Skip to content

Commit

Permalink
Fix import paths in simulation module (#838)
Browse files (browse the repository at this point in the history)
Co-authored-by: Saaketh Narayan <saaketh.narayan@databricks.com>
  • Loading branch information
srstevenson and snarayan21 authored Dec 9, 2024
1 parent 67d0cbf commit e1e28f1
Show file tree
Hide file tree
Showing 13 changed files with 46 additions and 57 deletions.
11 changes: 6 additions & 5 deletions simulation/core/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@
from typing import Generator, Union

import numpy as np
from core.node_tracker import NodeTracker
from core.shard_downloads import run_cache_limit, simulate_shard_downloads
from core.sim_dataset import SimulationDataset
from core.sim_time import Time
from core.utils import bytes_to_time, get_batches_epochs, time_to_bytes
from numpy.typing import NDArray

from simulation.core.node_tracker import NodeTracker
from simulation.core.shard_downloads import run_cache_limit, simulate_shard_downloads
from simulation.core.sim_dataset import SimulationDataset
from simulation.core.sim_time import Time
from simulation.core.utils import bytes_to_time, get_batches_epochs, time_to_bytes

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

Expand Down
4 changes: 2 additions & 2 deletions simulation/core/node_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
from typing import Optional

import numpy as np
from core.last_used_ordered_set import LastUsedOrderedSet
from core.utils import remove_padded_samples
from numpy.typing import NDArray
from sortedcollections import OrderedSet

from simulation.core.last_used_ordered_set import LastUsedOrderedSet
from simulation.core.utils import remove_padded_samples
from streaming.base.spanner import Spanner


Expand Down
3 changes: 2 additions & 1 deletion simulation/core/shard_downloads.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
from typing import Optional

import numpy as np
from core.node_tracker import NodeTracker
from numpy.typing import NDArray

from simulation.core.node_tracker import NodeTracker


def simulate_shard_downloads(node: NodeTracker,
raw_shard_sizes: NDArray[np.int64],
Expand Down
2 changes: 1 addition & 1 deletion simulation/core/shuffle_quality.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
import logging

import numpy as np
from core.utils import remove_padded_samples
from numpy.typing import NDArray

from simulation.core.utils import remove_padded_samples
from streaming.base.partition.orig import get_partitions_orig
from streaming.base.shuffle import get_shuffle

Expand Down
4 changes: 2 additions & 2 deletions simulation/core/sim_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
from typing import Optional, Sequence, Union

import numpy as np
from core.sim_spanner import SimulationSpanner
from core.sim_world import SimulationWorld
from numpy.typing import NDArray

from simulation.core.sim_spanner import SimulationSpanner
from simulation.core.sim_world import SimulationWorld
from streaming.base import Stream, StreamingDataset
from streaming.base.batching import generate_work
from streaming.base.format import get_index_basename
Expand Down
5 changes: 3 additions & 2 deletions simulation/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
"""Peripheral functions for simulation functionality."""

import numpy as np
from core.sim_dataset import SimulationDataset
from core.sim_time import Time, TimeUnit
from numpy.typing import NDArray

from simulation.core.sim_dataset import SimulationDataset
from simulation.core.sim_time import Time, TimeUnit


def get_batches_epochs(dataset: SimulationDataset, max_duration: Time) -> tuple[int, int, int]:
"""Get batches per epoch, epochs, and total epochs from a Time object.
Expand Down
4 changes: 2 additions & 2 deletions simulation/core/yaml_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@

from typing import Optional

from core.sim_dataset import SimulationDataset
from core.sim_time import Time, TimeUnit, ensure_time
from omegaconf import DictConfig
from omegaconf import OmegaConf as om

from simulation.core.sim_dataset import SimulationDataset
from simulation.core.sim_time import Time, TimeUnit, ensure_time
from streaming.base import Stream


Expand Down
2 changes: 1 addition & 1 deletion simulation/interfaces/interface_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
from typing import Optional

import numpy as np
from core.utils import get_rolling_avg_throughput
from numpy.typing import NDArray

from simulation.core.utils import get_rolling_avg_throughput
from streaming.base.util import number_abbrev_to_int


Expand Down
8 changes: 4 additions & 4 deletions simulation/interfaces/sim_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
import argparse

import humanize
from core.main import simulate
from core.utils import get_simulation_stats
from core.yaml_processing import create_simulation_dataset, ingest_yaml
from interfaces.interface_utils import plot_simulation

from simulation.core.main import simulate
from simulation.core.utils import get_simulation_stats
from simulation.core.yaml_processing import create_simulation_dataset, ingest_yaml
from simulation.interfaces.interface_utils import plot_simulation
from streaming.base.util import bytes_to_int

if __name__ == '__main__':
Expand Down
15 changes: 6 additions & 9 deletions simulation/interfaces/sim_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,17 @@
"""Script for simulating training downloads and throughput, and displaying results."""

import os.path
import sys

import humanize
from core.create_index import create_stream_index
from core.main import simulate
from core.sim_dataset import SimulationDataset
from core.sim_time import TimeUnit, ensure_time
from core.utils import get_simulation_stats
from interfaces.interface_utils import plot_simulation

from simulation.core.create_index import create_stream_index
from simulation.core.main import simulate
from simulation.core.sim_dataset import SimulationDataset
from simulation.core.sim_time import TimeUnit, ensure_time
from simulation.core.utils import get_simulation_stats
from simulation.interfaces.interface_utils import plot_simulation
from streaming.base import Stream

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))

# Input Parameters

# dataset
Expand Down
23 changes: 10 additions & 13 deletions simulation/interfaces/sim_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import math
import os.path
import sys
from concurrent.futures import ProcessPoolExecutor
from io import StringIO
from typing import Union
Expand All @@ -14,21 +13,19 @@
import pandas as pd
import streamlit as st
import yaml
from core.create_index import create_stream_index
from core.main import simulate
from core.shuffle_quality import analyze_shuffle_quality_entropy
from core.sim_dataset import SimulationDataset
from core.sim_time import Time
from core.utils import get_total_batches
from core.yaml_processing import create_simulation_dataset, ingest_yaml
from interfaces.interface_utils import get_train_dataset_params
from interfaces.widgets import (display_shuffle_quality_graph, display_simulation_stats,
get_line_chart, param_inputs)

from simulation.core.create_index import create_stream_index
from simulation.core.main import simulate
from simulation.core.shuffle_quality import analyze_shuffle_quality_entropy
from simulation.core.sim_dataset import SimulationDataset
from simulation.core.sim_time import Time
from simulation.core.utils import get_total_batches
from simulation.core.yaml_processing import create_simulation_dataset, ingest_yaml
from simulation.interfaces.interface_utils import get_train_dataset_params
from simulation.interfaces.widgets import (display_shuffle_quality_graph, display_simulation_stats,
get_line_chart, param_inputs)
from streaming.base.util import bytes_to_int, number_abbrev_to_int

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))

# set up page
st.set_page_config(layout='wide')
col1, space, col2 = st.columns((10, 1, 6))
Expand Down
8 changes: 2 additions & 6 deletions simulation/interfaces/widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,20 @@

"""Streamlit widgets for simulation web UI."""

import os.path
import sys
from concurrent.futures import Future
from typing import Optional

import altair as alt
import humanize
import pandas as pd
import streamlit as st
from core.sim_time import TimeUnit, ensure_time
from core.utils import get_simulation_stats
from numpy.typing import NDArray
from streamlit.delta_generator import DeltaGenerator

from simulation.core.sim_time import TimeUnit, ensure_time
from simulation.core.utils import get_simulation_stats
from streaming.base.util import bytes_to_int

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))


def get_line_chart(data: pd.DataFrame,
throughput_window: int,
Expand Down
14 changes: 5 additions & 9 deletions simulation/testing/wandb_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,20 @@

"""Test simulation results against run results from a wandb project."""

import os.path
import sys

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))

import logging
import os
import os.path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import wandb
from core.create_index import create_stream_index
from core.main import simulate
from core.sim_dataset import SimulationDataset
from core.sim_time import TimeUnit, ensure_time
from numpy.typing import NDArray

from simulation.core.create_index import create_stream_index
from simulation.core.main import simulate
from simulation.core.sim_dataset import SimulationDataset
from simulation.core.sim_time import TimeUnit, ensure_time
from streaming.base import Stream

logger = logging.getLogger(__name__)
Expand Down

0 comments on commit e1e28f1

Please sign in to comment.