Skip to content

Commit 2bb264d

Browse files
Fix missing detach in compute_metrics
1 parent 264a13d commit 2bb264d

4 files changed

Lines changed: 29 additions & 15 deletions

File tree

README.md

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,7 @@
1717
Split-flows provide a probabilistic bridge between molecular resolutions, enabling conditional backmapping and direct measurement of the configuration-dependent (local) information loss.
1818

1919
<div align="center">
20-
<img src="figures/flow_trajectory.gif" alt="Flow Trajectory" style="max-width: 300px; width: 50%;">
20+
<img src="figures/flow_trajectory.gif" alt="Flow Trajectory" style="max-width: 300px; width: 50%; animation-iteration-count: 1;">
2121
</div>
2222

2323
## Installation

split_flows/models/split_flow.py

Lines changed: 22 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -75,9 +75,9 @@ def augment_gmm(self, R: Tensor, gmm: GaussianMixture) -> Tensor:
7575
:return: Full set of coordinates with noise."""
7676

7777
z = torch.empty((R.shape[0], self.num_particles, R.shape[2]), device=R.device)
78-
z_gmm = torch.tensor(gmm.sample(R.shape[0])[0], dtype=R.dtype, device=R.device).view(
79-
R.shape[0], -1, 3
80-
)
78+
z_gmm = torch.tensor(
79+
gmm.sample(R.shape[0])[0], dtype=R.dtype, device=R.device
80+
).view(R.shape[0], -1, 3)
8181

8282
start_idx = 0
8383
for i, (cg_idx, noise_idx) in enumerate(self.latent_groupings):
@@ -118,7 +118,9 @@ def log_prob(self, value: Tensor) -> Tensor:
118118
R_cg = value[:, cg_idx, :]
119119
R_noise = value[:, noise_idx, :]
120120
exponential_term = (
121-
-0.5 * sum_except_batch((R_noise - R_cg[:, None, :]) ** 2) / self.scale**2
121+
-0.5
122+
* sum_except_batch((R_noise - R_cg[:, None, :]) ** 2)
123+
/ self.scale**2
122124
)
123125
normalization_term = -torch.log(Z) * R_noise.shape[1]
124126
log_prob += exponential_term + normalization_term
@@ -188,7 +190,9 @@ def _init_weights(self):
188190

189191
# Coordinate MLP - Xavier uniform for all but last layer
190192
coord_mlp_layers = list(layer.coors_mlp.modules())
191-
linear_layers = [m for m in coord_mlp_layers if isinstance(m, nn.Linear)]
193+
linear_layers = [
194+
m for m in coord_mlp_layers if isinstance(m, nn.Linear)
195+
]
192196

193197
for i, m in enumerate(linear_layers):
194198
if i == len(linear_layers) - 1:
@@ -312,7 +316,9 @@ def velocity(self, xt: Tensor, t: Tensor) -> Tensor:
312316

313317
return self.velo_net(xt, t)
314318

315-
def compute_metrics(self, batch: tuple[Tensor, ...], batch_idx: int) -> dict[str, Tensor]:
319+
def compute_metrics(
320+
self, batch: tuple[Tensor, ...], batch_idx: int
321+
) -> dict[str, Tensor]:
316322
"""Compute training/validation metrics.
317323
318324
:param batch: Batch data tuple, expecting (r,) where r is a Tensor.
@@ -337,8 +343,10 @@ def compute_metrics(self, batch: tuple[Tensor, ...], batch_idx: int) -> dict[str
337343

338344
if not self.training:
339345
x1 = self.compute_flow(x0, return_intermediate=False, verbose=False)
340-
traj = md.Trajectory(x1.cpu().numpy(), self.top_aa)
341-
metrics["ged"] = torch.mean(torch.tensor(graph_edit_distance(traj=traj, verbose=False)))
346+
traj = md.Trajectory(x1.detach().cpu().numpy(), self.top_aa)
347+
metrics["ged"] = torch.mean(
348+
torch.tensor(graph_edit_distance(traj=traj, verbose=False))
349+
)
342350

343351
return metrics
344352

@@ -395,10 +403,12 @@ def fit_latent_gmm(
395403

396404
with torch.no_grad():
397405
x1 = r.to(self.device)
398-
x0 = self.compute_flow(x1, reverse=True, chunk_size=chunk_size, verbose=verbose).cpu()
399-
eps_sn = self.noise.to_standard_normal(x0)[:, self.indices_split[1].cpu(), :].view(
400-
x0.shape[0], -1
401-
)
406+
x0 = self.compute_flow(
407+
x1, reverse=True, chunk_size=chunk_size, verbose=verbose
408+
).cpu()
409+
eps_sn = self.noise.to_standard_normal(x0)[
410+
:, self.indices_split[1].cpu(), :
411+
].view(x0.shape[0], -1)
402412

403413
gmm = GaussianMixture(n_components=n_components, *args, **kwargs)
404414
gmm.fit(eps_sn.numpy())

split_flows/utils/metrics.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -147,7 +147,9 @@ def compute_bond_cutoff_mdtraj(topology, scale=1.3):
147147
"""Compute bond cutoffs for MDTraj topology"""
148148
atomic_nums = [atom.element.atomic_number for atom in topology.atoms]
149149
# COVCUTOFFTABLE values are in Angstroms, convert to nanometers for MDTraj
150-
vdw_array = torch.Tensor([COVCUTOFFTABLE[int(el)] / 10.0 for el in atomic_nums]) # Å to nm
150+
vdw_array = torch.Tensor(
151+
[COVCUTOFFTABLE[int(el)] / 10.0 for el in atomic_nums]
152+
) # Å to nm
151153

152154
cutoff_array = (vdw_array[None, :] + vdw_array[:, None]) * scale
153155

split_flows/utils/utils.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -80,5 +80,7 @@ def gradient(
8080

8181
if grad_outputs is None:
8282
grad_outputs = torch.ones_like(output).detach()
83-
grad = torch.autograd.grad(output, x, grad_outputs=grad_outputs, create_graph=create_graph)[0]
83+
grad = torch.autograd.grad(
84+
output, x, grad_outputs=grad_outputs, create_graph=create_graph
85+
)[0]
8486
return grad

0 commit comments

Comments (0)