Skip to content

Commit

Permalink
Cherry pick docs changes from main branch (#1304)
Browse files Browse the repository at this point in the history
* fix docs, make sure docs build (#1302)

* fix docs, make sure docs build

* adding stateful dataloader docs

* Add stateful dataloader tutorial docs (#1303)
  • Loading branch information
gokulavasan authored Jul 30, 2024
1 parent ba35881 commit 265d317
Show file tree
Hide file tree
Showing 7 changed files with 241 additions and 30 deletions.
36 changes: 27 additions & 9 deletions docs/Makefile
Original file line number Diff line number Diff line change
@@ -1,23 +1,41 @@
# Minimal makefile for Sphinx documentation
#

# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS ?=
SPHINXBUILD ?= sphinx-build
ifneq ($(EXAMPLES_PATTERN),)
EXAMPLES_PATTERN_OPTS := -D sphinx_gallery_conf.filename_pattern="$(EXAMPLES_PATTERN)"
endif

# You can set these variables from the command line.
SPHINXOPTS = -W -j auto $(EXAMPLES_PATTERN_OPTS)
SPHINXBUILD = sphinx-build
SPHINXPROJ = torchdata
SOURCEDIR = source
BUILDDIR = build

# Put it first so that "make" without argument is like "make help".
help:
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

doctest: html
$(SPHINXBUILD) -b doctest $(SPHINXOPTS) "$(SOURCEDIR)" "$(BUILDDIR)"/doctest
@echo "Testing of doctests in the sources finished, look at the " \
"results in $(BUILDDIR)/doctest/output.txt."
docset: html
doc2dash --name $(SPHINXPROJ) --icon $(SOURCEDIR)/_static/img/pytorch-logo-flame.png --enable-js --online-redirect-url http://pytorch.org/data/ --force $(BUILDDIR)/html/

# Manually fix because Zeal doesn't deal well with `icon.png`-only at 2x resolution.
cp $(SPHINXPROJ).docset/icon.png $(SPHINXPROJ).docset/icon@2x.png
convert $(SPHINXPROJ).docset/icon@2x.png -resize 16x16 $(SPHINXPROJ).docset/icon.png

html-noplot: # Avoids running the gallery examples, which may take time
$(SPHINXBUILD) -D plot_gallery=0 -b html "${SOURCEDIR}" "$(BUILDDIR)"/html
@echo
@echo "Build finished. The HTML pages are in $(BUILDDIR)/html."

clean:
rm -rf $(BUILDDIR)/*
rm -rf $(SOURCEDIR)/generated_examples/ # sphinx-gallery
rm -rf $(SOURCEDIR)/gen_modules/ # sphinx-gallery
rm -rf $(SOURCEDIR)/sg_execution_times.rst # sphinx-gallery
rm -rf $(SOURCEDIR)/generated/ # autosummary

.PHONY: help doctest Makefile
.PHONY: help Makefile docset

# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
Expand Down
22 changes: 11 additions & 11 deletions docs/source/dp_tutorial.rst
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ Accessing AWS S3 with ``fsspec`` DataPipes
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

This requires the installation of the libraries ``fsspec``
(`documentation <https://filesystem-spec.readthedocs.io/en/latest/>`_) and ``s3fs``
(`documentation <https://filesystem-spec.readthedocs.io/en/latest/>`__) and ``s3fs``
(`s3fs GitHub repo <https://github.com/fsspec/s3fs>`_).

You can list out the files within a S3 bucket directory by passing a path that starts
Expand Down Expand Up @@ -363,7 +363,7 @@ is also available for writing data to cloud.
Accessing Google Cloud Storage (GCS) with ``fsspec`` DataPipes
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
This requires the installation of the libraries ``fsspec``
(`documentation <https://filesystem-spec.readthedocs.io/en/latest/>`_) and ``gcsfs``
(`documentation <https://filesystem-spec.readthedocs.io/en/latest/>`__) and ``gcsfs``
(`gcsfs GitHub repo <https://github.com/fsspec/gcsfs>`_).

You can list out the files within a GCS bucket directory by specifying a path that starts
Expand Down Expand Up @@ -400,11 +400,11 @@ Accessing Azure Blob storage with ``fsspec`` DataPipes
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

This requires the installation of the libraries ``fsspec``
(`documentation <https://filesystem-spec.readthedocs.io/en/latest/>`_) and ``adlfs``
(`documentation <https://filesystem-spec.readthedocs.io/en/latest/>`__) and ``adlfs``
(`adlfs GitHub repo <https://github.com/fsspec/adlfs>`_).
You can access data in Azure Data Lake Storage Gen2 by providing URIs staring with ``abfs://``.
You can access data in Azure Data Lake Storage Gen2 by providing URIs staring with ``abfs://``.
For example,
`FSSpecFileLister <generated/torchdata.datapipes.iter.FSSpecFileLister.html>`_ (``.list_files_by_fsspec(...)``)
`FSSpecFileLister <generated/torchdata.datapipes.iter.FSSpecFileLister.html>`_ (``.list_files_by_fsspec(...)``)
can be used to list files in a directory in a container:

.. code:: python
Expand All @@ -430,11 +430,11 @@ directory ``curated/covid-19/ecdc_cases/latest``, belonging to account ``pandemi
.open_files_by_fsspec(account_name='pandemicdatalake') \
.parse_csv()
print(list(dp)[:3])
# [['date_rep', 'day', ..., 'iso_country', 'daterep'],
# [['date_rep', 'day', ..., 'iso_country', 'daterep'],
# ['2020-12-14', '14', ..., 'AF', '2020-12-14'],
# ['2020-12-13', '13', ..., 'AF', '2020-12-13']]
If necessary, you can also access data in Azure Data Lake Storage Gen1 by using URIs staring with
If necessary, you can also access data in Azure Data Lake Storage Gen1 by using URIs staring with
``adl://`` and ``abfs://``, as described in `README of adlfs repo <https://github.com/fsspec/adlfs/blob/main/README.md>`_

Accessing Azure ML Datastores with ``fsspec`` DataPipes
Expand All @@ -446,11 +446,11 @@ An Azure ML datastore is a *reference* to an existing storage account on Azure.
- Authentication is automatically handled - both *credential-based* access (service principal/SAS/key) and *identity-based* access (Azure Active Directory/managed identity) are supported. When using credential-based authentication, you do not need to expose secrets in your code.

This requires the installation of the library ``azureml-fsspec``
(`documentation <https://learn.microsoft.com/python/api/azureml-fsspec/?view=azure-ml-py>`_).
(`documentation <https://learn.microsoft.com/python/api/azureml-fsspec/?view=azure-ml-py>`__).

You can access data in an Azure ML datastore by providing URIs staring with ``azureml://``.
You can access data in an Azure ML datastore by providing URIs staring with ``azureml://``.
For example,
`FSSpecFileLister <generated/torchdata.datapipes.iter.FSSpecFileLister.html>`_ (``.list_files_by_fsspec(...)``)
`FSSpecFileLister <generated/torchdata.datapipes.iter.FSSpecFileLister.html>`_ (``.list_files_by_fsspec(...)``)
can be used to list files in a directory in a container:

.. code:: python
Expand All @@ -470,7 +470,7 @@ can be used to list files in a directory in a container:
dp = IterableWrapper([uri]).list_files_by_fsspec()
print(list(dp))
# ['azureml:///<sub_id>/resourcegroups/<rg_name>/workspaces/<ws_name>/datastores/<datastore>/paths/<folder>/file1.txt',
# ['azureml:///<sub_id>/resourcegroups/<rg_name>/workspaces/<ws_name>/datastores/<datastore>/paths/<folder>/file1.txt',
# 'azureml:///<sub_id>/resourcegroups/<rg_name>/workspaces/<ws_name>/datastores/<datastore>/paths/<folder>/file2.txt', ...]
You can also open files using `FSSpecFileOpener <generated/torchdata.datapipes.iter.FSSpecFileOpener.html>`_
Expand Down
2 changes: 2 additions & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ Features described in this documentation are classified by release status:
:maxdepth: 2
:caption: API Reference:

torchdata.stateful_dataloader.rst
torchdata.datapipes.iter.rst
torchdata.datapipes.map.rst
torchdata.datapipes.utils.rst
Expand All @@ -47,6 +48,7 @@ Features described in this documentation are classified by release status:
:maxdepth: 2
:caption: Tutorial and Examples:

stateful_dataloader_tutorial.rst
dp_tutorial.rst
dlv2_tutorial.rst
examples.rst
Expand Down
177 changes: 177 additions & 0 deletions docs/source/stateful_dataloader_tutorial.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
Stateful DataLoader Tutorial
============================

Saving and loading state
------------------------

Stateful DataLoader adds the ``load_state_dict``, ``state_dict`` methods to the ``torch.utils.data.DataLoader``. State fetch and set can be done as follows:

.. code:: python
from torchdata.stateful_dataloader import StatefulDataLoader
dataloader = StatefulDataLoader(dataset, num_workers=2)
for i, batch in enumerate(dataloader):
...
if i == 10:
state_dict = dataloader.state_dict()
break
# Training run resumes with the previous checkpoint
dataloader = StatefulDataLoader(dataset, num_workers=2)
# Resume state with DataLoader
dataloader.load_state_dict(state_dict)
for i, batch in enumerate(dataloader):
...
Saving Custom State with Map-Style Datasets
-------------------------------------------

For efficient resuming of `Map-style datasets <https://pytorch.org/docs/stable/data.html#map-style-datasets>`_, you can resume iteration by defining ``state_dict`` / ``load_state_dict`` methods in your sampler. If your dataset has worker-specific state (eg RNG transform state) you can add ``state_dict`` / ``load_state_dict`` methods to your dataset.

.. code:: python
from typing import *
import torch
import torch.utils.data
from torchdata.stateful_dataloader import StatefulDataLoader
# If you are using the default RandomSampler and BatchSampler in torch.utils.data, they are patched when you import torchdata.stateful_dataloader so that defining, a custom sampler here is unnecessary
class MySampler(torch.utils.data.Sampler[int]):
def __init__(self, high: int, seed: int, limit: int):
self.seed, self.high, self.limit = seed, high, limit
self.g = torch.Generator()
self.g.manual_seed(self.seed)
self.i = 0
def __iter__(self):
while self.i < self.limit:
val = int(torch.randint(high=self.high, size=(1,), generator=self.g))
self.i += 1
yield val
def load_state_dict(self, state_dict: Dict[str, Any]):
self.i = state_dict["i"]
self.g.set_state(state_dict["rng"])
def state_dict(self) -> Dict[str, Any]:
return {"i": self.i, "rng": self.g.get_state()}
# Optional: save dataset random transform state
class NoisyRange(torch.utils.data.Dataset):
def __init__(self, high: int, mean: float, std: float):
self.high, self.mean, self.std = high, torch.tensor([float(mean)]), float(std)
def __len__(self):
return self.high
def __getitem__(self, idx: int) -> float:
if not (0 <= idx < self.high):
raise IndexError()
x = torch.normal(self.mean, self.std)
noise = x.item()
return idx + noise
def load_state_dict(self, state_dict):
torch.set_rng_state(state_dict["rng"])
def state_dict(self):
return {"rng": torch.get_rng_state()}
# Test both single/multiprocess dataloading
for num_workers in [0, 2]:
print(f"{num_workers=}")
dl = StatefulDataLoader(NoisyRange(5, 1, 1), sampler=MySampler(5, 1, 10),
batch_size=2, drop_last=False, num_workers=num_workers)
batches = []
for i, batch in enumerate(dl):
batches.append(batch)
if i == 2:
sd = dl.state_dict()
dl.load_state_dict(sd)
batches2 = list(dl)
print(batches[3:])
print(batches2)
"""
Output:
num_workers=0
[tensor([-0.4526, 3.7948], dtype=torch.float64), tensor([6.5494, 3.0470], dtype=torch.float64)]
[tensor([-0.4526, 3.7948], dtype=torch.float64), tensor([6.5494, 3.0470], dtype=torch.float64)]
num_workers=2
[tensor([3.7412, 1.2438], dtype=torch.float64), tensor([4.4807, 4.0036], dtype=torch.float64)]
[tensor([3.7412, 1.2438], dtype=torch.float64), tensor([4.4807, 4.0036], dtype=torch.float64)]
"""
Saving Custom State with Iterable-Style Datasets
------------------------------------------------

Tracking iteration order with `Iterable-style datasets <https://pytorch.org/docs/stable/data.html#iterable-style-datasets>`_ requires state from each worker-level instance of the dataset to be captured. You can define ``state_dict`` / ``load_state_dict`` methods on your dataset which capture worker-level state. :class:`StatefulDataLoader` will handle aggregation across workers and distribution back to the workers. Calling ``load_state_dict`` requires :class:`StatefulDataLoader`` to have same ``num_workers`` as those of the provided ``state_dict``.

.. code:: python
from typing import *
import torch
import torch.utils.data
from torchdata.stateful_dataloader import StatefulDataLoader
class MyIterableDataset(torch.utils.data.IterableDataset):
def __init__(self, high: int, seed: int):
self.high, self.seed = high, seed
self.g = torch.Generator()
self.i = 0
def __iter__(self):
worker_info = torch.utils.data.get_worker_info()
if worker_info is not None:
worker_id = worker_info.id
num_workers = worker_info.num_workers
else:
worker_id = 0
num_workers = 1
self.g.manual_seed(self.seed)
arr = torch.randperm(self.high, generator=self.g)
arr = arr[worker_id:self.high:num_workers]
for idx in range(self.i, len(arr)):
self.i += 1
yield arr[idx]
self.i = 0
def state_dict(self):
return {"i": self.i}
def load_state_dict(self, state_dict):
self.i = state_dict["i"]
# Test both single/multiprocess dataloading
for num_workers in [0, 2]:
print(f"{num_workers=}")
dl = StatefulDataLoader(
MyIterableDataset(12, 0), batch_size=2, drop_last=False,
num_workers=num_workers)
batches = []
for i, batch in enumerate(dl):
batches.append(batch)
if i == 2:
sd = dl.state_dict()
dl.load_state_dict(sd)
batches2 = list(dl)
print(batches[3:])
print(batches2)
"""
Output:
num_workers=0
[tensor([ 2, 10]), tensor([3, 1]), tensor([11, 6])]
[tensor([ 2, 10]), tensor([3, 1]), tensor([11, 6])]
num_workers=2
[tensor([ 4, 10]), tensor([ 3, 11]), tensor([1, 6])]
[tensor([ 4, 10]), tensor([ 3, 11]), tensor([1, 6])]
"""
13 changes: 13 additions & 0 deletions docs/source/torchdata.stateful_dataloader.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
:tocdepth: 3

Stateful DataLoader
===================

.. automodule:: torchdata.stateful_dataloader

StatefulDataLoader is a drop-in replacement for `torch.utils.data.DataLoader <https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader>`_ which offers ``state_dict`` / ``load_state_dict`` methods for handling mid-epoch checkpointing which operate on the previous/next iterator requested from the dataloader (resp.).

By default, the state includes the number of batches yielded and uses this to naively fast-forward the sampler (map-style) or the dataset (iterable-style). However if the sampler and/or dataset include ``state_dict`` / ``load_state_dict`` methods, then it will call them during its own ``state_dict`` / ``load_state_dict`` calls. Under the hood, :class:`StatefulDataLoader` handles aggregation and distribution of state across multiprocess workers (but not across ranks).

.. autoclass:: StatefulDataLoader
:members:
2 changes: 1 addition & 1 deletion torchdata/dataloader2/reading_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def __new__(cls, *args, **kwargs):

class InProcessReadingService(ReadingServiceInterface):
r"""
Default ReadingService to serve the ``DataPipe` graph in the main process,
Default ReadingService to serve the ``DataPipe`` graph in the main process,
and apply graph settings like determinism control to the graph.
Args:
Expand Down
19 changes: 10 additions & 9 deletions torchdata/stateful_dataloader/stateful_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,12 @@

class StatefulDataLoader(DataLoader[_T_co]):
r"""
This is a drop in replacement for :class:`~torch.utils.data.DataLoader`
This is a drop in replacement for ``torch.utils.data.DataLoader``
that implements state_dict and load_state_dict methods, enabling mid-epoch
checkpointing.
All arguments are identical to :class:`~torch.utils.data.DataLoader`, with
a new kwarg: `snapshot_every_n_steps: Optional[int] = `.
See :py:mod:`torch.utils.data` documentation page for more details.
All arguments are identical to ``torch.utils.data.DataLoader``, with
a new kwarg: ``snapshot_every_n_steps``.
Args:
dataset (Dataset): dataset from which to load the data.
Expand Down Expand Up @@ -148,11 +147,13 @@ class StatefulDataLoader(DataLoader[_T_co]):
maintain the workers `Dataset` instances alive. (default: ``False``)
pin_memory_device (str, optional): the device to :attr:`pin_memory` to if ``pin_memory`` is
``True``.
snapshot_every_n_steps (int, optional): Defines how often the state is
transferred from the dataloader workers to the dataloader. By default, it is set to ``1``, i.e., state is transferred every step. If the state is large, this value can be increased (and ideally set to the frequency of training checkpointing) to reduce the overhead of transferring state every step.
.. warning:: If the ``spawn`` start method is used, :attr:`worker_init_fn`
cannot be an unpicklable object, e.g., a lambda function. See
:ref:`multiprocessing-best-practices` on more details related
`multiprocessing-best-practices <https://pytorch.org/docs/stable/notes/multiprocessing.html#multiprocessing-best-practices>`_ on more details related
to multiprocessing in PyTorch.
.. warning:: ``len(dataloader)`` heuristic is based on the length of the sampler used.
Expand All @@ -169,12 +170,12 @@ class StatefulDataLoader(DataLoader[_T_co]):
dropped when :attr:`drop_last` is set. Unfortunately, PyTorch can not detect such
cases in general.
See `Dataset Types`_ for more details on these two types of datasets and how
See `Dataset Types <https://pytorch.org/docs/stable/data.html>`_ for more details on these two types of datasets and how
:class:`~torch.utils.data.IterableDataset` interacts with
`Multi-process data loading`_.
`Multi-process data loading <https://pytorch.org/docs/stable/data.html#multi-process-data-loading>`_.
.. warning:: See :ref:`reproducibility`, and :ref:`dataloader-workers-random-seed`, and
:ref:`data-loading-randomness` notes for random seed related questions.
.. warning:: See `Reproducibility <https://pytorch.org/docs/stable/notes/randomness.html#reproducibility>`_, and `Dataloader-workers-random-seed <https://pytorch.org/docs/stable/notes/faq.html#dataloader-workers-random-seed>`_, and
`Data-loading-randomness <https://pytorch.org/docs/stable/data.html#data-loading-randomness>`_ notes for random seed related questions.
.. _multiprocessing context:
https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods
Expand Down

0 comments on commit 265d317

Please sign in to comment.