Skip to content

Commit

Permalink
Merge branch 'main' into srs/change-broken-user-guide-link-to
Browse files Browse the repository at this point in the history
  • Loading branch information
snarayan21 authored Dec 9, 2024
2 parents 80098ab + a0d491e commit a979454
Show file tree
Hide file tree
Showing 25 changed files with 68 additions and 69 deletions.
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ pytest -vv -s . # run all the unittests
cd docs && make clean && make doctest # run doctests
```

6\. [Optional] Compile and visualize the documentation locally. If you have a documentation changes, running the below commands is mandatory.
6\. [Optional] Compile and visualize the documentation locally. If you have documentation changes, running the below commands is mandatory.

<!--pytest.mark.skip-->
```bash
Expand Down
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# several pytest settings
PYTHON ?= python # Python command
PYTEST ?= pytest # Pytest command
PYRIGHT ?= pyright # Pyright command. Pyright must be installed seperately -- e.g. `node install -g pyright`
PYRIGHT ?= pyright # Pyright command. Pyright must be installed separately -- e.g. `node install -g pyright`
EXTRA_ARGS ?= # extra arguments for pytest

dirs := streaming tests docs
Expand Down
2 changes: 1 addition & 1 deletion docs/source/_templates/base.html
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@
version = fragments[1].split("/")[0]

// NOTE: The version string will resolve to the PR number for RTD sites.
// Checking whether first charater is a number.
// Checking whether first character is a number.
if (version[0] >= '0' && version[0] <= '9') {
version = undefined
}
Expand Down
2 changes: 1 addition & 1 deletion docs/source/dataset_configuration/shuffling.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,4 +70,4 @@ Samples within each shard are shuffled both before and after shards are split am

Globally shuffles all samples. This is useful for single-node training on small data, where you want the most random shuffle possible, but is the least download-efficient of all shuffle algorithms. Training throughput is often much lower when using the `naive` shuffling algorithm.

If you are having trouble with throughput, network downloads, or shuffle quality, please refer to the [perfomance tuning page](../distributed_training/performance_tuning.md).
If you are having trouble with throughput, network downloads, or shuffle quality, please refer to the [performance tuning page](../distributed_training/performance_tuning.md).
2 changes: 1 addition & 1 deletion docs/source/distributed_training/performance_tuning.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ $$L = 2 \cdot S \cdot \lceil\frac{C}{P}\rceil $$

