From e1e28f1b4d5bac7994a3b6b8d91afc7901bff448 Mon Sep 17 00:00:00 2001 From: Scott Stevenson Date: Mon, 9 Dec 2024 20:37:15 +0000 Subject: [PATCH] Fix import paths in `simulation` module (#838) Co-authored-by: Saaketh Narayan --- simulation/core/main.py | 11 ++++++----- simulation/core/node_tracker.py | 4 ++-- simulation/core/shard_downloads.py | 3 ++- simulation/core/shuffle_quality.py | 2 +- simulation/core/sim_dataset.py | 4 ++-- simulation/core/utils.py | 5 +++-- simulation/core/yaml_processing.py | 4 ++-- simulation/interfaces/interface_utils.py | 2 +- simulation/interfaces/sim_cli.py | 8 ++++---- simulation/interfaces/sim_script.py | 15 ++++++--------- simulation/interfaces/sim_ui.py | 23 ++++++++++------------- simulation/interfaces/widgets.py | 8 ++------ simulation/testing/wandb_testing.py | 14 +++++--------- 13 files changed, 46 insertions(+), 57 deletions(-) diff --git a/simulation/core/main.py b/simulation/core/main.py index bf211421f..264233294 100644 --- a/simulation/core/main.py +++ b/simulation/core/main.py @@ -8,13 +8,14 @@ from typing import Generator, Union import numpy as np -from core.node_tracker import NodeTracker -from core.shard_downloads import run_cache_limit, simulate_shard_downloads -from core.sim_dataset import SimulationDataset -from core.sim_time import Time -from core.utils import bytes_to_time, get_batches_epochs, time_to_bytes from numpy.typing import NDArray +from simulation.core.node_tracker import NodeTracker +from simulation.core.shard_downloads import run_cache_limit, simulate_shard_downloads +from simulation.core.sim_dataset import SimulationDataset +from simulation.core.sim_time import Time +from simulation.core.utils import bytes_to_time, get_batches_epochs, time_to_bytes + logger = logging.getLogger(__name__) logging.basicConfig(level=logging.INFO) diff --git a/simulation/core/node_tracker.py b/simulation/core/node_tracker.py index 873f49175..0bddc7833 100644 --- a/simulation/core/node_tracker.py +++ b/simulation/core/node_tracker.py @@ -6,11 +6,11 @@ from typing import Optional import numpy as np -from core.last_used_ordered_set import LastUsedOrderedSet -from core.utils import remove_padded_samples from numpy.typing import NDArray from sortedcollections import OrderedSet +from simulation.core.last_used_ordered_set import LastUsedOrderedSet +from simulation.core.utils import remove_padded_samples from streaming.base.spanner import Spanner diff --git a/simulation/core/shard_downloads.py b/simulation/core/shard_downloads.py index 81cb22790..3ffb8b1ea 100644 --- a/simulation/core/shard_downloads.py +++ b/simulation/core/shard_downloads.py @@ -6,9 +6,10 @@ from typing import Optional import numpy as np -from core.node_tracker import NodeTracker from numpy.typing import NDArray +from simulation.core.node_tracker import NodeTracker + def simulate_shard_downloads(node: NodeTracker, raw_shard_sizes: NDArray[np.int64], diff --git a/simulation/core/shuffle_quality.py b/simulation/core/shuffle_quality.py index b6333d990..7d9cbef47 100644 --- a/simulation/core/shuffle_quality.py +++ b/simulation/core/shuffle_quality.py @@ -6,9 +6,9 @@ import logging import numpy as np -from core.utils import remove_padded_samples from numpy.typing import NDArray +from simulation.core.utils import remove_padded_samples from streaming.base.partition.orig import get_partitions_orig from streaming.base.shuffle import get_shuffle diff --git a/simulation/core/sim_dataset.py b/simulation/core/sim_dataset.py index 8d1fb176e..fe1036ea3 100644 --- a/simulation/core/sim_dataset.py +++ b/simulation/core/sim_dataset.py @@ -12,10 +12,10 @@ from typing import Optional, Sequence, Union import numpy as np -from core.sim_spanner import SimulationSpanner -from core.sim_world import SimulationWorld from numpy.typing import NDArray +from simulation.core.sim_spanner import SimulationSpanner +from simulation.core.sim_world import SimulationWorld from streaming.base import Stream, StreamingDataset from streaming.base.batching import generate_work from streaming.base.format import get_index_basename diff --git a/simulation/core/utils.py b/simulation/core/utils.py index 7759f6316..6284cc38b 100644 --- a/simulation/core/utils.py +++ b/simulation/core/utils.py @@ -4,10 +4,11 @@ """Peripheral functions for simulation functionality.""" import numpy as np -from core.sim_dataset import SimulationDataset -from core.sim_time import Time, TimeUnit from numpy.typing import NDArray +from simulation.core.sim_dataset import SimulationDataset +from simulation.core.sim_time import Time, TimeUnit + def get_batches_epochs(dataset: SimulationDataset, max_duration: Time) -> tuple[int, int, int]: """Get batches per epoch, epochs, and total epochs from a Time object. diff --git a/simulation/core/yaml_processing.py b/simulation/core/yaml_processing.py index 80d0e6927..b5a1a8512 100644 --- a/simulation/core/yaml_processing.py +++ b/simulation/core/yaml_processing.py @@ -5,11 +5,11 @@ from typing import Optional -from core.sim_dataset import SimulationDataset -from core.sim_time import Time, TimeUnit, ensure_time from omegaconf import DictConfig from omegaconf import OmegaConf as om +from simulation.core.sim_dataset import SimulationDataset +from simulation.core.sim_time import Time, TimeUnit, ensure_time from streaming.base import Stream diff --git a/simulation/interfaces/interface_utils.py b/simulation/interfaces/interface_utils.py index b4ca9f40c..a00eb1e7d 100644 --- a/simulation/interfaces/interface_utils.py +++ b/simulation/interfaces/interface_utils.py @@ -11,9 +11,9 @@ from typing import Optional import numpy as np -from core.utils import get_rolling_avg_throughput from numpy.typing import NDArray +from simulation.core.utils import get_rolling_avg_throughput from streaming.base.util import number_abbrev_to_int diff --git a/simulation/interfaces/sim_cli.py b/simulation/interfaces/sim_cli.py index 4fa4a8f55..d27fd39a4 100644 --- a/simulation/interfaces/sim_cli.py +++ b/simulation/interfaces/sim_cli.py @@ -11,11 +11,11 @@ import argparse import humanize -from core.main import simulate -from core.utils import get_simulation_stats -from core.yaml_processing import create_simulation_dataset, ingest_yaml -from interfaces.interface_utils import plot_simulation +from simulation.core.main import simulate +from simulation.core.utils import get_simulation_stats +from simulation.core.yaml_processing import create_simulation_dataset, ingest_yaml +from simulation.interfaces.interface_utils import plot_simulation from streaming.base.util import bytes_to_int if __name__ == '__main__': diff --git a/simulation/interfaces/sim_script.py b/simulation/interfaces/sim_script.py index 0515798d1..1bb9757d0 100644 --- a/simulation/interfaces/sim_script.py +++ b/simulation/interfaces/sim_script.py @@ -4,20 +4,17 @@ """Script for simulating training downloads and throughput, and displaying results.""" import os.path -import sys import humanize -from core.create_index import create_stream_index -from core.main import simulate -from core.sim_dataset import SimulationDataset -from core.sim_time import TimeUnit, ensure_time -from core.utils import get_simulation_stats -from interfaces.interface_utils import plot_simulation +from simulation.core.create_index import create_stream_index +from simulation.core.main import simulate +from simulation.core.sim_dataset import SimulationDataset +from simulation.core.sim_time import TimeUnit, ensure_time +from simulation.core.utils import get_simulation_stats +from simulation.interfaces.interface_utils import plot_simulation from streaming.base import Stream -sys.path.append(os.path.join(os.path.dirname(__file__), '..')) - # Input Parameters # dataset diff --git a/simulation/interfaces/sim_ui.py b/simulation/interfaces/sim_ui.py index 691c1c9cb..85b08ed47 100644 --- a/simulation/interfaces/sim_ui.py +++ b/simulation/interfaces/sim_ui.py @@ -5,7 +5,6 @@ import math import os.path -import sys from concurrent.futures import ProcessPoolExecutor from io import StringIO from typing import Union @@ -14,21 +13,19 @@ import pandas as pd import streamlit as st import yaml -from core.create_index import create_stream_index -from core.main import simulate -from core.shuffle_quality import analyze_shuffle_quality_entropy -from core.sim_dataset import SimulationDataset -from core.sim_time import Time -from core.utils import get_total_batches -from core.yaml_processing import create_simulation_dataset, ingest_yaml -from interfaces.interface_utils import get_train_dataset_params -from interfaces.widgets import (display_shuffle_quality_graph, display_simulation_stats, - get_line_chart, param_inputs) +from simulation.core.create_index import create_stream_index +from simulation.core.main import simulate +from simulation.core.shuffle_quality import analyze_shuffle_quality_entropy +from simulation.core.sim_dataset import SimulationDataset +from simulation.core.sim_time import Time +from simulation.core.utils import get_total_batches +from simulation.core.yaml_processing import create_simulation_dataset, ingest_yaml +from simulation.interfaces.interface_utils import get_train_dataset_params +from simulation.interfaces.widgets import (display_shuffle_quality_graph, display_simulation_stats, + get_line_chart, param_inputs) from streaming.base.util import bytes_to_int, number_abbrev_to_int -sys.path.append(os.path.join(os.path.dirname(__file__), '..')) - # set up page st.set_page_config(layout='wide') col1, space, col2 = st.columns((10, 1, 6)) diff --git a/simulation/interfaces/widgets.py b/simulation/interfaces/widgets.py index 760fcf415..529e3411e 100644 --- a/simulation/interfaces/widgets.py +++ b/simulation/interfaces/widgets.py @@ -3,8 +3,6 @@ """Streamlit widgets for simulation web UI.""" -import os.path -import sys from concurrent.futures import Future from typing import Optional @@ -12,15 +10,13 @@ import humanize import pandas as pd import streamlit as st -from core.sim_time import TimeUnit, ensure_time -from core.utils import get_simulation_stats from numpy.typing import NDArray from streamlit.delta_generator import DeltaGenerator +from simulation.core.sim_time import TimeUnit, ensure_time +from simulation.core.utils import get_simulation_stats from streaming.base.util import bytes_to_int -sys.path.append(os.path.join(os.path.dirname(__file__), '..')) - def get_line_chart(data: pd.DataFrame, throughput_window: int, diff --git a/simulation/testing/wandb_testing.py b/simulation/testing/wandb_testing.py index e2a98dc99..d4b3e5cca 100644 --- a/simulation/testing/wandb_testing.py +++ b/simulation/testing/wandb_testing.py @@ -3,24 +3,20 @@ """Test simulation results against run results from a wandb project.""" -import os.path -import sys - -sys.path.append(os.path.join(os.path.dirname(__file__), '..')) - import logging import os +import os.path import matplotlib.pyplot as plt import numpy as np import pandas as pd import wandb -from core.create_index import create_stream_index -from core.main import simulate -from core.sim_dataset import SimulationDataset -from core.sim_time import TimeUnit, ensure_time from numpy.typing import NDArray +from simulation.core.create_index import create_stream_index +from simulation.core.main import simulate +from simulation.core.sim_dataset import SimulationDataset +from simulation.core.sim_time import TimeUnit, ensure_time from streaming.base import Stream logger = logging.getLogger(__name__)