diff --git a/streaming/base/dataloader.py b/streaming/base/dataloader.py index 2325280e5..cfd8ec9fc 100644 --- a/streaming/base/dataloader.py +++ b/streaming/base/dataloader.py @@ -72,7 +72,8 @@ def state_dict(self) -> Optional[Dict[str, Any]]: """ if isinstance(self.dataset, StreamingDataset): world = World() - return self.dataset.state_dict(self.num_samples_yielded * world.num_ranks) + num_samples = self.num_samples_yielded * world.num_ranks + return self.dataset.state_dict(num_samples, False) return None def load_state_dict(self, obj: Dict[str, Any]) -> None: @@ -84,7 +85,7 @@ def load_state_dict(self, obj: Dict[str, Any]) -> None: obj (Dict[str, Any]): The state. """ if isinstance(self.dataset, StreamingDataset): - return self.dataset.load_state_dict(obj) + self.dataset.load_state_dict(obj) def __del__(self) -> None: """Terminate the workers during cleanup.""" diff --git a/streaming/base/dataset.py b/streaming/base/dataset.py index 5783dcf5a..72622ea88 100644 --- a/streaming/base/dataset.py +++ b/streaming/base/dataset.py @@ -666,13 +666,18 @@ def __iter__(self) -> Iterator[Dict[str, Any]]: for sample_id in self._each_sample(sample_ids): yield self[sample_id] - def state_dict(self, sample_in_epoch: int) -> Dict[str, Any]: + def state_dict(self, num_samples: int, from_beginning: bool) -> Dict[str, Any]: """Get a dict containing training state (called from non-worker process). This is called on rank zero. + Our stock StreamingDataLoader counts samples from the start of training (from_beginning=False). + However, if you are always counting from the start of the epoch, set from_beginning=True. + Args: - sample_in_epoch (int): The number of samples processed so far in the current epoch. + num_samples (int): The number of samples processed so far, counted as per from_beginning. + from_beginning (bool): Whether we are counting samples from the start of this epoch, or + from the start of this (potentially resumed) training run within this epoch. Returns: Dict[str, Any]: The state. 
@@ -680,9 +685,13 @@ def state_dict(self, sample_in_epoch: int) -> Dict[str, Any]: world = World() epoch = self.next_epoch - 1 epoch, offset = self._resume(world, epoch) + if from_beginning: + sample_in_epoch = num_samples + else: + sample_in_epoch = offset + num_samples return { 'epoch': epoch, - 'sample_in_epoch': offset + sample_in_epoch, + 'sample_in_epoch': sample_in_epoch, 'num_canonical_nodes': self.num_canonical_nodes, 'shuffle_seed': self.shuffle_seed }