diff --git a/weightwatcher/RMT_Util.py b/weightwatcher/RMT_Util.py index 97744e8..80b8a98 100644 --- a/weightwatcher/RMT_Util.py +++ b/weightwatcher/RMT_Util.py @@ -149,22 +149,23 @@ def _smooth_W_torch(W, n_comp): """ # Convert W to a torch tensor - W_tensor = torch.tensor(W, dtype=torch.float32) + device = "cuda" if torch.cuda.is_available() else "cpu" + W_tensor = torch.tensor(W, dtype=torch.float32).to(device) # Perform SVD low-rank approximation # Note: torch.svd_lowrank returns U, S, Vh such that W ≈ U * diag(S) * Vh U, S, V = torch.svd_lowrank(W_tensor, q=n_comp) # Compute the smoothed W using the low-rank approximation smoothed_W_tensor = torch.mm(U, torch.mm(torch.diag(S), V.T)) - + # If the original W has more columns than rows, transpose the result if W.shape[0] < W.shape[1]: smoothed_W_tensor = smoothed_W_tensor.T - + # Convert the smoothed W back to a NumPy array if necessary # not using half...need to check this - smoothed_W = smoothed_W_tensor.float().numpy() - + smoothed_W = smoothed_W_tensor.float().cpu().numpy() + del W_tensor return smoothed_W @@ -235,8 +236,8 @@ def _svd_lowrank_fast(M, k): U, S, V = torch.svd_lowrank(M_cuda, q=k) del M_cuda return torch_T_to_np_32(U), torch_T_to_np_32(S), torch_T_to_np_32(V).T - - + + def _svd_values_fast(M, k): torch.cuda.empty_cache() with torch.no_grad(): @@ -244,7 +245,7 @@ def _svd_values_fast(M, k): _, S, _ = torch.svd_lowrank(M_cuda, q=k) del M_cuda return torch_T_to_np_32(S) - + def _smooth_W_fast(W, k): torch.cuda.empty_cache() with torch.no_grad():