diff --git a/.github/dependabot.yml b/.github/dependabot.yml
index d49f7d476..85e1b4163 100644
--- a/.github/dependabot.yml
+++ b/.github/dependabot.yml
@@ -11,3 +11,21 @@ updates:
interval: "weekly"
# Allow up to 5 open pull requests for pip dependencies
open-pull-requests-limit: 5
+ ignore:
+ - dependency-name: "GitPython"
+ - dependency-name: "docutils"
+ - dependency-name: "furo"
+ - dependency-name: "myst-parser"
+ - dependency-name: "nbsphinx"
+ - dependency-name: "pandoc"
+ - dependency-name: "pypandoc"
+ - dependency-name: "sphinx-argparse"
+ - dependency-name: "sphinx-copybutton"
+ - dependency-name: "sphinx"
+ - dependency-name: "sphinx-tabs"
+ - dependency-name: "sphinxcontrib.katex"
+ - dependency-name: "sphinxcontrib-applehelp"
+ - dependency-name: "sphinxcontrib-devhelp"
+ - dependency-name: "sphinxcontrib-htmlhelp"
+ - dependency-name: "sphinxcontrib-qthelp"
+ - dependency-name: "sphinxcontrib-serializinghtml"
diff --git a/docs/source/_static/images/batching_methods.png b/docs/source/_static/images/batching_methods.png
new file mode 100644
index 000000000..0359cbd79
Binary files /dev/null and b/docs/source/_static/images/batching_methods.png differ
diff --git a/docs/source/_static/images/intra_canonical_node_shuffling.png b/docs/source/_static/images/intra_canonical_node_shuffling.png
new file mode 100644
index 000000000..c11b19922
Binary files /dev/null and b/docs/source/_static/images/intra_canonical_node_shuffling.png differ
diff --git a/docs/source/_static/images/mds_writing.png b/docs/source/_static/images/mds_writing.png
new file mode 100644
index 000000000..252d21af6
Binary files /dev/null and b/docs/source/_static/images/mds_writing.png differ
diff --git a/docs/source/_static/images/py1b_py1br.png b/docs/source/_static/images/py1b_py1br.png
new file mode 100644
index 000000000..558b30f12
Binary files /dev/null and b/docs/source/_static/images/py1b_py1br.png differ
diff --git a/docs/source/_static/images/py1e.png b/docs/source/_static/images/py1e.png
new file mode 100644
index 000000000..fdb5634ae
Binary files /dev/null and b/docs/source/_static/images/py1e.png differ
diff --git a/docs/source/_static/images/remote_streams.png b/docs/source/_static/images/remote_streams.png
new file mode 100644
index 000000000..f7f6bcf50
Binary files /dev/null and b/docs/source/_static/images/remote_streams.png differ
diff --git a/docs/source/_static/images/sample_partition.png b/docs/source/_static/images/sample_partition.png
new file mode 100644
index 000000000..287ddb079
Binary files /dev/null and b/docs/source/_static/images/sample_partition.png differ
diff --git a/docs/source/_static/images/sample_retrieval.png b/docs/source/_static/images/sample_retrieval.png
new file mode 100644
index 000000000..64c3a70e5
Binary files /dev/null and b/docs/source/_static/images/sample_retrieval.png differ
diff --git a/docs/source/_static/images/shards_canonical_nodes.png b/docs/source/_static/images/shards_canonical_nodes.png
new file mode 100644
index 000000000..c3f3c664a
Binary files /dev/null and b/docs/source/_static/images/shards_canonical_nodes.png differ
diff --git a/docs/source/_static/images/shards_sequential.png b/docs/source/_static/images/shards_sequential.png
new file mode 100644
index 000000000..72e05d6bc
Binary files /dev/null and b/docs/source/_static/images/shards_sequential.png differ
diff --git a/docs/source/_static/images/shards_shuffled.png b/docs/source/_static/images/shards_shuffled.png
new file mode 100644
index 000000000..1bcd60841
Binary files /dev/null and b/docs/source/_static/images/shards_shuffled.png differ
diff --git a/docs/source/_static/images/shuffling_example.png b/docs/source/_static/images/shuffling_example.png
new file mode 100644
index 000000000..c18f5b959
Binary files /dev/null and b/docs/source/_static/images/shuffling_example.png differ
diff --git a/docs/source/_static/images/streaming_partitioning.png b/docs/source/_static/images/streaming_partitioning.png
new file mode 100644
index 000000000..27402344a
Binary files /dev/null and b/docs/source/_static/images/streaming_partitioning.png differ
diff --git a/docs/source/conf.py b/docs/source/conf.py
index dbd0f1b83..0a4e98872 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -77,14 +77,15 @@
'sphinx.ext.extlinks',
'sphinx.ext.coverage',
'sphinx.ext.napoleon',
+ 'sphinxcontrib.katex',
'sphinx.ext.viewcode',
'sphinx.ext.intersphinx',
'sphinx_copybutton',
'myst_parser',
'sphinxarg.ext',
'sphinx.ext.doctest',
- 'nbsphinx',
'sphinx_tabs.tabs',
+ 'nbsphinx',
]
diff --git a/docs/source/dataset_configuration/mixing_data_sources.md b/docs/source/dataset_configuration/mixing_data_sources.md
new file mode 100644
index 000000000..7a9be245d
--- /dev/null
+++ b/docs/source/dataset_configuration/mixing_data_sources.md
@@ -0,0 +1,89 @@
+# Mixing Datasets
+
+Training a model often requires combining data from multiple different sources. Streaming makes combining these data sources, or streams, easy and configurable. See the [main concepts page](../getting_started/main_concepts.md#distributed-model-training) for a high-level view of distributed training with multiple streams.
+
+## Using Streams
+
+A stream is a data source: a collection of shard files (or a set of subdirectories containing shard files). Shard files can optionally be compressed. Streams are represented by the {class}`streaming.Stream` object. Similar to {class}`streaming.StreamingDataset` itself, a `Stream` object can take in `remote` and `local` paths -- see [here](../getting_started/main_concepts.md#remote-data-streams) for an example.
+
+It is possible, though not recommended, for streams to have different schemas.
+
+## Configuring the data mix
+The `proportion`, `repeat`, or `choose` arguments to `Stream` are used to configure different dataset mixing schemes. Only one of them may be set at a time, and all streams must use the same mixing scheme (e.g., Stream A with `proportion` and Stream B with `choose` are incompatible).
+- **`proportion`**: Specifies how to sample this Stream relative to other Streams.
+- **`repeat`**: Specifies the degree to which a Stream is upsampled or downsampled.
+- **`choose`**: Specifies the number of samples to choose from a Stream.
+
+Let's look at some examples of dataset mixing in action.
+
+### Using `proportion` for relative weighting
+
+As an example, let's say we have Stream A with 100 samples and Stream B with 200 samples. The `epoch_size`, if not set, will default to the total number of unique samples -- in this case, 300. To configure our training dataset to be 25% from Stream A and 75% from Stream B, we simply set `proportion` to those values:
+
+```python
+stream_A = Stream(
+ remote = 's3://stream_A_remote',
+ local = '/tmp/stream_A',
+ proportion = 0.25,
+)
+stream_B = Stream(
+ remote = 's3://stream_B_remote',
+ local = '/tmp/stream_B',
+ proportion = 0.75,
+)
+dataset = StreamingDataset(
+ streams = [stream_A, stream_B],
+)
+```
+
+Since `epoch_size` has not been specified, the epoch will be 300 samples long, of which 75 samples will come from Stream A, and 225 from Stream B. Equivalently, we could have also set `proportion` to 2 for Stream A and 6 for Stream B for the same weighting -- StreamingDataset will normalize the proportion weights.
+
+If `epoch_size` is explicitly set, then proportions will apply to that value instead. For example, if `epoch_size` was passed as 400 to StreamingDataset, as below, and proportions stayed the same, then in each epoch, 100 samples would be from Stream A and 300 would be from Stream B.
+
+```python
+dataset = StreamingDataset(
+ epoch_size = 400,
+ streams = [stream_A, stream_B], # With proportions A: 0.25 and B: 0.75.
+)
+```
+
+For multi-epoch training, to control how samples are chosen between epochs, see the [inter-epoch sampling](replication_and_sampling.md#inter-epoch-sampling) section.
+
+### Using `repeat` for absolute weighting
+
+It can be useful to specify how many times to upsample or downsample a Stream -- the `repeat` argument fulfills this use case. For example, to see every sample from Stream A 3 times per epoch, simply set `repeat` to 3:
+
+```python
+stream_A = Stream(
+ remote = 's3://stream_A_remote',
+ local = '/tmp/stream_A',
+ repeat = 3,
+)
+```
+
+To downsample a stream, meaning that only a fraction of the total samples from that stream are seen every epoch, set `repeat` to less than 1. For example, to see only a quarter of the samples from Stream A per epoch, set `repeat` to 0.25.
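+
+Concretely, reusing the stream from the example above:
+
+```python
+stream_A = Stream(
+    remote = 's3://stream_A_remote',
+    local = '/tmp/stream_A',
+    repeat = 0.25,    # only a quarter of Stream A's samples are seen each epoch
+)
+```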
+
+### Using `choose` for absolute weighting
+
+Specifying the absolute number of samples to choose from a Stream can also be useful when mixing datasets. Use the `choose` argument to indicate the number of samples to take from a stream per epoch. For example, to see exactly 250 samples from Stream A per epoch, set `choose` to 250:
+
+```python
+stream_A = Stream(
+ remote = 's3://stream_A_remote',
+ local = '/tmp/stream_A',
+ choose = 250,
+)
+```
+
+## Batching Methods
+
+Controlling how a global batch is constructed is a requirement for some training runs. StreamingDataset's `batching_method` argument takes in three different options to configure the composition of each global batch:
+- **`'random'`**: (default) Global batches respect dataset mixing *in expectation*. Stream proportions can vary somewhat between batches.
+- **`'stratified'`**: *Every* global batch respects dataset mixing exactly. Can help mitigate loss spikes and divergence by making sure stream proportions hold for every batch.
+- **`'per_stream'`**: Each global batch contains samples from only one stream at a time. Particularly useful when your streams contain data of different tensor shapes/sizes, so that each batch can contain samples of the same shape/size.
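+
+For instance, here is a minimal sketch of enforcing exact stream proportions in every global batch (the remote/local paths are reused from the example above, and `batch_size` is an illustrative assumption):
+
+```python
+from streaming import Stream, StreamingDataset
+
+# Two streams mixed 50/50; 'stratified' enforces that split in every global batch.
+stream_A = Stream(remote='s3://stream_A_remote', local='/tmp/stream_A', proportion=0.5)
+stream_B = Stream(remote='s3://stream_B_remote', local='/tmp/stream_B', proportion=0.5)
+
+dataset = StreamingDataset(
+    streams = [stream_A, stream_B],
+    batch_size = 8,                    # per-device batch size
+    batching_method = 'stratified',    # every global batch is exactly 50% A, 50% B
+)
+```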
+
+As an example, suppose we have Stream A (green) and Stream B (blue), each making up half of our total dataset. Applying each of the batching methods would make global batches look like this:
+
+
+
+Each bar represents a single global batch. We see that `random` batching can have some variance in stream composition, while `stratified` batching keeps composition exact, and `per_stream` batching constructs each batch with a single stream.
diff --git a/docs/source/dataset_configuration/replication_and_sampling.md b/docs/source/dataset_configuration/replication_and_sampling.md
new file mode 100644
index 000000000..c458bc783
--- /dev/null
+++ b/docs/source/dataset_configuration/replication_and_sampling.md
@@ -0,0 +1,51 @@
+# Replication and Sampling
+
+You can control how samples are replicated, chosen between epochs, and chosen from shards. These are useful for a variety of cases:
+- **Replication**: Replicate training samples among subsets of devices. This is particularly useful for Tensor Parallelism (TP) or Sequence Parallelism (SP).
+- **Inter-epoch Sampling**: Control whether the samples seen across epochs vary or not.
+- **Sampling from shards**: Control how many samples to choose from each shard at a time.
+
+Let's see when and how to use these features.
+
+## Replication
+
+Training with Tensor Parallelism (TP) or Sequence Parallelism (SP) requires multiple devices to see the same sample of data. The `replication` parameter of {class}`streaming.StreamingDataset` controls how many consecutive devices will see the same samples in each batch. For example, if `replication` is set to 4 for a training job with 16 GPUs, devices 0 through 3 will see the same samples, devices 4 through 7 will see the same samples, and so on.
+
+```python
+dataset = StreamingDataset(
+ ...
+ replication = 4, # Every 4 GPUs will see the same samples.
+ ...
+)
+```
+
+Be aware that samples are only replicated across consecutive GPUs, as denoted by their rank from [PyTorch's distributed module](https://pytorch.org/docs/stable/distributed.html).
+
+## Epoch size
+
+You can specify the size of each epoch of training with the `epoch_size` argument:
+
+```python
+dataset = StreamingDataset(
+ ...
+ epoch_size = 10000, # Each epoch will be 10k samples.
+ ...
+)
+```
+
+## Inter-epoch sampling
+
+You can choose how sampling from your dataset(s) occurs between epochs by specifying the `sampling_method` when instantiating `StreamingDataset`. This can be one of two values:
+
+- `'balanced'`: (default) Samples are chosen at random from dataset(s) during each epoch.
+- `'fixed'`: The same samples from the dataset(s) are chosen during every epoch.
+
+For example, with `balanced` sampling, if the size of an epoch is 1000 samples, but the dataset contains 2000 samples, then each epoch will consist of 1000 samples taken at random from the underlying 2000. But with `fixed` sampling, the same 1000 samples that are seen in epoch 0 will be seen in all subsequent epochs as well.
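+
+A minimal sketch of fixing the per-epoch sample set, matching the example above (the remote/local paths and `batch_size` are illustrative assumptions):
+
+```python
+dataset = StreamingDataset(
+    remote = 's3://some-bucket/my-dataset',
+    local = '/tmp/my-dataset',
+    epoch_size = 1000,            # subsample 1000 of the 2000 underlying samples
+    sampling_method = 'fixed',    # reuse the same 1000 samples every epoch
+    batch_size = 8,
+)
+```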
+
+## Sampling from shards
+
+If not all samples from a shard need to be used in training, the number of samples to choose from each shard at a time is set by the `sampling_granularity` parameter of StreamingDataset. The `sampling_granularity` argument defaults to 1, meaning that one sample is chosen from each shard at a time.
+
+This is particularly useful when training on just a small subset of your overall dataset. In this case, the way in which samples are chosen from shards becomes important, and directly impacts how many shards must be downloaded for the training job. For example, suppose the overall dataset has 100,000 samples, split up between 1000 shards of 100 samples each, but the epoch size is just 1000 samples. If `sampling_granularity` is set to 1, then the training dataset will consist of a single sample from each of the 1000 shards, meaning that all 1000 shards have to be downloaded over the course of the run. Instead, if `sampling_granularity` is set to 100, then the training dataset will consist of all 100 samples from just 10 shards, and only 10 shards will have to be downloaded for the run.
+
+If the run's epoch size is large enough such that all shards have to be downloaded anyway, setting `sampling_granularity` will not change shard download demand.
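+
+Continuing the example above, a sketch of choosing whole shards' worth of samples at a time (the remote/local paths and `batch_size` are illustrative assumptions):
+
+```python
+dataset = StreamingDataset(
+    remote = 's3://some-bucket/my-dataset',
+    local = '/tmp/my-dataset',
+    epoch_size = 1000,             # 1000-sample epochs from a 100,000-sample dataset
+    sampling_granularity = 100,    # choose samples from shards in groups of 100, i.e., whole shards at a time
+    batch_size = 8,
+)
+```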
diff --git a/docs/source/dataset_configuration/shard_retrieval.md b/docs/source/dataset_configuration/shard_retrieval.md
new file mode 100644
index 000000000..02d551d97
--- /dev/null
+++ b/docs/source/dataset_configuration/shard_retrieval.md
@@ -0,0 +1,100 @@
+# Shard Retrieval
+
+Shards are downloaded on the fly during training and samples are retrieved from them. You can configure {class}`streaming.StreamingDataset`'s shard retrieval to meet your training job's needs. For more information about shard retrieval during distributed model training, refer to the [main concepts page](../getting_started/main_concepts.md#distributed-model-training).
+
+## Loading datasets
+
+### Pointing to your dataset
+
+To train on a dataset that lives in a remote location, simply pass the path to StreamingDataset's `remote` argument. The dataset's `index.json` file should live in this directory. StreamingDataset works with all major cloud providers. The `local` argument should be used to specify where the downloaded shards will be stored on local disk.
+
+```python
+dataset = StreamingDataset(
+ remote = 's3://some-bucket/my-dataset', # dataset lives at this remote path
+ local = '/local/dataset', # shards downloaded and stored locally at this path
+)
+```
+
+If your dataset is already available on local disk for your GPUs to access, only specify the `local` argument.
+
+```python
+dataset = StreamingDataset(
+ local = '/local/dataset', # dataset shards are already locally available at this path
+)
+```
+
+The `split` argument can be used to specify a particular subdirectory to use -- for example, a training dataset split.
+
+```python
+dataset = StreamingDataset(
+ remote = 's3://some-bucket/my-dataset',
+ local = '/local/dataset',
+ split = 'train', # dataset will be loaded from 's3://some-bucket/my-dataset/train'
+)
+```
+
+### Multiple streams
+
+If using multiple data sources, specify the `remote` and/or `local` paths for each one in a separate {class}`streaming.Stream` object, and pass those to StreamingDataset's `streams` argument. An example can be found [here](../getting_started/main_concepts.md#remote-data-streams).
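+
+As a minimal inline sketch (the bucket names and local paths here are hypothetical):
+
+```python
+from streaming import Stream, StreamingDataset
+
+stream_A = Stream(remote='s3://bucket/dataset-A', local='/tmp/dataset-A')
+stream_B = Stream(remote='s3://bucket/dataset-B', local='/tmp/dataset-B')
+
+# With no proportion/repeat/choose set, each stream contributes all of its samples.
+dataset = StreamingDataset(streams=[stream_A, stream_B], batch_size=8)
+```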
+
+### Hash Validation
+
+If you wrote out your dataset shards with specific hash functions (see [here](../preparing_datasets/basic_dataset_conversion.md#configuring-dataset-writing)) and want to validate them at training time, set the `validate_hash` argument of StreamingDataset. Depending on the hash function, this may slow down data loading.
+
+```python
+dataset = StreamingDataset(
+ ...
+ validate_hash = 'sha1', # validate shard using sha1 hash function
+ ...
+)
+```
+
+## Controlling shard downloads
+
+### Downloading ahead
+
+Setting the `predownload` argument ensures that StreamingDataset will download the shards needed for the upcoming `predownload` samples, per worker. For example, if `predownload` is set to 8, then each DataLoader worker will download the shards needed for up to 8 samples ahead of the current point in training. The default value of `predownload` in StreamingDataset performs well, so only set this argument if you want to prepare more samples ahead of the current training batch.
+
+```python
+dataset = StreamingDataset(
+ ...
+ predownload = 8, # each worker will download shards for up to 8 samples ahead
+ ...
+)
+```
+
+### Retries and Timeout
+
+Set the `download_retry` argument to the number of times a shard download should be retried. The `download_timeout` argument specifies, in seconds, how long to wait for a shard download before throwing an exception. For larger shards, a longer `download_timeout` can be necessary.
+
+```python
+dataset = StreamingDataset(
+ ...
+ download_retry = 3, # retry shard downloads up to 3 times
+ download_timeout = 120, # wait 2 minutes for a shard to download
+ ...
+)
+```
+
+## Configure shard storage
+
+### Cache limit
+
+If you have limited local disk space, specify the `cache_limit` argument. Once locally stored shards reach the `cache_limit`, Streaming will begin evicting shards to stay under the limit. This is particularly useful for very large datasets or small disks. Setting `cache_limit` too low will hinder performance, since shards may be continually evicted and redownloaded. This can be specified as integer bytes or as a human-readable string.
+
+```python
+cache_limit = 10*1024**2 # cache limit of 10mb
+cache_limit = '10mb' # also a cache limit of 10mb
+```
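+
+For example, a sketch of capping the local shard cache (the paths and limit here are illustrative):
+
+```python
+dataset = StreamingDataset(
+    remote = 's3://some-bucket/my-dataset',
+    local = '/local/dataset',
+    cache_limit = '100gb',    # evict local shards once they exceed 100 GB per node
+    batch_size = 8,
+)
+```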
+
+### Keeping compressed shards
+
+If your dataset shards are compressed (see [here](../preparing_datasets/basic_dataset_conversion.md#configuring-dataset-writing)), StreamingDataset will decompress them upon download for use in training. To control whether the compressed versions of shards are kept locally, use the `keep_zip` flag. This defaults to `False`, meaning that StreamingDataset will default to deleting compressed shards and only keeping the decompressed shards.
+
+```python
+dataset = StreamingDataset(
+ ...
+ keep_zip = True, # keep compressed versions of shards locally
+ ...
+)
+```
diff --git a/docs/source/dataset_configuration/shuffling.md b/docs/source/dataset_configuration/shuffling.md
new file mode 100644
index 000000000..aed01fd31
--- /dev/null
+++ b/docs/source/dataset_configuration/shuffling.md
@@ -0,0 +1,77 @@
+# Shuffling
+
+Shuffling is important for model convergence during training, but can be computationally expensive. Unshuffled data can lead to divergence, loss spikes, or suboptimal training. Streaming's shuffling is designed to give you great shuffle quality without sacrificing training throughput. Shuffling depends on the five arguments shown below. **StreamingDataset's defaults ensure high shuffle quality and throughput, so merely setting `shuffle=True` is performant in nearly all cases.**
+
+| Parameter | Type | Default value | Description |
+| :-------- | :--- | :------------ | :---------- |
+| `shuffle` | `bool` | `False` | turn shuffling on or off |
+| `shuffle_algo` | `str` | `'py1e'` | which shuffling algorithm to use |
+| `shuffle_seed` | `int` | `9176` | all randomness in StreamingDataset is derived from this seed |
+| `shuffle_block_size` | `int` | `max(4000000/num_canonical_nodes, 1<<18)` | Number of samples to shuffle at a time, only used by py1b, py1br, and py1e algorithms |
+| `num_canonical_nodes` | `int` | # of physical nodes | Number of sample buckets. Controls shuffling in py1s and py2s algorithms |
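+
+In most cases, simply turning shuffling on is enough. A minimal sketch (the paths and `batch_size` are illustrative assumptions):
+
+```python
+dataset = StreamingDataset(
+    remote = 's3://some-bucket/my-dataset',
+    local = '/tmp/my-dataset',
+    shuffle = True,    # defaults: shuffle_algo='py1e', shuffle_seed=9176
+    batch_size = 8,
+)
+```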
+
+## How Shuffling Works
+
+**Step 1**: StreamingDataset downloads shard metadata contained in `index.json` files and partitions sample IDs among nodes, devices, and workers. Shards and samples are not shuffled. See the [main concepts page](../getting_started/main_concepts.md#distributed-model-training) for more information. Below, we have two streams, or data sources, in green and blue. Each stream is a collection of shards, which are numbered.
+
+
+
+**Step 2**: The shard order is shuffled.
+
+
+
+**Step 3**: Shards are split over canonical nodes, which are simply buckets of samples. The `num_canonical_nodes` parameter controls how many of these buckets there are. Some shards can be split between canonical nodes. There are two canonical nodes shown here, in pink and purple.
+
+
+
+**Step 4**: Samples are shuffled within each canonical node, to maximize performance and reduce inter-node shard downloads. Canonical nodes are then assigned to physical training nodes. Now, the sample IDs are shuffled and ready for training. The `shuffle_algo` parameter controls how this shuffling happens.
+
+
+
+## Shuffling Algorithms
+
+The `shuffle_algo` can be set to one of five choices, each with different tradeoffs. We recommend the default shuffle algorithm, `py1e`, as it achieves great shuffle quality while balancing shard downloads.
+
+### Shuffle-block-based algorithms
+
+If your dataset has not been pre-shuffled, or you are using multiple streams, you should use a shuffle-block-based algorithm. The `py1e`, `py1br`, and `py1b` shuffles use the `shuffle_block_size` parameter, which determines how many samples within each canonical node are shuffled at once. You should set `shuffle_block_size` to be larger than the number of samples in a single shard (usually, at least 10x) for a high quality shuffle.
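+
+For example, a sketch of tuning the shuffle block size relative to shard size (the shard size, paths, and `batch_size` here are illustrative assumptions):
+
+```python
+# Suppose each shard holds roughly 10,000 samples; a 100,000-sample block spans ~10 shards.
+dataset = StreamingDataset(
+    remote = 's3://some-bucket/my-dataset',
+    local = '/tmp/my-dataset',
+    shuffle = True,
+    shuffle_algo = 'py1br',          # or 'py1e' (default) / 'py1b'
+    shuffle_block_size = 100_000,    # roughly 10x the samples per shard
+    batch_size = 8,
+)
+```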
+
+#### `'py1e'` (default)
+
+Samples from each shard are spread out across a range of maximum size `shuffle_block_size`. The diagram below shows how the samples from each shard are spread out over a specified range, shuffling them.
+
+
+
+This algorithm provides great shuffle quality, just like `py1br` and `py1b`, while also reducing the maximum needed cache limit and better balancing shard downloads. StreamingDataset defaults to using this shuffling algorithm.
+
+#### `'py1br'`
+
+Samples within each canonical node are shuffled in blocks of size `shuffle_block_size`. The block sizes are slightly randomized. The diagram below shows the boundaries of each shuffle block as a dashed line.
+
+
+
+This algorithm is a more download-optimal version of `py1b`, which is being deprecated.
+
+#### `'py1b'`
+
+This algorithm is very similar to `py1br`, without randomizing shuffle block sizes, resulting in suboptimal download performance. It will soon be deprecated -- please use `py1e` or `py1br` instead.
+
+### Intra-shard shuffle algorithms
+
+If your dataset has been pre-shuffled, if your cache limit per node is very small, or if a lightweight shuffle is acceptable for your training run, use an intra-shard shuffle algorithm. These shuffling algorithms only require storing one shard per canonical node at a time during training, resulting in very low disk usage.
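+
+A sketch of selecting an intra-shard shuffle (the paths and `batch_size` are illustrative assumptions):
+
+```python
+dataset = StreamingDataset(
+    remote = 's3://some-bucket/my-dataset',
+    local = '/tmp/my-dataset',
+    shuffle = True,
+    shuffle_algo = 'py1s',    # or 'py2s' for a higher quality, more expensive shuffle
+    batch_size = 8,
+)
+```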
+
+#### `'py1s'`
+
+Samples within each shard or shard part are shuffled. The `py1s` algorithm only performs shuffling once, *after* shards are split among canonical nodes, in contrast to `py2s`.
+
+#### `'py2s'`
+
+Samples within each shard are shuffled both before and after shards are split among canonical nodes. This corresponds to an additional shuffle before Step 3 above. The `py2s` shuffle results in a higher quality shuffle for shards that are split between two canonical nodes, but requires two shuffles, in contrast to `py1s`. As a result, it is more computationally expensive.
+
+### Naive shuffling
+
+#### `'naive'`
+
+Globally shuffles all samples. This is useful for single-node training on small data, where you want the most random shuffle possible, but is the least download-efficient of all shuffle algorithms. Training throughput is often much lower when using the `naive` shuffling algorithm.
+
+If you are having trouble with throughput, network downloads, or shuffle quality, please refer to the [performance tuning page](../distributed_training/performance_tuning.md).
diff --git a/docs/source/distributed_training/elastic_determinism.md b/docs/source/distributed_training/elastic_determinism.md
new file mode 100644
index 000000000..7d960a73c
--- /dev/null
+++ b/docs/source/distributed_training/elastic_determinism.md
@@ -0,0 +1,37 @@
+# Elastic Determinism
+
+Deterministic and reproducible training across varying numbers of GPUs is essential for resizing workloads, debugging distributed training jobs, and more. Streaming is built to provide **elastically deterministic training and resumption**. For example, a training run on 24 GPUs can be stopped, resumed on 16 GPUs, and later, finished on 48 GPUs, all with the same loss curve and global batch size. Here's an example of completely deterministic loss curves as the number of GPUs increases from 8 to 64:
+
+
+
+When combining elastic determinism with elastically sharded checkpoints, as our Composer training library does, distributed training becomes easier and much more flexible. See [here](https://docs.mosaicml.com/projects/composer/en/stable/trainer/checkpointing.html) for more information on Composer's checkpointing.
+
+## Requirements
+
+For elastic determinism, Streaming merely requires that your global batch size stays constant over the course of the training job, and is also divisible by all the numbers of GPUs you wish to run on. For example, with a global batch size of 18, you can train deterministically on 1, 2, 3, 6, 9, or 18 GPUs, but not on 7, since 18 samples cannot be evenly split among 7 GPUs.
+
+Streaming uses the `num_canonical_nodes` parameter, which controls the number of buckets into which samples are partitioned, to ensure that the global sample order remains elastically deterministic. To retain determinism between runs, set `num_canonical_nodes` to the same value in each run. If unset, `num_canonical_nodes` defaults to the number of physical nodes of the first run.
+
+For example, if Run 1 was trained on 32 GPUs, where each physical node had 8 GPUs, then the total number of physical nodes is 4, and `num_canonical_nodes` defaults to 4. If Run 2 is required to have the same loss curve as Run 1, explicitly set `num_canonical_nodes` to 4, and remember to set `batch_size` accordingly:
+
+```python
+# Dataset for Run 1 does not specify `num_canonical_nodes`. Assuming that each physical node has 8 GPUs,
+# and Run 1 is launched on 32 GPUs, `num_canonical_nodes` is set to the number of physical nodes, 4.
+run_1_32_gpu_dataset = StreamingDataset(
+ remote = 'oci://some_remote_path/dataset',
+ local = 'tmp/local/cache',
+ batch_size = 4, # This is the per-device batch size. Global batch size is 32 gpus * 4 samples/gpu = 128 samples
+)
+
+# To make Run 2 have the same loss curve as Run 1, explicitly set `num_canonical_nodes` to 4.
+# Assuming Run 2 is launched on 8 GPUs, the `batch_size` (per-device) must increase by a factor of 4
+# so that the global batch size stays the same (128 samples).
+run_2_8_gpu_dataset = StreamingDataset(
+ remote = 'oci://some_remote_path/dataset',
+ local = 'tmp/local/cache',
+ num_canonical_nodes = 4, # Explicitly set to the same as Run 1 for deterministic training
+ batch_size = 16, # This is the per-device batch size. Global batch size is 8 gpus * 16 samples/gpu = 128 samples
+)
+```
+
+See [this section](../dataset_configuration/shuffling.md#how-shuffling-works) for more information on how `num_canonical_nodes` is used.
diff --git a/docs/source/distributed_training/fast_resumption.md b/docs/source/distributed_training/fast_resumption.md
new file mode 100644
index 000000000..52e940167
--- /dev/null
+++ b/docs/source/distributed_training/fast_resumption.md
@@ -0,0 +1,45 @@
+# Fast Resumption
+
+Being resistant to timeouts, hardware failures, or other errors is crucial to efficient distributed training. While many dataset implementations require iterating through previously seen samples before resuming, Streaming allows for **immediate and deterministic resumption** in the middle of an epoch because it is stateful.
+
+## Saving and loading state
+
+To get fast, deterministic mid-epoch resumption, **make sure to use the {class}`streaming.StreamingDataLoader` object**. StreamingDataLoader works in conjunction with StreamingDataset to save and load dataset state. It works exactly like a normal PyTorch DataLoader.
+
+When checkpointing, simply call the `state_dict` method of `StreamingDataLoader` and save it along with your checkpoint. Then, when resuming, call `load_state_dict` with the saved state, and you'll be running in no time. Here's an example:
+
+```python
+from streaming import StreamingDataset
+from streaming import StreamingDataLoader
+
+dataset = StreamingDataset(local='/tmp/cache', remote='s3://remote/dataset', batch_size=1)
+dataloader = StreamingDataLoader(dataset, batch_size=1)
+
+# Here, we assume each sample in our dataset has fields 'x' and 'y'.
+# We save the dataloader state after batch index 4, and stop after batch index 6.
+state_dict = None
+for i, batch in enumerate(dataloader):
+ print(i, batch['x'], batch['y'])
+ if i == 4:
+ state_dict = dataloader.state_dict()
+ if i == 6:
+ break
+```
+
+The saved state now reflects training through batch index 4, even though we iterated up to batch index 6 before training "stopped". This is akin to a training job failing some time after a checkpointing interval. Now, we resume from where we left off:
+
+```python
+# Create a new dataset
+dataset_2 = StreamingDataset(local='/tmp/cache', remote='s3://remote/dataset', batch_size=1)
+dataloader_2 = StreamingDataLoader(dataset_2, batch_size=1)
+# Load in the state dict that was previously saved
+dataloader_2.load_state_dict(state_dict)
+
+# Iterate over the dataset, which will start from batch 5 now.
+for i, batch in enumerate(dataloader_2):
+ print(i, batch['x'], batch['y'])
+```
+
+## Resumption with Composer
+
+When training with [Composer](https://docs.mosaicml.com/projects/composer/en/stable/), our open-source deep learning training library built on top of PyTorch, fast resumption is handled automatically. Composer and Streaming work seamlessly together to provide efficient, scalable neural network training.
diff --git a/docs/source/distributed_training/performance_tuning.md b/docs/source/distributed_training/performance_tuning.md
new file mode 100644
index 000000000..f637ec457
--- /dev/null
+++ b/docs/source/distributed_training/performance_tuning.md
@@ -0,0 +1,77 @@
+# Performance Tuning
+
+Getting the best performance from your training jobs is of utmost importance -- GPUs are expensive! Streaming's default parameters give great performance out-of-the-box for most model configurations. Tuning performance with Streaming mostly comes down to downloading and storing shards optimally.
+
+Great performance with Streaming means that dataloading is never a bottleneck during model training. Streaming also provides the [Streaming Simulator tool](#streaming-simulator) to help with performance optimization. Let's dive in!
+
+## Downloading shards
+
+Streaming downloads dataset shards on the fly to make sure that the samples they contain are ready to be trained with. Refer to this [section](../dataset_configuration/shard_retrieval.md#controlling-shard-downloads) for information on how to control shard downloads. Some potential issues we have seen:
+- The dataset `predownload` and dataloader `workers` are set too low. Either increase the `predownload` parameter of StreamingDataset, or increase the number of workers for your dataloader, to allow more samples to be prepared ahead of time (see the sketch below this list).
+- Shard downloads are not balanced, often in conjunction with low download bandwidth. To determine if shuffling is increasing download demand too much, try running with `shuffle` set to `False`. Then, make sure `shuffle_algo` is set to `'py1e'` (more info [here](../dataset_configuration/shuffling.md#py1e-default)) to help balance out inter- and intra-node shard downloads while maintaining shuffle quality. If this is still slowing down training, try the [`'py1s'`](../dataset_configuration/shuffling.md#py1s) or [`'py2s'`](../dataset_configuration/shuffling.md#py2s) shuffling algorithms.
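+
+A sketch of raising both knobs (the values here are illustrative, not recommendations):
+
+```python
+from torch.utils.data import DataLoader
+from streaming import StreamingDataset
+
+dataset = StreamingDataset(
+    remote = 's3://some-bucket/my-dataset',
+    local = '/tmp/my-dataset',
+    predownload = 64,    # each worker prepares shards for up to 64 samples ahead
+    batch_size = 8,
+)
+dataloader = DataLoader(dataset, batch_size=8, num_workers=8)    # more workers downloading in parallel
+```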
+
+## Storing shards
+
+Once shards are downloaded, they are stored on each node's disk, and are available to that node's GPUs. Refer to this [section](../dataset_configuration/shard_retrieval.md#configure-shard-storage) for information on controlling how shards are stored. The main issue that can crop up here is when the node's available disk space is less than the cache size required to store dataset shards. If `cache_limit` is not set, each node's cache size for shards is given by:
+
+$$L = \frac{S \cdot N}{P}$$
+
+
+Where $L$ is the required cache limit per node, in MB, $S$ is the average shard size, in MB, $N$ is the total number of shard files, and $P$ is the number of physical nodes. However, for optimal performance, the *minimum* required `cache_limit` can be much lower, since each node only needs to store shards that have samples that are actively being used for training. If `shuffle` is `False`, or if using the [`'py1s'`](../dataset_configuration/shuffling.md#py1s) or [`'py2s'`](../dataset_configuration/shuffling.md#py2s) shuffling algorithms, the required cache limit will be approximately:
+
+$$L = 2 \cdot S \cdot \lceil\frac{C}{P}\rceil $$
+
+Where $L$ is the required minimum cache limit per node, in MB, $S$ is the average shard size, in MB, $C$ is the number of canonical nodes (see [here](../dataset_configuration/shuffling.md#how-shuffling-works) and [here](../distributed_training/elastic_determinism.md#requirements)), and $P$ is the number of physical nodes. This is because only a single shard, plus a potentially predownloaded subsequent shard, needs to be resident per canonical node to make progress during training.
+
+If using a shuffle-block-based algorithm such as [`'py1e'`](../dataset_configuration/shuffling.md#py1e-default), [`'py1br'`](../dataset_configuration/shuffling.md#py1br), or [`'py1b'`](../dataset_configuration/shuffling.md#py1b), the required minimum cache limit per node will be approximately:
+
+$$L = k \cdot S \cdot \lceil \frac{B}{Q} \rceil \cdot \lceil \frac{C}{P} \rceil$$
+
+Where $L$ is the required minimum cache limit per node, in MB, $k$ is a constant that depends on the shuffle algorithm used, $S$ is the average shard size, in MB, $B$ is the shuffle block size (see [here](../dataset_configuration/shuffling.md#shuffle-block-based-algorithms)) as a number of samples, $Q$ is the average number of samples per shard, $C$ is the number of canonical nodes (sample buckets), and $P$ is the number of physical nodes. This is because each shuffle block consists of $\lceil \frac{B}{Q}\rceil$ shards, and the subsequent shuffle block's shards may have to be predownloaded. The constant $k$ is $1$ for the [`'py1e'`](../dataset_configuration/shuffling.md#py1e-default) algorithm, whereas it is $2$ for both [`'py1br'`](../dataset_configuration/shuffling.md#py1br) and [`'py1b'`](../dataset_configuration/shuffling.md#py1b), meaning that `'py1e'` gives better cache limit performance, while retaining shuffle quality.
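+
+As a rough worked example of the formula above (all numbers here are illustrative assumptions):
+
+```python
+import math
+
+S = 128          # average shard size, MB
+Q = 10_000       # average samples per shard
+B = 100_000      # shuffle block size, in samples
+C = 8            # number of canonical nodes
+P = 4            # number of physical nodes
+k = 1            # 'py1e'; use k = 2 for 'py1br' or 'py1b'
+
+L = k * S * math.ceil(B / Q) * math.ceil(C / P)
+print(L)    # 2560 MB minimum cache limit per node, under these assumptions
+```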
+
+## Streaming Simulator
+
+Streaming Simulator is a tool for simulating throughput, network usage, and shuffle quality with MosaicML Streaming. The simulator allows you to:
+- Plan runs and anticipate issues beforehand
+- Find optimal run configurations
+- Debug issues with underperforming runs
+- Better understand the impact of different configurations
+
+### Getting Started
+Run the following to install simulator-specific dependencies, if they don't already exist:
+```bash
+pip install --upgrade "mosaicml-streaming[simulator]"
+```
+Then, simply run `simulator` from the command line to open the web UI and get simulating!
+
+### Key Features
+
+#### Throughput
+Throughput is estimated for the duration of the run and is displayed as the simulation progresses. We estimate throughput by iterating over the samples of the dataset in order, and performing shard downloads based on an estimate of network bandwidth. The 10-step rolling average is displayed.
+
+
+
+#### Network Downloads
+Cumulative network downloads are also estimated for the run and displayed. They are calculated in conjunction with throughput. If shards are compressed, we assume they are downloaded in compressed form and immediately uncompressed.
+
+
+
+#### Simulation Stats
+We also provide various useful statistics from the simulation, such as:
+- Minimum cache limit (i.e., maximum space used by live shards)
+- Steps slowed down by shard downloads
+- Estimated time to first batch
+- Estimated warmup time (i.e., time until throughput is maximized)
+
+
+
+#### Shuffle Quality
+You can choose to evaluate the quality of different shuffling algorithms for your run. We provide an estimate of shuffle quality based on the entropy calculated over the probability distribution of differences between neighboring sample indices and shard indices of the dataset. *These shuffle quality metrics are noisy and may not reflect the true strength of a shuffle.*
+
+
+
+
+
+#### YAML Support
+YAML files that follow MosaicML conventions can be uploaded and simulated as well. Simply click the toggle, enter any needed additional information, and see your results. Parameters can also be modified to quickly test out configurations.
+
+
diff --git a/docs/source/distributed_training/requirements.md b/docs/source/distributed_training/requirements.md
new file mode 100644
index 000000000..d48ca2b2c
--- /dev/null
+++ b/docs/source/distributed_training/requirements.md
@@ -0,0 +1,25 @@
+# Requirements for Distributed Training
+
+Streaming is purpose-built for fast, large-scale distributed training. It relies on the environment variables below, which must be set for each device/GPU to correctly assign data.
+
+- **WORLD_SIZE**: Total number of processes to launch across all nodes.
+- **LOCAL_WORLD_SIZE**: Total number of processes to launch for each node.
+- **RANK**: Rank of the current process, which is in the range `0` to `WORLD_SIZE - 1`.
+- **MASTER_ADDR**: The hostname for the rank-zero process.
+- **MASTER_PORT**: The port for the rank-zero process.
+
+Some launchers will automatically take care of setting these environment variables. For example, using [Composer](https://docs.mosaicml.com/projects/composer/en/stable/) in conjunction with the [MosaicML Platform](https://docs.mosaicml.com/projects/mcli/en/latest/) will automatically enable distributed training.
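+
+If your launcher does not set these for you, they can be exported in your shell or set in your training script. A minimal sketch for a single-GPU, single-node run (the values are illustrative):
+
+```python
+import os
+
+# One process on one node; adjust per process for multi-GPU launches.
+os.environ['WORLD_SIZE'] = '1'
+os.environ['LOCAL_WORLD_SIZE'] = '1'
+os.environ['RANK'] = '0'
+os.environ['MASTER_ADDR'] = '127.0.0.1'
+os.environ['MASTER_PORT'] = '29500'
+```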
+
+More info about using different distributed training launchers with Streaming can be found [here](using_launchers.md).
+
+## Parallelism Strategies
+
+Streaming supports a variety of distributed training parallelism strategies, including Distributed Data Parallelism ([DDP](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html)), Fully Sharded Data Parallelism ([FSDP](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html), akin to [ZeRO](https://www.microsoft.com/en-us/research/blog/zero-deepspeed-new-system-optimizations-enable-training-models-with-over-100-billion-parameters/)), Hybrid Sharded Data Parallelism ([HSDP](https://pytorch.org/tutorials/recipes/distributed_device_mesh.html)), Tensor Parallelism ([TP](https://docs.aws.amazon.com/sagemaker/latest/dg/model-parallel-extended-features-pytorch-tensor-parallelism.html)), and Sequence Parallelism ([SP](https://arxiv.org/pdf/2105.13120.pdf)).
+
+### Data Parallel strategies
+
+Parallelism strategies like DDP, FSDP, and HSDP are all data-parallel strategies, where each device needs to see a unique part of the global training batch. StreamingDataset supports this out-of-the-box without any configuration changes.
+
+### Data Replication strategies
+
+Parallelism strategies like TP and SP require multiple devices to receive the same data samples, requiring replication. Simply set the `replication` argument to StreamingDataset to specify how many consecutive devices should receive the same data. An example can be found [here](../dataset_configuration/replication_and_sampling.md#replication).
diff --git a/docs/source/distributed_training/using_launchers.md b/docs/source/distributed_training/using_launchers.md
new file mode 100644
index 000000000..2e9839138
--- /dev/null
+++ b/docs/source/distributed_training/using_launchers.md
@@ -0,0 +1,3 @@
+# Using Launchers
+
+Streaming can be used with multiple different launchers, including [Composer](https://docs.mosaicml.com/projects/composer/en/stable/), [torchrun](https://pytorch.org/docs/stable/elastic/run.html), and others. This page is still under construction so please refer to the distributed training [requirements page](requirements.md) for now!
diff --git a/docs/source/examples b/docs/source/examples
deleted file mode 120000
index d15735c1d..000000000
--- a/docs/source/examples
+++ /dev/null
@@ -1 +0,0 @@
-../../examples
\ No newline at end of file
diff --git a/docs/source/fundamentals/batching.md b/docs/source/fundamentals/batching.md
deleted file mode 100644
index 5560c2a1a..000000000
--- a/docs/source/fundamentals/batching.md
+++ /dev/null
@@ -1,7 +0,0 @@
-# Batching
-
-You can choose how batches are constructed by specifying the `batching_method` argument when instantiating `StreamingDataset`. Currently, this can take on one of three values:
-
-- `'random'`: (default) Samples for each batch are chosen at random from input streams. While stream proportions hold in aggregate over the course of training, this batching method does not guarantee that stream proportions hold for each batch.
-- `'stratified'`: Every single batch is divided up between streams in the same proportions. Unlike in the default case, stream proportions hold for every batch, unlike in the default case, where they hold only in aggregate.
-- `'per_stream'`: Each batch has samples from just one stream. In aggregate over all batches, stream proportions still hold.
diff --git a/docs/source/fundamentals/compression.md b/docs/source/fundamentals/compression.md
deleted file mode 100644
index 4f1e8d0b5..000000000
--- a/docs/source/fundamentals/compression.md
+++ /dev/null
@@ -1,15 +0,0 @@
-# Compression
-
-Compression allows us to store and download a small dataset and use a large dataset. Compression is beneficial for text, often compressing shards to a third of the original size, whereas it is marginally helpful for other modalities like images. Compression operates based on shards. We provide several compression algorithms, but in practice, `Zstandard` is a safe bet across the entire time-size Pareto frontier. The higher the quality level, the higher the compression ratio. However, using higher compression levels will impact the compression speed.
-
-Table of supported compression algorithms:
-
-| Name | Code | Min Level | Default Level | Max Level |
-| --------------------------------------------- | ------ | --------- | ------------- | --------- |
-| [Brotli](https://github.com/google/brotli) | br | 0 | 11 | 11 |
-| [Bzip2](https://sourceware.org/bzip2/) | bz2 | 1 | 9 | 9 |
-| [Gzip](https://www.gzip.org/) | gz | 0 | 9 | 9 |
-| [Snappy](https://github.com/google/snappy) | snappy | – | – | – |
-| [Zstandard](https://github.com/facebook/zstd) | zstd | 1 | 3 | 22 |
-
-The compression algorithm to use, if any, is specified by passing `code` or `code:level` as a string to the [Writer](https://docs.mosaicml.com/projects/streaming/en/stable/api_reference/generated/streaming.MDSWriter.html). Decompression happens behind the scenes in the [Stream](https://docs.mosaicml.com/projects/streaming/en/stable/api_reference/generated/streaming.Stream.html) (inside [StreamingDataset](https://docs.mosaicml.com/projects/streaming/en/stable/api_reference/generated/streaming.StreamingDataset.html)) as shards are downloaded. Control whether to keep the compressed version of shards by setting the `keep_zip` flag in the specific Stream’s init or for all streams in StreamingDataset init.
diff --git a/docs/source/fundamentals/dataset_conversion_guide.md b/docs/source/fundamentals/dataset_conversion_guide.md
deleted file mode 100644
index ff4fb9d94..000000000
--- a/docs/source/fundamentals/dataset_conversion_guide.md
+++ /dev/null
@@ -1,66 +0,0 @@
-# Dataset Conversion Guide
-
-If you haven't read the [Dataset Format](dataset_format.md) guide, then we highly recommend doing so before you read this.
-
-## MDSWriter
-
-To convert the dataset into MDS format, one must use {class}`streaming.MDSWriter`. MDSWriter is like a native file writer; instead of writing the content line by line, MDSWriter writes the data sample by sample. It writes the data into a first shard file (for example, `shard.00000.mds`), and once the shard file reaches a size limit, it creates a new shard file with a number incremented (for example, `shard.00001.mds`), and so on. {class}`streaming.MDSWriter` support various parameters you can tweak based on your requirements. Let's understand each parameter one by one:
-
-1. An `out` parameter is an output dataset directory to save shard files. If the parameter is a local directory path, the shard files are stored locally. If the parameter is a remote directory, a local temporary directory is created to cache the shard files, and then the shard files are uploaded to a remote location. In the end, the temp directory is deleted once shards are uploaded. If the parameter is a tuple of `(local_dir, remote_dir)`, shard files are saved in the `local_dir` and uploaded to a remote location. As shard files are ready, it gets uploaded in the background to a remote location if provided. The user does not have to worry about uploading the shard files manually. `MDSWriter` also support a `keep_local` parameter where after uploading of an individual shard file is completed, you have the flexibility of deleting the shard file locally by providing `keep_local` to `False` (Default is `False`) to avoid running out of disk space.Checkout the [out](https://docs.mosaicml.com/projects/streaming/en/stable/api_reference/generated/streaming.MDSWriter.html) parameter for more detail. For example, one can provide the `out` parameter as shown below:
-
-```python
-out = '/tmp/data'
-out = 's3://bucket/data'
-out = {'/local/data', 'oci://bucket/data'}
-```
-
-2. A `column` parameter that maps a feature name or label name with a streaming supported encoding type. `MDSWriter` encodes your data from provided encoding type to bytes, and later it gets decoded back automatically to its original data type when calling `StreamingDataset`. The `index.json` file saves `column` information for decoding. Below is the list of supported encoding formats.
-
-| Name | Class | Name | Class | Name | Class |
-| ------ | ------ | ------- | ------- | ---- | ------ |
-| bytes | `Bytes` | int8 | `Int8` | pil | `PIL` |
-| str | `Str` | int16 | `Int16` | jpeg | `JPEG` |
-| int | `Int` | int32 | `Int32` | png | `PNG` |
-| uint8 | `UInt8` | int64 | `Int64` | pkl | `Pickle` |
-| uint16 | `UInt16` | float16 | `Float16` | json | `JSON` |
-| uint32 | `UInt32` | float32 | `Float32` | ndarray | `NDArray` |
-| uint64 | `UInt64` | float64 | `Float64` | | |
-
-Below is one example where the feature name `x` is an image, and the label `y` is a class value.
-
-```python
-column = {
- 'x': 'jpeg',
- 'y': 'int8'
-}
-```
-
-**Advanced use-case:** If the data type you are interested in is not listed in the above table, then you can write your own data type class with `encode` and `decode` method in it and patch it inside streaming. For example, let say, you would like to write the same for `int32` data type.
-
-
-```python
-import numpy as np
-from typing import Any
-
-from streaming.base.format.mds.encodings import Encoding, _encodings
-
-class Int32(Encoding):
- def encode(self, obj: Any) -> bytes:
- return obj.tobytes()
-
- def decode(self, data: bytes) -> Any:
- return np.frombuffer(data, np.int32)
-
-_encodings['int32'] = Int32
-```
-
-
-3. A `compression` algorithm name if you would like to compress the shard files. Check out the [compression](compression.md) document for more details.
-
-4. A `hashes` algorithm name to verify data integrity. Check out the [hashing](hashing.md) document for additional details.
-
-5. A shard `size_limit` in bytes for each shard file, after which point to start a new shard. Shard file size depends on the dataset size, but generally, too small of a shard size creates a ton of shard files and heavy network overheads, and too large of a shard size creates fewer shard files, but the training start time would increase since it has to wait for a shard file to get downloaded locally. Based on our intuition, the shard file size of 64Mb, and 128Mb play a balanced role. This parameter is a number of bytes, either directly as an `int` or a human-readable suffix (ex: `1024` or `"1kb"`)
-
-6. A `keep_local` parameter if you would like to keep the shard files locally after it has been uploaded to a remote cloud location by MDSWriter.
-
-This gives you a good understanding of {class}`streaming.MDSWriter` parameters. If you would like to convert your raw data into an MDS format, check out the [Dataset Conversion to MDS Format](../how_to_guides/dataset_conversion_to_mds_format.md) guide.
diff --git a/docs/source/fundamentals/dataset_format.md b/docs/source/fundamentals/dataset_format.md
deleted file mode 100644
index 24a0c8d89..000000000
--- a/docs/source/fundamentals/dataset_format.md
+++ /dev/null
@@ -1,23 +0,0 @@
-# Dataset Format
-
-## Introduction
-
-To use StreamingDataset, one must convert raw data into one of our supported serialized dataset formats. With massive datasets, our serialization format choices are critical to the ultimate observed performance of the system. If we care about performance, we must own the format ourselves. In the Deep Learning model training, we need extremely low latency cold random access on individual samples' granularity to ensure the dataset is not a bottleneck.
-
-StreamingDataset is compatible with any data type, including **images**, **text**, **video**, and **multimodal** data. StreamingDataset supports MDS (Mosaic Data Shard), CSV/TSV, and JSONL format, which can encode and decode most python objects.
-
-## High-level design
-During dataset conversion, StreamingDataset generates two types of files.
-1. **index.json:** The index.json file contains metadata about the shard files, such as how many samples are in each shard file, the compression algorithm used, the shard filename, etc.
-2. **One or more shard files:** Contains the data encoded into bytes. It is conceptualized as a table, each row being a sample. Columns have data types, which deserialize into any Python object such as int, str, PIL Image, etc. The shard file name starts with `shard.00000.` such as `shard.00000.mds`, `shard.00000.csv`, `shard.00000.jsonl`, etc., and number increments as it generates more shards.
-
-
-## Formats
-### 1. MDS
-Mosaic Data Shard (MDS) is like a row-oriented parquet that reads a sample by reading its start/stop bytes from the header and then seeks to sample. The sample can be a singular data entity or a dictionary of key/value pairs where the key is a data field, and the value is data. Check out the [Dataset Conversion Guide](dataset_conversion_guide.md) section to understand more about MDSWriter. Most of the existing [dataset conversion](../how_to_guides/dataset_conversion_to_mds_format.md) script uses MDS format, which is highly flexible with any data type.
-
-### 2. CSV/TSV
-CSV (Comma-Separated Values) and TSV (Tab-Separated Values) are tabular data stored in plain-text form. CSV separates the data using delimiter `,` and TSV separates the data using delimiter `\t`.
-
-### 3. JSONL
-JSON Lines text format, also called newline-delimited JSON. JSON Lines consist of several lines where each line is a valid JSON object, separated by newline character `\n`.
diff --git a/docs/source/fundamentals/environments.md b/docs/source/fundamentals/environments.md
deleted file mode 100644
index ace4dc383..000000000
--- a/docs/source/fundamentals/environments.md
+++ /dev/null
@@ -1,9 +0,0 @@
-# Environments
-
-StreamingDataset relies on certain environment variables that need to be set to work for the distributed workload. If the launcher that you are using does not set the below environment variables, you need to set it manually either in your script or export globally.
-
-- **WORLD_SIZE**: Total number of processes to launch across all nodes.
-- **LOCAL_WORLD_SIZE**: Total number of processes to launch for each node.
-- **RANK**: Rank of the current process, which is the range between `0` to `WORLD_SIZE - 1`.
-- **MASTER_ADDR**: The hostname for the rank-zero process.
-- **MASTER_PORT**: The port for the rank-zero process.
diff --git a/docs/source/fundamentals/hashing.md b/docs/source/fundamentals/hashing.md
deleted file mode 100644
index 9677a51a3..000000000
--- a/docs/source/fundamentals/hashing.md
+++ /dev/null
@@ -1,34 +0,0 @@
-# Hashing
-
-Streaming supports a variety of hash and checksum algorithms to verify data integrity.
-
-We optionally hash shards while serializing a streaming dataset, saving the resulting hashes in the index, which is written last. After the dataset is finished being written, we may hash the index file itself, the results of which must be stored elsewhere. Hashing during writing is controlled by the Writer argument `hashes: Optional[List[str]] = None`. We generally weakly recommend writing streaming datasets with one cryptographic hash algorithm and one fast hash algorithm for offline dataset validation in the future.
-
-Then, we optionally validate shard hashes upon download while reading a streaming dataset. Hashing during reading is controlled separately by the StreamingDataset argument `validate_hash: Optional[str] = None`. We recommend reading streaming datasets for training purposes without validating hashes because of the extra cost in time and computation.
-
-Available cryptographic hash functions:
-
-| Hash | Digest Bytes |
-| -------- | ------------ |
-| blake2b | 64 |
-| blake2s | 32 |
-| md5 | 16 |
-| sha1 | 20 |
-| sha224 | 28 |
-| sha256 | 32 |
-| sha384 | 48 |
-| sha512 | 64 |
-| sha3_224 | 28 |
-| sha3_256 | 32 |
-| sha3_384 | 48 |
-| sha3_512 | 64 |
-
-Available non-cryptographic hash functions:
-
-| Hash | Digest Bytes |
-| -------- | ------------ |
-| xxh32 | 4 |
-| xxh64 | 8 |
-| xxh128 | 16 |
-| xxh3_64 | 8 |
-| xxh3_128 | 16 |
diff --git a/docs/source/fundamentals/parallelism.md b/docs/source/fundamentals/parallelism.md
deleted file mode 100644
index ef5cd4075..000000000
--- a/docs/source/fundamentals/parallelism.md
+++ /dev/null
@@ -1,11 +0,0 @@
-# Parallelism
-
-Streaming supports data parallelism as well as sequence/tensor parallelism.
-
-- **Data Parallelism**: Streaming supports this by default. Each device will get a unique portion of
-each global batch. Samples are not replicated across devices. DDP, FSDP and HSDP all fall under
-this category.
-- **Sequence/Tensor Parallelism**: These parallelism strategies require groups of devices to share a
-portion of each global batch. Specifying the `replication` argument of `StreamingDataset` to `x`
-ensures that `x` consecutive devices will receive the same data. For example, `replication=4`
-sends one set of samples to devices 0 through 3, another to devices 4 through 7, and so on.
diff --git a/docs/source/fundamentals/sampling.md b/docs/source/fundamentals/sampling.md
deleted file mode 100644
index 36e1448f9..000000000
--- a/docs/source/fundamentals/sampling.md
+++ /dev/null
@@ -1,6 +0,0 @@
-# Sampling
-
-You can choose how sampling from your dataset(s) occurs between epochs by specifying the `sampling_method` when instantiating `StreamingDataset`. Currently, this can take on one of two values:
-
-- `'balanced'`: (default) Samples are chosen at random from dataset(s) during each epoch according to the proportions specified.
-- `'fixed'`: The same samples from the dataset(s) are chosen during every epoch, still according to the proportions specified.
diff --git a/docs/source/fundamentals/shuffling.md b/docs/source/fundamentals/shuffling.md
deleted file mode 100644
index 248897c56..000000000
--- a/docs/source/fundamentals/shuffling.md
+++ /dev/null
@@ -1,107 +0,0 @@
-# Shuffling
-
-Shuffling is not simple because very large numbers of samples cannot be shuffled on the fly in their entirety with acceptable performance, and you would not want to if you could for distributed download reasons. Instead, we rely on a combination of factors to cleverly achieve the effect of a global shuffle while being shard-efficient.
-
-## StreamingDataset arguments
-
-StreamingDataset takes four arguments to directly control shuffling.
-
-| Parameter | Type | Description |
-| :-------- | :--- | :---------- |
-| `shuffle` | `bool = False` | turn shuffling on or off |
-| `shuffle_algo` | `str = 'py1s'` | which shuffling algorithm to use |
-| `shuffle_seed` | `int = 9176` | all randomness in StreamingDataset is derived from this argument |
-| `shuffle_block_size` | `int = 1 << 18` | shuffling unit used by py1b and py1br algorithms |
-
-StreamingDataset also takes two other arguments that shuffling interacts with:
-
-| Parameter | Type | Description |
-| :-------- | :--- | :---------- |
-| `predownload` | `Optional[int] = None` | tune together with shuffle block size to keep workers from ever starving of shard pre-downloads while iterating (`None` means derive its value using batch size and number of canonical nodes `max(batch_size, 256 * batch_size // num_canonical_nodes)`) |
-| `num_canonical_nodes` | `Optional[int] = None` | number of divisions of the sample space, which are iterated from beginning to end concurrently (defaults to 64 times the number of initial physical nodes) |
-
-## Algorithms
-
-For `shuffle_algo`, you have four possible options, which have different tradeoffs. Once written, they cannot be changed, although it is easy to add new algorithms:
-
-### naive
-
-Globally shuffle the samples.
-
-Useful for single-node training on small data, where you want the most random shuffle possible.
-
-Statistically, this algorithm will result in all nodes downloading all shards, with those downloads all happening at the start of the epoch and needing to stay resident to make progress, bringing training to a crawl if the dataset is too large.
-
-### py1b
-
-Globally shuffle shards, divide that sample space over canonical nodes, then shuffle samples in fixed-size blocks (given by `shuffle_block_size`). So named because it shuffles samples in python, once, intra-block. A canonical node, for the purposes of shuffling, is simply a collection of shards. In order to have determinism with a different number of physical nodes, the shuffle ordering is done over the canonical nodes and these are then assigned to physical nodes.
-
-Shuffle block size should be set larger or much larger than a single shard. If so, this algorithm is useful for spacing out the contents of shards to mitigate a bad or non-existent pre-shuffle (i.e. if samples from the same shard are related in some way).
-
-In order to improve shuffle quality, this algorithm requires more shards to be downloaded and stay resident to make progress than py1s or py2s, noticed as longer start/resume latency, as a multiple of shuffle block size divided by samples per shard. If you see step-like burstiness in throughput, your workers may not be downloading far enough ahead – try raising predownload (it should be scaled with block size). Step size scales with shuffle block size.
-
-### py1br
-
-Globally shuffle shards, divide that sample space over canonical nodes, then shuffle samples in variable-size blocks (uniformly selected within the range `[0.75*shuffle_block_size, 1.25*shuffle_block_size)`). Shuffle blocks are also staggered -- along with variable shuffle block size, this works to prevent many simultaneous shard downloads. So named because it shuffles samples in python, once, intra-block, and blocks are randomized.
-
-Shuffle block size should be set larger or much larger than a single shard. If so, this algorithm is useful for spacing out the contents of shards to mitigate a bad or non-existent pre-shuffle (i.e. if samples from the same shard are related in some way).
-
-In order to improve shuffle quality, this algorithm requires more shards to be downloaded and stay resident to make progress than py1s or py2s, noticed as longer start/resume latency, as a multiple of shuffle block size divided by samples per shard. However, shard downloads with py1br are more balanced than with py1b, and this effect is more apparent when training with a higher number of nodes, resulting in less network bottlenecks. The shuffle quality of py1br and py1b are similar.
-
-### py1e
-
-Globally shuffle shards, divide that sample space over canonical nodes, shuffle the samples in each shard, then randomly distribute the samples from each shard over an expanded range (determined using `shuffle_block_size`). So named because it shuffles samples by extending the range of each shard, in python.
-
-Shuffle block size should be set larger or much larger than the number of samples in a single shard. This algorithm provides bounds on the range that samples from a shard can appear, allowing for a lower cache limit without decreasing throughput compared to py1b.
-
-This algorithm requires more shards to be downloaded and stay resident to make progress than py1s or py2s, noticed as longer start/resume latency, as a multiple of shuffle block size divided by samples per shard, similar to py1b. However, these shards will be downloaded in a more balanced fashion, reducing network bandwidth bottlenecks.
-
-### py1s
-
-Globally shuffle shards, divide that sample space over canonical nodes, then shuffle the samples within each shard or shard part. So named because it shuffles samples in python, once, intra-shard.
-
-This algorithm only requires one shard to be resident at a time per canonical node, so is smoother and more responsive than py1b, however your pre-shuffle should be good. Also note that different modalities have vastly different numbers of samples per shard, with downstream effects on shuffle quality (rule of thumb: 500 samples/shard for vision, 50K samples/shard for text).
-
-Shuffles twice as fast as py2s by being deterministic (“biased”) about assigning samples to canonical node divisions at boundary shards. In practice, we have not observed any negative downstream impacts from cutting a corner in this way. What effect does exist would be the strongest when your number of shards is very low relative to your number of canonical nodes.
-
-### py2s
-
-Globally shuffle shards, then shuffle the samples within each shard, then divide that sample space over canonical nodes, then shuffle the samples within each shard or shard part. So named because it shuffles samples in python, twice, intra-shard.
-
-This algorithm only requires one shard to be resident at a time per canonical node, so is smoother and more responsive at downloading than py1b, however your pre-shuffle should be good. Also note that different modalities have vastly different numbers of samples per shard, with downstream effects on shuffle quality (rule of thumb: 500 samples/shard for vision, 50K samples/shard for text).
-
-Shuffles roughly twice as slowly as py1s by being random (“correct”) about assigning samples to canonical node divisions at boundary shards. This becomes a pain point at around a billion samples.
-
-## Factors to consider
-
-Philosophers have long debated what it means to be a good shuffle. StreamingDataset relies on five approaches to shuffle quality which work synergistically:
-
-### Pre-shuffle
-
-The foundation of shuffle quality and therefore model learning is the preprocessing that was applied to the dataset, including deduping and pre-shuffling. Pre-shuffling refers to an offline preprocessing step that bakes in a global shuffle of the samples. You pre-shuffle once and benefit or suffer from the results forever.
-
-For performance reasons, samples which are collocated in the same shard are much more likely to be seen in time proximity to one another. The choice of shard size also matters: generally, shards are shuffled globally but samples only intra-shard or intra-block. While there are mitigations below, it is important for balance that we get a good diversity of samples on each shard and minimize repetition.
-
-### Shuffle algorithm
-
-How the shuffle works intimately impacts the distribution of samples and weaknesses thereof. See the preceding section for shuffling algorithms.
-
-### Shuffle block size
-
-You can strengthen the shuffle by increasing the size of the shuffling unit, within reason. For py1s or py2s this is the shard, but py1b provides a sliding scale via `shuffle_block_size` all the way from one sample at a time to all the samples at once (which would be like naive but with canonical node divisions).
-
-Large shuffle block sizes can save you from a bad or missing pre-shuffle. They are also a proportionally cheap and essential measure to take when training for many epochs on small datasets with very few devices. Conversely, large shuffle block sizes are a superfluous waste of time if training with many canonical nodes and many devices on many shards. There is a balance.
-
-### Number of canonical nodes
-
-When iterating, the sample space is divided evenly according to the number of canonical nodes. These divisions are read concurrently from beginning to end striping over dataloader workers in a precise pattern that preserves elastic determinism.
-
-The higher that `num_canonical_nodes` is set, the more independent non-overlapping paths the StreamingDataset replicas take through the shards per model replica (increasing data source diversity), and the more shards need to be downloaded concurrently. Data source diversity becomes increasingly important as you raise the number of different streams comprising the dataset. `num_canonical_nodes` can be raised arbitrarily high so long as the number of physical nodes evenly divides into it, which is ultimately limited by download throughput.
-
-### Splitting shard repeats
-
-The final shuffle quality technique is so important, it is always turned on. When upsampling, each repeat of a shard, including the last fractional repeat if it exists, is treated as a different shard for the purposes of shuffling. This results in the copies getting scattered across the epoch’s sample space, at the cost of more downloads.
-
-Without this, StreamingDataset would have to up/down-sample by stretching shards larger or smaller. Heavily upsampling shards would cause the model to see the same samples many times in rapid succession (at scale), which we have found interacts disastrously with small shuffle units, modulo data augmentation. A potential landmine during training.
-
-Our general advice on shuffling is that there are different tradeoffs at play, and the best answer often depends. We endeavor to provide reasonable defaults. Shuffling choices 2-4 can and should be tested empirically on your own models and your own data.
diff --git a/docs/source/fundamentals/simulator.md b/docs/source/fundamentals/simulator.md
deleted file mode 100644
index 59437d29f..000000000
--- a/docs/source/fundamentals/simulator.md
+++ /dev/null
@@ -1,45 +0,0 @@
-# Streaming Simulator
-A simulator for throughput, network use, and shuffle quality with MosaicML Streaming. The simulator allows you to:
-- Plan runs and anticipate issues beforehand
-- Find optimal run configurations
-- Debug issues with underperforming runs
-- Better understand the impact of different configurations
-
-## Getting Started
-Run the following to install simulator-specific dependencies, if they don't already exist:
-```
-pip install --upgrade "mosaicml-streaming[simulator]"
-```
-Then, simply run `simulator` in your command line to open the Web UI and get simulating!
-## Key Features
-
-### Throughput
-Throughput is estimated for the duration of the run and is displayed as the simulation progresses. We estimate throughput by iterating over the samples of the dataset in order, and performing shard downloads based on an estimate of network bandwidth. The 10-step rolling average is displayed.
-
-
-
-### Network Downloads
-Cumulative network downloads are also estimated for the run and displayed. It is calculated in conjunction with throughput. If shards are compressed, we assume they are downloaded in compressed form and immediately uncompressed.
-
-
-
-### Simulation Stats
-We also provide various useful statistics from the simulation, such as:
-- Minimum cache limit (i.e., maximum space used by live shards)
-- Steps slowed down by shard downloads
-- Estimated time to first batch
-- Estimated warmup time (i.e., time until throughput maximized)
-
-
-
-### Shuffle Quality
-You can choose to evaluate the quality of different shuffling algorithms for your run. We provide an estimate of shuffle quality based on the entropy calculated over the probability distribution of differences between neighboring sample indices and shard indices of the dataset. *These shuffle quality metrics are noisy and may not reflect the true strength of a shuffle.*
-
-
-
-
-
-### Yaml Support
-Yaml files that follow MosaicML conventions can be uploaded and simulated as well. Simply click the toggle, enter any needed additional information, and see your results. Parameters can also be modified to quickly test out configurations.
-
-
diff --git a/docs/source/getting_started/faqs_and_tips.md b/docs/source/getting_started/faqs_and_tips.md
new file mode 100644
index 000000000..57f41aca7
--- /dev/null
+++ b/docs/source/getting_started/faqs_and_tips.md
@@ -0,0 +1,102 @@
+# 🤔 FAQs and Tips
+
+## ❓ FAQs
+
+### Can I write datasets in parallel? How does this work?
+Yes, you can! Please see the [parallel dataset conversion](../preparing_datasets/parallel_dataset_conversion.ipynb) page for instructions. If you're using Spark, follow the [Spark dataframe to MDS](../preparing_datasets/spark_dataframe_to_mds.ipynb) example.
+
+### Is StreamingDataset's `batch_size` the global or device batch size?
+The `batch_size` argument to StreamingDataset is the *device* batch size. It should be set the same as the DataLoader `batch_size` argument. For optimal performance and deterministic resumption, you must pass `batch_size` to StreamingDataset.
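+
+As a minimal sketch (the remote bucket and local cache paths here are placeholders), the same device batch size is passed to both the dataset and the DataLoader:
+
+```python
+from torch.utils.data import DataLoader
+from streaming import StreamingDataset
+
+device_batch_size = 16  # per-device batch size
+
+# Placeholder paths, for illustration only.
+dataset = StreamingDataset(remote='s3://my-bucket/my-dataset', local='/tmp/cache', batch_size=device_batch_size)
+dataloader = DataLoader(dataset, batch_size=device_batch_size)
+```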
+
+### How can I calculate ingress and egress costs?
+Ingress costs will depend on your GPU provider, but egress costs from cloud storage are equal to the egress costs for a single epoch of training. Streaming is smart about how samples are partitioned, and minimizes duplicate shard downloads between nodes. The egress cost is calculated as:
+
+$$\text{Egress cost} = (\text{Egress cost per MB}) \times (\text{Average shard size in MB}) \times (\text{Total number of shards})$$
+
+For multi-epoch training, if your nodes have persistent storage or if your training job does not experience hardware failures, the egress cost will be the same as a single epoch of training. Otherwise, with ephemeral storage and training failures, you will likely have to redownload shards.
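+
+As a purely illustrative calculation: at an egress rate of \$0.09 per GB (about \$0.00009 per MB), an average shard size of 64 MB, and 10,000 shards, one epoch of training would incur roughly 0.00009 × 64 × 10,000 ≈ \$57.60 in egress.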
+
+### How can I mix and weight different data sources?
+Mixing data sources is easy, flexible, and can even be controlled at the batch level. The [mixing data sources](../dataset_configuration/mixing_data_sources.md) page shows how you can do this.
+
+### Can I use only a subset of a data source when training for multiple epochs?
+Yes, you can! For example, if your dataset is 1000 samples, but you want to train only on 400 samples per epoch, simply set
+the `epoch_size` argument to 400. For more control over how these 400 samples are chosen in each epoch, see the [inter-epoch sampling](../dataset_configuration/replication_and_sampling.md#inter-epoch-sampling) section.
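+
+As a minimal sketch (the paths are placeholders), limiting each epoch to 400 samples looks like:
+
+```python
+from streaming import StreamingDataset
+
+# The underlying dataset has 1000 samples, but each epoch iterates over only 400 of them.
+dataset = StreamingDataset(
+    remote='s3://my-bucket/my-dataset',
+    local='/tmp/cache',
+    batch_size=4,
+    epoch_size=400,
+)
+```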
+
+### How can I apply a transformation to each sample?
+StreamingDataset is a subclass of PyTorch's IterableDataset, so applying transforms works the exact same way. See [here](https://pytorch.org/tutorials/beginner/data_loading_tutorial.html) for an example on how to use transforms with PyTorch. Our [CIFAR-10 guide](../how_to_guides/cifar10.ipynb) also has an example of using transforms with StreamingDataset.
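+
+As a rough sketch (the column names `'image'` and `'label'` and the paths are assumptions for illustration), you can subclass `StreamingDataset` and apply a transform in `__getitem__`:
+
+```python
+from torchvision import transforms
+from streaming import StreamingDataset
+
+class TransformedDataset(StreamingDataset):
+    """Applies a torchvision transform to each retrieved sample."""
+
+    def __init__(self, transform, **kwargs):
+        super().__init__(**kwargs)
+        self.transform = transform
+
+    def __getitem__(self, idx):
+        sample = super().__getitem__(idx)
+        # 'image' and 'label' are assumed column names from dataset conversion.
+        return self.transform(sample['image']), sample['label']
+
+dataset = TransformedDataset(
+    transform=transforms.ToTensor(),
+    remote='s3://my-bucket/my-dataset',
+    local='/tmp/cache',
+    batch_size=4,
+)
+```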
+
+### If my dataset is larger than disk, how can I train?
+You can set the per-node cache limit using StreamingDataset's `cache_limit` argument, detailed [here](../dataset_configuration/shard_retrieval.md#cache-limit). When shard usage hits the `cache_limit`, Streaming will begin evicting shards.
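+
+A minimal sketch (the paths and limit are placeholders, not recommendations):
+
+```python
+from streaming import StreamingDataset
+
+# Keep at most ~100 GB of shards cached per node; shards are evicted once usage exceeds this.
+dataset = StreamingDataset(
+    remote='s3://my-bucket/my-dataset',
+    local='/tmp/cache',
+    batch_size=4,
+    cache_limit='100gb',
+)
+```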
+
+### I'm seeing loss spikes and divergence on my training runs. How do I fix this?
+Training loss may suffer from loss spikes or divergence for a variety of reasons. Higher quality shuffling and dataset mixing can help mitigate loss variance, divergence, and spikes. First, make sure that `shuffle` is set to `True` in your dataset. If you're already shuffling, you should make your shuffle strength higher. If using a shuffle-block-based shuffling algorithm like [`'py1e'`](../dataset_configuration/shuffling.md#py1e-default), [`'py1br'`](../dataset_configuration/shuffling.md#py1br), or [`'py1b'`](../dataset_configuration/shuffling.md#py1b), increase the `shuffle_block_size` parameter. If using an intra-shard shuffle such as [`'py1s'`](../dataset_configuration/shuffling.md#py1s) or [`'py2s'`](../dataset_configuration/shuffling.md#py2s), increase the `num_canonical_nodes` parameter. Read more about shuffling [here](../dataset_configuration/shuffling.md).
+
+Changing how datasets are mixed can also help with training stability. Specifically, setting `batching_method` to `stratified` when mixing datasets provides consistent dataset mixing in every batch. Read more about dataset mixing [here](../dataset_configuration/mixing_data_sources.md).
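+
+Putting these together, here is a hedged sketch (the streams, paths, and values are illustrative, not recommendations):
+
+```python
+from streaming import Stream, StreamingDataset
+
+streams = [
+    Stream(remote='s3://my-bucket/stream_a', local='/tmp/cache/stream_a'),
+    Stream(remote='s3://my-bucket/stream_b', local='/tmp/cache/stream_b'),
+]
+dataset = StreamingDataset(
+    streams=streams,
+    batch_size=4,
+    shuffle=True,
+    shuffle_algo='py1e',
+    shuffle_block_size=1 << 20,    # larger blocks give a stronger shuffle
+    batching_method='stratified',  # consistent stream mixing in every batch
+)
+```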
+
+### When training for multiple epochs, training takes a long time between epochs. How can I address this?
+Training is likely taking longer between epochs due to DataLoader workers not persisting. Make sure to set `persistent_workers=True` in your DataLoader, which will keep `StreamingDataset` instances alive between epochs. More information can be found [here](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader).
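+
+For example (the paths and worker count here are illustrative):
+
+```python
+from torch.utils.data import DataLoader
+from streaming import StreamingDataset
+
+dataset = StreamingDataset(remote='s3://my-bucket/my-dataset', local='/tmp/cache', batch_size=4)
+dataloader = DataLoader(
+    dataset,
+    batch_size=4,
+    num_workers=8,
+    persistent_workers=True,  # keep worker processes, and their dataset state, alive between epochs
+)
+```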
+
+If this still does not address the issue, refer to the [performance tuning page](../distributed_training/performance_tuning.md).
+
+### I'm not seeing deterministic resumption on my training runs. How can I enable this?
+To enable elastic determinism and resumption, you should be using the {class}`streaming.StreamingDataLoader` instead of the generic PyTorch DataLoader. You should also make sure you're passing in `batch_size` to StreamingDataset in addition to your DataLoader. Certain launchers, such as [Composer](https://github.com/mosaicml/composer), support deterministic resumption with StreamingDataset automatically. See the [resumption](../distributed_training/fast_resumption.md) page for more information.
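+
+A minimal sketch (the paths are placeholders):
+
+```python
+from streaming import StreamingDataset, StreamingDataLoader
+
+# Pass batch_size to both the dataset and the dataloader.
+dataset = StreamingDataset(remote='s3://my-bucket/my-dataset', local='/tmp/cache', batch_size=8)
+dataloader = StreamingDataLoader(dataset, batch_size=8)
+
+# StreamingDataLoader exposes state_dict() / load_state_dict() for checkpointing and resumption.
+state = dataloader.state_dict()
+```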
+
+### Is it possible for each global batch to consist only of samples from one Stream?
+Yes -- use the `per_stream` batching method as detailed in the [batching methods](../dataset_configuration/mixing_data_sources.md#batching-methods) section.
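+
+For example (the streams and paths here are hypothetical):
+
+```python
+from streaming import Stream, StreamingDataset
+
+streams = [
+    Stream(remote='s3://my-bucket/stream_a', local='/tmp/cache/stream_a'),
+    Stream(remote='s3://my-bucket/stream_b', local='/tmp/cache/stream_b'),
+]
+dataset = StreamingDataset(
+    streams=streams,
+    batch_size=4,
+    batching_method='per_stream',  # each global batch is drawn from a single stream
+)
+```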
+
+### I'm seeing a shared memory error. How can I fix this?
+Streaming uses shared memory to communicate between workers. These errors are indicative of stale shared memory, likely from a previous training run. To fix this, call `python` in your terminal and run the commands below:
+
+```
+>>> import streaming.base.util as util
+>>> util.clean_stale_shared_memory()
+```
+
+### What's the difference between StreamingDataset's `epoch_size`, `__len__()`, and `size()`?
+The `epoch_size` attribute of StreamingDataset is the number of samples per epoch of training. The `__len__()` method returns the `epoch_size` divided by the number of devices -- it is the number of samples seen per device, per epoch. The `size()` method returns the number of unique samples in the underlying dataset. Due to upsampling/downsampling, `size()` may not be the same as `epoch_size`.
+
+### What's the difference between `StreamingDataset` vs. datasets vs. streams?
+`StreamingDataset` is the dataset class. It can take in multiple streams, which are simply data sources, and combines them into a single dataset. `StreamingDataset` does not *stream* data as a continuous flow of bytes; instead, it downloads shard files as needed to keep a continuous flow of samples into the training job. `StreamingDataset` is an `IterableDataset` as opposed to a map-style dataset -- samples are retrieved as needed.
+
+
+## 🤓 Helpful Tips
+
+### Using locally available datasets
+If your dataset is locally accessible from your GPUs, you only need to specify the `local` argument to StreamingDataset as the path to those shard files. You should leave the `remote` field as `None`.
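+
+A minimal sketch (the local path is a placeholder):
+
+```python
+from streaming import StreamingDataset
+
+# Point `local` at already-written shard files; `remote` defaults to None.
+dataset = StreamingDataset(local='/datasets/my_mds_dataset', batch_size=4)
+```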
+
+### Access specific shards and samples
+You can use the `get_item` method of StreamingDataset to access particular samples -- StreamingDataset supports NumPy-style indexing. To further access information at the shard and sample level, the StreamingDataset attributes below are useful:
+
+- `dataset.stream_per_shard`: contains the stream index for each shard.
+- `dataset.shards_per_stream`: contains the number of shards per stream.
+- `dataset.samples_per_shard`: contains the number of samples per shard.
+- `dataset.samples_per_stream`: contains the number of samples per stream.
+- `dataset.spanner`: maps a global sample index to its shard index and relative sample index within that shard.
+- `dataset.shard_offset_per_stream`: contains the offset of the shard indices for each stream, which can be used to get a stream-relative shard index from the global shard index.
+- `dataset.prepare_shard(shard_id)`: downloads and extracts samples from the shard with index `shard_id`.
+- `dataset[sample_id]`: retrieves the sample with index `sample_id`, implicitly downloading the relevant shard.
+
+You can use these in a variety of ways to inspect your dataset. For example, to retrieve the stream index, relative shard index in that stream, and sample index in that shard, for every sample in your dataset, you could do:
+
+```python
+# Instantiate a StreamingDataset however you would like
+dataset = StreamingDataset(
+ ...
+)
+# Retrieves the number of unique samples -- no up or down sampling applied
+num_dataset_samples = dataset.size()
+# Will contain tuples of (stream id, shard id, sample id)
+stream_shard_sample_ids = []
+for global_sample_idx in range(num_dataset_samples):
+ # Go from global sample index -> global shard index and relative sample index (in the shard)
+ global_shard_idx, relative_sample_idx = dataset.spanner[global_sample_idx]
+ # Get the stream index of that shard
+ stream_idx = dataset.stream_per_shard[global_shard_idx]
+ # Get the relative shard index (in the stream) by subtracting the offset
+ relative_shard_idx = global_shard_idx - dataset.shard_offset_per_stream[stream_idx]
+
+ stream_shard_sample_ids.append((stream_idx, relative_shard_idx, relative_sample_idx))
+```
+
+### Don't make your shard file size too large or small
+You can control the maximum file size of your shards with the `size_limit` argument to the `Writer` objects -- for example, in {class}`streaming.MDSWriter`. The default shard size is 67MB, and we see that 50-100MB shards work well across modalities and workloads. If shards are too small, then you will get too many download requests, and if shards are too large, then shard downloads become more expensive and harder to balance.
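+
+As an illustrative sketch (the output path, column schema, and samples are assumptions), you can cap shard file size when writing:
+
+```python
+from streaming import MDSWriter
+
+columns = {'tokens': 'bytes'}
+samples = [{'tokens': bytes(range(16))} for _ in range(1000)]
+
+# '64mb' caps each shard file at roughly 64 MB before a new shard is started.
+with MDSWriter(out='s3://my-bucket/my-dataset', columns=columns, size_limit='64mb') as out:
+    for sample in samples:
+        out.write(sample)
+```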
diff --git a/docs/source/getting_started/installation.md b/docs/source/getting_started/installation.md
deleted file mode 100644
index 97091eacb..000000000
--- a/docs/source/getting_started/installation.md
+++ /dev/null
@@ -1,47 +0,0 @@
-# 💾 Installation
-StreamingDataset installs via pip. Installing StreamingDataset in a virtual environment is recommended to avoid any dependency conflicts. StreamingDataset has been tested on Python 3.9, 3.10, and 3.11.
-
-## Create a virtual environment
-1. Create and navigate to your project directory:
- ```
- mkdir custom-project
- cd custom-project
- ```
-
-2. Create a virtual environment inside your project directory:
- ```
- python -m venv
-
- # For example
- python -m venv venv_streaming
- ```
-
-3. Activate the virtual environment:
- ```
- source venv_streaming/bin/activate
- ```
-
-## Install using pip
-StreamingDataset can be installed using pip as follows:
-```
-pip install mosaicml-streaming
-```
-
-Run the following command to ensure the proper installation of the StreamingDataset. The following command prints the StreamingDataset version.
-```
-python -c "import streaming; print(streaming.__version__)"
-```
-
-## Install from source
-Building and installing StreamingDataset from the source allows you to change the code base.
-```
-git clone https://github.com/mosaicml/streaming.git
-cd streaming
-pip install -e .
-```
-Run the following command to ensure the proper installation of the StreamingDataset. The following command prints the StreamingDataset version.
-```
-python -c "import streaming; print(streaming.__version__)"
-```
-
-That's it! Check out our [Quick Start](quick_start.md) guide on using the StreamingDataset.
diff --git a/docs/source/getting_started/main_concepts.md b/docs/source/getting_started/main_concepts.md
new file mode 100644
index 000000000..d64901ab6
--- /dev/null
+++ b/docs/source/getting_started/main_concepts.md
@@ -0,0 +1,158 @@
+# 🧠 Main Concepts
+
+## Overview
+Training a model with Streaming happens in 2 main steps:
+1. Convert your data to a StreamingDataset-compatible format
+2. Distributed model training with StreamingDataset
+
+Let's cover the key concepts in this process.
+
+## Dataset conversion
+Raw data samples need to be processed into a **stream**: a set of **shard** files that allows fast random access during training. Streaming supports the following file formats:
+* MDS (most performant)
+* CSV/TSV
+* JSONL
+
+A **shard** is a file, compatible with Streaming, that contains samples that are ready for training.
+
+A **stream** is a collection of shard files.
+
+The diagram below shows how raw data samples are converted to MDS shards using {class}`streaming.MDSWriter` objects.
+
+
+
+`MDSWriter` objects take in original dataset samples and convert them to binary MDS shards, which contain serialized samples. The mapping from original files to shard files is not strict and can be one-to-one, many-to-one, or one-to-many. Each shard has a header that allows for fast random access to every sample during model training.
+
+As shown above, an `index.json` file is also created for the set of shard files (the stream), containing information such as the number of shards, the number of samples per shard, and shard sizes. An example `index.json` file for a stream of MDS shards, where each sample has a single column called `"tokens"` encoded as bytes, is structured as follows:
+
+```json
+{
+ "shards": [
+ { // Shard 0
+ "column_encodings": ["bytes"],
+ "column_names": ["tokens"],
+ "column_sizes": [null],
+ "compression": null,
+ "format": "mds",
+ "hashes": [],
+ "raw_data": {
+ "basename": "shard.00000.mds",
+ "bytes": 67092637,
+ "hashes": {}
+ },
+ "samples": 4093,
+ "size_limit": 67108864,
+ "version": 2,
+ "zip_data": null
+ },
+ { // Shard 1, very similar to Shard 0 metadata
+ ...
+ "raw_data": {
+ "basename": "shard.00001.mds",
+ "bytes": 67092637,
+ "hashes": {}
+ },
+ ...
+ },
+ // and so on
+ ]
+}
+```
+
+Below, we use `MDSWriter` to write out a stream to a remote location that contains integer columns `'x'` and `'y'` in each sample:
+
+```python
+columns = {'x': 'int', 'y': 'int'}
+output_dir = 's3://path/for/mds/dataset'
+with MDSWriter(out=output_dir, columns=columns) as out:
+ for sample in raw_dataset_iterator:
+ out.write(sample)
+```
+
+Read more about dataset formats [here](../preparing_datasets/dataset_format.md), and about dataset conversion [here](../preparing_datasets/basic_dataset_conversion.md).
+
+## Distributed model training
+StreamingDataset splits up samples between nodes, ranks, and dataloader workers. Shards are downloaded, and samples retrieved from them, during distributed model training.
+
+A **node** is a host CPU system with its own disk space and multiple GPUs attached (typically 8).
+
+A **rank** in this context is a specialized component for accelerated model training -- typically a GPU.
+
+A **worker** is a CPU process that fetches shards and samples; each worker is assigned to a particular rank on a node.
+
+### Start of training
+StreamingDataset downloads the `index.json` files for input streams, and uses the information they contain to partition samples across nodes, ranks, and workers. The diagram below shows this process:
+
+
+
+Let's understand what's happening here.
+
+#### Remote data streams
+The `index.json` files and shards for multiple streams are stored in the cloud. A Stream typically corresponds to one data source. The overall dataset combines samples from all input streams. Two streams are shown here, but StreamingDataset supports combining any number. Streams can be mixed in various [ways](../dataset_configuration/mixing_data_sources.md), making final dataset composition flexible.
+
+
+
+Below, we pass in a list of {class}`streaming.Stream` objects to a {class}`streaming.StreamingDataset`, and also specify the proportion of the overall dataset we want to take from each stream.
+
+```python
+# Stream 1 uses its own set of shard files and will be 1/4 of the training dataset.
+stream_1 = Stream(
+ remote = 's3://stream_1/directory',
+ local = '/local/cache/stream_1',
+ batch_size = 4,
+ proportion = 0.25,
+)
+# Stream 2 is similar to above, but will be 3/4 of the training dataset.
+stream_2 = Stream(
+ remote = 's3://stream_2/directory',
+ local = '/local/cache/stream_2',
+ batch_size = 4,
+ proportion = 0.75,
+)
+
+# This dataset uses multiple streams.
+dataset = StreamingDataset(streams=[stream_1, stream_2], batch_size=4)
+```
+
+If using a single stream, we could just specify the `remote` and `local` locations directly instead:
+
+```python
+dataset = StreamingDataset(remote='s3://some/path', local='/local/path', batch_size=4)
+```
+
+#### Sample partitioning
+Dataset samples are partitioned among nodes, GPUs, and workers. This partitioning reduces the redundant downloads during training, making it much more performant.
+
+
+
+In the diagram above, we have 2 nodes, 8 GPUs per node, and 2 dataloader workers per GPU. These values can vary depending on the training configuration. The number of nodes and ranks is detected through PyTorch, and the number of workers is passed in through the DataLoader. Zooming into the sample partition for just one GPU, the samples are split up between dataloader workers (2 per GPU above) and grouped together by GPU batch size (4 above).
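+
+As a rough sketch, the worker count in this setup comes from the DataLoader, while nodes and ranks are detected from the distributed environment (for example, a `torchrun` launch). Here, `dataset` stands in for a `StreamingDataset` configured as above:
+
+```python
+from torch.utils.data import DataLoader
+
+# 2 dataloader workers per GPU, matching the diagram; batch_size is the per-device batch size.
+dataloader = DataLoader(dataset, batch_size=4, num_workers=2)
+```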
+
+#### Dataset shuffling
+You can shuffle the samples within each node using one of our specialized [shuffling algorithms](../dataset_configuration/shuffling.md#shuffling-algorithms). Having a shuffled dataset is highly important for ML model training. By partitioning samples among nodes and only shuffling samples intra-node, overall download demand is controlled, since duplicate shard downloads between nodes are minimized.
+
+
+
+Enabling shuffling is as simple as setting `shuffle` to `True` in `StreamingDataset`.
+
+```python
+dataset = StreamingDataset(
+ ...
+ shuffle = True,
+ ...
+)
+```
+
+### Sample retrieval during training
+StreamingDataset retrieves shards and reads samples from them on the fly during training, as it is iterated over. The diagram below shows how this happens:
+
+
+
+Shards are progressively downloaded from the specified `remote` location(s) as needed for training. Dataloader workers on each node's CPUs access the StreamingDataset sample partition, which tells them the order of samples they need to retrieve and which shards contain those samples. Workers make samples available for training using the steps below.
+
+1. **Worker sample retrieval:** Each Dataloader worker is responsible for just a part of the entire dataset’s samples. The samples per worker are specified in the StreamingDataset’s partition. The worker checks whether the samples that model training will soon require are present. For a particular sample, the worker checks if the shard containing that sample is present on the node’s disk. If it is not present, the worker proceeds to step 2. If it is present, the worker jumps to step 3.
+2. **Shard download:** Since the shard with the required sample is not present on disk, the worker downloads the shard from remote cloud storage.
+3. **Load to GPU:** The sample is present on disk. The worker loads the sample to GPU for training when required.
+
+For more information on how to control StreamingDataset's shard fetching and storage behavior, see the [shard retrieval](../dataset_configuration/shard_retrieval.md) page.
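+
+As a hedged sketch (paths and values are illustrative, not recommendations), a few StreamingDataset arguments control this retrieval behavior:
+
+```python
+from streaming import StreamingDataset
+
+dataset = StreamingDataset(
+    remote='s3://my-bucket/my-dataset',
+    local='/local/cache/path',
+    batch_size=4,
+    predownload=64,        # how far ahead (in samples) each worker prepares shards
+    cache_limit='100gb',   # per-node cap on cached shards; shards are evicted past this
+    download_retry=2,      # retries per failed shard download
+    download_timeout=60,   # seconds before a shard download attempt times out
+)
+```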
+
+And that's all! Please look at the API reference for more information on specific objects and parameters.
diff --git a/docs/source/getting_started/quick_start.md b/docs/source/getting_started/quick_start.md
index 28b90e742..24a2c19a0 100644
--- a/docs/source/getting_started/quick_start.md
+++ b/docs/source/getting_started/quick_start.md
@@ -1,9 +1,9 @@
# 🚀 Quick Start
-Start training your model with the Streaming dataset in a few steps!
-
-1. Convert your raw dataset into one of our supported streaming formats, for example, `mds` (Mosaic Data Shard) format.
+Start training your model with Streaming in just a few steps!
+1. Convert your raw dataset into one of our supported file formats. Here, we convert an image dataset to MDS (Mosaic Data Shard) format.
+
```python
import numpy as np
from PIL import Image
@@ -12,8 +12,6 @@ Start training your model with the Streaming dataset in a few steps!
from streaming import MDSWriter
# Local or remote directory path to store the output compressed files.
- # For remote directory, the output files are automatically upload to a remote cloud storage
- # location.
out_root = 'dirname'
# A dictionary of input fields to an Encoder/Decoder type
@@ -26,10 +24,7 @@ Start training your model with the Streaming dataset in a few steps!
# Compression algorithm name
compression = 'zstd'
- # Hash algorithm name
- hashes = 'sha1', 'xxh64'
-
- # Generates random images and classes for input sample
+ # Generate random images and classes
samples = [
{
'uuid': str(uuid4()),
@@ -39,32 +34,62 @@ Start training your model with the Streaming dataset in a few steps!
for _ in range(1000)
]
- # Call `MDSWriter` to iterate through the input data and write into a shard `mds` file
- with MDSWriter(out=out_root, columns=columns, compression=compression, hashes=hashes) as out:
+ # Use `MDSWriter` to iterate through the input data and write to a collection of `.mds` files.
+ with MDSWriter(out=out_root, columns=columns, compression=compression) as out:
for sample in samples:
out.write(sample)
-
- # Clean up
- rmtree(out_root)
```
-2. Replace the original {class}`torch.utils.data.IterableDataset` with your new {class}`streaming.StreamingDataset`.
+2. Replace the original {class}`torch.utils.data.IterableDataset` with your new {class}`streaming.StreamingDataset`. Point it to the dataset written out above, and specify the `batch_size` to StreamingDataset and the DataLoader.
```python
from torch.utils.data import DataLoader
from streaming import StreamingDataset
- # Remote directory (S3 or local filesystem) where dataset is stored
- remote_dir = 's3://datapath'
+ # Remote directory where dataset is stored, from above
+ remote_dir = 's3://path/to/dataset'
- # Local directory where dataset is cached during operation
- local_dir = 'local_dir'
- dataset = StreamingDataset(local=local_dir, remote=remote_dir, split=None, shuffle=True)
+ # Local directory where dataset is cached during training
+ local_dir = '/local/cache/path'
+ dataset = StreamingDataset(local=local_dir, remote=remote_dir, batch_size=1, split=None, shuffle=True)
# Create PyTorch DataLoader
- dataloader = DataLoader(dataset)
+ dataloader = DataLoader(dataset, batch_size=1)
```
-That's it! For additional details on using {mod}`streaming`, please check out our [User Guide](user_guide.md) and [Examples](../examples/cifar10.ipynb).
+That's it! For additional details on using Streaming, check out the [Main Concepts](main_concepts.md) page and [How-to Guides](../how_to_guides/cifar10.ipynb).
+
+We also have starter code for the following popular datasets, which can be found in the `streaming` [directory](https://github.com/mosaicml/streaming/tree/main/streaming):
+
+| Dataset | Task | Read | Write |
+| --- | --- | --- | --- |
+| LAION-400M | Text and image | [Read](https://github.com/mosaicml/diffusion-benchmark/blob/main/data.py) | [Write](https://github.com/mosaicml/streaming/tree/main/streaming/multimodal/convert/laion/laion400m) |
+| WebVid | Text and video | [Read](https://github.com/mosaicml/streaming/blob/main/streaming/multimodal/webvid.py) | [Write](https://github.com/mosaicml/streaming/blob/main/streaming/multimodal/convert/webvid.py) |
+| C4 | Text | [Read](https://github.com/mosaicml/streaming/blob/main/streaming/text/c4.py) | [Write](https://github.com/mosaicml/streaming/blob/main/streaming/text/convert/c4.py) |
+| EnWiki | Text | [Read](https://github.com/mosaicml/streaming/blob/main/streaming/text/enwiki.py) | [Write](https://github.com/mosaicml/streaming/tree/main/streaming/text/convert/enwiki) |
+| Pile | Text | [Read](https://github.com/mosaicml/streaming/blob/main/streaming/text/pile.py) | [Write](https://github.com/mosaicml/streaming/blob/main/streaming/text/convert/pile.py) |
+| ADE20K | Image segmentation | [Read](https://github.com/mosaicml/streaming/blob/main/streaming/vision/ade20k.py) | [Write](https://github.com/mosaicml/streaming/blob/main/streaming/vision/convert/ade20k.py) |
+| CIFAR10 | Image classification | [Read](https://github.com/mosaicml/streaming/blob/main/streaming/vision/cifar10.py) | [Write](https://github.com/mosaicml/streaming/blob/main/streaming/vision/convert/cifar10.py) |
+| COCO | Object detection | [Read](https://github.com/mosaicml/streaming/blob/main/streaming/vision/coco.py) | [Write](https://github.com/mosaicml/streaming/blob/main/streaming/vision/convert/coco.py) |
+| ImageNet | Image classification | [Read](https://github.com/mosaicml/streaming/blob/main/streaming/vision/imagenet.py) | [Write](https://github.com/mosaicml/streaming/blob/main/streaming/vision/convert/imagenet.py) |
+
+**To start training on these datasets:**
+
+1. Convert raw data into .mds format using the corresponding script from the `convert` directory.
+
+For example:
+
+
+```bash
+$ python -m streaming.multimodal.convert.webvid --in <input path> --out <output MDS directory>
+```
+
+2. Import the dataset class to start training the model.
+
+
+```python
+from streaming.multimodal import StreamingInsideWebVid
+dataset = StreamingInsideWebVid(local=local, remote=remote, batch_size=1, shuffle=True)
+```
Happy training!
diff --git a/docs/source/getting_started/user_guide.md b/docs/source/getting_started/user_guide.md
deleted file mode 100644
index a6fe46b76..000000000
--- a/docs/source/getting_started/user_guide.md
+++ /dev/null
@@ -1,176 +0,0 @@
-# 🖼️ User Guide
-
-At a very high level, one needs to convert a raw dataset into streaming format files and then use the same streaming format files using {class}`streaming.StreamingDataset` class for model training.
-
-Streaming supports different dataset writers based on your need for conversion of raw datasets into a streaming format such as
-- {class}`streaming.MDSWriter`: Writes the dataset into `.mds` (Mosaic Data Shard) extension. It supports various encoding/decoding formats(`str`, `int`, `bytes`, `jpeg`, `png`, `pil`, `pkl`, and `json`) which convert the data from that format to bytes and vice-versa.
-- {class}`streaming.CSVWriter`: Writes the dataset into `.csv` (Comma Separated Values) extension. It supports various encoding/decoding formats(`str`, `int`, and `float`) which convert the data from that format to string and vice-versa.
-- {class}`streaming.JSONWriter`: Writes the dataset into `.json` (JavaScript Object Notation) extension. It supports various encoding/decoding formats(`str`, `int`, and `float`).
-- {class}`streaming.TSVWriter`: Writes the dataset into `.tsv` (Tab Separated Values) extension. It supports various encoding/decoding formats(`str`, `int`, and `float`) which convert the data from that format to string and vice-versa.
-- {class}`streaming.XSVWriter`: Writes the dataset into `.xsv` (user defined Separated Values) extension. It supports various encoding/decoding formats(`str`, `int`, and `float`) which convert the data from that format to string and vice-versa.
-
-For more information about writers and their parameters, look at the [API reference doc](../api_reference/streaming.rst).
-
-After the dataset has been converted to one of our streaming formats, one just needs to instantiate the {class}`streaming.StreamingDataset` class by providing the dataset path of the streaming formats and use that dataset object in PyTorch {class}`torch.utils.data.DataLoader` class. For more information about `streaming.StreamingDataset` and its parameters, look at the {class}`streaming.StreamingDataset` API reference doc.
-
-Streaming supports various dataset compression formats (Brotli, Bzip2, Gzip, Snappy, and Zstandard) that reduces downloading time and cloud egress fees. Additionally, Streaming also supports various hashing algorithms (SHA2, SHA3, MD5, xxHash, etc.) that ensures data integrity through cryptographic and non-cryptographic hashing algorithm.
-
-Let's jump right into an example on how to convert a raw dataset into a streaming format and load the same streaming format dataset for model training.
-
-## Writing a dataset to streaming format
-
-This guide shows you how to use your custom StreamingDataset with {class}`streaming.MDSWriter`, but the steps would remain the same for other writers.
-
-The {class}`streaming.MDSWriter` takes the raw dataset and converts it into a sharded `.mds` format for fast data access.
-
-For this tutorial, let's create a Synthetic Classification dataset drawn from a normal distribution that returns a tuple of features and a label.
-
-```python
-import numpy as np
-
-class RandomClassificationDataset:
- """Classification dataset drawn from a normal distribution.
-
- Args:
- shape: shape of features (default: (5, 1, 1))
- size: number of samples (default: 100)
- num_classes: number of classes (default: 2)
- """
-
- def __init__(self, shape=(1, 1, 1), size=100, num_classes=2):
- self.size = size
- self.x = np.random.randn(size, *shape)
- self.y = np.random.randint(0, num_classes, size=(size,))
-
- def __len__(self):
- return self.size
-
- def __getitem__(self, index: int):
- return self.x[index], self.y[index]
-```
-
-There are a few parameters that need to be initialized before {class}`streaming.MDSWriter` gets called. Some of the parameters are optional, and others are required parameters. Let's look at each of them where we start with two required parameters.
-
-1. Provide the local filesystem directory path or a remote cloud provider storage path to store the compressed dataset files. If it is a remote path, the output files are automatically upload to a remote path.
-
-```python
-output_dir = 'test_output_dir'
-```
-
-2. Provide the column field as `Dict[str, str]`, which maps a feature name or label name with a streaming supported encoding type.
-
-```python
-columns = {'x': 'pkl', 'y': 'pkl'}
-```
-
-The below parameters are optional to {class}`streaming.MDSWriter`. Let's look at each one of them
-
-1. Provide a name of a compression algorithm; the default is `None`. Streaming supports families of compression algorithms such as `br`, `gzip`, `snappy`, `zstd`, and `bz2` with the level of compression.
-
-```python
-compression = 'zstd:7'
-```
-
-2. Provide a name of a hashing algorithm; the default is `None`. Streaming supports families of hashing algorithm such as `sha`, `blake`, `md5`, `xxHash`, etc.
-
-```python
-hashes = ['sha1']
-```
-
-3. Provide a shard size limit, after which point to start a new shard.
-
-```python
-# Number act as a byte, e.g., 1024 bytes. A string abbreviation (ex: "1024b" or "1kb") is also acceptable
-limit = 1024
-```
-
-Once the parameters are initialized, the last thing we need is a generator that iterates over the data sample.
-
-```python
-def each(samples):
- """Generator over each dataset sample.
-
- Args:
- samples (list): List of samples of (feature, label).
-
- Yields:
- Sample dicts.
- """
- for x, y in samples:
- yield {
- 'x': x,
- 'y': y,
- }
-```
-
-It's time to call the {class}`streaming.MDSWriter` with the above initialized parameters and write the samples by iterating over a dataset.
-
-```python
-from streaming.base import MDSWriter
-
-dataset = RandomClassificationDataset()
-with MDSWriter(out=output_dir, columns=columns, compression=compression, hashes=hashes, size_limit=limit) as out:
- for sample in each(dataset):
- out.write(sample)
-```
-
-Clean up after ourselves.
-
-```
-from shutil import rmtree
-
-rmtree(output_dir)
-```
-
-Once the dataset has been written, the output directory contains two types of files. The first is an index.json file that contains the metadata of shards and second is the shard files. For example,
-
-```bash
-dirname
-├── index.json
-├── shard.00000.mds.zstd
-└── shard.00001.mds.zstd
-```
-
-## Loading a streaming dataset
-
-After writing a dataset in the streaming format in the previous step and uploading to a cloud object storage as s3, we are ready to start loading the data.
-
-To load the same dataset files that were created in the above steps, create a `CustomDataset` class by inheriting the {class}`streaming.StreamingDataset` class and override the `__getitem__(idx: int)` method to get the samples. The {class}`streaming.StreamingDataset` class requires two mandatory parameters which are `remote` which is a remote directory (S3 or local filesystem) where dataset is stored and `local` which is a local directory where dataset is cached during operation.
-
- ```python
-from streaming import StreamingDataset
-
-class CustomDataset(StreamingDataset):
- def __init__(self, local, remote, batch_size):
- super().__init__(local=local, remote=remote, batch_size=batch_size)
-
- def __getitem__(self, idx: int) -> Any:
- obj = super().__getitem__(idx)
- return obj['x'], obj['y']
- ```
-
-The next step is to Instantiate the `CustomDataset` class with local and remote paths.
-
-```python
-# Local filesystem directory where dataset is cached during operation
-local = '/tmp/cache'
-
-# Remote directory (S3 or local filesystem) where dataset is stored
-remote='s3://mybucket/myfolder'
-
-dataset = CustomDataset(local=local, remote=remote)
-```
-
-The final step is to pass the dataset to PyTorch {class}`torch.utils.data.DataLoader` and use this dataloader to train your model.
-
-```python
-from torch.utils.data import DataLoader
-
-dataloader = DataLoader(dataset=dataset)
-```
-
-You've now seen an in-depth look at how to prepare and use streaming datasets with PyTorch. To continue learning about Streaming, please continue to explore our [examples](../examples/cifar10.ipynb/)!
-
-## Other options
-
-Please look at the API reference page for the complete list of {class}`streaming.StreamingDataset` supporting parameters.
diff --git a/docs/source/how_to_guides/cifar10.ipynb b/docs/source/how_to_guides/cifar10.ipynb
new file mode 100644
index 000000000..a5f6048f3
--- /dev/null
+++ b/docs/source/how_to_guides/cifar10.ipynb
@@ -0,0 +1,590 @@
+{
+ "cells": [
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Image Data: CIFAR10\n",
+ "\n",
+ "In this tutorial, we will demonstrate how to use the streaming CIFAR-10 dataset to train a classification model.\n",
+ "\n",
+ "### Recommended Background\n",
+ "\n",
+ "This tutorial assumes that you're reasonably familiar with the workings of datasets and dataloaders for training deep learning models. In addition, since we'll be building from a computer vision example, familiarity in that area will likely be useful as well.\n",
+ "\n",
+ "If you're already familiar with streaming's dataset classes ([StreamingDataset][streaming_dataset] and [MDSWriter][streaming_dataset_mds_writer]), that's great. If not, you may want to pause while working through the tutorial and look at the docs referenced along the way.\n",
+ "\n",
+ "### Tutorial Goals and Concepts Covered\n",
+ "\n",
+ "The goal of this tutorial is to showcase how to prepare the dataset and use Streaming data loading to train the model. It will consist of a few steps:\n",
+ "\n",
+ "1. Obtaining the dataset\n",
+ "2. Preparing the dataset for streaming\n",
+ "3. Streaming the dataset to the local machine\n",
+ "4. Training a model using these datasets\n",
+ "\n",
+ "Let's get started!\n",
+ "\n",
+ "[streaming_dataset]: https://streaming.docs.mosaicml.com/en/stable/api_reference/generated/streaming.StreamingDataset.html#streaming.StreamingDataset\n",
+ "[streaming_dataset_mds_writer]: https://streaming.docs.mosaicml.com/en/latest/api_reference/generated/streaming.MDSWriter.html"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Setup\n",
+ "\n",
+ "Let's start by making sure the right packages are installed and imported. We need to install the `mosaicml-streaming` package, which installs the dependencies needed to run this tutorial."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%pip install mosaicml-streaming\n",
+ "# To install from source instead of the last release, comment the command above and uncomment the following one.\n",
+ "# %pip install git+https://github.com/mosaicml/streaming.git\n",
+ "\n",
+ "# (Optional) To upload a streaming dataset to an AWS S3 bucket\n",
+ "%pip install awscli"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import time\n",
+ "import os\n",
+ "import shutil\n",
+ "from typing import Callable, Any, Tuple\n",
+ "\n",
+ "import numpy as np\n",
+ "from tqdm import tqdm\n",
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "import torch.nn.functional as F\n",
+ "from torch.utils.data import DataLoader, Dataset\n",
+ "from torchvision import transforms, models\n",
+ "from torchvision.datasets import CIFAR10"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We'll be using Streaming's `MDSWriter` which writes the dataset in Streaming format and `StreamingDataset` class to load the streaming dataset."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from streaming import MDSWriter, StreamingDataset"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Global settings\n",
+ "\n",
+ "For this tutorial, it makes the most sense to organize our global settings here rather than distribute them throughout the cells in which they're used."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# the location of our dataset\n",
+ "in_root = \"./dataset\"\n",
+ "\n",
+ "# the location of the \"remote\" streaming dataset (`sds`). \n",
+ "# Upload `out_root` to your cloud storage provider of choice.\n",
+ "out_root = \"./sds\"\n",
+ "out_train = \"./sds/train\"\n",
+ "out_test = \"./sds/test\"\n",
+ "\n",
+ "# the location to download the streaming dataset during training\n",
+ "local = './local'\n",
+ "local_train = './local/train'\n",
+ "local_test = './local/test'\n",
+ "\n",
+ "# toggle shuffling in dataloader\n",
+ "shuffle_train = True\n",
+ "shuffle_test = False\n",
+ "\n",
+ "# shard size limit, in bytes\n",
+ "size_limit = 1 << 25\n",
+ "\n",
+ "# training batch size\n",
+ "batch_size = 32 \n",
+ "\n",
+ "# training hardware parameters\n",
+ "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
+ "\n",
+ "# number of training epochs\n",
+ "train_epochs = 2 # increase the number of epochs for greater accuracy\n",
+ "\n",
+ "# Hashing algorithm to use for dataset\n",
+ "hashes = ['sha1', 'xxh64']"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# upload location for the dataset splits (change this if you want to upload to a different location, for example, AWS S3 bucket location)\n",
+ "upload_location = None\n",
+ "\n",
+ "if upload_location is None:\n",
+ " upload_train_location = None\n",
+ " upload_test_location = None\n",
+ "else:\n",
+ " upload_train_location = os.path.join(upload_location, 'train')\n",
+ " upload_test_location = os.path.join(upload_location, 'test')"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Download the CIFAR10 raw dataset"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Download the CIFAR10 raw dataset using torchvision\n",
+ "train_raw_dataset = CIFAR10(root=in_root, train=True, download=True)\n",
+ "test_raw_dataset = CIFAR10(root=in_root, train=False, download=True)"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Next, we'll make the directories for our binary streaming dataset files."
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Preparing and writing the dataset\n",
+ "\n",
+ "Below, we'll set up the logic for writing our starting dataset to files that can be read using a streaming dataloader.\n",
+ "\n",
+ "For more information on the `MDSWriter` check out the [API reference][api].\n",
+ "\n",
+ "[api]: https://streaming.docs.mosaicml.com/en/latest/api_reference/generated/streaming.MDSWriter.html"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def write_datasets(dataset: Dataset, split_dir: str) -> None:\n",
+ " fields = {\n",
+ " 'x': 'pil',\n",
+ " 'y': 'int',\n",
+ " }\n",
+ " indices = np.random.permutation(len(dataset))\n",
+ " indices = tqdm(indices)\n",
+ " with MDSWriter(out=split_dir, columns=fields, hashes=hashes, size_limit=size_limit) as out:\n",
+ " for i in indices:\n",
+ " x, y = dataset[i]\n",
+ " out.write({\n",
+ " 'x': x,\n",
+ " 'y': y,\n",
+ " })"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Now that we've written the datasets to `out_root`, we can upload them to a cloud storage provider and stream them from there."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "remote_train = upload_train_location or out_train # replace this with your URL for cloud streaming\n",
+ "remote_test = upload_test_location or out_test"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Loading the Data\n",
+ "\n",
+ "We extend StreamingDataset to deserialize the data.\n",
+ "\n",
+ "For more information on the `StreamingDataset` class check out the [API reference](https://streaming.docs.mosaicml.com/en/stable/api_reference/generated/streaming.StreamingDataset.html#streaming.StreamingDataset)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class CIFAR10Dataset(StreamingDataset):\n",
+ " def __init__(self,\n",
+ " remote: str,\n",
+ " local: str,\n",
+ " shuffle: bool,\n",
+ " batch_size: int,\n",
+ " transforms: Callable\n",
+ " ) -> None:\n",
+ " super().__init__(local=local, remote=remote, shuffle=shuffle, batch_size=batch_size)\n",
+ " self.transforms = transforms\n",
+ "\n",
+ " def __getitem__(self, idx:int) -> Any:\n",
+ " obj = super().__getitem__(idx)\n",
+ " x = obj['x']\n",
+ " y = obj['y']\n",
+ " return self.transforms(x), y"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Initialize the data transformation"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "transformation = transforms.Compose([\n",
+ " transforms.ToTensor(),\n",
+ " transforms.Normalize( \n",
+ " (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010) \n",
+ " )\n",
+ "])"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Putting It All Together"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We're now ready to actually write the streamable dataset. Let's do that if we haven't already."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "if not os.path.exists(out_train):\n",
+ " write_datasets(train_raw_dataset, out_train)\n",
+ " write_datasets(test_raw_dataset, out_test)"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "(Optional) Upload the Streaming dataset to an AWS S3 bucket of your choice. Uncomment the below line if you have provided the S3 bucket link to `upload_location`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "vscode": {
+ "languageId": "shellscript"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "# !aws s3 cp $out_root $upload_location --recursive"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Once that's done, we can instantiate our streaming datasets and wrap them in standard dataloaders for training!"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "train_dataset = CIFAR10Dataset(remote_train, local_train, shuffle_train, batch_size=batch_size, transforms=transformation)\n",
+ "test_dataset = CIFAR10Dataset(remote_test, local_test, shuffle_test, batch_size=batch_size, transforms=transformation)\n",
+ "\n",
+ "train_dataloader = DataLoader(train_dataset, batch_size=batch_size)\n",
+ "test_dataloader = DataLoader(test_dataset, batch_size=batch_size)"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Create a model\n",
+ "\n",
+ "We are going to create a Convolutional Neural Network (CNN) classification model in this tutorial. We will be using the CrossEntropyLoss to calculate the loss value and SGD Stochastic Gradient Descent method as the optimizer. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class Net(nn.Module):\n",
+ " def __init__(self):\n",
+ " super().__init__()\n",
+ " self.conv1 = nn.Conv2d(3, 6, 5)\n",
+ " self.pool = nn.MaxPool2d(2, 2)\n",
+ " self.conv2 = nn.Conv2d(6, 16, 5)\n",
+ " self.fc1 = nn.Linear(16 * 5 * 5, 120)\n",
+ " self.fc2 = nn.Linear(120, 84)\n",
+ " self.fc3 = nn.Linear(84, 10)\n",
+ "\n",
+ " def forward(self, x):\n",
+ " x = self.pool(F.relu(self.conv1(x)))\n",
+ " x = self.pool(F.relu(self.conv2(x)))\n",
+ " x = torch.flatten(x, 1) # flatten all dimensions except batch\n",
+ " x = F.relu(self.fc1(x))\n",
+ " x = F.relu(self.fc2(x))\n",
+ " x = self.fc3(x)\n",
+ " return x\n",
+ "\n",
+ "\n",
+ "model = Net()\n",
+ "model = model.to(device)\n",
+ "criterion = nn.CrossEntropyLoss()\n",
+ "optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=5e-4)"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Define a model train function"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def fit(model: nn.Module, train_dataloader: DataLoader) -> Tuple[float, float]:\n",
+ " model.train()\n",
+ " train_running_loss = 0.0\n",
+ " train_running_correct = 0\n",
+ " with tqdm(train_dataloader, unit=\"batch\") as tepoch:\n",
+ " for imgs, labels in tepoch:\n",
+ " imgs = imgs.to(device)\n",
+ " labels = labels.to(device)\n",
+ " optimizer.zero_grad()\n",
+ " labels_hat = model(imgs)\n",
+ " loss = criterion(labels_hat, labels)\n",
+ " train_running_loss += loss.item()\n",
+ " _, preds = torch.max(labels_hat.data, 1)\n",
+ " train_running_correct += (preds == labels).sum().item()\n",
+ " loss.backward()\n",
+ " optimizer.step()\n",
+ "\n",
+ " train_loss = train_running_loss/len(train_dataloader.dataset)\n",
+ " train_accuracy = 100. * train_running_correct/len(train_dataloader.dataset)\n",
+ " \n",
+ " return train_loss, train_accuracy"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Define a model evaluation function"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def eval(model: nn.Module, test_dataloader: DataLoader) -> Tuple[float, float]:\n",
+ " model.eval()\n",
+ " val_running_loss = 0.0\n",
+ " val_running_correct = 0\n",
+ " with tqdm(test_dataloader, unit=\"batch\") as tepoch:\n",
+ " for imgs, labels in tepoch:\n",
+ " imgs = imgs.to(device)\n",
+ " labels = labels.to(device)\n",
+ " output = model(imgs)\n",
+ " loss = criterion(output, labels)\n",
+ " val_running_loss += loss.item()\n",
+ " _, preds = torch.max(output.data, 1)\n",
+ " val_running_correct += (preds == labels).sum().item()\n",
+ " \n",
+ " val_loss = val_running_loss/len(test_dataloader.dataset)\n",
+ " val_accuracy = 100. * val_running_correct/len(test_dataloader.dataset)\n",
+ " \n",
+ " return val_loss, val_accuracy"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Train with the Streaming Dataloaders\n",
+ "\n",
+ "Now all that's left to do is train!"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "for epoch in range(train_epochs):\n",
+ " train_epoch_loss, train_epoch_accuracy = fit(model, train_dataloader)\n",
+ " print(f'epoch: {epoch+1}/{train_epochs} Train Loss: {train_epoch_loss:.4f}, Train Acc: {train_epoch_accuracy:.2f}')\n",
+ " val_epoch_loss, val_epoch_accuracy = eval(model, test_dataloader)\n",
+ " print(f'epoch: {epoch+1}/{train_epochs} Val Loss: {val_epoch_loss:.4f}, Val Acc: {val_epoch_accuracy:.2f}')"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Cleanup\n",
+ "\n",
+ "That's it. No need to hang on to the files created by the tutorial..."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "shutil.rmtree(out_root, ignore_errors=True)\n",
+ "shutil.rmtree(in_root, ignore_errors=True)\n",
+ "shutil.rmtree(local, ignore_errors=True)"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n",
+ "## What next?\n",
+ "\n",
+ "You've now seen an in-depth look at how to prepare and use streaming datasets with PyTorch.\n",
+ "\n",
+ "To continue learning about Streaming, please continue to explore our examples!"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Come get involved with MosaicML!\n",
+ "\n",
+ "We'd love for you to get involved with the MosaicML community in any of these ways:\n",
+ "\n",
+ "### [Star Streaming on GitHub](https://github.com/mosaicml/streaming)\n",
+ "\n",
+ "Help make others aware of our work by [starring Streaming on GitHub](https://github.com/mosaicml/streaming).\n",
+ "\n",
+ "### [Join the MosaicML Slack](https://mosaicml.me/slack)\n",
+ "\n",
+ "Head on over to the [MosaicML slack](https://mosaicml.me/slack) to join other ML efficiency enthusiasts. Come for the paper discussions, stay for the memes!\n",
+ "\n",
+ "### Contribute to Streaming\n",
+ "\n",
+ "Is there a bug you noticed or a feature you'd like? File an [issue](https://github.com/mosaicml/streaming/issues) or make a [pull request](https://github.com/mosaicml/streaming/pulls)!"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3.10.6 ('streaming_py3_10')",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.7"
+ },
+ "orig_nbformat": 4,
+ "vscode": {
+ "interpreter": {
+ "hash": "cb0371d9985d03b7be04a8e8a123b72f0ef8951070c9235d824cee9281d7d420"
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/docs/source/how_to_guides/configure_cloud_storage_credentials.md b/docs/source/how_to_guides/configure_cloud_storage_credentials.md
index b68879cb5..57a29a89d 100644
--- a/docs/source/how_to_guides/configure_cloud_storage_credentials.md
+++ b/docs/source/how_to_guides/configure_cloud_storage_credentials.md
@@ -14,7 +14,7 @@ For an S3 bucket with public access, no additional setup is required, simply spe
### MosaicML platform
-For [MosaicML platform](https://www.mosaicml.com/cloud) users, follow the steps mentioned in the [AWS S3](https://docs.mosaicml.com/projects/mcli/en/latest/resources/secrets/s3.html) MCLI documentation page on how to configure the cloud provider credentials.
+For [MosaicML platform](https://docs.mosaicml.com/projects/mcli/en/latest/) users, follow the steps mentioned in the [AWS S3](https://docs.mosaicml.com/projects/mcli/en/latest/resources/secrets/s3.html) MCLI documentation page on how to configure the cloud provider credentials.
### Others
@@ -139,7 +139,7 @@ Note that even with S3 compatible object stores, URLs should be of the form `s3:
### MosaicML platform
-For [MosaicML platform](https://www.mosaicml.com/cloud) users, follow the steps mentioned in the [Google Cloud Storage](https://docs.mosaicml.com/projects/mcli/en/latest/resources/secrets/gcp.html) MCLI documentation page on how to configure the cloud provider credentials.
+For [MosaicML platform](https://docs.mosaicml.com/projects/mcli/en/latest/) users, follow the steps mentioned in the [Google Cloud Storage](https://docs.mosaicml.com/projects/mcli/en/latest/resources/secrets/gcp.html) MCLI documentation page on how to configure the cloud provider credentials.
### GCP User Auth Credentials Mounted as Environment Variables
@@ -193,7 +193,7 @@ export GOOGLE_APPLICATION_CREDENTIALS='KEY_FILE'
### MosaicML platform
-For [MosaicML platform](https://www.mosaicml.com/cloud) users, follow the steps mentioned in the [Oracle Cloud Storage](https://docs.mosaicml.com/projects/mcli/en/latest/resources/secrets/oci.html) MCLI documentation page on how to configure the cloud provider credentials.
+For [MosaicML platform](https://docs.mosaicml.com/projects/mcli/en/latest/) users, follow the steps mentioned in the [Oracle Cloud Storage](https://docs.mosaicml.com/projects/mcli/en/latest/resources/secrets/oci.html) MCLI documentation page on how to configure the cloud provider credentials.
### Others
@@ -259,7 +259,7 @@ See the [Databricks documentation](https://docs.databricks.com/en/dev-tools/auth
### MosaicML platform
-For [MosaicML platform](https://www.mosaicml.com/cloud) users, follow the steps mentioned in the [Databricks](https://docs.mosaicml.com/projects/mcli/en/latest/resources/secrets/databricks.html) MCLI documentation page on how to configure the credentials.
+For [MosaicML platform](https://docs.mosaicml.com/projects/mcli/en/latest/) users, follow the steps mentioned in the [Databricks](https://docs.mosaicml.com/projects/mcli/en/latest/resources/secrets/databricks.html) MCLI documentation page on how to configure the credentials.
### Others
diff --git a/docs/source/how_to_guides/dataset_conversion_to_mds_format.md b/docs/source/how_to_guides/dataset_conversion_to_mds_format.md
deleted file mode 100644
index 10d045c7c..000000000
--- a/docs/source/how_to_guides/dataset_conversion_to_mds_format.md
+++ /dev/null
@@ -1,46 +0,0 @@
-# Dataset Conversion to MDS Format
-
-If you have not read the [Dataset Format](../fundamentals/dataset_format.md) guide and [Dataset Conversion](../fundamentals/dataset_conversion_guide.md) guide, then we highly recommend you do so before you start.
-
-To use Streaming Dataset we must first convert the dataset from its native format to MosaicML's Streaming Dataset format called Mosaic Dataset Shard (MDS). Once in MDS format, we can access the dataset from the local file system (disk network attached storage, etc.) or object store (GCS, OCS, S3, etc.). From object store, data can be streamed to train deep learning models and it all just works.
-
-## Convert a raw data into MDS format
-
-Let's look at the steps one needs to perform to convert their raw data into an MDS format.
-
-1. Get the raw dataset, either you can download all locally or create an iterator which downloads on the fly.
-2. For the raw dataset, you need some form of iterator which fetches one sample at a time.
-3. Convert the raw sample in the form of `column` field.
-4. Instantiate MDSWriter and call the `write` method to write a raw sample one at a time.
-
-Checkout the [user guide](../getting_started/user_guide.md) section which contains a simplistic example for the data conversion using single process. For multiprocess dataset conversion example, checkout [this](../examples/multiprocess_dataset_conversion.ipynb) tutorial.
-
-
-We've already created conversion scripts that can be used to convert popular public datasets to MDS format. Please see below for usage instructions.
-
-## Spark Dataframe Conversion Examples
-```{include} ../../../streaming/base/converters/README.md
-:start-line: 2
-```
-
-## NLP Dataset Conversion Examples
-
-```{include} ../../../streaming/text/convert/README.md
-:start-line: 8
-```
-
-## Vision Dataset Conversion Examples
-
-```{include} ../../../streaming/vision/convert/README.md
-:start-line: 8
-```
-
-## Multimodal Dataset Conversion Examples
-### [LAION-400M](https://laion.ai/blog/laion-400-open-dataset/)
-```{include} ../../../streaming/multimodal/convert/laion/laion400m/README.md
-:start-line: 8
-```
-### [WebVid](https://m-bain.github.io/webvid-dataset/)
-```{include} ../../../streaming/multimodal/convert/webvid/README.md
-:start-line: 12
-```
diff --git a/docs/source/how_to_guides/synthetic_nlp.ipynb b/docs/source/how_to_guides/synthetic_nlp.ipynb
new file mode 100644
index 000000000..934c591ee
--- /dev/null
+++ b/docs/source/how_to_guides/synthetic_nlp.ipynb
@@ -0,0 +1,545 @@
+{
+ "cells": [
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Text Data: Synthetic NLP\n",
+ "\n",
+ "In this tutorial, we will demonstrate how to create a Synthetic dataset, write a synthetic dataset into a streaming format and use the [StreamingDataset][streaming_dataset] class to load the dataset.\n",
+ "\n",
+ "### Recommended Background\n",
+ "\n",
+ "This tutorial assumes that you're reasonably familiar with the workings of datasets and dataloaders for training deep learning models.\n",
+ "\n",
+ "If you're already familiar with streaming's dataset classes ([Dataset][streaming_dataset] and [MDSWriter][streaming_dataset_mds_writer]), that's great. If not, you may want to pause while working through the tutorial and look at the docs referenced along the way.\n",
+ "\n",
+ "### Tutorial Goals and Concepts Covered\n",
+ "\n",
+ "The goal of this tutorial is to showcase how to prepare the dataset and use Streaming data loading to iterate and fetch the samples. It will consist of a few steps:\n",
+ "\n",
+ "1. Generate a synthetic dataset\n",
+ "2. Preparing the dataset for streaming\n",
+ "3. Streaming the dataset to the local machine\n",
+ "4. Iterate through the dataset and fetch the samples\n",
+ "\n",
+ "Let's get started!\n",
+ "\n",
+ "[streaming_dataset]: https://streaming.docs.mosaicml.com/en/stable/api_reference/generated/streaming.StreamingDataset.html#streaming.StreamingDataset\n",
+ "[streaming_dataset_mds_writer]: https://streaming.docs.mosaicml.com/en/latest/api_reference/generated/streaming.MDSWriter.html"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Setup\n",
+ "\n",
+ "Let's start by making sure the right packages are installed and imported. We need to install the `mosaicml-streaming` package which installs the sufficient dependencies to run this tutorial."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%pip install mosaicml-streaming\n",
+ "# To install from source instead of the last release, comment the command above and uncomment the following one.\n",
+ "# %pip install git+https://github.com/mosaicml/streaming.git\n",
+ "\n",
+ "# (Optional) To upload a streaming dataset to an AWS S3 bucket\n",
+ "%pip install awscli"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "import shutil\n",
+ "from typing import Any, Dict, List, Tuple\n",
+ "\n",
+ "import numpy as np\n",
+ "import torch\n",
+ "from torch.utils.data import DataLoader\n",
+ "from tqdm import tqdm"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We'll be using Streaming's `MDSWriter` which writes the dataset in Streaming format and `StreamingDataset` to load the streaming dataset."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from streaming import MDSWriter, StreamingDataset"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Global settings\n",
+ "\n",
+ "For this tutorial, let's import some of the global setting at the start."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# the location of the \"remote\" streaming dataset (`sds`). \n",
+ "# Upload `out_root` to your cloud storage provider of choice. If `out_root` is a cloud provider\n",
+ "# path, shard files are automatically uploaded.\n",
+ "out_root = \"./sds\"\n",
+ "out_train = \"./sds/train\"\n",
+ "out_val = \"./sds/val\"\n",
+ "\n",
+ "# the location to download the streaming dataset during training\n",
+ "local = './local'\n",
+ "local_train = './local/train'\n",
+ "local_val = './local/val'\n",
+ "\n",
+ "# toggle shuffling in dataloader\n",
+ "shuffle_train = True\n",
+ "shuffle_val = False\n",
+ "\n",
+ "# training batch size\n",
+ "batch_size = 512"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# upload location for the dataset splits (change this if you want to upload to a different location, for example, AWS S3 bucket location)\n",
+ "upload_location = None\n",
+ "\n",
+ "if upload_location is None:\n",
+ " upload_train_location = None\n",
+ " upload_val_location = None\n",
+ "else:\n",
+ " upload_train_location = os.path.join(upload_location, 'train')\n",
+ " upload_val_location = os.path.join(upload_location, 'val')"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Create a Synthetic NLP dataset\n",
+ "\n",
+ "In this tutorial, we will be creating a synthetic number-saying dataset, i.e. converting a numbers from digits to words, for example, number `123` would spell as `one hundred twenty three`. The numbers are generated sequentially with a random positive/negative prefix sign.\n",
+ "\n",
+ "Let's import a utility functions to generate those synthetic number-saying dataset."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Word representation of a number\n",
+ "ones = ('zero one two three four five six seven eight nine ten eleven twelve thirteen fourteen ' +\n",
+ " 'fifteen sixteen seventeen eighteen nineteen').split()\n",
+ "\n",
+ "tens = 'twenty thirty forty fifty sixty seventy eighty ninety'.split()\n",
+ "\n",
+ "\n",
+ "def say(i: int) -> List[str]:\n",
+ " \"\"\"Get the word form of a number.\n",
+ "\n",
+ " Args:\n",
+ " i (int): The number.\n",
+ "\n",
+ " Returns:\n",
+ " List[str]: The number in word form.\n",
+ " \"\"\"\n",
+ " if i < 0:\n",
+ " return ['negative'] + say(-i)\n",
+ " elif i <= 19:\n",
+ " return [ones[i]]\n",
+ " elif i < 100:\n",
+ " return [tens[i // 10 - 2]] + ([ones[i % 10]] if i % 10 else [])\n",
+ " elif i < 1_000:\n",
+ " return [ones[i // 100], 'hundred'] + (say(i % 100) if i % 100 else [])\n",
+ " elif i < 1_000_000:\n",
+ " return say(i // 1_000) + ['thousand'] + (say(i % 1_000) if i % 1_000 else [])\n",
+ " elif i < 1_000_000_000:\n",
+ " return say(i // 1_000_000) + ['million'] + (say(i % 1_000_000) if i % 1_000_000 else [])\n",
+ " else:\n",
+ " assert False\n",
+ "\n",
+ "def get_numbers(num_train: int, num_val: int) -> Tuple[List[int], List[int]]:\n",
+ " \"\"\"Get two non-overlapping splits of a sequential random numbers.\n",
+ "\n",
+ " The train sample indices goes from [0, num_train] and val sample indices goes \n",
+ " from [num_train, num_val].\n",
+ "\n",
+ " Args:\n",
+ " num_train (int): Number of training samples.\n",
+ " num_val (int): Number of validation samples.\n",
+ "\n",
+ " Returns:\n",
+ " Tuple[List[int], List[int]]: The two generated splits.\n",
+ " \"\"\"\n",
+ " total = num_train + num_val\n",
+ " numbers = []\n",
+ " bar = tqdm(total=total, leave=False)\n",
+ " i = 0\n",
+ " while i < total:\n",
+ " was = len(numbers)\n",
+ " sign = (np.random.random() < 0.8) * 2 - 1\n",
+ " numbers.append(sign * i)\n",
+ " bar.update(len(numbers) - was)\n",
+ " i += 1\n",
+ " return numbers[:num_train], numbers[num_train:]\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Initialize a method to generate a train and validation samples where each sample is a dictionary with attributes `{'number': , 'words': }`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def generate_samples(numbers: List[int]) -> List[Dict[str, Any]]:\n",
+ " \"\"\"Generate samples from a list of numbers.\n",
+ "\n",
+ " Args:\n",
+ " numbers (List[int]): The numbers.\n",
+ "\n",
+ " Returns:\n",
+ " List[Dict[str, Any]]: The corresponding samples.\n",
+ " \"\"\"\n",
+ " samples = []\n",
+ " for num in numbers:\n",
+ " words = ' '.join(say(num))\n",
+ " sample = {'number': num, 'words': words}\n",
+ " samples.append(sample)\n",
+ " return samples\n",
+ "\n",
+ "\n",
+ "def get_dataset(num_train: int, num_val: int) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:\n",
+ " \"\"\"Generate a number-saying dataset of the given size.\n",
+ "\n",
+ " Args:\n",
+ " num_train (int): Number of training samples.\n",
+ " num_val (int): Number of validation samples.\n",
+ "\n",
+ " Returns:\n",
+ " Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: The two generated splits.\n",
+ " \"\"\"\n",
+ " train_nums, val_nums = get_numbers(num_train, num_val)\n",
+ " train_samples = generate_samples(train_nums)\n",
+ " val_samples = generate_samples(val_nums)\n",
+ " return train_samples, val_samples"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Create a non-overlapping `train` and `val` split dataset of unique random numbers."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Number of training and validation samples\n",
+ "num_train_samples = 10_000 # 10k samples\n",
+ "num_val_samples = 2000 # 2k samples\n",
+ "\n",
+ "# Create the samples.\n",
+ "print(f'Generating synthetic dataset ({num_train_samples} train, {num_val_samples} val)...')\n",
+ "train_samples, val_samples = get_dataset(num_train_samples, num_val_samples)\n",
+ "splits = [\n",
+ " ('train', train_samples),\n",
+ " ('val', val_samples)\n",
+ "]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Let's inspect the first train and test sample."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "print(f'Train sample: {train_samples[0]}')\n",
+ "print(f'Val sample: {val_samples[0]}')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Convert the dataset to MosaicML Streaming\n",
+ "\n",
+ "We are going to use the `MDSWriter` to convert the raw synthetic NLP dataset into a `.mds` file format.\n",
+ "\n",
+ "For more information on the Streaming `MDSWriter` class check out the [API reference](https://streaming.docs.mosaicml.com/en/latest/api_reference/generated/streaming.MDSWriter.html)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Mapping of sample keyword with their data type\n",
+ "columns = {\n",
+ " 'number': 'int',\n",
+ " 'words': 'str',\n",
+ "}\n",
+ "\n",
+ "# Compression algorithm to use for dataset\n",
+ "compression = 'zstd:12'\n",
+ "\n",
+ "# Hashing algorithm to use for dataset\n",
+ "hashes = ['sha1', 'xxh3_64']\n",
+ "\n",
+ "# shard size limit, in bytes\n",
+ "size_limit = 1 << 16 # Override to a small number for more shards.\n",
+ "\n",
+ "print(f'Saving dataset (to {out_root})...')\n",
+ "for split, samples in splits:\n",
+ " print(f'* {split}')\n",
+ " dirname = os.path.join(out_root, split)\n",
+ " with MDSWriter(out=dirname, columns=columns, compression=compression, \n",
+ " hashes=hashes, size_limit=size_limit) as out:\n",
+ " for sample in tqdm(samples, leave=False):\n",
+ " out.write(sample)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Now that we've written the datasets to `out_root`, one can upload them to a cloud storage provider, and we are ready to stream them. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "remote_train = upload_train_location or out_train # replace this with your URL for cloud streaming\n",
+ "remote_val = upload_val_location or out_val"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Loading the Data\n",
+ "\n",
+ "We extend Streaming's Dataset to deserialize the data. Let's verify that iterating over the `StreamingDataset` class gives us the exact raw samples in the same deterministic sample order.\n",
+ "\n",
+ "Note that `StreamingDataset` requires passing in a `batch_size` parameter for iteration. This `batch_size` is per-device, and should be the same as the `DataLoader` batch size.\n",
+ "\n",
+ "For more information on the `StreamingDataset` class check out the [API reference](https://streaming.docs.mosaicml.com/en/stable/api_reference/generated/streaming.StreamingDataset.html#streaming.StreamingDataset)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Load the samples back.\n",
+ "print('Walking the dataset:')\n",
+ "\n",
+ "print(f'verifying samples for train split')\n",
+ "train_dataset = StreamingDataset(remote=upload_location or out_root, local=local, batch_size=batch_size, split='train', shuffle=False)\n",
+ "for old, new in tqdm(zip(train_samples, train_dataset), total=len(train_samples), leave=False):\n",
+ " assert old == new\n",
+ "\n",
+ "print(f'verifying samples for val split')\n",
+ "val_dataset = StreamingDataset(remote=upload_location or out_root, local=local, batch_size=batch_size, split='val', shuffle=False)\n",
+ "for old, new in tqdm(zip(val_samples, val_dataset), total=len(val_samples), leave=False):\n",
+ " assert old == new\n"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We can also visualize the sample(s) by doing pythonic or NumPy indexing on a `StreamingDataset`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Fetch the 10th sample and print it on a console\n",
+ "print(f'Sample 10: {train_dataset[10]}')\n",
+ "\n",
+ "# Fetch multiple samples\n",
+ "indices = [-1, 30, [12, -14], slice(-1, -10, -2), np.array([10, -20])]\n",
+ "for indx in indices:\n",
+ " print(f'Sample {indx}: {train_dataset[indx]}')"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Below are some utility methods about the dataset which would be highly useful for debugging and model training. For more information on the `StreamingDataset` parameters, check out the [API reference](https://streaming.docs.mosaicml.com/en/stable/api_reference/generated/streaming.StreamingDataset.html#streaming.StreamingDataset)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Get the total number of samples\n",
+ "print(f'Total number of samples: {train_dataset.num_samples}')\n",
+ "\n",
+ "# Get the number of shard files\n",
+ "print(f'Total number of shards: {len(train_dataset.shards)}')\n",
+ "\n",
+ "# Get the number of samples inside each shard files.\n",
+ "# Number of samples in each shard can vary based on each sample size.\n",
+ "print(f'Number of samples inside each shards: {train_dataset.samples_per_shard}')"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We can now wrap our streaming datasets in a standard PyTorch dataloaders for training!"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "train_dataloader = DataLoader(train_dataset, batch_size=batch_size)\n",
+ "val_dataloader = DataLoader(val_dataset, batch_size=batch_size)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Cleanup\n",
+ "\n",
+ "That's it. No need to hang on to the files created by the tutorial..."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "shutil.rmtree(out_root, ignore_errors=True)\n",
+ "shutil.rmtree(local, ignore_errors=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n",
+ "## What next?\n",
+ "\n",
+ "You've now seen an in-depth look at how to prepare and use streaming datasets with PyTorch.\n",
+ "\n",
+ "To continue learning about Streaming, please continue to explore our examples!"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Come get involved with MosaicML!\n",
+ "\n",
+ "We'd love for you to get involved with the MosaicML community in any of these ways:\n",
+ "\n",
+ "### [Star Streaming on GitHub](https://github.com/mosaicml/streaming)\n",
+ "\n",
+ "Help make others aware of our work by [starring Streaming on GitHub](https://github.com/mosaicml/streaming).\n",
+ "\n",
+ "### [Join the MosaicML Slack](https://mosaicml.me/slack)\n",
+ "\n",
+ "Head on over to the [MosaicML slack](https://mosaicml.me/slack) to join other ML efficiency enthusiasts. Come for the paper discussions, stay for the memes!\n",
+ "\n",
+ "### Contribute to Streaming\n",
+ "\n",
+ "Is there a bug you noticed or a feature you'd like? File an [issue](https://github.com/mosaicml/streaming/issues) or make a [pull request](https://github.com/mosaicml/streaming/pulls)!"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3.10.6 ('streaming_py3_10')",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.4"
+ },
+ "orig_nbformat": 4,
+ "vscode": {
+ "interpreter": {
+ "hash": "cb0371d9985d03b7be04a8e8a123b72f0ef8951070c9235d824cee9281d7d420"
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/docs/source/index.md b/docs/source/index.md
index c381b8911..3b8de699e 100644
--- a/docs/source/index.md
+++ b/docs/source/index.md
@@ -2,85 +2,97 @@
# Streaming
-StreamingDataset helps to make training on large datasets from cloud storage as fast, cheap, and scalable as possible. It’s specially designed for multi-node, distributed training for large models—maximizing correctness guarantees, performance, and ease of use. Now, you can efficiently train anywhere, independent of your training data location. Just stream in the data you need, when you need it.
+StreamingDataset makes training on large datasets from cloud storage as fast, cheap, and scalable as possible. It’s specially designed for multi-node, distributed training of large models—maximizing correctness guarantees, performance, flexibility, and ease of use. Now, you can efficiently train anywhere, independent of where your dataset lives. Just train on the data you need, right when you need it.
-StreamingDataset is compatible with any data type, including **images, text, video, and multimodal data**. With support for major cloud storage providers ([AWS](https://aws.amazon.com/s3/), [OCI](https://www.oracle.com/cloud/storage/object-storage/), [GCS](https://cloud.google.com/storage), [Azure](https://azure.microsoft.com/en-us/products/storage/blobs), and any S3 compatible object store such as [Cloudflare R2](https://www.cloudflare.com/products/r2/), [Coreweave](https://docs.coreweave.com/storage/object-storage), [Backblaze b2](https://www.backblaze.com/b2/cloud-storage.html), etc. ) and designed as a drop-in replacement for your PyTorch [IterableDataset](https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset) class, StreamingDataset seamlessly integrates into your existing training workflows.
+StreamingDataset is compatible with any data type, including **images, text, video, and multimodal data**. With support for major cloud storage providers ([AWS](https://aws.amazon.com/s3/), [OCI](https://www.oracle.com/cloud/storage/object-storage/), [GCS](https://cloud.google.com/storage), [Azure](https://azure.microsoft.com/en-us/products/storage/blobs), [Databricks UC Volume](https://docs.databricks.com/en/sql/language-manual/sql-ref-volumes.html), and any S3 compatible object store such as [Cloudflare R2](https://www.cloudflare.com/products/r2/), [Coreweave](https://docs.coreweave.com/storage/object-storage), [Backblaze b2](https://www.backblaze.com/b2/cloud-storage.html), etc.) and designed as a drop-in replacement for your PyTorch [IterableDataset](https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset) class, StreamingDataset seamlessly integrates into your existing training workflows.
```python
from torch.utils.data import DataLoader
from streaming import StreamingDataset
-dataloader = DataLoader(dataset=StreamingDataset(remote='s3://...'))
+dataloader = DataLoader(dataset=StreamingDataset(remote='s3://...', batch_size=1))
```
-
-
+## **💾** Installation
+1. Set up your Python development environment.
+2. Install Streaming with `pip`:
+```bash
+pip install mosaicml-streaming
+```
+3. Verify the installation with:
+```bash
+python -c "import streaming; print(streaming.__version__)"
+```
+4. Jump to our [Quick Start](getting_started/quick_start.md) and [Main Concepts](getting_started/main_concepts.md) guides.
## **🔑** Key Features
-- **True Determinism**: Samples are in the same order regardless of the number of GPUs, nodes, or CPU workers. This makes it easier to reproduce and debug training runs and loss spikes and load a checkpoint trained on 64 GPUs and debug on 8 GPUs with reproducibility.
-- **Instant Mid-Epoch Resumption**: Resume training in seconds, not hours, in the middle of a long training run. Minimizing resumption latency can save thousands of dollars in egress fees and idle GPU compute time compared to existing solutions.
-- **High throughput**: Our MDS format cuts extraneous work to the bone, resulting in ultra-low sample latency and higher throughput compared to alternatives for workloads bottlenecked by the dataloader.
-- **Equal Convergence**: Model convergence from using StreamingDataset is just as good as using local disk, thanks to our shuffling algorithm. StreamingDataset shuffles across all samples assigned to a node, whereas alternative solutions only shuffle samples in a smaller pool (within a single process).
-- **Random access**: Access the data you need when you need it. Even if a sample isn’t downloaded yet, you can access `dataset[i]` to get sample `i`.
-- **Numpy style indexing**: Fetch data on the fly by providing a NumPy style indexing to `StreamingDataset`.
-- **Seamless data mixing**: During streaming, the different datasets are streamed, shuffled, and mixed seamlessly just-in-time.
-- **Disk usage limits**: Dynamically delete least recently used shards in order to keep disk usage under a specified limit.
-
-To get started, please checkout our [Quick Start](getting_started/quick_start.md) and [User Guide](getting_started/user_guide.md).
+- **Elastic Determinism**: Samples are in the same order regardless of the number of GPUs, nodes, or CPU workers. This makes it simple to reproduce and debug training runs and loss spikes. You can load a checkpoint trained on 64 GPUs and debug on 8 GPUs with complete reproducibility. Read more [here](distributed_training/elastic_determinism.md).
+- **Instant Mid-Epoch Resumption**: Resume training in seconds, not hours, in the middle of a long training run. Minimizing resumption latency saves thousands of dollars in egress fees and idle GPU compute time compared to existing solutions. Read more [here](distributed_training/fast_resumption.md).
+- **High throughput**: Our MDS format cuts extraneous work to the bone, resulting in ultra-low sample retrieval latency and higher throughput compared to alternatives.
+- **Effective Shuffling**: Model convergence using StreamingDataset is just as good as using local disk, thanks to our [specialized shuffling algorithms](dataset_configuration/shuffling.md#shuffling-algorithms). StreamingDataset's shuffling reduces egress costs, preserves shuffle quality, and runs efficiently, whereas alternative solutions force tradeoffs between these factors.
+- **Random access**: Access samples right when you need them -- simply call `dataset[i]` to get sample `i`. You can also fetch data on the fly by providing NumPy style indexing to `StreamingDataset`.
+- **Flexible data mixing**: During streaming, different data sources are shuffled and mixed seamlessly just-in-time. Control how datasets are combined using our [batching](dataset_configuration/mixing_data_sources.md#batching-methods) and [sampling](dataset_configuration/replication_and_sampling.md#inter-epoch-sampling) methods.
+- **Disk usage limits**: Dynamically delete least recently used shards in order to keep disk usage under a specified limit. Read more [here](dataset_configuration/shard_retrieval.md#configure-shard-storage).
+- **Parallelism-aware**: Easily train with data parallelism, sequence parallelism, or tensor parallelism -- the right samples end up in the right GPUs, using sample [replication](dataset_configuration/replication_and_sampling.md#replication).
## Community
-Streaming is part of the broader Machine Learning community, and we welcome any contributions, pull requests, and issues.
+Streaming is part of the broader ML/AI community, and we welcome any contributions, pull requests, and issues.
-If you have any questions, please feel free to reach out to us on [Twitter](https://twitter.com/mosaicml),
+If you have any questions, please feel free to reach out to us on [Twitter](https://twitter.com/DbrxMosaicAI),
[Email](mailto:community%40mosaicml.com), or [Slack](https://mosaicml.me/slack)!
```{eval-rst}
.. toctree::
:hidden:
:maxdepth: 1
- :caption: Getting Started
+ :caption: Overview
- getting_started/installation.md
getting_started/quick_start.md
- getting_started/user_guide.md
+ getting_started/main_concepts.md
+ getting_started/faqs_and_tips.md
.. toctree::
:hidden:
:maxdepth: 1
- :caption: Fundamentals
-
- fundamentals/dataset_format.md
- fundamentals/dataset_conversion_guide.md
- fundamentals/compression.md
- fundamentals/hashing.md
- fundamentals/environments.md
- fundamentals/shuffling.md
- fundamentals/sampling.md
- fundamentals/batching.md
- fundamentals/parallelism.md
- fundamentals/simulator.md
+ :caption: Preparing Datasets
+
+ preparing_datasets/dataset_format.md
+ preparing_datasets/basic_dataset_conversion.md
+ preparing_datasets/parallel_dataset_conversion.ipynb
+ preparing_datasets/spark_dataframe_to_mds.ipynb
.. toctree::
:hidden:
:maxdepth: 1
- :caption: How-to Guides
+ :caption: Dataset Configuration
- how_to_guides/configure_cloud_storage_credentials.md
- how_to_guides/dataset_conversion_to_mds_format.md
+ dataset_configuration/shard_retrieval.md
+ dataset_configuration/shuffling.md
+ dataset_configuration/mixing_data_sources.md
+ dataset_configuration/replication_and_sampling.md
+
+.. toctree::
+ :hidden:
+ :maxdepth: 1
+ :caption: Distributed Training
+
+ distributed_training/requirements.md
+ distributed_training/using_launchers.md
+ distributed_training/elastic_determinism.md
+ distributed_training/fast_resumption.md
+ distributed_training/performance_tuning.md
.. toctree::
:hidden:
:maxdepth: 1
- :caption: Examples
+ :caption: How-to Guides
- examples/cifar10.ipynb
- examples/facesynthetics.ipynb
- examples/synthetic_nlp.ipynb
- examples/multiprocess_dataset_conversion.ipynb
- examples/spark_dataframe_to_MDS.ipynb
+ how_to_guides/configure_cloud_storage_credentials.md
+ how_to_guides/cifar10.ipynb
+ how_to_guides/synthetic_nlp.ipynb
.. toctree::
:hidden:
diff --git a/docs/source/preparing_datasets/basic_dataset_conversion.md b/docs/source/preparing_datasets/basic_dataset_conversion.md
new file mode 100644
index 000000000..58b214078
--- /dev/null
+++ b/docs/source/preparing_datasets/basic_dataset_conversion.md
@@ -0,0 +1,337 @@
+# Basic Dataset Conversion
+
+This guide covers how to convert your raw data to MDS format using {class}`streaming.MDSWriter`. Writing to other supported shard formats is very similar. Read more about dataset shard formats in the [Dataset Format](dataset_format.md) guide. For a high-level explanation of how dataset writing works, check out the [main concepts](../getting_started/main_concepts.md#Dataset-conversion) page.
+
+## Configuring dataset writing
+
+Use {class}`streaming.MDSWriter` to convert raw data to MDS format. MDSWriter is like a native file writer; instead of writing the content line by line, MDSWriter writes the data sample by sample. It writes the data into shard files in a sequential manner (for example, `shard.00000.mds`, then `shard.00001.mds`, and so on). Configure {class}`streaming.MDSWriter` according to your requirements with the parameters below:
+
+1. The `out` parameter is the output directory for shard files. It can be specified in three ways:
+ * **Local path**: Shard files are stored locally.
+ * **Remote path**: A local temporary directory is created to cache the shard files, and when shard creation is complete, they are uploaded to the remote location.
+ * **`(local_dir, remote_dir)` tuple**: Shard files are saved in the specified `local_dir` and uploaded to `remote_dir`.
+
+
+```python
+out = '/local/data'
+out = 's3://bucket/data' # Will create a temporary local dir
+out = ('/local/data', 'oci://bucket/data')
+```
+
+2. The optional `keep_local` parameter controls whether to keep the shard files locally after they have been uploaded to a remote cloud location. To save local disk space, this defaults to `False`; see the sketch below.
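+
+A minimal sketch, where the remote S3 path and the column spec are hypothetical placeholders:
+
+```python
+from streaming import MDSWriter
+
+columns = {'number': 'int', 'words': 'str'}
+
+# Shards are written to ./local_cache, uploaded to the remote path,
+# and kept on disk afterwards because keep_local=True.
+with MDSWriter(out=('./local_cache', 's3://my-bucket/my-dataset'),
+               columns=columns,
+               keep_local=True) as out:
+    out.write({'number': 7, 'words': 'seven'})
+```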
+
+3. The `columns` parameter is a `dict` mapping each feature or label name to a Streaming-supported encoding type. `MDSWriter` encodes your data to bytes, and at training time, data gets decoded back automatically to its original form. The `index.json` file stores the `columns` metadata for decoding. Supported encoding formats are:
+
+| Category | Name | Class | Notes |
+|--------------------|---------------|--------------|--------------------------|
+| Encoding | 'bytes' | `Bytes` | no-op encoding |
+| Encoding | 'str' | `Str` | stores in UTF-8 |
+| Encoding | 'int' | `Int` | Python `int`, uses `numpy.int64` for encoding |
+| Numpy Array | 'ndarray:dtype:shape' | `NDArray(dtype: Optional[str] = None, shape: Optional[Tuple[int]] = None)` | uses `numpy.ndarray` |
+| Numpy Unsigned Int | 'uint8' | `UInt8` | uses `numpy.uint8` |
+| Numpy Unsigned Int | 'uint16' | `UInt16` | uses `numpy.uint16` |
+| Numpy Unsigned Int | 'uint32' | `UInt32` | uses `numpy.uint32` |
+| Numpy Unsigned Int | 'uint64' | `UInt64` | uses `numpy.uint64` |
+| Numpy Signed Int | 'int8' | `Int8` | uses `numpy.int8` |
+| Numpy Signed Int | 'int16' | `Int16` | uses `numpy.int16` |
+| Numpy Signed Int | 'int32' | `Int32` | uses `numpy.int32` |
+| Numpy Signed Int | 'int64' | `Int64` | uses `numpy.int64` |
+| Numpy Float | 'float16' | `Float16` | uses `numpy.float16` |
+| Numpy Float | 'float32' | `Float32` | uses `numpy.float32` |
+| Numpy Float | 'float64' | `Float64` | uses `numpy.float64` |
+| Numerical String | 'str_int' | `StrInt` | stores in UTF-8 |
+| Numerical String | 'str_float' | `StrFloat` | stores in UTF-8 |
+| Numerical String | 'str_decimal' | `StrDecimal` | stores in UTF-8 |
+| Image | 'pil' | `PIL` | raw PIL image class ([link](https://pillow.readthedocs.io/en/stable/reference/Image.html)) |
+| Image | 'jpeg' | `JPEG` | PIL image as JPEG |
+| Image | 'png' | `PNG` | PIL image as PNG |
+| Pickle | 'pkl' | `Pickle` | arbitrary Python objects |
+| JSON | 'json' | `JSON` | arbitrary data as JSON |
+
+Here's an example where the field `x` is an image, and `y` is a class label, as an integer.
+
+```python
+columns = {
+ 'x': 'jpeg',
+ 'y': 'int8',
+}
+```
+
+If the data type you need is not listed in the table above, you can write your own data type class with `encode` and `decode` methods and register it inside Streaming. For example, let's say you wanted to add a `complex128` data type (64 bits each for the real and imaginary parts):
+
+
+```python
+import numpy as np
+from typing import Any
+
+from streaming.base.format.mds.encodings import Encoding, _encodings
+
+class Complex128(Encoding):
+
+ def encode(self, obj: Any) -> bytes:
+ return np.complex128(obj).tobytes()
+
+ def decode(self, data: bytes) -> Any:
+ return np.frombuffer(data, np.complex128)[0]
+
+_encodings['complex128'] = Complex128
+```
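+
+Once registered, the hypothetical `'complex128'` encoding can be used like any built-in type. Continuing from the registration above (the output directory name is arbitrary):
+
+```python
+from streaming import MDSWriter
+
+# 'complex128' resolves to the Complex128 encoding registered above.
+with MDSWriter(out='complex_dataset/', columns={'z': 'complex128'}) as out:
+    out.write({'z': 3 + 4j})
+```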
+
+4. An optional shard `size_limit`, in bytes, for each *uncompressed* shard file. This defaults to 67 MB. Specify this either directly as an `int` number of bytes, or as a human-readable string:
+
+
+```python
+size_limit = 1024 # 1kB limit, as an int
+size_limit = '1kb' # 1kB limit, as a human-readable string
+```
+The optimal shard file size depends on your dataset, but generally, too small a shard size creates many shard files with heavy network overhead, while too large a shard size creates fewer shard files whose downloads are less well balanced. A shard size between 50-100MB works well in practice.
+
+5. An optional `compression` algorithm name (and level) if you would like to compress the shard files. This can reduce egress costs during training. StreamingDataset will uncompress shard files upon download during training. You can control whether to keep compressed shard files locally during training with the `keep_zip` flag -- more information [here](../dataset_configuration/shard_retrieval.md#Keeping-compressed-shards).
+
+Supported compression algorithms:
+
+| Name | Code | Min Level | Default Level | Max Level |
+| --------------------------------------------- | ------ | --------- | ------------- | --------- |
+| [Brotli](https://github.com/google/brotli) | br | 0 | 11 | 11 |
+| [Bzip2](https://sourceware.org/bzip2/) | bz2 | 1 | 9 | 9 |
+| [Gzip](https://www.gzip.org/) | gz | 0 | 9 | 9 |
+| [Snappy](https://github.com/google/snappy) | snappy | – | – | – |
+| [Zstandard](https://github.com/facebook/zstd) | zstd | 1 | 3 | 22 |
+
+The compression algorithm to use, if any, is specified by passing `code` or `code:level` as a string. For example:
+
+
+```python
+compression = 'zstd' # zstd, defaults to level 3.
+compression = 'zstd:9' # zstd, specifying level 9.
+```
+The higher the level, the higher the compression ratio. However, using higher compression levels will impact the compression speed. In our experience, `zstd` is optimal over the time-size Pareto frontier. Compression is most beneficial for text, whereas it is less helpful for other modalities like images.
+
+6. An optional `hashes` list of algorithm names, used to verify data integrity. Hashes are saved in the `index.json` file. Hash verification during training is controlled with the `validate_hash` argument; more information [here](../dataset_configuration/shard_retrieval.md#Hash-validation).
+
+Available cryptographic hash functions:
+
+| Hash | Digest Bytes |
+| ---------- | ------------ |
+| 'blake2b' | 64 |
+| 'blake2s' | 32 |
+| 'md5' | 16 |
+| 'sha1' | 20 |
+| 'sha224' | 28 |
+| 'sha256' | 32 |
+| 'sha384' | 48 |
+| 'sha512' | 64 |
+| 'sha3_224' | 28 |
+| 'sha3_256' | 32 |
+| 'sha3_384' | 48 |
+| 'sha3_512' | 64 |
+
+Available non-cryptographic hash functions:
+
+| Hash | Digest Bytes |
+| ---------- | ------------ |
+| 'xxh32' | 4 |
+| 'xxh64' | 8 |
+| 'xxh128' | 16 |
+| 'xxh3_64' | 8 |
+| 'xxh3_128' | 16 |
+
+As an example:
+
+
+```python
+hashes = ['sha256', 'xxh64']
+```
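+
+The recorded hashes can then be checked at training time with the `validate_hash` argument mentioned above. A rough sketch, using placeholder paths and columns:
+
+```python
+from streaming import MDSWriter, StreamingDataset
+
+# Writing: record both hashes for every shard in index.json.
+with MDSWriter(out='hashed_dataset/', columns={'words': 'str'}, hashes=['sha256', 'xxh64']) as out:
+    out.write({'words': 'hello'})
+
+# Training: verify one of the recorded hashes as shards are read.
+dataset = StreamingDataset(local='hashed_dataset/', batch_size=1, validate_hash='sha256')
+print(dataset[0])
+```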
+
+## Example: Writing a dataset to MDS format
+
+Let's put it all together with an example. Here, we create a synthetic classification dataset that returns a tuple of features and a label.
+
+```python
+import numpy as np
+
+class RandomClassificationDataset:
+ """Classification dataset drawn from a normal distribution.
+
+ Args:
+ shape: data sample dimensions (default: (10,))
+ size: number of samples (default: 10000)
+ num_classes: number of classes (default: 2)
+ """
+
+ def __init__(self, shape=(10,), size=10000, num_classes=2):
+ self.size = size
+ self.x = np.random.randn(size, *shape)
+ self.y = np.random.randint(0, num_classes, size)
+
+ def __len__(self):
+ return self.size
+
+ def __getitem__(self, index: int):
+ return self.x[index], self.y[index]
+```
+
+Here, we write shards to a local directory. You can specify a remote path as well.
+
+```python
+output_dir = 'test_output_dir'
+```
+
+Specify the column encoding types for each sample and label:
+
+```python
+columns = {'x': 'pkl', 'y': 'int64'}
+```
+
+Optionally, specify a compression algorithm and level:
+
+```python
+compression = 'zstd:7' # compress shards with ZStandard, level 7
+```
+
+Optionally, specify a list of hash algorithms for verification:
+
+```python
+hashes = ['sha1'] # Use only SHA1 hashing on each shard
+```
+
+Optionally, provide a shard size limit, after which a new shard starts. In this small example, we use 10kb, but for production datasets 50-100MB is more appropriate.
+
+```python
+# Here we use a human-readable string, but we could also
+# pass in an int specifying the number of bytes.
+limit = '10kb'
+```
+
+It's time to call the {class}`streaming.MDSWriter` with the above initialized parameters and write the samples by iterating over a dataset.
+
+```python
+from streaming.base import MDSWriter
+
+dataset = RandomClassificationDataset()
+with MDSWriter(out=output_dir, columns=columns, compression=compression, hashes=hashes, size_limit=limit) as out:
+ for x, y in dataset:
+ out.write({'x': x, 'y': y})
+```
+
+Once the dataset has been written, the output directory contains an `index.json` file with shard metadata, along with the shard files themselves. For example:
+
+```bash
+test_output_dir
+├── index.json
+├── shard.00000.mds.zstd
+└── shard.00001.mds.zstd
+```
+
+Clean up after ourselves.
+
+```python
+from shutil import rmtree
+
+rmtree(output_dir)
+```
+
+## Example: Writing `ndarray`s to MDS format
+
+Here, we show how to write `ndarray`s to MDS format in three ways:
+1. dynamic shape and dtype
+2. dynamic shape but fixed dtype
+3. fixed shape and dtype
+
+Serializing ndarrays with fixed dtype and shape is more efficient than fixed dtype and dynamic shape, which is in turn more efficient than dynamic dtype and shape.
+
+### Dynamic shape, dynamic dtype
+
+The streaming encoding type, as the value in the `columns` dict, should simply be `ndarray`.
+
+```python
+import numpy as np
+from streaming.base import MDSWriter, StreamingDataset
+# Write to MDS
+with MDSWriter(out='my_dataset1/',
+ columns={'my_array': 'ndarray'}) as out:
+ for i in range(42):
+ # Dimension can change
+ ndim = np.random.randint(1, 5)
+ shape = np.random.randint(1, 5, ndim)
+ shape = tuple(shape.tolist())
+ my_array = np.random.normal(0, 1, shape)
+ out.write({'my_array': my_array})
+
+# Inspect dataset
+dataset = StreamingDataset(local='my_dataset1/', batch_size=1)
+for i in range(dataset.num_samples):
+ print(dataset[i])
+```
+
+### Dynamic shape, fixed dtype
+
+The streaming encoding type, as the value in the `columns` dict, should be `ndarray:dtype`. So in this example, it is `ndarray:int16`.
+
+
+```python
+# Write to MDS
+with MDSWriter(out='my_dataset2/',
+ columns={'my_array': 'ndarray:int16'}) as out:
+ for i in range(42):
+ # Dimension can change
+ ndim = np.random.randint(1, 5)
+ shape = np.random.randint(1, 5, ndim)
+ shape = tuple(shape.tolist())
+ # Datatype is fixed
+ my_array = np.random.normal(0, 100, shape).astype(np.int16)
+ out.write({'my_array': my_array})
+
+# Inspect dataset
+dataset = StreamingDataset(local='my_dataset2/', batch_size=1)
+for i in range(dataset.num_samples):
+ print(dataset[i])
+```
+
+### Fixed shape, fixed dtype
+
+The streaming encoding type, as the value in the `columns` dict, should be `ndarray:dtype:shape`. So in this example, it is `ndarray:int16:3,3,3`.
+
+
+```python
+# Write to MDS
+with MDSWriter(out='my_dataset3/',
+ columns={'my_array': 'ndarray:int16:3,3,3'}) as out:
+ for i in range(42):
+ # Shape is fixed
+ shape = 3, 3, 3
+ # Datatype is fixed
+ my_array = np.random.normal(0, 100, shape).astype(np.int16)
+ out.write({'my_array': my_array})
+
+# Inspect dataset
+dataset = StreamingDataset(local='my_dataset3/', batch_size=1)
+for i in range(dataset.num_samples):
+ print(dataset[i])
+```
+
+We can see that the dataset is more efficiently serialized when we are more specific about array shape and datatype:
+
+
+```python
+import subprocess
+
+# Dynamic shape, dynamic dtype uses most space
+subprocess.run(['du', '-sh', 'my_dataset1'])
+
+# Dynamic shape, fixed dtype uses less space
+subprocess.run(['du', '-sh', 'my_dataset2'])
+
+# Fixed shape, fixed dtype uses the least space
+subprocess.run(['du', '-sh', 'my_dataset3'])
+```
+
+Clean up after ourselves.
+
+```python
+from shutil import rmtree
+
+rmtree('my_dataset1')
+rmtree('my_dataset2')
+rmtree('my_dataset3')
+```
diff --git a/docs/source/preparing_datasets/dataset_format.md b/docs/source/preparing_datasets/dataset_format.md
new file mode 100644
index 000000000..339fcea54
--- /dev/null
+++ b/docs/source/preparing_datasets/dataset_format.md
@@ -0,0 +1,66 @@
+# Dataset Format
+
+## Introduction
+
+To use StreamingDataset, one must convert raw data into one of our supported serialized dataset formats. With massive datasets, the choice of serialization format is critical to the ultimate observed performance of the system. For deep learning models, we need extremely low-latency cold random access to individual samples to ensure that dataloading is not a bottleneck to training.
+
+StreamingDataset is compatible with any data type, including **images**, **text**, **video**, and **multimodal** data. StreamingDataset supports the following formats:
+ * MDS (Mosaic Data Shard, most performant), through {class}`streaming.MDSWriter`
+ * CSV/TSV, through {class}`streaming.CSVWriter` or {class}`streaming.TSVWriter`
+ * JSONL, through {class}`streaming.JSONWriter`
+
+These formats can encode and decode most Python objects.
+
+For a high-level explanation of how dataset writing works, check out the [main concepts](../getting_started/main_concepts.md#Dataset-conversion) page. The [Dataset Conversion Guide](basic_dataset_conversion.md) shows how to use the {class}`streaming.MDSWriter` to convert your raw data to supported file formats. For large datasets, use the [Parallel Dataset Conversion](parallel_dataset_conversion.ipynb) guide.
+
+
+## Formats
+### 1. MDS
+Mosaic Data Shard (MDS) is our most performant file format for fast sample random-access, and stores data in serialized tabular form. A single sample is a dictionary of key/value pairs where the key is the column name, and the value is the sample's entry for that column. Use {class}`streaming.MDSWriter` for MDS.
+
+### 2. CSV/TSV
+CSV/TSV, or more generally XSV, is a plaintext tabular data format consisting of delimiter-separated values. For convenience, we have added two named sub-types which you will recognize as CSV (comma-delimited) and TSV (tab-delimited). To create datasets in these formats, use {class}`streaming.XSVWriter`, {class}`streaming.CSVWriter`, or {class}`streaming.TSVWriter`.
+
+### 3. JSONL
+JSONL is a simple and popular dataset format in which each sample is a JSON dict terminated by a newline. Use {class}`streaming.JSONWriter` for JSONL.
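+
+As a brief sketch, writing a JSONL dataset with {class}`streaming.JSONWriter` looks much like the MDS examples elsewhere in these docs (the output directory and columns below are placeholders); the XSV/CSV/TSV writers follow the same pattern:
+
+```python
+from streaming import JSONWriter
+
+columns = {'number': 'int', 'words': 'str'}
+with JSONWriter(out='jsonl_dataset/', columns=columns) as out:
+    out.write({'number': 7, 'words': 'seven'})
+```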
+
+## Metadata
+
+Streaming also must store some metadata to keep track of a dataset's shards and samples. With MDS, only the `index.json` file is present, but with CSV/TSV and JSONL, additional files must also be stored which contain information about where specific samples are stored.
+
+### The `index.json` file
+As mentioned in the [main concepts](../getting_started/main_concepts.md#dataset-conversion) page, an `index.json` file is also created alongside the shard files, containing information such as the number of shards, the number of samples per shard, shard sizes, etc. An example `index.json` file, which holds metadata for multiple MDS shards whose samples contain only one column called "tokens" encoded as `bytes`, is structured as below:
+
+```json
+{
+ "shards": [
+ { // Shard 0
+ "column_encodings": ["bytes"],
+ "column_names": ["tokens"],
+ "column_sizes": [null],
+ "compression": null,
+ "format": "mds",
+ "hashes": [],
+ "raw_data": {
+ "basename": "shard.00000.mds",
+ "bytes": 67092637,
+ "hashes": {}
+ },
+ "samples": 4093,
+ "size_limit": 67108864,
+ "version": 2,
+ "zip_data": null
+ },
+ { // Shard 1, very similar to Shard 0 metadata
+ ...
+ "raw_data": {
+ "basename": "shard.00001.mds",
+ "bytes": 67092637,
+ "hashes": {}
+ },
+ ...
+ },
+ // and so on
+ ]
+}
+```
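+
+Since `index.json` is plain JSON, it can be inspected directly. A small sketch (the dataset path is only an example):
+
+```python
+import json
+
+with open('my_mds_dataset/index.json') as f:
+    index = json.load(f)
+
+num_shards = len(index['shards'])
+num_samples = sum(shard['samples'] for shard in index['shards'])
+print(f'{num_shards} shards, {num_samples} total samples')
+```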
diff --git a/examples/multiprocess_dataset_conversion.ipynb b/docs/source/preparing_datasets/parallel_dataset_conversion.ipynb
similarity index 78%
rename from examples/multiprocess_dataset_conversion.ipynb
rename to docs/source/preparing_datasets/parallel_dataset_conversion.ipynb
index d0ce9f134..35005ad2a 100644
--- a/examples/multiprocess_dataset_conversion.ipynb
+++ b/docs/source/preparing_datasets/parallel_dataset_conversion.ipynb
@@ -7,14 +7,14 @@
"id": "daY0p3RNvzFR"
},
"source": [
- "# Multiprocess dataset conversion\n",
+ "# Parallel dataset conversion\n",
"\n",
- "If your dataset is huge, running single process dataset conversion script could be very time consuming. You can use multiprocessing with MDSWriter to convert your dataset in parallel. There are few ways in which you can convert your raw data into MDS format parallelly.\n",
+ "If your dataset is huge, running a single-process dataset conversion script can be very time consuming. You can use multiprocessing with MDSWriter to convert your dataset in parallel. There are a few ways to convert your raw data into MDS format in parallel.\n",
"\n",
"1. Download a raw data in parallel and convert to MDS format sequentially.\n",
- "2. Group a raw data and convert to MDS format parallely in separate sub-directories and then merge all the sub-directories index.json file to get a unified MDS dataset.\n",
+ "2. Group raw data and convert in parallel to MDS format in separate sub-directories. Then, merge all the `index.json` files from these subdirectories to get a unified MDS dataset.\n",
"\n",
- "Let's look at the small example of each one on how to that."
+ "Let's look at an example for each option."
]
},
{
@@ -25,7 +25,7 @@
},
"source": [
"## 1. Fetch raw data in parallel and write sequentially\n",
- "For a large individual dataset file such as image or a video, it would be useful to download those files in parallel by multiple processes and once it is downloaded, call the MDSWriter to write the data into MDS format. Below is one such example on how to do that."
+ "For a dataset with large files (such as images or videos), it would be useful to download those files in parallel using multiple processes and call the MDSWriter to write the data into MDS format."
]
},
{
@@ -37,7 +37,7 @@
"source": [
"### Setup\n",
"\n",
- "Let's start by making sure the right packages are installed and imported. We need to install the `mosaicml-streaming` package which installs the sufficient dependencies to run this tutorial."
+ "Let's start by installing the `mosaicml-streaming` package, and importing necessary dependencies."
]
},
{
@@ -74,7 +74,7 @@
"source": [
"### Global settings\n",
"\n",
- "Initialize the global variable"
+ "Initialize global variables:"
]
},
{
@@ -98,7 +98,7 @@
"id": "tq4NInVovzFW"
},
"source": [
- "Download data from URL. Here, we just return a number for demonstration purpose. "
+ "Download data from remote URLs. Here, we just return a number for demonstration purposes. "
]
},
{
@@ -122,7 +122,7 @@
"id": "yByl7cpsvzFX"
},
"source": [
- "Initialization method for each worker process which prints the worker PID."
+ "An initialization method for each worker process which prints the worker PID."
]
},
{
@@ -149,7 +149,7 @@
"source": [
"### Convert to MDS format\n",
"\n",
- "Initialize 4 worker processes which downloads the data in parallel and once the data is ready, it is getting written in MDS format using `write` method call."
+ "Initialize 4 worker processes which download the data in parallel. Once the data is ready, it is written to MDS format using the `write` method of {class}`streaming.MDSWriter`."
]
},
{
@@ -178,7 +178,7 @@
"source": [
"### Load MDS dataset\n",
"\n",
- "Read the sample using `StreamingDataset` which prints the sample ID."
+ "Read samples from MDS by iterating over `StreamingDataset`. Here, we just print sample IDs."
]
},
{
@@ -218,7 +218,7 @@
"source": [
"## 2. Group the raw data and convert to MDS format in parallel\n",
"\n",
- "For a large dataset file such as a tar file, zip file, or any other file, we would recommend to map one raw data file to one MDS sub-directories so that the dataset conversion happens by multiple process in parallel."
+ "For large raw datasets, or raw datasets with large files, we recommend partitioning dataset conversion among multiple `MDSWriter`s. Dataset conversion will take place with multiple processes in parallel."
]
},
{
@@ -228,7 +228,7 @@
"id": "3GbXadoPJne7"
},
"source": [
- "Import dependencies"
+ "Import dependencies:"
]
},
{
@@ -258,7 +258,7 @@
"source": [
"### Global settings\n",
"\n",
- "Initialize the global variable"
+ "Initialize the needed global variables:"
]
},
{
@@ -281,9 +281,9 @@
"id": "1qkyMzckKl0V"
},
"source": [
- "Get a sub-directory MDS path and raw dataset sample range of 10. For example, first sub-directory yields a sample from 0 to 9, second sub-directory yields a sample from 10 to 19, and so on.\n",
+ "This function yields a sub-directory path where MDS shards will be stored, as well as the raw dataset sample range of that directory. For example, the first sub-directory will contain samples 0 to 9, the second sub-directory will contain samples 10 to 19, and so on.\n",
"\n",
- "If you are working with a large file, you can also yield a raw dataset file path instead of sample range."
+ "If you are working with large files, you can also yield a single raw dataset file path instead of a sample range."
]
},
{
@@ -295,7 +295,7 @@
"outputs": [],
"source": [
"def each_task(out_root: str, groups: int) -> Iterator[Tuple[str, int, int]]:\n",
- " \"\"\"Get the sub-directory path and the sample range.\n",
+ " \"\"\"Get the sub-directory path and the sample range for each sub-directory.\n",
"\n",
" Args:\n",
" out_root (str): base output mds directory\n",
@@ -318,7 +318,7 @@
"id": "p9XqWLD-Moqz"
},
"source": [
- "Convert a raw dataset into MDS format. "
+ "This function converts raw dataset samples into MDS format. "
]
},
{
@@ -360,7 +360,7 @@
"id": "by7aPVIDM1mG"
},
"source": [
- "Divide the dataset into 4 sub-groups, each process takes a sub-group and converts a data into MDS format in their respective sub-directories."
+ "We partition the raw dataset into 4 sub-groups, and each process converts a sub-group into MDS format. The resulting shards are stored in the respective sub-directories."
]
},
{
@@ -390,7 +390,7 @@
"id": "dX60JcG9M_aT"
},
"source": [
- "Once dataset has been converted to an MDS format, let's look at the directory structure. You will find 4 sub-directories and each sub-directories contain a `index.json` file and a shard files."
+ "Once the dataset has been converted to MDS format, let's look at the directory structure. You will find 4 sub-directories, each containing an `index.json` file and shard files."
]
},
{
@@ -413,7 +413,7 @@
"source": [
"### Merge meta data\n",
"\n",
- "The last step of the conversion process is to merge all the sub-directories `index.json` file. The content of the Shard files will remain as it is. By calling the merge_index utility function, the global shard information will be written to a new `index.json` file placed in `out`."
+ "The last step of the conversion process is to merge all the `index.json` files from the sub-directories. The contents of the shard files remain the same. By calling the `merge_index` utility function, information for all the shards will be written to a new `index.json` file placed in the `out` directory, as sketched below.\n",
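+ "\n",
+ "A minimal sketch of that call, assuming `out` is the parent directory that holds the sub-directories:\n",
+ "\n",
+ "```python\n",
+ "from streaming.base.util import merge_index\n",
+ "\n",
+ "merge_index(out, keep_local=True)\n",
+ "```"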
]
},
{
@@ -435,7 +435,7 @@
"id": "NByc1ZGINcXe"
},
"source": [
- "Let's checkout the root directories where you can see one `index.json` file and many shard files."
+ "Let's check out the root directory, where you can see one `index.json` file along with sub-directories that contain shard files."
]
},
{
@@ -458,7 +458,7 @@
"source": [
"### Load MDS dataset\n",
"\n",
- "Read the sample using `StreamingDataset` which prints the sample ID."
+ "Read samples by iterating over `StreamingDataset`. Here, we just print the sample IDs."
]
},
{
@@ -508,31 +508,7 @@
"\n",
"## What next?\n",
"\n",
- "You've now seen an in-depth tutorial on converting a dataset into MDS format using multiple process. If you are interested in the real world example, then, checkout the [WebVid](https://github.com/mosaicml/streaming/blob/main/streaming/multimodal/convert/webvid/crawl_webvid.py) and [Pile](https://github.com/mosaicml/streaming/blob/main/streaming/text/convert/pile.py) dataset conversion scripts which converts the dataset into MDS format via multiprocessing."
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "metadata": {
- "id": "dvL3gQ3OXMZo"
- },
- "source": [
- "## Come get involved with MosaicML!\n",
- "\n",
- "We'd love for you to get involved with the MosaicML community in any of these ways:\n",
- "\n",
- "### [Star Streaming on GitHub](https://github.com/mosaicml/streaming)\n",
- "\n",
- "Help make others aware of our work by [starring Streaming on GitHub](https://github.com/mosaicml/streaming).\n",
- "\n",
- "### [Join the MosaicML Slack](https://mosaicml.me/slack)\n",
- "\n",
- "Head on over to the [MosaicML slack](https://mosaicml.me/slack) to join other ML efficiency enthusiasts. Come for the paper discussions, stay for the memes!\n",
- "\n",
- "### Contribute to Streaming\n",
- "\n",
- "Is there a bug you noticed or a feature you'd like? File an [issue](https://github.com/mosaicml/streaming/issues) or make a [pull request](https://github.com/mosaicml/streaming/pulls)!"
+ "You've now seen an in-depth tutorial on converting a dataset into MDS format using multiple processes. If you are interested in some real-world examples, check out the [WebVid](https://github.com/mosaicml/streaming/blob/main/streaming/multimodal/convert/webvid/crawl_webvid.py) and [Pile](https://github.com/mosaicml/streaming/blob/main/streaming/text/convert/pile.py) dataset conversion scripts, which convert datasets into MDS format via multiprocessing."
]
}
],
diff --git a/docs/source/preparing_datasets/spark_dataframe_to_mds.ipynb b/docs/source/preparing_datasets/spark_dataframe_to_mds.ipynb
new file mode 100644
index 000000000..60b0fc767
--- /dev/null
+++ b/docs/source/preparing_datasets/spark_dataframe_to_mds.ipynb
@@ -0,0 +1,857 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {},
+ "inputWidgets": {},
+ "nuid": "992ea5bf-f420-44e3-9c56-ab04c086bf05",
+ "showTitle": false,
+ "title": ""
+ },
+ "id": "vticii3WR_HN"
+ },
+ "source": [
+ "# Spark DataFrame to MDS\n",
+ "In this tutorial, we will demonstrate how to use the Streaming Spark converter to convert a Spark DataFrame into MDS format for use with StreamingDataset. You have the option to pass a preprocessing job, such as a tokenizer, to the converter, which can be useful if materializing the intermediate dataframe is time consuming or requires extra development effort."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {},
+ "inputWidgets": {},
+ "nuid": "ed551b02-c3a7-47dc-8aa2-8e6e501cf5e3",
+ "showTitle": false,
+ "title": ""
+ },
+ "id": "EV5xY06KR_HO"
+ },
+ "source": [
+ "## Tutorial Covers\n",
+ "1. Installation of libraries\n",
+ "2. **Basic**: Convert Spark DataFrame to MDS format.\n",
+ "3. **Advanced**: Convert Spark DataFrame into tokenized format and convert to MDS format."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {},
+ "inputWidgets": {},
+ "nuid": "93f83c00-600f-4605-972c-5f0eb4a4152d",
+ "showTitle": false,
+ "title": ""
+ },
+ "id": "gyESiU8KR_HP"
+ },
+ "source": [
+ "## Setup\n",
+ "Let\u2019s start by installing `mosaicml-streaming` and some other needed packages."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "a3547314-ce67-4617-84b0-fafe366f82e4",
+ "showTitle": false,
+ "title": ""
+ },
+ "id": "QE2DHOK_R_HP"
+ },
+ "outputs": [],
+ "source": [
+ "%pip install --upgrade fsspec datasets transformers"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "115c2bc7-94d6-4ada-926c-fc5bdfd6c29c",
+ "showTitle": false,
+ "title": ""
+ },
+ "id": "PeEMBLaSR_HP"
+ },
+ "outputs": [],
+ "source": [
+ "%pip install mosaicml-streaming"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "zETa4qH0TnPE"
+ },
+ "outputs": [],
+ "source": [
+ "%pip install pyspark==3.4.1"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "zjlPAIBfhSbg"
+ },
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "import shutil\n",
+ "from typing import Any, Sequence, Dict, Iterable, Optional\n",
+ "from pyspark.sql import SparkSession\n",
+ "import pandas as pd\n",
+ "import numpy as np\n",
+ "from tempfile import mkdtemp\n",
+ "import datasets as hf_datasets\n",
+ "from transformers import AutoTokenizer, PreTrainedTokenizerBase"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "G1I2CtGpRk40"
+ },
+ "source": [
+ "We\u2019ll be using Streaming\u2019s `dataframe_to_mds()` method which converts a DataFrame into Streaming's MDS format."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "uzYHe6yYRzyV"
+ },
+ "outputs": [],
+ "source": [
+ "from streaming.base.converters import dataframe_to_mds"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "42e9ffbc-52a7-479d-b3a5-608d044d1f6a",
+ "showTitle": false,
+ "title": ""
+ },
+ "id": "Un4G3VdgR_HQ"
+ },
+ "source": [
+ "## **Basic:** Convert Spark DataFrame to MDS format\n",
+ "**Steps:**\n",
+ "1. Create a Synthetic NLP dataset.\n",
+ "2. Store the above dataset as a Parquet file.\n",
+ "3. Load the Parquet file as a Spark DataFrame.\n",
+ "4. Convert the Spark DataFrame to MDS format.\n",
+ "5. Load the MDS dataset and inspect the output."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {},
+ "inputWidgets": {},
+ "nuid": "90184dc8-cb85-4921-bc8e-e1525c6ee212",
+ "showTitle": false,
+ "title": ""
+ },
+ "id": "oKA2TTJ-R_HQ"
+ },
+ "source": [
+ "### Create a Synthetic NLP dataset\n",
+ "\n",
+ "In this tutorial, we will create a synthetic number-saying dataset, i.e. converting numbers from digits to words. For example, the number `123` would be converted to `one hundred twenty three`. The numbers are generated randomly with a fixed seed.\n",
+ "\n",
+ "Let\u2019s make a short synthetic number-saying dataset class."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "2e58eb05-4271-4ea4-b7ce-2e5466a1405b",
+ "showTitle": false,
+ "title": ""
+ },
+ "id": "SMirmZEqR_HQ"
+ },
+ "outputs": [],
+ "source": [
+ "class NumberAndSayDataset:\n",
+ " \"\"\"Generate a synthetic number-saying dataset.\n",
+ "\n",
+ " Converts numbers from digits to words. Supports positive and negative numbers\n",
+ " up to approximately 99 million.\n",
+ "\n",
+ " Args:\n",
+ " num_samples (int): number of samples. Defaults to 100.\n",
+ " column_names list[str]: A list of features and target name. Defaults to ['number',\n",
+ " 'words'].\n",
+ " seed (int): seed value for deterministic randomness.\n",
+ " \"\"\"\n",
+ "\n",
+ " ones = (\n",
+ " 'zero one two three four five six seven eight nine ten eleven twelve thirteen fourteen ' +\n",
+ " 'fifteen sixteen seventeen eighteen nineteen').split()\n",
+ "\n",
+ " tens = 'twenty thirty forty fifty sixty seventy eighty ninety'.split()\n",
+ "\n",
+ " def __init__(self,\n",
+ " num_samples: int = 100,\n",
+ " column_names: list[str] = ['number', 'words'],\n",
+ " seed: int = 987) -> None:\n",
+ " self.num_samples = num_samples\n",
+ " self.column_encodings = ['int', 'str']\n",
+ " self.column_sizes = [8, None]\n",
+ " self.column_names = column_names\n",
+ " self._index = 0\n",
+ " self.seed = seed\n",
+ "\n",
+ " def __len__(self) -> int:\n",
+ " return self.num_samples\n",
+ "\n",
+ " def _say(self, i: int) -> list[str]:\n",
+ " if i < 0:\n",
+ " return ['negative'] + self._say(-i)\n",
+ " elif i <= 19:\n",
+ " return [self.ones[i]]\n",
+ " elif i < 100:\n",
+ " return [self.tens[i // 10 - 2]] + ([self.ones[i % 10]] if i % 10 else [])\n",
+ " elif i < 1_000:\n",
+ " return [self.ones[i // 100], 'hundred'] + (self._say(i % 100) if i % 100 else [])\n",
+ " elif i < 1_000_000:\n",
+ " return self._say(i // 1_000) + ['thousand'\n",
+ " ] + (self._say(i % 1_000) if i % 1_000 else [])\n",
+ " elif i < 1_000_000_000:\n",
+ " return self._say(\n",
+ " i // 1_000_000) + ['million'] + (self._say(i % 1_000_000) if i % 1_000_000 else [])\n",
+ " else:\n",
+ " assert False\n",
+ "\n",
+ " def _get_number(self) -> int:\n",
+ " sign = (np.random.random() < 0.8) * 2 - 1\n",
+ " mag = 10**np.random.uniform(1, 4) - 10\n",
+ " return sign * int(mag**2)\n",
+ "\n",
+ " def __iter__(self):\n",
+ " return self\n",
+ "\n",
+ " def __next__(self) -> dict[str, Any]:\n",
+ " if self._index >= self.num_samples:\n",
+ " raise StopIteration\n",
+ " number = self._get_number()\n",
+ " words = ' '.join(self._say(number))\n",
+ " self._index += 1\n",
+ " return {\n",
+ " self.column_names[0]: number,\n",
+ " self.column_names[1]: words,\n",
+ " }\n",
+ "\n",
+ " @property\n",
+ " def seed(self) -> int:\n",
+ " return self._seed\n",
+ "\n",
+ " @seed.setter\n",
+ " def seed(self, value: int) -> None:\n",
+ " self._seed = value # pyright: ignore\n",
+ " np.random.seed(self._seed)\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "c4tvDe8hTeMS"
+ },
+ "source": [
+ "### Store the dataset as a Parquet file"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "89707251-35c0-4c17-841d-2cc026e6f96f",
+ "showTitle": false,
+ "title": ""
+ },
+ "id": "ysofCNb3R_HR"
+ },
+ "outputs": [],
+ "source": [
+ "# Create a temporary directory\n",
+ "local_dir = mkdtemp()\n",
+ "\n",
+ "syn_dataset = NumberAndSayDataset()\n",
+ "df = pd.DataFrame.from_dict([record for record in syn_dataset])\n",
+ "df.to_parquet(os.path.join(local_dir, 'synthetic_dataset.parquet'))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {},
+ "inputWidgets": {},
+ "nuid": "e9e12fcc-fa78-4635-9e62-fb67ca5f520f",
+ "showTitle": false,
+ "title": ""
+ },
+ "id": "4hvJgJGxR_HR"
+ },
+ "source": [
+ "### Load the Parquet file as a Spark DataFrame"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "10e4ZrH3he0q"
+ },
+ "outputs": [],
+ "source": [
+ "spark = SparkSession.builder.getOrCreate()\n",
+ "pdf = spark.read.parquet(os.path.join(local_dir, 'synthetic_dataset.parquet'))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {},
+ "inputWidgets": {},
+ "nuid": "9a43d639-c0a8-438a-9c56-f748b6ca79ad",
+ "showTitle": false,
+ "title": ""
+ },
+ "id": "8JfeOLqWR_HS"
+ },
+ "source": [
+ "Take a peek at the Spark DataFrame"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "iGkjgXtDhk6g"
+ },
+ "outputs": [],
+ "source": [
+ "pdf.show(5, truncate=False)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {},
+ "inputWidgets": {},
+ "nuid": "2a2468b8-9538-42e6-b0c6-937a3de92210",
+ "showTitle": false,
+ "title": ""
+ },
+ "id": "pt5MYIrmR_HS"
+ },
+ "source": [
+ "### Convert the Spark DataFrame to MDS format"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "aad931d4-8b97-4d5b-9c40-0f35e715c7b4",
+ "showTitle": false,
+ "title": ""
+ },
+ "id": "zRghnT00R_HS"
+ },
+ "outputs": [],
+ "source": [
+ "# Empty the MDS output directory\n",
+ "out_path = os.path.join(local_dir, 'mds')\n",
+ "shutil.rmtree(out_path, ignore_errors=True)\n",
+ "\n",
+ "# Specify the mandatory MDSWriter arguments `out` and `columns`.\n",
+ "mds_kwargs = {'out': out_path, 'columns': {'number': 'int64', 'words':'str'}}\n",
+ "\n",
+ "# Convert the dataset to MDS format. This divides the dataframe into 4 parts, one part per worker, and merges the `index.json` files from the 4 sub-parts into one in the parent directory.\n",
+ "dataframe_to_mds(pdf.repartition(4), merge_index=True, mds_kwargs=mds_kwargs)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {},
+ "inputWidgets": {},
+ "nuid": "327543b6-b2c8-4c60-bff3-bb4e5a980a28",
+ "showTitle": false,
+ "title": ""
+ },
+ "id": "MdKpnmkhR_HS"
+ },
+ "source": [
+ "Let's check the file structure of the output MDS dataset. You can see four sub-directories and one `index.json` file. The `index.json` file contains the metadata for all four sub-directories."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "268c137e-6e38-4f31-8e51-0adfe3c82c44",
+ "showTitle": false,
+ "title": ""
+ },
+ "id": "CDbPsmUeR_HS"
+ },
+ "outputs": [],
+ "source": [
+ "%ls {out_path}"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {},
+ "inputWidgets": {},
+ "nuid": "6d6aa896-9d5a-4763-aaf4-197f9170ec27",
+ "showTitle": false,
+ "title": ""
+ },
+ "id": "S6j3pr_IR_HS"
+ },
+ "source": [
+ "### Load the MDS dataset using StreamingDataset\n",
+ "Here, we use StreamingDataset to load the MDS dataset and inspect it."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "ad4359ac-1652-48bd-b4db-ea034e180e01",
+ "showTitle": false,
+ "title": ""
+ },
+ "id": "qfpr2Df3R_HS"
+ },
+ "outputs": [],
+ "source": [
+ "from torch.utils.data import DataLoader\n",
+ "import streaming\n",
+ "from streaming import StreamingDataset\n",
+ "\n",
+ "# clean stale shared memory if any\n",
+ "streaming.base.util.clean_stale_shared_memory()\n",
+ "\n",
+ "dataset = StreamingDataset(local=out_path, remote=None, batch_size=2, predownload=4)\n",
+ "\n",
+ "dataloader = DataLoader(dataset, batch_size=2, num_workers=1)\n",
+ "\n",
+ "for i, data in enumerate(dataloader):\n",
+ " print(data)\n",
+ " # Display only first 10 batches\n",
+ " if i == 10:\n",
+ " break"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "84e7df6b-b890-45ad-93de-00903a59c058",
+ "showTitle": false,
+ "title": ""
+ },
+ "id": "Uo1jLj1dR_HS"
+ },
+ "source": [
+ "## **Advanced**: Convert Spark DataFrame into tokenized format and convert to MDS format\n",
+ "**Steps:**\n",
+ "1. [Same as above] Create a Synthetic NLP dataset.\n",
+ "2. [Same as above] Store the above dataset as a Parquet file.\n",
+ "3. [Same as above] Load the Parquet file as a Spark DataFrame.\n",
+ "4. Create a user-defined function which modifies the dataframe.\n",
+ "5. Convert the modified data into MDS format.\n",
+ "6. Load the MDS dataset and look at the output.\n",
+ "\n",
+ "For steps 1-3, follow the steps detailed above."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {},
+ "inputWidgets": {},
+ "nuid": "57ffd58d-118b-485b-9f5c-d42ecc6f81ad",
+ "showTitle": false,
+ "title": ""
+ },
+ "id": "geRPrzMhR_HS"
+ },
+ "source": [
+ "### Create a user-defined function which modifies the dataframe\n",
+ "\n",
+ "The user-defined function should be an iterable function that yields each output sample as a dictionary, with the column name as the key and the output for that column as the value. For example, in this tutorial, the key is `tokens` and the value is the tokenized output in bytes. If an iterable function is defined, the user takes full responsibility for providing the correct `columns` argument; in the case below, it should be\n",
+ "\n",
+ "```\n",
+ "columns={'tokens': 'bytes'}\n",
+ "```\n",
+ "\n",
+ "where `tokens` is the key created by the user-defined iterable function, and `bytes` represents the format of the field so that MDS chooses the proper encoding method."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "P1P7BsXsaDCy"
+ },
+ "source": [
+ "\n",
+ "Take a peek at the Spark DataFrame"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "mU5qfafGzKsQ"
+ },
+ "outputs": [],
+ "source": [
+ "pdf.show(5, truncate=False)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "JMeyEPoSaGgH"
+ },
+ "source": [
+ "### Convert the Spark DataFrame to MDS format\n",
+ "This time we supply the user-defined iterable function and its associated function arguments. For the purpose of demonstration, the user-defined tokenization function `pandas_processing_fn` is greatly simplified. For practical applications, you may want more involved preprocessing steps. For dataset concatenation and more preprocessing examples, see [Mosaic's LLM Foundry](https://github.com/mosaicml/llm-foundry/blob/main/llmfoundry/data/data.py)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "H7Z5kjNie4ZR"
+ },
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "import warnings\n",
+ "from typing import Dict, Iterable, Union\n",
+ "import datasets as hf_datasets\n",
+ "import pandas as pd\n",
+ "import numpy as np\n",
+ "from torch.utils.data import IterableDataset\n",
+ "from transformers import PreTrainedTokenizerBase\n",
+ "\n",
+ "\n",
+ "def pandas_processing_fn(df: pd.DataFrame, **args) -> Iterable[Dict[str, bytes]]:\n",
+ " \"\"\"\n",
+ " Parameters:\n",
+ " -----------\n",
+ " df : pandas.DataFrame\n",
+ " The input pandas DataFrame that needs to be processed.\n",
+ "\n",
+ " **args : keyword arguments\n",
+ "        Additional keyword arguments used during processing (e.g. the tokenizer name, split, and concat_tokens).\n",
+ "\n",
+ " Returns:\n",
+ " --------\n",
+ " iterable obj\n",
+ " \"\"\"\n",
+ " hf_dataset = hf_datasets.Dataset.from_pandas(df=df, split=args['split'])\n",
+ " tokenizer = AutoTokenizer.from_pretrained(args['tokenizer'])\n",
+ " # we will enforce length, so suppress warnings about sequences too long for the model\n",
+ " tokenizer.model_max_length = int(1e30)\n",
+ " max_length = args['concat_tokens']\n",
+ "\n",
+ "    # Iterate over the HF dataset once, concatenating token ids into fixed-length samples.\n",
+ "    buffer = []\n",
+ "    for sample in hf_dataset:\n",
+ "        encoded = tokenizer(sample['words'],\n",
+ "                            truncation=False,\n",
+ "                            padding=False)\n",
+ "        iids = encoded['input_ids']\n",
+ "        buffer = buffer + iids\n",
+ "        while len(buffer) >= max_length:\n",
+ "            concat_sample = buffer[:max_length]\n",
+ "            # keep the leftover tokens for the next concatenated sample\n",
+ "            buffer = buffer[max_length:]\n",
+ "            yield {\n",
+ "                # convert to bytes to store in MDS binary format\n",
+ "                'tokens': np.asarray(concat_sample).tobytes()\n",
+ "            }"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "b0b3cfd1-9753-44ca-9b03-55f29c60fb56",
+ "showTitle": false,
+ "title": ""
+ },
+ "id": "R9Xxr0CDR_HT"
+ },
+ "outputs": [],
+ "source": [
+ "# Empty the MDS output directory\n",
+ "out_path = os.path.join(local_dir, 'mds')\n",
+ "shutil.rmtree(out_path, ignore_errors=True)\n",
+ "\n",
+ "# Provide MDS keyword args. Ensure the `columns` field matches the output of the iterable function (the tokenizer in this example).\n",
+ "mds_kwargs = {'out': out_path, 'columns': {'tokens': 'bytes'}}\n",
+ "\n",
+ "# Tokenizer arguments\n",
+ "udf_kwargs = {\n",
+ " 'concat_tokens': 4,\n",
+ " 'tokenizer': 'EleutherAI/gpt-neox-20b',\n",
+ " 'eos_text': '<|endoftext|>',\n",
+ " 'compression': 'zstd',\n",
+ " 'split': 'train',\n",
+ " 'no_wrap': False,\n",
+ " 'bos_text': '',\n",
+ "}\n",
+ "\n",
+ "# Convert the dataset to MDS format. This fetches samples from the dataframe, tokenizes them, and then converts them to MDS format.\n",
+ "# It divides the dataframe into 4 parts, one part per worker, and merges the `index.json` files from the 4 sub-parts into one in the parent directory.\n",
+ "dataframe_to_mds(pdf.repartition(4), merge_index=True, mds_kwargs=mds_kwargs, udf_iterable=pandas_processing_fn, udf_kwargs=udf_kwargs)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {},
+ "inputWidgets": {},
+ "nuid": "8c3ea1b0-cab4-4d0b-b4a2-ed30dae39959",
+ "showTitle": false,
+ "title": ""
+ },
+ "id": "Vady_sp9R_HT"
+ },
+ "source": [
+ "Let's check the file structure of the output MDS dataset. You can see four sub-directories and one `index.json` file. The `index.json` file contains the metadata for all four sub-directories."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "a071fe82-8d1e-4dd4-9e34-8a0dcbee32e0",
+ "showTitle": false,
+ "title": ""
+ },
+ "id": "VUP5VsILR_HT"
+ },
+ "outputs": [],
+ "source": [
+ "%ls {out_path}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "WVR9ld1vl5VX"
+ },
+ "outputs": [],
+ "source": [
+ "%cat {out_path +'/index.json'}"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {},
+ "inputWidgets": {},
+ "nuid": "f84d61f1-9f09-484c-97cf-1f5ee1e8a214",
+ "showTitle": false,
+ "title": ""
+ },
+ "id": "q7r3o-mKR_HT"
+ },
+ "source": [
+ "### Load the MDS dataset using StreamingDataset\n",
+ "Here, we use StreamingDataset to load the MDS dataset and inspect it."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {},
+ "inputWidgets": {},
+ "nuid": "ee73abb8-2261-484d-93d5-722b67090ed6",
+ "showTitle": false,
+ "title": ""
+ },
+ "id": "gQSmFIq6R_HT"
+ },
+ "outputs": [],
+ "source": [
+ "from torch.utils.data import DataLoader\n",
+ "import streaming\n",
+ "from streaming import StreamingDataset\n",
+ "\n",
+ "# clean stale shared memory if any\n",
+ "streaming.base.util.clean_stale_shared_memory()\n",
+ "\n",
+ "dataset = StreamingDataset(local=out_path, remote=None, batch_size=2, predownload=4)\n",
+ "\n",
+ "dataloader = DataLoader(dataset, batch_size=2, num_workers=1)\n",
+ "\n",
+ "for i, data in enumerate(dataloader):\n",
+ " print(data)\n",
+ " # Display only first 10 batches\n",
+ " if i == 10:\n",
+ " break"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "xcVwdCL_bcg8"
+ },
+ "source": [
+ "## Cleanup"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "Wa_ZGkBckq-b"
+ },
+ "outputs": [],
+ "source": [
+ "shutil.rmtree(out_path, ignore_errors=True)\n",
+ "shutil.rmtree(local_dir, ignore_errors=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "uia53VVabf5P"
+ },
+ "source": [
+ "\n",
+ "## What next?\n",
+ "\n",
+ "You've now seen an in-depth look at how to convert a Spark DataFrame to MDS format and load the same MDS dataset for model training."
+ ]
+ }
+ ],
+ "metadata": {
+ "application/vnd.databricks.v1+notebook": {
+ "dashboards": [],
+ "language": "python",
+ "notebookMetadata": {
+ "mostRecentlyExecutedCommandWithImplicitDF": {
+ "commandId": -1,
+ "dataframes": [
+ "_sqldf"
+ ]
+ },
+ "pythonIndentUnit": 2
+ },
+ "notebookName": "SPARK DataFrame",
+ "widgets": {}
+ },
+ "colab": {
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/setup.py b/setup.py
index 24c5c05c4..41e6e54e2 100644
--- a/setup.py
+++ b/setup.py
@@ -81,17 +81,23 @@
]
extra_deps['docs'] = [
- 'GitPython==3.1.41',
- 'docutils==0.18.1',
- 'furo==2024.1.29',
- 'myst-parser==2.0.0',
- 'nbsphinx==0.9.2',
+ 'GitPython==3.1.42',
+ 'docutils==0.17.1',
+ 'furo==2022.9.29',
+ 'myst-parser==0.16.1',
+ 'nbsphinx==0.9.1',
'pandoc==2.3',
'pypandoc==1.13',
'sphinx-argparse==0.4.0',
'sphinx-copybutton==0.5.2',
- 'sphinx==6.2.1',
+ 'sphinx==4.4.0',
'sphinx-tabs==3.4.5',
+ 'sphinxcontrib.katex==0.9.6',
+ 'sphinxcontrib-applehelp==1.0.0',
+ 'sphinxcontrib-devhelp==1.0.0',
+ 'sphinxcontrib-htmlhelp==2.0.0',
+ 'sphinxcontrib-qthelp==1.0.0',
+ 'sphinxcontrib-serializinghtml==1.1.5',
]
extra_deps['simulator'] = [
diff --git a/streaming/base/converters/README.md b/streaming/base/converters/README.md
index 82bd501fb..b811055f7 100644
--- a/streaming/base/converters/README.md
+++ b/streaming/base/converters/README.md
@@ -4,4 +4,4 @@ Users can read datasets of any formats that Spark supports and convert the Spark
1. We enable converting a Spark DataFrame into an MDS format via the utility function [dataframe_to_mds](https://github.com/mosaicml/streaming/blob/main/streaming/base/converters/dataframe_to_mds.py). This utility function is flexible and supports a callable function, allowing modifications to the original data format. The function iterates over the callable, processes the modified data, and writes it in MDS format. For instance, it can be used with a tokenizer callable function that yields tokens as output.
-2. Users are recommended to refer to the starting example [Jupyter notebook](https://github.com/mosaicml/streaming/blob/main/examples/spark_dataframe_to_MDS.ipynb) which demonstrates a complete workflow. It illustrates how to use Spark to read raw data into a Spark DataFrame and then convert it into the MDS format via the `dataframe_to_mds` function. In that tutorial, we also demonstrate the option to pass in a preprocessing tokenization job to the converter, which can be useful if materializing the intermediate dataframe is time consuming or taking extra development.
+2. Users are recommended to refer to the starting example [Jupyter notebook](https://github.com/mosaicml/streaming/blob/main/docs/source/preparing_datasets/spark_dataframe_to_mds.ipynb) which demonstrates a complete workflow. It illustrates how to use Spark to read raw data into a Spark DataFrame and then convert it into the MDS format via the `dataframe_to_mds` function. In that tutorial, we also demonstrate the option to pass in a preprocessing tokenization job to the converter, which can be useful if materializing the intermediate dataframe is time consuming or requires extra development effort.
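+
+A minimal sketch of calling `dataframe_to_mds` directly, where `spark_df` and the output settings are illustrative placeholders (see the notebook for a complete workflow):
+
+```python
+from streaming.base.converters import dataframe_to_mds
+
+# spark_df is an existing Spark DataFrame (placeholder for illustration).
+mds_kwargs = {'out': '/tmp/mds_out', 'columns': {'number': 'int64', 'words': 'str'}}
+dataframe_to_mds(spark_df.repartition(4), merge_index=True, mds_kwargs=mds_kwargs)
+```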
diff --git a/streaming/base/dataset.py b/streaming/base/dataset.py
index 5c3bb1e8c..9966031d6 100644
--- a/streaming/base/dataset.py
+++ b/streaming/base/dataset.py
@@ -285,7 +285,7 @@ class StreamingDataset(Array, IterableDataset):
.. note::
For sequential sample ordering, set ``shuffle`` to ``False`` and
- ``num_canonical_nodes`` to 1.
+ ``num_canonical_nodes`` to the number of physical nodes of the initial run.
batch_size (int, optional): Per-device batch size, the same as what is passed to the
DataLoader. This affects how the dataset is partitioned over the workers and is
necessary for deterministic resumption and optimal performance. Defaults to ``None``.