Where $L$ is the required minimum cache limit per node, in MB, $S$ is the average shard size, in MB, $C$ is the number of canonical nodes (see [here](../dataset_configuration/shuffling.md#how-shuffling-works) and [here](../distributed_training/elastic_determinism.md#requirements)), and $P$ is the number of physical nodes. This is because only a single shard, plus a potentially predownloaded subsequent shard, needs to be resident per canonical node to make progress during training.

If using a shuffle-block-based algorithm such as [`'py1e'`](../dataset_configuration/shuffling.md#py1e-default) or [`'py1br'`](../dataset_configuration/shuffling.md#py1br), the required minumum cache limit per node will be approximately:
If using a shuffle-block-based algorithm such as [`'py1e'`](../dataset_configuration/shuffling.md#py1e-default) or [`'py1br'`](../dataset_configuration/shuffling.md#py1br), the required minimum cache limit per node will be approximately:

$$L = k \cdot S \lceil \frac{B}{Q} \rceil \cdot \lceil\frac{C}{P}\rceil $$

Expand Down
4 changes: 2 additions & 2 deletions scripts/samples/bench_and_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ def bench(args: Namespace, bench_name: str, desc: str, generate: Callable,
args (Namespace): Command-line arguments.
bench_name (str): What to call this benchmark.
desc (str): Brief description of the data.
generate (Callable): Method to genereate the dataset.
generate (Callable): Method to generate the dataset.
formats (List[str]): List of shard formats to benchmark this data in.
"""
print(f'Bench: {bench_name}')
Expand Down Expand Up @@ -373,7 +373,7 @@ def bench(args: Namespace, bench_name: str, desc: str, generate: Callable,
y *= args.plot_bins
y = y.astype(np.int64)

# Truncate the higest ``args.truncate_highest_frac`` timings because they get further
# Truncate the highest ``args.truncate_highest_frac`` timings because they get further
# and further spaced as you ascend, which would ruin the plot.
y = y[np.nonzero(y < args.plot_bins)[0]]

Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
'Brotli>=1.0.9',
'google-cloud-storage>=2.9.0,<2.11.0',
'matplotlib>=3.5.2,<4',
'numpy>=1.21.5,<2.2.0',
'paramiko>=2.11.0,<4',
'python-snappy>=0.6.1,<1',
'torch>=1.10,<3',
Expand Down
11 changes: 6 additions & 5 deletions simulation/core/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@
from typing import Generator, Union

import numpy as np
from core.node_tracker import NodeTracker
from core.shard_downloads import run_cache_limit, simulate_shard_downloads
from core.sim_dataset import SimulationDataset
from core.sim_time import Time
from core.utils import bytes_to_time, get_batches_epochs, time_to_bytes
from numpy.typing import NDArray

from simulation.core.node_tracker import NodeTracker
from simulation.core.shard_downloads import run_cache_limit, simulate_shard_downloads
from simulation.core.sim_dataset import SimulationDataset
from simulation.core.sim_time import Time
from simulation.core.utils import bytes_to_time, get_batches_epochs, time_to_bytes

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

Expand Down
4 changes: 2 additions & 2 deletions simulation/core/node_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
from typing import Optional

import numpy as np
from core.last_used_ordered_set import LastUsedOrderedSet
from core.utils import remove_padded_samples
from numpy.typing import NDArray
from sortedcollections import OrderedSet

from simulation.core.last_used_ordered_set import LastUsedOrderedSet
from simulation.core.utils import remove_padded_samples
from streaming.base.spanner import Spanner


Expand Down
3 changes: 2 additions & 1 deletion simulation/core/shard_downloads.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
from typing import Optional

import numpy as np
from core.node_tracker import NodeTracker
from numpy.typing import NDArray

from simulation.core.node_tracker import NodeTracker


def simulate_shard_downloads(node: NodeTracker,
raw_shard_sizes: NDArray[np.int64],
Expand Down
2 changes: 1 addition & 1 deletion simulation/core/shuffle_quality.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
import logging

import numpy as np
from core.utils import remove_padded_samples
from numpy.typing import NDArray

from simulation.core.utils import remove_padded_samples
from streaming.base.partition.orig import get_partitions_orig
from streaming.base.shuffle import get_shuffle

Expand Down
4 changes: 2 additions & 2 deletions simulation/core/sim_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
from typing import Optional, Sequence, Union

import numpy as np
from core.sim_spanner import SimulationSpanner
from core.sim_world import SimulationWorld
from numpy.typing import NDArray

from simulation.core.sim_spanner import SimulationSpanner
from simulation.core.sim_world import SimulationWorld
from streaming.base import Stream, StreamingDataset
from streaming.base.batching import generate_work
from streaming.base.format import get_index_basename
Expand Down
7 changes: 4 additions & 3 deletions simulation/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
"""Peripheral functions for simulation functionality."""

import numpy as np
from core.sim_dataset import SimulationDataset
from core.sim_time import Time, TimeUnit
from numpy.typing import NDArray

from simulation.core.sim_dataset import SimulationDataset
from simulation.core.sim_time import Time, TimeUnit


def get_batches_epochs(dataset: SimulationDataset, max_duration: Time) -> tuple[int, int, int]:
"""Get batches per epoch, epochs, and total epochs from a Time object.
Expand All @@ -19,7 +20,7 @@ def get_batches_epochs(dataset: SimulationDataset, max_duration: Time) -> tuple[
Returns:
Tuple[int, int, int]: batches per epoch, epochs, and the total batches.
"""
# get epochs, batches_per_epoch, and total_batches from a Time obect
# get epochs, batches_per_epoch, and total_batches from a Time object
dataset_batches = dataset.get_num_batches()
batches_per_epoch = dataset_batches
epochs = 1
Expand Down
4 changes: 2 additions & 2 deletions simulation/core/yaml_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@

from typing import Optional

from core.sim_dataset import SimulationDataset
from core.sim_time import Time, TimeUnit, ensure_time
from omegaconf import DictConfig
from omegaconf import OmegaConf as om

from simulation.core.sim_dataset import SimulationDataset
from simulation.core.sim_time import Time, TimeUnit, ensure_time
from streaming.base import Stream


Expand Down
2 changes: 1 addition & 1 deletion simulation/interfaces/interface_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
from typing import Optional

import numpy as np
from core.utils import get_rolling_avg_throughput
from numpy.typing import NDArray

from simulation.core.utils import get_rolling_avg_throughput
from streaming.base.util import number_abbrev_to_int


Expand Down
8 changes: 4 additions & 4 deletions simulation/interfaces/sim_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
import argparse

import humanize
from core.main import simulate
from core.utils import get_simulation_stats
from core.yaml_processing import create_simulation_dataset, ingest_yaml
from interfaces.interface_utils import plot_simulation

from simulation.core.main import simulate
from simulation.core.utils import get_simulation_stats
from simulation.core.yaml_processing import create_simulation_dataset, ingest_yaml
from simulation.interfaces.interface_utils import plot_simulation
from streaming.base.util import bytes_to_int

if __name__ == '__main__':
Expand Down
15 changes: 6 additions & 9 deletions simulation/interfaces/sim_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,17 @@
"""Script for simulating training downloads and throughput, and displaying results."""

import os.path
import sys

import humanize
from core.create_index import create_stream_index
from core.main import simulate
from core.sim_dataset import SimulationDataset
from core.sim_time import TimeUnit, ensure_time
from core.utils import get_simulation_stats
from interfaces.interface_utils import plot_simulation

from simulation.core.create_index import create_stream_index
from simulation.core.main import simulate
from simulation.core.sim_dataset import SimulationDataset
from simulation.core.sim_time import TimeUnit, ensure_time
from simulation.core.utils import get_simulation_stats
from simulation.interfaces.interface_utils import plot_simulation
from streaming.base import Stream

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))

# Input Parameters

# dataset
Expand Down
23 changes: 10 additions & 13 deletions simulation/interfaces/sim_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import math
import os.path
import sys
from concurrent.futures import ProcessPoolExecutor
from io import StringIO
from typing import Union
Expand All @@ -14,21 +13,19 @@
import pandas as pd
import streamlit as st
import yaml
from core.create_index import create_stream_index
from core.main import simulate
from core.shuffle_quality import analyze_shuffle_quality_entropy
from core.sim_dataset import SimulationDataset
from core.sim_time import Time
from core.utils import get_total_batches
from core.yaml_processing import create_simulation_dataset, ingest_yaml
from interfaces.interface_utils import get_train_dataset_params
from interfaces.widgets import (display_shuffle_quality_graph, display_simulation_stats,
get_line_chart, param_inputs)

from simulation.core.create_index import create_stream_index
from simulation.core.main import simulate
from simulation.core.shuffle_quality import analyze_shuffle_quality_entropy
from simulation.core.sim_dataset import SimulationDataset
from simulation.core.sim_time import Time
from simulation.core.utils import get_total_batches
from simulation.core.yaml_processing import create_simulation_dataset, ingest_yaml
from simulation.interfaces.interface_utils import get_train_dataset_params
from simulation.interfaces.widgets import (display_shuffle_quality_graph, display_simulation_stats,
get_line_chart, param_inputs)
from streaming.base.util import bytes_to_int, number_abbrev_to_int

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))

# set up page
st.set_page_config(layout='wide')
col1, space, col2 = st.columns((10, 1, 6))
Expand Down
8 changes: 2 additions & 6 deletions simulation/interfaces/widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,20 @@

"""Streamlit widgets for simulation web UI."""

import os.path
import sys
from concurrent.futures import Future
from typing import Optional

import altair as alt
import humanize
import pandas as pd
import streamlit as st
from core.sim_time import TimeUnit, ensure_time
from core.utils import get_simulation_stats
from numpy.typing import NDArray
from streamlit.delta_generator import DeltaGenerator

from simulation.core.sim_time import TimeUnit, ensure_time
from simulation.core.utils import get_simulation_stats
from streaming.base.util import bytes_to_int

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))


def get_line_chart(data: pd.DataFrame,
throughput_window: int,
Expand Down
14 changes: 5 additions & 9 deletions simulation/testing/wandb_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,20 @@

"""Test simulation results against run results from a wandb project."""

import os.path
import sys

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))

import logging
import os
import os.path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import wandb
from core.create_index import create_stream_index
from core.main import simulate
from core.sim_dataset import SimulationDataset
from core.sim_time import TimeUnit, ensure_time
from numpy.typing import NDArray

from simulation.core.create_index import create_stream_index
from simulation.core.main import simulate
from simulation.core.sim_dataset import SimulationDataset
from simulation.core.sim_time import TimeUnit, ensure_time
from streaming.base import Stream

logger = logging.getLogger(__name__)
Expand Down
2 changes: 1 addition & 1 deletion streaming/base/batching/stratified.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def generate_work_stratified_batching(dataset: StreamingDataset, world: World, e
f'Number of samples for stream {stream_id} is {batch_portion} because the portion '
+
f'of this stream in the global batch, which is of size {global_batch_size}, is ' +
f'too low. Please increase the global batch size or increase the porportion of ' +
f'too low. Please increase the global batch size or increase the proportion of ' +
f'total samples that come from stream {stream_id}.')

# We now merge the partitions from each stream to get our final partition over all
Expand Down
9 changes: 9 additions & 0 deletions streaming/base/storage/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,15 @@ def _create_s3_client(self, unsigned: bool = False, timeout: float = DEFAULT_TIM
config=config,
endpoint_url=os.environ.get('S3_ENDPOINT_URL'))

def __getstate__(self) -> dict:
state = self.__dict__.copy()
state['_s3_client'] = None # Exclude _s3_client from being pickled
return state

def __setstate__(self, state: dict):
self.__dict__.update(state)
self._s3_client = None # Ensure _s3_client is reset after unpickling


class SFTPDownloader(CloudDownloader):
"""Download files from SFTP to local filesystem."""
Expand Down
2 changes: 1 addition & 1 deletion streaming/text/convert/enwiki/mds/merge_shard_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@


def parse_args() -> Namespace:
"""Parse commmand-line arguments.
"""Parse command-line arguments.
Returns:
Namespace: Command-line arguments.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Script for picking certain number of sampels.
"""Script for picking certain number of samples.
"""

import argparse
Expand Down
2 changes: 1 addition & 1 deletion tests/test_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,7 @@ def test_stratified_batching_Exception(local_remote_dir: tuple[str, str], stream

with pytest.raises(ValueError, match=f'Number of samples for stream*'):
# When we iterate through the dataloader, the samples will be partitioned.
# This should thow ValueError since stream 2 is too small to be included in each batch.
# This should throw ValueError since stream 2 is too small to be included in each batch.
for _ in dataloader:
continue

Expand Down

0 comments on commit a979454

Please sign in to comment.