From 265d317aa541e6713071ddb8b137dd3d8a6b9eda Mon Sep 17 00:00:00 2001
From: Gokul
Date: Tue, 30 Jul 2024 11:45:10 -0700
Subject: [PATCH] Cherry pick docs changes from main branch (#1304)

* fix docs, make sure docs build (#1302)

* fix docs, make sure docs build

* adding stateful dataloader docs

* Add stateful dataloader tutorial docs (#1303)
---
 docs/Makefile                                 |  36 +++-
 docs/source/dp_tutorial.rst                   |  22 +--
 docs/source/index.rst                         |   2 +
 docs/source/stateful_dataloader_tutorial.rst  | 177 ++++++++++++++++++
 docs/source/torchdata.stateful_dataloader.rst |  13 ++
 torchdata/dataloader2/reading_service.py      |   2 +-
 .../stateful_dataloader.py                    |  19 +-
 7 files changed, 241 insertions(+), 30 deletions(-)
 create mode 100644 docs/source/stateful_dataloader_tutorial.rst
 create mode 100644 docs/source/torchdata.stateful_dataloader.rst

diff --git a/docs/Makefile b/docs/Makefile
index 1e84df4e5..e9312f177 100644
--- a/docs/Makefile
+++ b/docs/Makefile
@@ -1,10 +1,14 @@
 # Minimal makefile for Sphinx documentation
 #
-# You can set these variables from the command line, and also
-# from the environment for the first two.
-SPHINXOPTS    ?=
-SPHINXBUILD   ?= sphinx-build
+ifneq ($(EXAMPLES_PATTERN),)
+    EXAMPLES_PATTERN_OPTS := -D sphinx_gallery_conf.filename_pattern="$(EXAMPLES_PATTERN)"
+endif
+
+# You can set these variables from the command line.
+SPHINXOPTS    = -W -j auto $(EXAMPLES_PATTERN_OPTS)
+SPHINXBUILD   = sphinx-build
+SPHINXPROJ    = torchdata
 SOURCEDIR     = source
 BUILDDIR      = build
 
@@ -12,12 +16,26 @@ BUILDDIR      = build
 help:
 	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
 
-doctest: html
-	$(SPHINXBUILD) -b doctest $(SPHINXOPTS) "$(SOURCEDIR)" "$(BUILDDIR)"/doctest
-	@echo "Testing of doctests in the sources finished, look at the " \
-	      "results in $(BUILDDIR)/doctest/output.txt."
+docset: html
+	doc2dash --name $(SPHINXPROJ) --icon $(SOURCEDIR)/_static/img/pytorch-logo-flame.png --enable-js --online-redirect-url http://pytorch.org/data/ --force $(BUILDDIR)/html/
+
+	# Manually fix because Zeal doesn't deal well with `icon.png`-only at 2x resolution.
+	cp $(SPHINXPROJ).docset/icon.png $(SPHINXPROJ).docset/icon@2x.png
+	convert $(SPHINXPROJ).docset/icon@2x.png -resize 16x16 $(SPHINXPROJ).docset/icon.png
+
+html-noplot: # Avoids running the gallery examples, which may take time
+	$(SPHINXBUILD) -D plot_gallery=0 -b html "${SOURCEDIR}" "$(BUILDDIR)"/html
+	@echo
+	@echo "Build finished. The HTML pages are in $(BUILDDIR)/html."
+
+clean:
+	rm -rf $(BUILDDIR)/*
+	rm -rf $(SOURCEDIR)/generated_examples/ # sphinx-gallery
+	rm -rf $(SOURCEDIR)/gen_modules/ # sphinx-gallery
+	rm -rf $(SOURCEDIR)/sg_execution_times.rst # sphinx-gallery
+	rm -rf $(SOURCEDIR)/generated/ # autosummary
 
-.PHONY: help doctest Makefile
+.PHONY: help Makefile docset
 
 # Catch-all target: route all unknown targets to Sphinx using the new
 # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
diff --git a/docs/source/dp_tutorial.rst b/docs/source/dp_tutorial.rst
index 900880f12..b1d0965df 100644
--- a/docs/source/dp_tutorial.rst
+++ b/docs/source/dp_tutorial.rst
@@ -321,7 +321,7 @@ Accessing AWS S3 with ``fsspec`` DataPipes
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
 This requires the installation of the libraries ``fsspec``
-(`documentation `_) and ``s3fs``
+(`documentation `__) and ``s3fs``
 (`s3fs GitHub repo `_).
 
 You can list out the files within a S3 bucket directory by passing a path that starts
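As a quick illustration of the listing API this hunk refers to, a minimal sketch might look like the following (the bucket path is hypothetical, and a public bucket may additionally need ``anon=True`` forwarded through to ``s3fs``):

.. code:: python

    from torchdata.datapipes.iter import IterableWrapper

    # List objects under a hypothetical S3 prefix through fsspec/s3fs
    dp = IterableWrapper(["s3://example-bucket/folder/"]).list_files_by_fsspec(anon=True)
    # Each element is the full s3:// URL of one object under the prefix
    for file_url in dp:
        print(file_url)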
@@ -363,7 +363,7 @@ is also available for writing data to cloud.
 Accessing Google Cloud Storage (GCS) with ``fsspec`` DataPipes
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 This requires the installation of the libraries ``fsspec``
-(`documentation `_) and ``gcsfs``
+(`documentation `__) and ``gcsfs``
 (`gcsfs GitHub repo `_).
 
 You can list out the files within a GCS bucket directory by specifying a path that starts
@@ -400,11 +400,11 @@
 Accessing Azure Blob storage with ``fsspec`` DataPipes
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 This requires the installation of the libraries ``fsspec``
-(`documentation `_) and ``adlfs``
+(`documentation `__) and ``adlfs``
 (`adlfs GitHub repo `_).
-You can access data in Azure Data Lake Storage Gen2 by providing URIs staring with ``abfs://``.
+You can access data in Azure Data Lake Storage Gen2 by providing URIs starting with ``abfs://``.
 For example,
-`FSSpecFileLister `_ (``.list_files_by_fsspec(...)``)
+`FSSpecFileLister `_ (``.list_files_by_fsspec(...)``)
 can be used to list files in a directory in a container:
 
 .. code:: python
@@ -430,11 +430,11 @@ directory ``curated/covid-19/ecdc_cases/latest``, belonging to account ``pandemi
             .open_files_by_fsspec(account_name='pandemicdatalake') \
             .parse_csv()
     print(list(dp)[:3])
-    # [['date_rep', 'day', ..., 'iso_country', 'daterep'],
+    # [['date_rep', 'day', ..., 'iso_country', 'daterep'],
     # ['2020-12-14', '14', ..., 'AF', '2020-12-14'],
     # ['2020-12-13', '13', ..., 'AF', '2020-12-13']]
 
-If necessary, you can also access data in Azure Data Lake Storage Gen1 by using URIs staring with
+If necessary, you can also access data in Azure Data Lake Storage Gen1 by using URIs starting with
 ``adl://`` and ``abfs://``, as described in `README of adlfs repo `_
 
 Accessing Azure ML Datastores with ``fsspec`` DataPipes
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 An Azure ML datastore is a *reference* to an existing storage account on Azure. ...
 
 - Authentication is automatically handled - both *credential-based* access (service principal/SAS/key) and *identity-based* access (Azure Active Directory/managed identity) are supported. When using credential-based authentication, you do not need to expose secrets in your code.
 
 This requires the installation of the library ``azureml-fsspec``
-(`documentation `_).
+(`documentation `__).
 
-You can access data in an Azure ML datastore by providing URIs staring with ``azureml://``.
+You can access data in an Azure ML datastore by providing URIs starting with ``azureml://``.
 For example,
-`FSSpecFileLister `_ (``.list_files_by_fsspec(...)``)
+`FSSpecFileLister `_ (``.list_files_by_fsspec(...)``)
 can be used to list files in a directory in a container:
 
 .. code:: python
@@ -470,7 +470,7 @@ can be used to list files in a directory in a container:
 
     dp = IterableWrapper([uri]).list_files_by_fsspec()
     print(list(dp))
-    # ['azureml:////resourcegroups//workspaces//datastores//paths//file1.txt',
+    # ['azureml:////resourcegroups//workspaces//datastores//paths//file1.txt',
     # 'azureml:////resourcegroups//workspaces//datastores//paths//file2.txt', ...]
 
 You can also open files using `FSSpecFileOpener `_
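As a rough sketch of how that opener chains with the listing step above (reusing the ``uri`` placeholder from the example and assuming the listed files are CSVs):

.. code:: python

    from torchdata.datapipes.iter import IterableWrapper

    # `uri` is the same azureml:// datastore path used in the listing example above
    dp = (
        IterableWrapper([uri])
        .list_files_by_fsspec()   # list files in the datastore directory
        .open_files_by_fsspec()   # open each listed file through fsspec
        .parse_csv()              # parse rows, assuming CSV content
    )
    print(list(dp)[:3])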
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 5aa5895af..cec30b4ad 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -36,6 +36,7 @@ Features described in this documentation are classified by release status:
    :maxdepth: 2
    :caption: API Reference:
 
+   torchdata.stateful_dataloader.rst
    torchdata.datapipes.iter.rst
    torchdata.datapipes.map.rst
    torchdata.datapipes.utils.rst
@@ -47,6 +48,7 @@ Features described in this documentation are classified by release status:
    :maxdepth: 2
    :caption: Tutorial and Examples:
 
+   stateful_dataloader_tutorial.rst
    dp_tutorial.rst
    dlv2_tutorial.rst
    examples.rst
diff --git a/docs/source/stateful_dataloader_tutorial.rst b/docs/source/stateful_dataloader_tutorial.rst
new file mode 100644
index 000000000..7bfc7f930
--- /dev/null
+++ b/docs/source/stateful_dataloader_tutorial.rst
@@ -0,0 +1,177 @@
+Stateful DataLoader Tutorial
+============================
+
+Saving and loading state
+------------------------
+
+Stateful DataLoader adds ``state_dict`` / ``load_state_dict`` methods to ``torch.utils.data.DataLoader``.
+The state can be fetched and set as follows:
+
+.. code:: python
+
+    from torchdata.stateful_dataloader import StatefulDataLoader
+
+    dataloader = StatefulDataLoader(dataset, num_workers=2)
+    for i, batch in enumerate(dataloader):
+        ...
+        if i == 10:
+            state_dict = dataloader.state_dict()
+            break
+
+    # Training run resumes with the previous checkpoint
+    dataloader = StatefulDataLoader(dataset, num_workers=2)
+    # Resume state with DataLoader
+    dataloader.load_state_dict(state_dict)
+    for i, batch in enumerate(dataloader):
+        ...
+
+Saving Custom State with Map-Style Datasets
+-------------------------------------------
+
+For efficient resuming of `Map-style datasets `_, define ``state_dict`` / ``load_state_dict``
+methods in your sampler. If your dataset has worker-specific state (e.g., RNG transform state),
+you can also add ``state_dict`` / ``load_state_dict`` methods to your dataset.
+
+.. code:: python
+
+    from typing import *
+    import torch
+    import torch.utils.data
+    from torchdata.stateful_dataloader import StatefulDataLoader
+
+    # If you are using the default RandomSampler and BatchSampler in torch.utils.data,
+    # they are patched when you import torchdata.stateful_dataloader, so defining a
+    # custom sampler here is unnecessary
+    class MySampler(torch.utils.data.Sampler[int]):
+        def __init__(self, high: int, seed: int, limit: int):
+            self.seed, self.high, self.limit = seed, high, limit
+            self.g = torch.Generator()
+            self.g.manual_seed(self.seed)
+            self.i = 0
+
+        def __iter__(self):
+            while self.i < self.limit:
+                val = int(torch.randint(high=self.high, size=(1,), generator=self.g))
+                self.i += 1
+                yield val
+
+        def load_state_dict(self, state_dict: Dict[str, Any]):
+            self.i = state_dict["i"]
+            self.g.set_state(state_dict["rng"])
+
+        def state_dict(self) -> Dict[str, Any]:
+            return {"i": self.i, "rng": self.g.get_state()}
+
+    # Optional: save dataset random transform state
+    class NoisyRange(torch.utils.data.Dataset):
+        def __init__(self, high: int, mean: float, std: float):
+            self.high, self.mean, self.std = high, torch.tensor([float(mean)]), float(std)
+
+        def __len__(self):
+            return self.high
+
+        def __getitem__(self, idx: int) -> float:
+            if not (0 <= idx < self.high):
+                raise IndexError()
+            x = torch.normal(self.mean, self.std)
+            noise = x.item()
+            return idx + noise
+
+        def load_state_dict(self, state_dict):
+            torch.set_rng_state(state_dict["rng"])
+
+        def state_dict(self):
+            return {"rng": torch.get_rng_state()}
+
+    # Test both single/multiprocess dataloading
+    for num_workers in [0, 2]:
+        print(f"{num_workers=}")
+        dl = StatefulDataLoader(NoisyRange(5, 1, 1), sampler=MySampler(5, 1, 10),
+            batch_size=2, drop_last=False, num_workers=num_workers)
+
+        batches = []
+        for i, batch in enumerate(dl):
+            batches.append(batch)
+            if i == 2:
+                sd = dl.state_dict()
+
+        dl.load_state_dict(sd)
+        batches2 = list(dl)
+
+        print(batches[3:])
+        print(batches2)
+
+    """
+    Output:
+    num_workers=0
+    [tensor([-0.4526, 3.7948], dtype=torch.float64), tensor([6.5494, 3.0470], dtype=torch.float64)]
+    [tensor([-0.4526, 3.7948], dtype=torch.float64), tensor([6.5494, 3.0470], dtype=torch.float64)]
+    num_workers=2
+    [tensor([3.7412, 1.2438], dtype=torch.float64), tensor([4.4807, 4.0036], dtype=torch.float64)]
+    [tensor([3.7412, 1.2438], dtype=torch.float64), tensor([4.4807, 4.0036], dtype=torch.float64)]
+    """
+
+Saving Custom State with Iterable-Style Datasets
+------------------------------------------------
+
+Tracking iteration order with `Iterable-style datasets `_ requires capturing state from each
+worker-level instance of the dataset. You can define ``state_dict`` / ``load_state_dict`` methods
+on your dataset that capture worker-level state; :class:`StatefulDataLoader` will handle
+aggregation across workers and distribution back to the workers. Calling ``load_state_dict``
+requires the :class:`StatefulDataLoader` to have the same ``num_workers`` as the one that produced
+the provided ``state_dict``.
+
+.. code:: python
+
+    from typing import *
+    import torch
+    import torch.utils.data
+    from torchdata.stateful_dataloader import StatefulDataLoader
+
+
+    class MyIterableDataset(torch.utils.data.IterableDataset):
+        def __init__(self, high: int, seed: int):
+            self.high, self.seed = high, seed
+            self.g = torch.Generator()
+            self.i = 0
+
+        def __iter__(self):
+            worker_info = torch.utils.data.get_worker_info()
+            if worker_info is not None:
+                worker_id = worker_info.id
+                num_workers = worker_info.num_workers
+            else:
+                worker_id = 0
+                num_workers = 1
+            self.g.manual_seed(self.seed)
+            arr = torch.randperm(self.high, generator=self.g)
+            arr = arr[worker_id:self.high:num_workers]
+            for idx in range(self.i, len(arr)):
+                self.i += 1
+                yield arr[idx]
+            self.i = 0
+
+        def state_dict(self):
+            return {"i": self.i}
+
+        def load_state_dict(self, state_dict):
+            self.i = state_dict["i"]
+
+    # Test both single/multiprocess dataloading
+    for num_workers in [0, 2]:
+        print(f"{num_workers=}")
+        dl = StatefulDataLoader(
+            MyIterableDataset(12, 0), batch_size=2, drop_last=False,
+            num_workers=num_workers)
+
+        batches = []
+        for i, batch in enumerate(dl):
+            batches.append(batch)
+            if i == 2:
+                sd = dl.state_dict()
+
+        dl.load_state_dict(sd)
+        batches2 = list(dl)
+
+        print(batches[3:])
+        print(batches2)
+
+    """
+    Output:
+    num_workers=0
+    [tensor([ 2, 10]), tensor([3, 1]), tensor([11, 6])]
+    [tensor([ 2, 10]), tensor([3, 1]), tensor([11, 6])]
+    num_workers=2
+    [tensor([ 4, 10]), tensor([ 3, 11]), tensor([1, 6])]
+    [tensor([ 4, 10]), tensor([ 3, 11]), tensor([1, 6])]
+    """
diff --git a/docs/source/torchdata.stateful_dataloader.rst b/docs/source/torchdata.stateful_dataloader.rst
new file mode 100644
index 000000000..a7d161b34
--- /dev/null
+++ b/docs/source/torchdata.stateful_dataloader.rst
@@ -0,0 +1,13 @@
+:tocdepth: 3
+
+Stateful DataLoader
+===================
+
+.. automodule:: torchdata.stateful_dataloader
+
+StatefulDataLoader is a drop-in replacement for `torch.utils.data.DataLoader `_ that offers
+``state_dict`` / ``load_state_dict`` methods for mid-epoch checkpointing: ``state_dict`` captures
+the state of the iterator most recently requested from the dataloader, and ``load_state_dict``
+applies to the next iterator requested.
+
+By default, the state includes the number of batches yielded and uses this to naively fast-forward
+the sampler (map-style) or the dataset (iterable-style). However, if the sampler and/or dataset
+include ``state_dict`` / ``load_state_dict`` methods, the dataloader will call them during its own
+``state_dict`` / ``load_state_dict`` calls. Under the hood, :class:`StatefulDataLoader` handles
+aggregation and distribution of state across multiprocess workers (but not across ranks).
+
+.. autoclass:: StatefulDataLoader
+   :members:
diff --git a/torchdata/dataloader2/reading_service.py b/torchdata/dataloader2/reading_service.py
index 1af26f875..776c6f7ef 100644
--- a/torchdata/dataloader2/reading_service.py
+++ b/torchdata/dataloader2/reading_service.py
@@ -149,7 +149,7 @@ def __new__(cls, *args, **kwargs):
 
 class InProcessReadingService(ReadingServiceInterface):
     r"""
-    Default ReadingService to serve the ``DataPipe` graph in the main process,
+    Default ReadingService to serve the ``DataPipe`` graph in the main process,
     and apply graph settings like determinism control to the graph.
 
     Args:
diff --git a/torchdata/stateful_dataloader/stateful_dataloader.py b/torchdata/stateful_dataloader/stateful_dataloader.py
index 92553cf69..9b162b4f8 100644
--- a/torchdata/stateful_dataloader/stateful_dataloader.py
+++ b/torchdata/stateful_dataloader/stateful_dataloader.py
@@ -92,13 +92,12 @@ class StatefulDataLoader(DataLoader[_T_co]):
     r"""
-    This is a drop in replacement for :class:`~torch.utils.data.DataLoader`
+    This is a drop-in replacement for ``torch.utils.data.DataLoader``
     that implements state_dict and load_state_dict methods, enabling mid-epoch
     checkpointing.
 
-    All arguments are identical to :class:`~torch.utils.data.DataLoader`, with
-    a new kwarg: `snapshot_every_n_steps: Optional[int] = `.
-    See :py:mod:`torch.utils.data` documentation page for more details.
+    All arguments are identical to ``torch.utils.data.DataLoader``, with
+    a new kwarg: ``snapshot_every_n_steps``.
 
     Args:
         dataset (Dataset): dataset from which to load the data.
@@ -148,11 +147,13 @@ class StatefulDataLoader(DataLoader[_T_co]):
             maintain the workers `Dataset` instances alive. (default: ``False``)
         pin_memory_device (str, optional): the device to :attr:`pin_memory` to if ``pin_memory`` is
             ``True``.
+        snapshot_every_n_steps (int, optional): Defines how often the state is
+            transferred from the dataloader workers to the dataloader. By default, it is set to
+            ``1``, i.e., state is transferred every step. If the state is large, this value can be
+            increased (and ideally set to the frequency of training checkpointing) to reduce the
+            overhead of transferring state every step.
 
     .. warning:: If the ``spawn`` start method is used, :attr:`worker_init_fn`
                  cannot be an unpicklable object, e.g., a lambda function. See
-                 :ref:`multiprocessing-best-practices` on more details related
+                 `multiprocessing-best-practices `_ for more details related
                  to multiprocessing in PyTorch.
 
     .. warning:: ``len(dataloader)`` heuristic is based on the length of the sampler used.
@@ -169,12 +170,12 @@ class StatefulDataLoader(DataLoader[_T_co]):
               dropped when :attr:`drop_last` is set. Unfortunately, PyTorch can not detect such
               cases in general.
 
-    See `Dataset Types`_ for more details on these two types of datasets and how
+    See `Dataset Types `_ for more details on these two types of datasets and how
     :class:`~torch.utils.data.IterableDataset` interacts with
-    `Multi-process data loading`_.
+    `Multi-process data loading `_.
 
-    .. warning:: See :ref:`reproducibility`, and :ref:`dataloader-workers-random-seed`, and
-                 :ref:`data-loading-randomness` notes for random seed related questions.
+    .. warning:: See `Reproducibility `_, `Dataloader-workers-random-seed `_, and
+                 `Data-loading-randomness `_ notes for random seed related questions.
 
     .. _multiprocessing context: https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods
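To make the new ``snapshot_every_n_steps`` argument documented above concrete, here is a minimal sketch of its intended use; the toy dataset and the frequency of ``100`` are illustrative only:

.. code:: python

    from torch.utils.data import Dataset
    from torchdata.stateful_dataloader import StatefulDataLoader

    class SquaresDataset(Dataset):  # toy map-style dataset for illustration
        def __len__(self):
            return 1000

        def __getitem__(self, idx):
            return idx * idx

    # Transfer worker state every 100 steps instead of every step, ideally
    # matching the frequency at which training checkpoints are written.
    dataloader = StatefulDataLoader(
        SquaresDataset(),
        batch_size=8,
        num_workers=2,
        snapshot_every_n_steps=100,
    )

    for step, batch in enumerate(dataloader):
        if step > 0 and step % 100 == 0:
            checkpoint = dataloader.state_dict()  # worker state is only transferred every 100 steps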