From e583d85ff932e4f3a9c59c141a3a812746772147 Mon Sep 17 00:00:00 2001 From: Michael McKinsey Date: Wed, 6 May 2026 12:51:51 -0700 Subject: [PATCH 1/3] fix dtypes for torch --- ScaFFold/utils/data_types.py | 5 ++++- ScaFFold/utils/trainer.py | 10 ++++++---- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/ScaFFold/utils/data_types.py b/ScaFFold/utils/data_types.py index b555811..ef1515d 100644 --- a/ScaFFold/utils/data_types.py +++ b/ScaFFold/utils/data_types.py @@ -19,7 +19,10 @@ # Masks are values 0 <= x <= n_categories MASK_DTYPE = np.uint16 # Volumes/img are 0 <= x <= 1 -VOLUME_DTYPE = np.float32 +VOLUME_DTYPE_NAME = "float32" +VOLUME_NP_DTYPE = getattr(np, VOLUME_DTYPE_NAME) +VOLUME_TORCH_DTYPE = getattr(torch, VOLUME_DTYPE_NAME) +VOLUME_DTYPE = VOLUME_NP_DTYPE # Shared AMP dtype selection for torch.autocast. AMP_DTYPE = torch.bfloat16 diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py index 1a1d2e0..3add912 100644 --- a/ScaFFold/utils/trainer.py +++ b/ScaFFold/utils/trainer.py @@ -30,7 +30,7 @@ from ScaFFold.utils.checkpointing import CheckpointManager from ScaFFold.utils.data_loading import FractalDataset, SpatialShardSpec -from ScaFFold.utils.data_types import AMP_DTYPE, VOLUME_DTYPE +from ScaFFold.utils.data_types import AMP_DTYPE, VOLUME_TORCH_DTYPE from ScaFFold.utils.dice_score import compute_sharded_dice from ScaFFold.utils.distributed import get_local_rank, get_world_rank, get_world_size @@ -436,7 +436,7 @@ def warmup(self): images = images.to( device=self.device, - dtype=VOLUME_DTYPE, + dtype=VOLUME_TORCH_DTYPE, memory_format=torch.channels_last_3d, non_blocking=True, ) @@ -611,7 +611,7 @@ def train(self): begin_code_region("image_to_device") images = images.to( device=self.device, - dtype=VOLUME_DTYPE, + dtype=VOLUME_TORCH_DTYPE, memory_format=torch.channels_last_3d, # NDHWC (channels last) vs NCDHW (channels first) non_blocking=True, ) @@ -749,7 +749,9 @@ def train(self): self.config.n_categories, self.config._parallel_strategy, ) - dice_info = torch.tensor([dice_sum, numsamples], dtype=VOLUME_DTYPE) + dice_info = torch.tensor( + [dice_sum, numsamples], dtype=VOLUME_TORCH_DTYPE + ) if self.config.dist: dice_info = dice_info.to(device=self.device) torch.distributed.all_reduce( From 3dfbd138977624bcc33acc7bdcfca7ebd6e98b6e Mon Sep 17 00:00:00 2001 From: Michael McKinsey Date: Thu, 7 May 2026 10:39:04 -0700 Subject: [PATCH 2/3] Add per minibatch timer --- ScaFFold/utils/trainer.py | 50 +++++++++++++++++++++++++++++++++------ 1 file changed, 43 insertions(+), 7 deletions(-) diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py index 1a1d2e0..1aa399a 100644 --- a/ScaFFold/utils/trainer.py +++ b/ScaFFold/utils/trainer.py @@ -267,6 +267,20 @@ def _current_learning_rate(self): return self.config.starting_learning_rate return self.optimizer.param_groups[0]["lr"] + def _timing_ddp_rank(self): + if self.ps is None: + return self.world_rank + return self.ps.ddp_ind + + def _timing_shard_label(self): + if self.ps is None: + return "replicated" + return "x".join(str(shard_index) for shard_index in self.ps.shard_ind) + + def _sync_device_for_timing(self): + if self.device.type == "cuda": + torch.cuda.synchronize(self.device) + class PyTorchTrainer(BaseTrainer): """ @@ -349,18 +363,18 @@ def cleanup_or_resume(self): with open(self.outfile_path, "a", newline="") as outfile: outfile.write(",".join(headers) + "\n") - def _truncate_stats_file(self, start_epoch): + def _truncate_stats_file(self, start_epoch, path=None): """ Scans the stats file and truncates it at the first occurrence of an epoch >= start_epoch. This is O(1) memory and safe for large logs. """ - self.log.info( - f"Truncating {self.outfile_path} to remove epochs >= {start_epoch}" - ) + if path is None: + path = self.outfile_path + self.log.info(f"Truncating {path} to remove epochs >= {start_epoch}") try: # Open in read+update mode ('r+') to allow seeking and truncating - with open(self.outfile_path, "r+") as f: + with open(path, "r+") as f: header = f.readline() if not header: return @@ -401,7 +415,7 @@ def _truncate_stats_file(self, start_epoch): pass except Exception as e: - self.log.warning(f"Failed to truncate stats file: {e}") + self.log.warning(f"Failed to truncate stats file {path}: {e}") def _get_memsize(self, tensor, tensor_label: str, verbosity: int = 0): """Log size of tensor in memory""" @@ -604,7 +618,12 @@ def train(self): disable=True if self.world_rank != 0 else False, ) as pbar: begin_code_region("batch_loop") - for batch in self.train_loader: + for batch_idx, batch in enumerate(self.train_loader): + time_minibatch = batch_idx == 0 and self.world_rank == 0 + if time_minibatch: + self._sync_device_for_timing() + minibatch_start_time = time.perf_counter() + # Load initial samples and labels images, true_masks = batch["image"], batch["mask"] @@ -724,6 +743,23 @@ def train(self): self.global_step += 1 # Stay on GPU epoch_loss += loss.detach() + if time_minibatch: + self._sync_device_for_timing() + minibatch_time_s = ( + time.perf_counter() - minibatch_start_time + ) + print( + "MINIBATCH_TIMER " + f"epoch={epoch} " + f"batch_idx={batch_idx} " + f"global_step={self.global_step} " + f"ddp_rank={self._timing_ddp_rank()} " + f"world_rank={self.world_rank} " + f"local_rank={self.local_rank} " + f"shard_index={self._timing_shard_label()} " + f"batch_size={images_dc.shape[0]} " + f"minibatch_time_s={minibatch_time_s:.6f}", + ) end_code_region("update_loss") end_code_region("batch_loop") From c9ef075c1f786510692e05b36867b5956b5ccef6 Mon Sep 17 00:00:00 2001 From: Michael McKinsey Date: Thu, 7 May 2026 13:11:13 -0700 Subject: [PATCH 3/3] cleanup --- ScaFFold/utils/trainer.py | 33 ++++----------------------------- 1 file changed, 4 insertions(+), 29 deletions(-) diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py index 77db2bf..e981e8a 100644 --- a/ScaFFold/utils/trainer.py +++ b/ScaFFold/utils/trainer.py @@ -267,20 +267,6 @@ def _current_learning_rate(self): return self.config.starting_learning_rate return self.optimizer.param_groups[0]["lr"] - def _timing_ddp_rank(self): - if self.ps is None: - return self.world_rank - return self.ps.ddp_ind - - def _timing_shard_label(self): - if self.ps is None: - return "replicated" - return "x".join(str(shard_index) for shard_index in self.ps.shard_ind) - - def _sync_device_for_timing(self): - if self.device.type == "cuda": - torch.cuda.synchronize(self.device) - class PyTorchTrainer(BaseTrainer): """ @@ -621,7 +607,6 @@ def train(self): for batch_idx, batch in enumerate(self.train_loader): time_minibatch = batch_idx == 0 and self.world_rank == 0 if time_minibatch: - self._sync_device_for_timing() minibatch_start_time = time.perf_counter() # Load initial samples and labels @@ -744,22 +729,12 @@ def train(self): # Stay on GPU epoch_loss += loss.detach() if time_minibatch: - self._sync_device_for_timing() + # This sync has some potential performance impact + # TODO: Would be better to measure this with Caliper, which uses CUDA events. + torch.cuda.synchronize(self.device) minibatch_time_s = ( time.perf_counter() - minibatch_start_time ) - print( - "MINIBATCH_TIMER " - f"epoch={epoch} " - f"batch_idx={batch_idx} " - f"global_step={self.global_step} " - f"ddp_rank={self._timing_ddp_rank()} " - f"world_rank={self.world_rank} " - f"local_rank={self.local_rank} " - f"shard_index={self._timing_shard_label()} " - f"batch_size={images_dc.shape[0]} " - f"minibatch_time_s={minibatch_time_s:.6f}", - ) end_code_region("update_loss") end_code_region("batch_loop") @@ -827,7 +802,7 @@ def train(self): ) outfile.flush() print( - f"Epoch {epoch} completed in {epoch_duration} seconds. Total train time so far: {time.time() - start}" + f"Epoch {epoch} completed in {epoch_duration} seconds. Total train time so far: {time.time() - start}. Rank 0 first batch minibatch_time_s={minibatch_time_s:.6f}." ) #