diff --git a/ScaFFold/utils/trainer.py b/ScaFFold/utils/trainer.py index 3add912..e981e8a 100644 --- a/ScaFFold/utils/trainer.py +++ b/ScaFFold/utils/trainer.py @@ -349,18 +349,18 @@ def cleanup_or_resume(self): with open(self.outfile_path, "a", newline="") as outfile: outfile.write(",".join(headers) + "\n") - def _truncate_stats_file(self, start_epoch): + def _truncate_stats_file(self, start_epoch, path=None): """ Scans the stats file and truncates it at the first occurrence of an epoch >= start_epoch. This is O(1) memory and safe for large logs. """ - self.log.info( - f"Truncating {self.outfile_path} to remove epochs >= {start_epoch}" - ) + if path is None: + path = self.outfile_path + self.log.info(f"Truncating {path} to remove epochs >= {start_epoch}") try: # Open in read+update mode ('r+') to allow seeking and truncating - with open(self.outfile_path, "r+") as f: + with open(path, "r+") as f: header = f.readline() if not header: return @@ -401,7 +401,7 @@ def _truncate_stats_file(self, start_epoch): pass except Exception as e: - self.log.warning(f"Failed to truncate stats file: {e}") + self.log.warning(f"Failed to truncate stats file {path}: {e}") def _get_memsize(self, tensor, tensor_label: str, verbosity: int = 0): """Log size of tensor in memory""" @@ -604,7 +604,11 @@ def train(self): disable=True if self.world_rank != 0 else False, ) as pbar: begin_code_region("batch_loop") - for batch in self.train_loader: + for batch_idx, batch in enumerate(self.train_loader): + time_minibatch = batch_idx == 0 and self.world_rank == 0 + if time_minibatch: + minibatch_start_time = time.perf_counter() + # Load initial samples and labels images, true_masks = batch["image"], batch["mask"] @@ -724,6 +728,13 @@ def train(self): self.global_step += 1 # Stay on GPU epoch_loss += loss.detach() + if time_minibatch: + # This sync has some potential performance impact + # TODO: Would be better to measure this with Caliper, which uses CUDA events. + torch.cuda.synchronize(self.device) + minibatch_time_s = ( + time.perf_counter() - minibatch_start_time + ) end_code_region("update_loss") end_code_region("batch_loop") @@ -791,7 +802,7 @@ def train(self): ) outfile.flush() print( - f"Epoch {epoch} completed in {epoch_duration} seconds. Total train time so far: {time.time() - start}" + f"Epoch {epoch} completed in {epoch_duration} seconds. Total train time so far: {time.time() - start}. Rank 0 first batch minibatch_time_s={minibatch_time_s:.6f}." ) #