StreamingDataloader: Add profiling support #19338

Merged: 26 commits, Jan 24, 2024
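In short: `StreamingDataLoader` gains opt-in profiling via a new `profile_batches` argument (an `int` traces the first N batches, `True` traces the whole run). When enabled, the PyTorch worker loop is patched to run under viztracer and a `result.json` trace, viewable in chrome://tracing/, is written into `profile_dir` (defaulting to the current working directory). Profiling requires `num_workers >= 1` and is recorded on global rank 0, worker 0 only.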
1 change: 1 addition & 0 deletions requirements/data/test.txt
@@ -4,3 +4,4 @@ pytest-cov ==4.1.0
pytest-timeout ==2.1.0
pytest-rerunfailures ==12.0
pytest-random-order ==1.1.0
viztracer
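
viztracer backs the new `profile_batches` option. It is needed only when profiling is requested: without it, the loader raises a `ModuleNotFoundError` suggesting `pip install viztracer`.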
122 changes: 121 additions & 1 deletion src/lightning/data/streaming/dataloader.py
@@ -347,6 +347,99 @@ def _get_iterator(self) -> "_BaseDataLoaderIter":
        return _MultiProcessingDataLoaderIterPatch(self)


def _wrapper(fetcher: Any, func: Callable, tracer: Any, profile: int, profile_dir: str) -> Callable:
    counter = 0

    def wrap(*args: Any, **kwargs: Any) -> Any:
        nonlocal counter
        result = func(*args, **kwargs)

        if tracer.enable and counter == profile:
            tracer.stop()
            tracer.save()
            print(
                f"Saved {os.path.join(profile_dir, 'result.json')} file after {profile} batches. "
                "Use chrome://tracing/ to view it."
            )
            # Restore the original fetch so subsequent batches run untraced.
            fetcher.fetch = func

        counter += 1
        return result

    return wrap


class _ProfileWorkerLoop:
    """Wrap the PyTorch DataLoader WorkerLoop to add profiling."""

    def __init__(self, profile: Union[int, bool], profile_dir: Optional[str] = None):
        self._profile = profile
        self._profile_dir = profile_dir if profile_dir else os.getcwd()

    def __call__(
        self,
        dataset_kind: Any,
        dataset: Any,
        index_queue: Any,
        data_queue: Any,
        done_event: Any,
        auto_collation: Any,
        collate_fn: Any,
        drop_last: Any,
        base_seed: Any,
        init_fn: Any,
        worker_id: Any,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        from torch.utils.data._utils import worker
        from viztracer import VizTracer

        # Only worker 0 is traced, to keep the profiling overhead bounded.
        if worker_id == 0:
            output_file = os.path.join(self._profile_dir, "result.json")

            if os.path.exists(output_file):
                os.remove(output_file)

            tracer = VizTracer(output_file=output_file, verbose=0)
            tracer.start()

        # Reload to remove the patching and recover the original worker loop
        reloaded_worker = reload(worker)
        create_fetcher = _DatasetKind.create_fetcher
        fetcher = None

        def create_fetcher_fn(*args: Any, **kwargs: Any) -> "_BaseDatasetFetcher":
            nonlocal fetcher
            fetcher = create_fetcher(*args, **kwargs)

            # In counted mode, wrap the fetcher so tracing stops itself after
            # `self._profile` batches. `bool` is excluded explicitly because
            # isinstance(True, int) is True, yet True means "trace the whole run".
            if worker_id == 0 and isinstance(self._profile, int) and not isinstance(self._profile, bool):
                fetcher.fetch = _wrapper(fetcher, fetcher.fetch, tracer, self._profile, self._profile_dir)
            return fetcher

        _DatasetKind.create_fetcher = create_fetcher_fn  # type: ignore

        reloaded_worker._worker_loop(
            dataset_kind,
            dataset,
            index_queue,
            data_queue,
            done_event,
            auto_collation,
            collate_fn,
            drop_last,
            base_seed,
            init_fn,
            worker_id,
            *args,
            **kwargs,
        )

        # In whole-run mode (profile=True), save the trace when the worker loop exits.
        if worker_id == 0 and isinstance(self._profile, bool):
            tracer.stop()
            tracer.save()
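
# Aside: a minimal, standalone sketch of the patching idea used above. With a
# fork-based worker start method, reassigning `worker._worker_loop` in the main
# process before workers are created makes each worker run the replacement
# (the names below are illustrative, not part of this diff):
#
#     from torch.utils.data._utils import worker
#
#     original_loop = worker._worker_loop
#
#     def wrapped_loop(*args, **kwargs):
#         # start a tracer here, then delegate to the real loop
#         return original_loop(*args, **kwargs)
#
#     worker._worker_loop = wrapped_loop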


class _StreamingMultiProcessingDataLoaderIter(_MultiProcessingDataLoaderIter):
    def __init__(self, loader: DataLoader) -> None:
        self._loader = loader
@@ -355,6 +448,15 @@ def __init__(self, loader: DataLoader) -> None:
            if self._loader._latest_worker_idx > 0
            else []
        )
        self._num_workers = loader.num_workers

        distributed_env = _DistributedEnv.detect()

        # Patch the worker loop on global rank 0 only, before super().__init__ spawns the workers.
        if self._loader._profile_batches and distributed_env.global_rank == 0 and _VIZ_TRACKER_AVAILABLE:
            from torch.utils.data._utils import worker

            worker._worker_loop = _ProfileWorkerLoop(self._loader._profile_batches, self._loader._profile_dir)

        super().__init__(loader)

    def _try_put_index(self) -> None:
@@ -388,6 +490,9 @@ def __init__(
        *args: Any,
        batch_size: int = 1,
        num_workers: int = 0,
        profile_batches: Union[bool, int] = False,
        profile_dir: Optional[str] = None,
        prefetch_factor: Optional[int] = None,
        **kwargs: Any,
    ) -> None:  # pyright: ignore
        if not isinstance(dataset, (StreamingDataset, CombinedStreamingDataset)):
@@ -396,17 +501,32 @@
f" Found {dataset}."
)

if profile_bactches and not _VIZ_TRACKER_AVAILABLE:
raise ModuleNotFoundError("To use profile_bactches, viztracer is required. Run `pip install viztracer`")

if profile_bactches and num_workers == 0:
raise ValueError("Profiling is supported only with num_workers >= 1.")

self.current_epoch = 0
self.batch_size = batch_size
self.num_workers = num_workers
self._profile_bactches = profile_bactches
self._profile_dir = profile_dir
self._num_samples_yielded_streaming = 0
self._num_samples_yielded_combined: Dict[int, List[Any]] = {}
self.rng_state: Optional[Any] = None
self._worker_idx = cycle(list(range(self.num_workers if self.num_workers > 0 else 1)))
self._worker_idx_iter: Optional[Any] = None
self._latest_worker_idx = 0
self.restore = False
super().__init__(dataset, *args, batch_size=batch_size, num_workers=num_workers, **kwargs) # type: ignore
super().__init__(
dataset,
*args,
batch_size=batch_size,
num_workers=num_workers,
prefetch_factor=(10 if num_workers > 0 else None) if prefetch_factor is None else prefetch_factor,
**kwargs,
) # type: ignore

    def __iter__(self) -> Any:
        if not self.restore:
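
For context, here is a hypothetical end-to-end use of the new arguments. This is a sketch, not code from this PR: the `input_dir` value is an assumption, and only `profile_batches`/`profile_dir` come from this change.

```python
from lightning.data.streaming import StreamingDataset, StreamingDataLoader

dataset = StreamingDataset(input_dir="s3://my-bucket/optimized-dataset")  # assumed location
dataloader = StreamingDataLoader(
    dataset,
    batch_size=32,
    num_workers=2,              # profiling requires num_workers >= 1
    profile_batches=10,         # int: stop tracing after 10 batches; True: trace the whole run
    profile_dir="/tmp/traces",  # result.json lands here; defaults to os.getcwd()
)

for batch in dataloader:
    pass  # iterate normally; worker 0 records the trace

# Open /tmp/traces/result.json in chrome://tracing/ (or with viztracer's `vizviewer`).
```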
22 changes: 22 additions & 0 deletions tests/tests_data/streaming/test_dataloader.py
@@ -1,5 +1,9 @@
import os

import pytest
import torch
from lightning.data.streaming import CombinedStreamingDataset, StreamingDataLoader
from lightning.data.streaming import dataloader as streaming_dataloader_module
from torch import tensor


@@ -70,3 +74,21 @@ def test_streaming_dataloader():
"latest_worker_idx": 0,
"num_samples_yielded": {0: [11, 9]},
}


@pytest.mark.parametrize("profile", [2, True])
def test_dataloader_profiling(profile, tmpdir, monkeypatch):
    monkeypatch.setattr(streaming_dataloader_module, "_VIZ_TRACKER_AVAILABLE", True)

    dataset = TestCombinedStreamingDataset(
        [TestStatefulDataset(10, 1), TestStatefulDataset(10, -1)], 42, weights=(0.5, 0.5)
    )
    dataloader = StreamingDataLoader(
        dataset, batch_size=2, profile_batches=profile, profile_dir=str(tmpdir), num_workers=1
    )
    dataloader_iter = iter(dataloader)
    batches = []
    for batch in dataloader_iter:
        batches.append(batch)

    assert os.path.exists(os.path.join(tmpdir, "result.json"))
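
The parametrization exercises both modes: `profile=2` takes the counted path, where the wrapped fetcher stops the tracer after two batches, and `profile=True` takes the whole-run path, where the trace is saved when the worker loop exits. In both cases the test only asserts that the `result.json` trace was produced.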