@@ -75,9 +75,9 @@ def augment_gmm(self, R: Tensor, gmm: GaussianMixture) -> Tensor:
7575 :return: Full set of coordinates with noise."""
7676
7777 z = torch .empty ((R .shape [0 ], self .num_particles , R .shape [2 ]), device = R .device )
78- z_gmm = torch .tensor (gmm . sample ( R . shape [ 0 ])[ 0 ], dtype = R . dtype , device = R . device ). view (
79- R .shape [0 ], - 1 , 3
80- )
78+ z_gmm = torch .tensor (
79+ gmm . sample ( R .shape [0 ])[ 0 ], dtype = R . dtype , device = R . device
80+ ). view ( R . shape [ 0 ], - 1 , 3 )
8181
8282 start_idx = 0
8383 for i , (cg_idx , noise_idx ) in enumerate (self .latent_groupings ):
@@ -118,7 +118,9 @@ def log_prob(self, value: Tensor) -> Tensor:
118118 R_cg = value [:, cg_idx , :]
119119 R_noise = value [:, noise_idx , :]
120120 exponential_term = (
121- - 0.5 * sum_except_batch ((R_noise - R_cg [:, None , :]) ** 2 ) / self .scale ** 2
121+ - 0.5
122+ * sum_except_batch ((R_noise - R_cg [:, None , :]) ** 2 )
123+ / self .scale ** 2
122124 )
123125 normalization_term = - torch .log (Z ) * R_noise .shape [1 ]
124126 log_prob += exponential_term + normalization_term
@@ -188,7 +190,9 @@ def _init_weights(self):
188190
189191 # Coordinate MLP - Xavier uniform for all but last layer
190192 coord_mlp_layers = list (layer .coors_mlp .modules ())
191- linear_layers = [m for m in coord_mlp_layers if isinstance (m , nn .Linear )]
193+ linear_layers = [
194+ m for m in coord_mlp_layers if isinstance (m , nn .Linear )
195+ ]
192196
193197 for i , m in enumerate (linear_layers ):
194198 if i == len (linear_layers ) - 1 :
@@ -312,7 +316,9 @@ def velocity(self, xt: Tensor, t: Tensor) -> Tensor:
312316
313317 return self .velo_net (xt , t )
314318
315- def compute_metrics (self , batch : tuple [Tensor , ...], batch_idx : int ) -> dict [str , Tensor ]:
319+ def compute_metrics (
320+ self , batch : tuple [Tensor , ...], batch_idx : int
321+ ) -> dict [str , Tensor ]:
316322 """Compute training/validation metrics.
317323
318324 :param batch: Batch data tuple, expecting (r,) where r is a Tensor.
@@ -337,8 +343,10 @@ def compute_metrics(self, batch: tuple[Tensor, ...], batch_idx: int) -> dict[str
337343
338344 if not self .training :
339345 x1 = self .compute_flow (x0 , return_intermediate = False , verbose = False )
340- traj = md .Trajectory (x1 .cpu ().numpy (), self .top_aa )
341- metrics ["ged" ] = torch .mean (torch .tensor (graph_edit_distance (traj = traj , verbose = False )))
346+ traj = md .Trajectory (x1 .detach ().cpu ().numpy (), self .top_aa )
347+ metrics ["ged" ] = torch .mean (
348+ torch .tensor (graph_edit_distance (traj = traj , verbose = False ))
349+ )
342350
343351 return metrics
344352
@@ -395,10 +403,12 @@ def fit_latent_gmm(
395403
396404 with torch .no_grad ():
397405 x1 = r .to (self .device )
398- x0 = self .compute_flow (x1 , reverse = True , chunk_size = chunk_size , verbose = verbose ).cpu ()
399- eps_sn = self .noise .to_standard_normal (x0 )[:, self .indices_split [1 ].cpu (), :].view (
400- x0 .shape [0 ], - 1
401- )
406+ x0 = self .compute_flow (
407+ x1 , reverse = True , chunk_size = chunk_size , verbose = verbose
408+ ).cpu ()
409+ eps_sn = self .noise .to_standard_normal (x0 )[
410+ :, self .indices_split [1 ].cpu (), :
411+ ].view (x0 .shape [0 ], - 1 )
402412
403413 gmm = GaussianMixture (n_components = n_components , * args , ** kwargs )
404414 gmm .fit (eps_sn .numpy ())
0 commit comments