-
Notifications
You must be signed in to change notification settings - Fork 155
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Cherry pick docs changes from main branch (#1304)
* fix docs, make sure docs build (#1302) * fix docs, make sure docs build * adding stateful dataloader docs * Add stateful dataloader tutorial docs (#1303)
- Loading branch information
1 parent
ba35881
commit 265d317
Showing
7 changed files
with
241 additions
and
30 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,177 @@ | ||
Stateful DataLoader Tutorial | ||
============================ | ||
|
||
Saving and loading state | ||
------------------------ | ||
|
||
Stateful DataLoader adds the ``load_state_dict``, ``state_dict`` methods to the ``torch.utils.data.DataLoader``. State fetch and set can be done as follows: | ||
|
||
.. code:: python | ||
from torchdata.stateful_dataloader import StatefulDataLoader | ||
dataloader = StatefulDataLoader(dataset, num_workers=2) | ||
for i, batch in enumerate(dataloader): | ||
... | ||
if i == 10: | ||
state_dict = dataloader.state_dict() | ||
break | ||
# Training run resumes with the previous checkpoint | ||
dataloader = StatefulDataLoader(dataset, num_workers=2) | ||
# Resume state with DataLoader | ||
dataloader.load_state_dict(state_dict) | ||
for i, batch in enumerate(dataloader): | ||
... | ||
Saving Custom State with Map-Style Datasets | ||
------------------------------------------- | ||
|
||
For efficient resuming of `Map-style datasets <https://pytorch.org/docs/stable/data.html#map-style-datasets>`_, you can resume iteration by defining ``state_dict`` / ``load_state_dict`` methods in your sampler. If your dataset has worker-specific state (eg RNG transform state) you can add ``state_dict`` / ``load_state_dict`` methods to your dataset. | ||
|
||
.. code:: python | ||
from typing import * | ||
import torch | ||
import torch.utils.data | ||
from torchdata.stateful_dataloader import StatefulDataLoader | ||
# If you are using the default RandomSampler and BatchSampler in torch.utils.data, they are patched when you import torchdata.stateful_dataloader so that defining, a custom sampler here is unnecessary | ||
class MySampler(torch.utils.data.Sampler[int]): | ||
def __init__(self, high: int, seed: int, limit: int): | ||
self.seed, self.high, self.limit = seed, high, limit | ||
self.g = torch.Generator() | ||
self.g.manual_seed(self.seed) | ||
self.i = 0 | ||
def __iter__(self): | ||
while self.i < self.limit: | ||
val = int(torch.randint(high=self.high, size=(1,), generator=self.g)) | ||
self.i += 1 | ||
yield val | ||
def load_state_dict(self, state_dict: Dict[str, Any]): | ||
self.i = state_dict["i"] | ||
self.g.set_state(state_dict["rng"]) | ||
def state_dict(self) -> Dict[str, Any]: | ||
return {"i": self.i, "rng": self.g.get_state()} | ||
# Optional: save dataset random transform state | ||
class NoisyRange(torch.utils.data.Dataset): | ||
def __init__(self, high: int, mean: float, std: float): | ||
self.high, self.mean, self.std = high, torch.tensor([float(mean)]), float(std) | ||
def __len__(self): | ||
return self.high | ||
def __getitem__(self, idx: int) -> float: | ||
if not (0 <= idx < self.high): | ||
raise IndexError() | ||
x = torch.normal(self.mean, self.std) | ||
noise = x.item() | ||
return idx + noise | ||
def load_state_dict(self, state_dict): | ||
torch.set_rng_state(state_dict["rng"]) | ||
def state_dict(self): | ||
return {"rng": torch.get_rng_state()} | ||
# Test both single/multiprocess dataloading | ||
for num_workers in [0, 2]: | ||
print(f"{num_workers=}") | ||
dl = StatefulDataLoader(NoisyRange(5, 1, 1), sampler=MySampler(5, 1, 10), | ||
batch_size=2, drop_last=False, num_workers=num_workers) | ||
batches = [] | ||
for i, batch in enumerate(dl): | ||
batches.append(batch) | ||
if i == 2: | ||
sd = dl.state_dict() | ||
dl.load_state_dict(sd) | ||
batches2 = list(dl) | ||
print(batches[3:]) | ||
print(batches2) | ||
""" | ||
Output: | ||
num_workers=0 | ||
[tensor([-0.4526, 3.7948], dtype=torch.float64), tensor([6.5494, 3.0470], dtype=torch.float64)] | ||
[tensor([-0.4526, 3.7948], dtype=torch.float64), tensor([6.5494, 3.0470], dtype=torch.float64)] | ||
num_workers=2 | ||
[tensor([3.7412, 1.2438], dtype=torch.float64), tensor([4.4807, 4.0036], dtype=torch.float64)] | ||
[tensor([3.7412, 1.2438], dtype=torch.float64), tensor([4.4807, 4.0036], dtype=torch.float64)] | ||
""" | ||
Saving Custom State with Iterable-Style Datasets | ||
------------------------------------------------ | ||
|
||
Tracking iteration order with `Iterable-style datasets <https://pytorch.org/docs/stable/data.html#iterable-style-datasets>`_ requires state from each worker-level instance of the dataset to be captured. You can define ``state_dict`` / ``load_state_dict`` methods on your dataset which capture worker-level state. :class:`StatefulDataLoader` will handle aggregation across workers and distribution back to the workers. Calling ``load_state_dict`` requires :class:`StatefulDataLoader`` to have same ``num_workers`` as those of the provided ``state_dict``. | ||
|
||
.. code:: python | ||
from typing import * | ||
import torch | ||
import torch.utils.data | ||
from torchdata.stateful_dataloader import StatefulDataLoader | ||
class MyIterableDataset(torch.utils.data.IterableDataset): | ||
def __init__(self, high: int, seed: int): | ||
self.high, self.seed = high, seed | ||
self.g = torch.Generator() | ||
self.i = 0 | ||
def __iter__(self): | ||
worker_info = torch.utils.data.get_worker_info() | ||
if worker_info is not None: | ||
worker_id = worker_info.id | ||
num_workers = worker_info.num_workers | ||
else: | ||
worker_id = 0 | ||
num_workers = 1 | ||
self.g.manual_seed(self.seed) | ||
arr = torch.randperm(self.high, generator=self.g) | ||
arr = arr[worker_id:self.high:num_workers] | ||
for idx in range(self.i, len(arr)): | ||
self.i += 1 | ||
yield arr[idx] | ||
self.i = 0 | ||
def state_dict(self): | ||
return {"i": self.i} | ||
def load_state_dict(self, state_dict): | ||
self.i = state_dict["i"] | ||
# Test both single/multiprocess dataloading | ||
for num_workers in [0, 2]: | ||
print(f"{num_workers=}") | ||
dl = StatefulDataLoader( | ||
MyIterableDataset(12, 0), batch_size=2, drop_last=False, | ||
num_workers=num_workers) | ||
batches = [] | ||
for i, batch in enumerate(dl): | ||
batches.append(batch) | ||
if i == 2: | ||
sd = dl.state_dict() | ||
dl.load_state_dict(sd) | ||
batches2 = list(dl) | ||
print(batches[3:]) | ||
print(batches2) | ||
""" | ||
Output: | ||
num_workers=0 | ||
[tensor([ 2, 10]), tensor([3, 1]), tensor([11, 6])] | ||
[tensor([ 2, 10]), tensor([3, 1]), tensor([11, 6])] | ||
num_workers=2 | ||
[tensor([ 4, 10]), tensor([ 3, 11]), tensor([1, 6])] | ||
[tensor([ 4, 10]), tensor([ 3, 11]), tensor([1, 6])] | ||
""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
:tocdepth: 3 | ||
|
||
Stateful DataLoader | ||
=================== | ||
|
||
.. automodule:: torchdata.stateful_dataloader | ||
|
||
StatefulDataLoader is a drop-in replacement for `torch.utils.data.DataLoader <https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader>`_ which offers ``state_dict`` / ``load_state_dict`` methods for handling mid-epoch checkpointing which operate on the previous/next iterator requested from the dataloader (resp.). | ||
|
||
By default, the state includes the number of batches yielded and uses this to naively fast-forward the sampler (map-style) or the dataset (iterable-style). However if the sampler and/or dataset include ``state_dict`` / ``load_state_dict`` methods, then it will call them during its own ``state_dict`` / ``load_state_dict`` calls. Under the hood, :class:`StatefulDataLoader` handles aggregation and distribution of state across multiprocess workers (but not across ranks). | ||
|
||
.. autoclass:: StatefulDataLoader | ||
:members: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters