From ff86bcf81581aa37788e4d95ad1df469c3ee36bc Mon Sep 17 00:00:00 2001 From: Fabian Guera Date: Thu, 7 May 2026 21:54:13 +0100 Subject: [PATCH] Fix _smooth_W_torch to run on CUDA when available The tensor was always created on CPU, so SVDSmoothing ran on CPU even when torch CUDA was available. Also add .cpu() before .numpy() to handle the CUDA tensor correctly. --- weightwatcher/RMT_Util.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/weightwatcher/RMT_Util.py b/weightwatcher/RMT_Util.py index 97744e8..80b8a98 100644 --- a/weightwatcher/RMT_Util.py +++ b/weightwatcher/RMT_Util.py @@ -149,22 +149,23 @@ def _smooth_W_torch(W, n_comp): """ # Convert W to a torch tensor - W_tensor = torch.tensor(W, dtype=torch.float32) + device = "cuda" if torch.cuda.is_available() else "cpu" + W_tensor = torch.tensor(W, dtype=torch.float32).to(device) # Perform SVD low-rank approximation # Note: torch.svd_lowrank returns U, S, Vh such that W ≈ U * diag(S) * Vh U, S, V = torch.svd_lowrank(W_tensor, q=n_comp) # Compute the smoothed W using the low-rank approximation smoothed_W_tensor = torch.mm(U, torch.mm(torch.diag(S), V.T)) - + # If the original W has more columns than rows, transpose the result if W.shape[0] < W.shape[1]: smoothed_W_tensor = smoothed_W_tensor.T - + # Convert the smoothed W back to a NumPy array if necessary # not using half...need to check this - smoothed_W = smoothed_W_tensor.float().numpy() - + smoothed_W = smoothed_W_tensor.float().cpu().numpy() + del W_tensor return smoothed_W @@ -235,8 +236,8 @@ def _svd_lowrank_fast(M, k): U, S, V = torch.svd_lowrank(M_cuda, q=k) del M_cuda return torch_T_to_np_32(U), torch_T_to_np_32(S), torch_T_to_np_32(V).T - - + + def _svd_values_fast(M, k): torch.cuda.empty_cache() with torch.no_grad(): @@ -244,7 +245,7 @@ def _svd_values_fast(M, k): _, S, _ = torch.svd_lowrank(M_cuda, q=k) del M_cuda return torch_T_to_np_32(S) - + def _smooth_W_fast(W, k): torch.cuda.empty_cache() with torch.no_grad():