From afece98fc75dda744304137ef8c26b633371813c Mon Sep 17 00:00:00 2001
From: thomas chaton
Date: Wed, 24 Jan 2024 13:26:59 +0000
Subject: [PATCH 01/19] update

---
 src/lightning/data/streaming/constants.py   |   1 +
 src/lightning/data/streaming/dataloader.py  | 100 ++++++++++++++++++++
 src/lightning/data/streaming/downloader.py  |  49 ++++++----
 src/lightning/data/streaming/item_loader.py |   4 +-
 src/lightning/data/streaming/reader.py      |   5 +
 5 files changed, 136 insertions(+), 23 deletions(-)

diff --git a/src/lightning/data/streaming/constants.py b/src/lightning/data/streaming/constants.py
index c020bb080d1df..9640f4552c416 100644
--- a/src/lightning/data/streaming/constants.py
+++ b/src/lightning/data/streaming/constants.py
@@ -29,6 +29,7 @@
 _LIGHTNING_CLOUD_LATEST = RequirementCache("lightning-cloud>=0.5.61")
 _BOTO3_AVAILABLE = RequirementCache("boto3")
 
+
 # DON'T CHANGE ORDER
 _TORCH_DTYPES_MAPPING = {
     0: torch.float32,
diff --git a/src/lightning/data/streaming/dataloader.py b/src/lightning/data/streaming/dataloader.py
index 29fabd8754f8a..764731ce22bb1 100644
--- a/src/lightning/data/streaming/dataloader.py
+++ b/src/lightning/data/streaming/dataloader.py
@@ -346,6 +346,93 @@ def _get_iterator(self) -> "_BaseDataLoaderIter":
         self.check_worker_number_rationality()
         return _MultiProcessingDataLoaderIterPatch(self)
 
+class StopRecordingException(Exception):
+    pass
+
+
+def _wrapper(func, tracer, profile):
+    counter = 0
+    has_stopped = False
+    def wrap(*args, **kwargs):
+        nonlocal counter
+        nonlocal has_stopped
+        result = func(*args, **kwargs)
+
+        if not has_stopped and counter >= profile:
+            tracer.stop()
+            tracer.save()
+            raise StopRecordingException("The collection has terminated.")
+        has_stopped = True
+
+        counter += 1
+        return result
+    return wrap
+
+
+class _ProfileWorkerLoop:
+    """Wrap the PyTorch DataLoader WorkerLoop to add profiling."""
+
+    def __init__(self, profile: Union[int, bool]):
+        self._profile = profile
+
+    def __call__(
+        self,
+        dataset_kind: Any,
+        dataset: Any,
+        index_queue: Any,
+        data_queue: Any,
+        done_event: Any,
+        auto_collation: Any,
+        collate_fn: Any,
+        drop_last: Any,
+        base_seed: Any,
+        init_fn: Any,
+        worker_id: Any,
+        *args: Any,
+        **kwargs: Any,
+    ) -> None:
+        from torch.utils.data._utils import worker
+
+        from viztracer import VizTracer
+
+        tracer = VizTracer(output_file=os.path.join(os.getcwd(), "result.json"))
+        tracer.start()
+
+        # Reload to remove the patching
+        reloaded_worker = reload(worker)
+        create_fetcher = _DatasetKind.create_fetcher
+        fetcher = None
+
+        def create_fetcher_fn(*args: Any, **kwargs: Any) -> "_BaseDatasetFetcher":
+            nonlocal fetcher
+            fetcher = create_fetcher(*args, **kwargs)
+
+            if isinstance(self._profile, int):
+                fetcher.fetch = _wrapper(fetcher.fetch, tracer, self._profile)
+            return fetcher
+
+        _DatasetKind.create_fetcher = create_fetcher_fn  # type: ignore
+
+        reloaded_worker._worker_loop(
+            dataset_kind,
+            dataset,
+            index_queue,
+            data_queue,
+            done_event,
+            auto_collation,
+            collate_fn,
+            drop_last,
+            base_seed,
+            init_fn,
+            worker_id,
+            *args,
+            **kwargs,
+        )
+
+        if isinstance(self._profile, bool):
+            tracer.stop()
+            tracer.save()
 
 class _StreamingMultiProcessingDataLoaderIter(_MultiProcessingDataLoaderIter):
     def __init__(self, loader: DataLoader) -> None:
@@ -355,6 +442,14 @@ def __init__(self, loader: DataLoader) -> None:
             if self._loader._latest_worker_idx > 0
             else []
         )
+        self._num_workers = loader.num_workers
+
+        distributed_env = _DistributedEnv.detect()
+
+        if self._loader._profile_batches and distributed_env.global_rank == 0:
+            from torch.utils.data._utils import worker
+            worker._worker_loop = _ProfileWorkerLoop(self._loader._profile_batches)
+
         super().__init__(loader)
 
     def _try_put_index(self) -> None:
@@ -388,6 +483,7 @@ def __init__(
         *args: Any,
         batch_size: int = 1,
         num_workers: int = 0,
+        profile_batches: Union[bool, int] = False,
         **kwargs: Any,
     ) -> None:  # pyright: ignore
         if not isinstance(dataset, (StreamingDataset, CombinedStreamingDataset)):
@@ -396,9 +492,13 @@ def __init__(
                 f" Found {dataset}."
             )
 
+        if profile_batches and not _VIZ_TRACKER_AVAILABLE:
+            raise ModuleNotFoundError("To use profile_batches, viztracer is required. Run `pip install viztracer`")
+
         self.current_epoch = 0
         self.batch_size = batch_size
         self.num_workers = num_workers
+        self._profile_batches = profile_batches
         self._num_samples_yielded_streaming = 0
         self._num_samples_yielded_combined: Dict[int, List[Any]] = {}
         self.rng_state: Optional[Any] = None
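Note: the dataloader diff above profiles workers by starting a VizTracer inside the patched worker loop and wrapping `fetcher.fetch` so that recording stops after a given number of batches. A minimal standalone sketch of that start/stop/save pattern, using only the VizTracer calls that appear in the diff (the helper name and call limit here are hypothetical):

    import os
    from viztracer import VizTracer

    def trace_first_calls(fn, limit=10):
        # Trace the first `limit` calls to `fn`, then write the trace to disk.
        tracer = VizTracer(output_file=os.path.join(os.getcwd(), "result.json"))
        tracer.start()
        counter = 0

        def wrapped(*args, **kwargs):
            nonlocal counter
            result = fn(*args, **kwargs)
            counter += 1
            if counter == limit:
                tracer.stop()
                tracer.save()  # writes result.json, viewable at chrome://tracing/
            return result

        return wrapped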
diff --git a/src/lightning/data/streaming/downloader.py b/src/lightning/data/streaming/downloader.py
index b9097c843e66e..9662ec339327b 100644
--- a/src/lightning/data/streaming/downloader.py
+++ b/src/lightning/data/streaming/downloader.py
@@ -15,9 +15,9 @@
 from abc import ABC
 from typing import Any, Dict, List
 from urllib import parse
-
+import subprocess
 from filelock import FileLock, Timeout
-
+from lightning.data.streaming.constants import _BOTO3_AVAILABLE
 from lightning.data.streaming.client import S3Client
 
 
@@ -40,7 +40,10 @@ def download_file(self, remote_chunkpath: str, local_chunkpath: str) -> None:
 class S3Downloader(Downloader):
     def __init__(self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]]):
         super().__init__(remote_dir, cache_dir, chunks)
-        self._client = S3Client()
+        self._s5cmd_available = os.system("s5cmd > /dev/null 2>&1") == 0
+
+        if not self._s5cmd_available:
+            self._client = S3Client()
 
     def download_file(self, remote_filepath: str, local_filepath: str) -> None:
         obj = parse.urlparse(remote_filepath)
@@ -48,24 +51,28 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None:
         if obj.scheme != "s3":
             raise ValueError(f"Expected obj.scheme to be `s3`, instead, got {obj.scheme} for remote={remote_filepath}")
 
-        from boto3.s3.transfer import TransferConfig
-
-        extra_args: Dict[str, Any] = {}
-
-        try:
-            with FileLock(local_filepath + ".lock", timeout=1):
-                if not os.path.exists(local_filepath):
-                    # Issue: https://github.com/boto/boto3/issues/3113
-                    self._client.client.download_file(
-                        obj.netloc,
-                        obj.path.lstrip("/"),
-                        local_filepath,
-                        ExtraArgs=extra_args,
-                        Config=TransferConfig(use_threads=False),
-                    )
-        except Timeout:
-            # another process is responsible to download that file, continue
-            pass
+        if self._s5cmd_available:
+            proc = subprocess.Popen(f"s5cmd cp {remote_filepath} {local_filepath}", shell=True, stdout=subprocess.PIPE)
+            proc.wait()
+        else:
+            from boto3.s3.transfer import TransferConfig
+
+            extra_args: Dict[str, Any] = {}
+
+            # try:
+            #     with FileLock(local_filepath + ".lock", timeout=1):
+            if not os.path.exists(local_filepath):
+                # Issue: https://github.com/boto/boto3/issues/3113
+                self._client.client.download_file(
+                    obj.netloc,
+                    obj.path.lstrip("/"),
+                    local_filepath,
+                    ExtraArgs=extra_args,
+                    Config=TransferConfig(use_threads=False),
+                )
+            # except Timeout:
+            #     # another process is responsible to download that file, continue
+            #     pass
 
 
 class LocalDownloader(Downloader):
diff --git a/src/lightning/data/streaming/item_loader.py b/src/lightning/data/streaming/item_loader.py
index 779a683146182..4ddb4e4cd5526 100644
--- a/src/lightning/data/streaming/item_loader.py
+++ b/src/lightning/data/streaming/item_loader.py
@@ -90,7 +90,7 @@ def load_item_from_chunk(self, index: int, chunk_index: int, chunk_filepath: str
             exists = os.path.exists(chunk_filepath) and os.stat(chunk_filepath).st_size > 0
 
             while not exists:
-                sleep(0.1)
+                sleep(0.01)
                 exists = os.path.exists(chunk_filepath) and os.stat(chunk_filepath).st_size > 0
 
             self._chunk_filepaths[chunk_filepath] = True
@@ -188,7 +188,7 @@ def load_item_from_chunk(self, index: int, chunk_index: int, chunk_filepath: str
             exists = os.path.exists(chunk_filepath) and os.stat(chunk_filepath).st_size > 0
 
             while not exists:
-                sleep(0.1)
+                sleep(0.01)
                 exists = os.path.exists(chunk_filepath) and os.stat(chunk_filepath).st_size > 0
 
             self._chunk_filepaths[chunk_filepath] = True
diff --git a/src/lightning/data/streaming/reader.py b/src/lightning/data/streaming/reader.py
index 50705cb18f663..946ba916cf1fd 100644
--- a/src/lightning/data/streaming/reader.py
+++ b/src/lightning/data/streaming/reader.py
@@ -71,6 +71,11 @@ def __init__(
         # Check whether a dataset slice fits on the node
         num_bytes_per_nodes = self._config.num_bytes // self._distributed_env.num_nodes
         self._delete_chunks_when_processed = num_bytes_per_nodes > max_cache_size if max_cache_size else False
+
+        # When the dataset slice fits on the node, we don't need to wait for downloading the chunks.
+        if not self._delete_chunks_when_processed:
+            self._max_pre_download = 10e7
+        self._has_exited = False
 
     def download(self, chunk_indexes: List[int]) -> None:
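Note: `S3Downloader` now probes for the `s5cmd` binary with `os.system("s5cmd > /dev/null 2>&1") == 0` and keeps boto3 as the fallback. A sketch of an equivalent probe that avoids spawning a shell, using only the standard library (an alternative to what the patch does; bucket and paths are hypothetical):

    import shutil
    import subprocess

    s5cmd_available = shutil.which("s5cmd") is not None

    if s5cmd_available:
        # Same copy the patch issues, without shell=True.
        subprocess.run(
            ["s5cmd", "cp", "s3://my-bucket/chunk-0.bin", "/cache/chunk-0.bin"],
            check=True,
        )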
From a1e5e30f9cfca84545d438eafe45265b6d8433e3 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Wed, 24 Jan 2024 13:54:02 +0000
Subject: [PATCH 02/19] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/lightning/data/streaming/dataloader.py | 19 +++++++++++--------
 src/lightning/data/streaming/downloader.py |  7 +++----
 2 files changed, 14 insertions(+), 12 deletions(-)

diff --git a/src/lightning/data/streaming/dataloader.py b/src/lightning/data/streaming/dataloader.py
index 764731ce22bb1..d51f3b90fbc7d 100644
--- a/src/lightning/data/streaming/dataloader.py
+++ b/src/lightning/data/streaming/dataloader.py
@@ -346,27 +346,30 @@ def _get_iterator(self) -> "_BaseDataLoaderIter":
         self.check_worker_number_rationality()
         return _MultiProcessingDataLoaderIterPatch(self)
 
+
 class StopRecordingException(Exception):
     pass
 
 
-def _wrapper(func, tracer, profile):
+def _wrapper(func, tracer, profile):
     counter = 0
     has_stopped = False
-    def wrap(*args, **kwargs):
+
+    def wrap(*args, **kwargs):
         nonlocal counter
         nonlocal has_stopped
-        result = func(*args, **kwargs)
+        result = func(*args, **kwargs)
 
         if not has_stopped and counter >= profile:
             tracer.stop()
             tracer.save()
             raise StopRecordingException("The collection has terminated.")
         has_stopped = True
-
+
         counter += 1
-        return result
-    return wrap
+        return result
+
+    return wrap
 
 
 class _ProfileWorkerLoop:
@@ -392,7 +395,6 @@ def __call__(
         **kwargs: Any,
     ) -> None:
         from torch.utils.data._utils import worker
-
         from viztracer import VizTracer
 
         tracer = VizTracer(output_file=os.path.join(os.getcwd(), "result.json"))
@@ -443,11 +445,12 @@ def __init__(self, loader: DataLoader) -> None:
             else []
         )
         self._num_workers = loader.num_workers
-
+
         distributed_env = _DistributedEnv.detect()
 
         if self._loader._profile_batches and distributed_env.global_rank == 0:
             from torch.utils.data._utils import worker
+
             worker._worker_loop = _ProfileWorkerLoop(self._loader._profile_batches)
 
         super().__init__(loader)
diff --git a/src/lightning/data/streaming/downloader.py b/src/lightning/data/streaming/downloader.py
index 9662ec339327b..92c2f66b649c1 100644
--- a/src/lightning/data/streaming/downloader.py
+++ b/src/lightning/data/streaming/downloader.py
@@ -12,12 +12,11 @@
 # limitations under the License.
 import os
 import shutil
+import subprocess
 from abc import ABC
 from typing import Any, Dict, List
 from urllib import parse
-import subprocess
-from filelock import FileLock, Timeout
-from lightning.data.streaming.constants import _BOTO3_AVAILABLE
+
 from lightning.data.streaming.client import S3Client
 
 
@@ -42,7 +41,7 @@ def __init__(self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]]
         super().__init__(remote_dir, cache_dir, chunks)
         self._s5cmd_available = os.system("s5cmd > /dev/null 2>&1") == 0
 
-        if not self._s5cmd_available:
+        if not self._s5cmd_available:
             self._client = S3Client()
From eed6bfe5b33797f2cba9318a9664282244d6bcc0 Mon Sep 17 00:00:00 2001
From: thomas chaton
Date: Wed, 24 Jan 2024 13:55:19 +0000
Subject: [PATCH 03/19] update

---
 _notebooks                                 |  1 -
 src/lightning/data/streaming/constants.py  |  1 -
 src/lightning/data/streaming/downloader.py | 49 ++++++++++++----------
 3 files changed, 27 insertions(+), 24 deletions(-)
 delete mode 160000 _notebooks

diff --git a/_notebooks b/_notebooks
deleted file mode 160000
index 543a8d8200662..0000000000000
--- a/_notebooks
+++ /dev/null
@@ -1 +0,0 @@
-Subproject commit 543a8d82006620906dc9eb669eab18d06ebe6863
diff --git a/src/lightning/data/streaming/constants.py b/src/lightning/data/streaming/constants.py
index 9640f4552c416..c020bb080d1df 100644
--- a/src/lightning/data/streaming/constants.py
+++ b/src/lightning/data/streaming/constants.py
@@ -29,7 +29,6 @@
 _LIGHTNING_CLOUD_LATEST = RequirementCache("lightning-cloud>=0.5.61")
 _BOTO3_AVAILABLE = RequirementCache("boto3")
 
-
 # DON'T CHANGE ORDER
 _TORCH_DTYPES_MAPPING = {
     0: torch.float32,
diff --git a/src/lightning/data/streaming/downloader.py b/src/lightning/data/streaming/downloader.py
index 9662ec339327b..5ad312d159554 100644
--- a/src/lightning/data/streaming/downloader.py
+++ b/src/lightning/data/streaming/downloader.py
@@ -51,28 +51,33 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None:
         if obj.scheme != "s3":
             raise ValueError(f"Expected obj.scheme to be `s3`, instead, got {obj.scheme} for remote={remote_filepath}")
 
-        if self._s5cmd_available:
-            proc = subprocess.Popen(f"s5cmd cp {remote_filepath} {local_filepath}", shell=True, stdout=subprocess.PIPE)
-            proc.wait()
-        else:
-            from boto3.s3.transfer import TransferConfig
-
-            extra_args: Dict[str, Any] = {}
-
-            # try:
-            #     with FileLock(local_filepath + ".lock", timeout=1):
-            if not os.path.exists(local_filepath):
-                # Issue: https://github.com/boto/boto3/issues/3113
-                self._client.client.download_file(
-                    obj.netloc,
-                    obj.path.lstrip("/"),
-                    local_filepath,
-                    ExtraArgs=extra_args,
-                    Config=TransferConfig(use_threads=False),
-                )
+        if os.path.exists(local_filepath):
+            return
+
+        try:
+            with FileLock(local_filepath + ".lock", timeout=1):
+                if self._s5cmd_available:
+                    proc = subprocess.Popen(f"s5cmd cp {remote_filepath} {local_filepath}", shell=True, stdout=subprocess.PIPE)
+                    proc.wait()
+                else:
+                    from boto3.s3.transfer import TransferConfig
+
+                    extra_args: Dict[str, Any] = {}
+
+                    # try:
+                    #     with FileLock(local_filepath + ".lock", timeout=1):
+                    if not os.path.exists(local_filepath):
+                        # Issue: https://github.com/boto/boto3/issues/3113
+                        self._client.client.download_file(
+                            obj.netloc,
+                            obj.path.lstrip("/"),
+                            local_filepath,
+                            ExtraArgs=extra_args,
+                            Config=TransferConfig(use_threads=False),
+                        )
+        except Timeout:
+            # another process is responsible to download that file, continue
+            pass
 
 
 class LocalDownloader(Downloader):
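Note: patch 03 moves the whole transfer, including the s5cmd branch, under a `FileLock` so that concurrent processes do not download the same chunk twice. A small self-contained sketch of that locking pattern with the `filelock` API used above (the paths and the `download` stub are hypothetical):

    from filelock import FileLock, Timeout

    def download(path: str) -> None:
        ...  # stands in for the actual S3 transfer

    local_filepath = "/cache/chunk-0.bin"
    try:
        # timeout=1: wait up to one second for the lock; filelock raises
        # Timeout if another process still holds it after that.
        with FileLock(local_filepath + ".lock", timeout=1):
            download(local_filepath)
    except Timeout:
        # Another process owns the lock and is downloading this file.
        pass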
From 739981235ad5ad77376ea69344e48da4eb4bd114 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Wed, 24 Jan 2024 13:58:19 +0000
Subject: [PATCH 04/19] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/lightning/data/streaming/downloader.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/lightning/data/streaming/downloader.py b/src/lightning/data/streaming/downloader.py
index e509332f8c64d..ed41c48857716 100644
--- a/src/lightning/data/streaming/downloader.py
+++ b/src/lightning/data/streaming/downloader.py
@@ -56,7 +56,9 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None:
         try:
             with FileLock(local_filepath + ".lock", timeout=1):
                 if self._s5cmd_available:
-                    proc = subprocess.Popen(f"s5cmd cp {remote_filepath} {local_filepath}", shell=True, stdout=subprocess.PIPE)
+                    proc = subprocess.Popen(
+                        f"s5cmd cp {remote_filepath} {local_filepath}", shell=True, stdout=subprocess.PIPE
+                    )
                     proc.wait()
                 else:
                     from boto3.s3.transfer import TransferConfig
From a68c2bf054284ae7a9b7396ad5600b4283c690f0 Mon Sep 17 00:00:00 2001
From: thomas
Date: Wed, 24 Jan 2024 14:01:13 +0000
Subject: [PATCH 05/19] update

---
 src/lightning/data/streaming/constants.py  |  1 -
 src/lightning/data/streaming/downloader.py | 58 ++++++++++++----------
 2 files changed, 33 insertions(+), 26 deletions(-)

diff --git a/src/lightning/data/streaming/constants.py b/src/lightning/data/streaming/constants.py
index 9640f4552c416..c020bb080d1df 100644
--- a/src/lightning/data/streaming/constants.py
+++ b/src/lightning/data/streaming/constants.py
@@ -29,7 +29,6 @@
 _LIGHTNING_CLOUD_LATEST = RequirementCache("lightning-cloud>=0.5.61")
 _BOTO3_AVAILABLE = RequirementCache("boto3")
 
-
 # DON'T CHANGE ORDER
 _TORCH_DTYPES_MAPPING = {
     0: torch.float32,
diff --git a/src/lightning/data/streaming/downloader.py b/src/lightning/data/streaming/downloader.py
index 9662ec339327b..33c03e254c40b 100644
--- a/src/lightning/data/streaming/downloader.py
+++ b/src/lightning/data/streaming/downloader.py
@@ -12,12 +12,13 @@
 # limitations under the License.
 import os
 import shutil
+import subprocess
 from abc import ABC
 from typing import Any, Dict, List
 from urllib import parse
-import subprocess
+
 from filelock import FileLock, Timeout
-from lightning.data.streaming.constants import _BOTO3_AVAILABLE
+
 from lightning.data.streaming.client import S3Client
 
 
@@ -42,7 +43,7 @@ def __init__(self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]]
         super().__init__(remote_dir, cache_dir, chunks)
         self._s5cmd_available = os.system("s5cmd > /dev/null 2>&1") == 0
 
-        if not self._s5cmd_available:
+        if not self._s5cmd_available:
             self._client = S3Client()
 
     def download_file(self, remote_filepath: str, local_filepath: str) -> None:
@@ -51,28 +52,35 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None:
         if obj.scheme != "s3":
             raise ValueError(f"Expected obj.scheme to be `s3`, instead, got {obj.scheme} for remote={remote_filepath}")
 
-        if self._s5cmd_available:
-            proc = subprocess.Popen(f"s5cmd cp {remote_filepath} {local_filepath}", shell=True, stdout=subprocess.PIPE)
-            proc.wait()
-        else:
-            from boto3.s3.transfer import TransferConfig
-
-            extra_args: Dict[str, Any] = {}
-
-            # try:
-            #     with FileLock(local_filepath + ".lock", timeout=1):
-            if not os.path.exists(local_filepath):
-                # Issue: https://github.com/boto/boto3/issues/3113
-                self._client.client.download_file(
-                    obj.netloc,
-                    obj.path.lstrip("/"),
-                    local_filepath,
-                    ExtraArgs=extra_args,
-                    Config=TransferConfig(use_threads=False),
-                )
+        if os.path.exists(local_filepath):
+            return
+
+        try:
+            with FileLock(local_filepath + ".lock", timeout=0):
+                if self._s5cmd_available:
+                    proc = subprocess.Popen(
+                        f"s5cmd cp {remote_filepath} {local_filepath}", shell=True, stdout=subprocess.PIPE
+                    )
+                    proc.wait()
+                else:
+                    from boto3.s3.transfer import TransferConfig
+
+                    extra_args: Dict[str, Any] = {}
+
+                    # try:
+                    #     with FileLock(local_filepath + ".lock", timeout=1):
+                    if not os.path.exists(local_filepath):
+                        # Issue: https://github.com/boto/boto3/issues/3113
+                        self._client.client.download_file(
+                            obj.netloc,
+                            obj.path.lstrip("/"),
+                            local_filepath,
+                            ExtraArgs=extra_args,
+                            Config=TransferConfig(use_threads=False),
+                        )
+        except Timeout:
+            # another process is responsible to download that file, continue
+            pass
 
 
 class LocalDownloader(Downloader):

From 41f11ccb22bf021a1f3542844e15656af8fbe5f7 Mon Sep 17 00:00:00 2001
From: thomas
Date: Wed, 24 Jan 2024 14:04:13 +0000
Subject: [PATCH 06/19] update

---
 _notebooks | 1 +
 1 file changed, 1 insertion(+)
 create mode 160000 _notebooks

diff --git a/_notebooks b/_notebooks
new file mode 160000
index 0000000000000..543a8d8200662
--- /dev/null
+++ b/_notebooks
@@ -0,0 +1 @@
+Subproject commit 543a8d82006620906dc9eb669eab18d06ebe6863
From 01197be35926812d1017156e8683c0c0e3bbc010 Mon Sep 17 00:00:00 2001
From: thomas chaton
Date: Wed, 24 Jan 2024 14:38:45 +0000
Subject: [PATCH 07/19] update

---
 src/lightning/data/streaming/dataloader.py | 28 ++++++++++++----------
 src/lightning/data/streaming/downloader.py |  4 ++--
 2 files changed, 18 insertions(+), 14 deletions(-)

diff --git a/src/lightning/data/streaming/dataloader.py b/src/lightning/data/streaming/dataloader.py
index d51f3b90fbc7d..32f2f0489576f 100644
--- a/src/lightning/data/streaming/dataloader.py
+++ b/src/lightning/data/streaming/dataloader.py
@@ -42,7 +42,7 @@
 from lightning.data.streaming.constants import _DEFAULT_CHUNK_BYTES, _TORCH_GREATER_EQUAL_2_1_0, _VIZ_TRACKER_AVAILABLE
 from lightning.data.streaming.dataset import StreamingDataset
 from lightning.data.streaming.sampler import CacheBatchSampler
-from lightning.data.utilities.env import _DistributedEnv
+from lightning.data.utilities.env import _DistributedEnv, _WorkerEnv
 
 if _TORCH_GREATER_EQUAL_2_1_0:
     from torch.utils._pytree import tree_flatten
@@ -351,20 +351,18 @@ class StopRecordingException(Exception):
     pass
 
 
-def _wrapper(func, tracer, profile):
+def _wrapper(fetcher, func, tracer, profile):
     counter = 0
-    has_stopped = False
 
     def wrap(*args, **kwargs):
         nonlocal counter
-        nonlocal has_stopped
         result = func(*args, **kwargs)
 
-        if not has_stopped and counter >= profile:
+        if tracer.enable and counter == profile:
             tracer.stop()
             tracer.save()
-            raise StopRecordingException("The collection has terminated.")
-        has_stopped = True
+            print(f"Saved result.json file after {profile} batches.")
+            fetcher.fetch = func
 
         counter += 1
         return result
@@ -395,8 +393,14 @@ def __call__(
         from torch.utils.data._utils import worker
         from viztracer import VizTracer
 
-        tracer = VizTracer(output_file=os.path.join(os.getcwd(), "result.json"))
-        tracer.start()
+        if worker_id == 0:
+            output_file = os.path.join(os.getcwd(), "result.json")
+
+            if os.path.exists(output_file):
+                os.remove(output_file)
+
+            tracer = VizTracer(output_file=output_file, verbose=0)
+            tracer.start()
 
         # Reload to remove the patching
         reloaded_worker = reload(worker)
@@ -409,8 +413,8 @@ def create_fetcher_fn(*args: Any, **kwargs: Any) -> "_BaseDatasetFetcher":
             nonlocal fetcher
             fetcher = create_fetcher(*args, **kwargs)
 
-            if isinstance(self._profile, int):
-                fetcher.fetch = _wrapper(fetcher.fetch, tracer, self._profile)
+            if worker_id == 0 and isinstance(self._profile, int):
+                fetcher.fetch = _wrapper(fetcher, fetcher.fetch, tracer, self._profile)
             return fetcher
 
         _DatasetKind.create_fetcher = create_fetcher_fn  # type: ignore
@@ -431,7 +435,7 @@ def create_fetcher_fn(*args: Any, **kwargs: Any) -> "_BaseDatasetFetcher":
             **kwargs,
         )
 
-        if isinstance(self._profile, bool):
+        if worker_id == 0 and isinstance(self._profile, bool):
             tracer.stop()
             tracer.save()
diff --git a/src/lightning/data/streaming/downloader.py b/src/lightning/data/streaming/downloader.py
index e509332f8c64d..5146ef6e94882 100644
--- a/src/lightning/data/streaming/downloader.py
+++ b/src/lightning/data/streaming/downloader.py
@@ -16,7 +16,7 @@
 from abc import ABC
 from typing import Any, Dict, List
 from urllib import parse
-
+from filelock import FileLock, Timeout
 from lightning.data.streaming.client import S3Client
 
 
@@ -54,7 +54,7 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None:
             return
 
         try:
-            with FileLock(local_filepath + ".lock", timeout=1):
+            with FileLock(local_filepath + ".lock", timeout=0):
                 if self._s5cmd_available:
                     proc = subprocess.Popen(f"s5cmd cp {remote_filepath} {local_filepath}", shell=True, stdout=subprocess.PIPE)
                     proc.wait()

From 20293a3f0a868e90cbf6efb86ef15a2165b9dc7b Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Wed, 24 Jan 2024 14:40:35 +0000
Subject: [PATCH 08/19] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/lightning/data/streaming/dataloader.py | 2 +-
 src/lightning/data/streaming/downloader.py | 2 ++
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/lightning/data/streaming/dataloader.py b/src/lightning/data/streaming/dataloader.py
index 32f2f0489576f..aef66bbd343bc 100644
--- a/src/lightning/data/streaming/dataloader.py
+++ b/src/lightning/data/streaming/dataloader.py
@@ -42,7 +42,7 @@
 from lightning.data.streaming.constants import _DEFAULT_CHUNK_BYTES, _TORCH_GREATER_EQUAL_2_1_0, _VIZ_TRACKER_AVAILABLE
 from lightning.data.streaming.dataset import StreamingDataset
 from lightning.data.streaming.sampler import CacheBatchSampler
-from lightning.data.utilities.env import _DistributedEnv, _WorkerEnv
+from lightning.data.utilities.env import _DistributedEnv
 
 if _TORCH_GREATER_EQUAL_2_1_0:
     from torch.utils._pytree import tree_flatten
diff --git a/src/lightning/data/streaming/downloader.py b/src/lightning/data/streaming/downloader.py
index ebf4a2e9dc79a..33c03e254c40b 100644
--- a/src/lightning/data/streaming/downloader.py
+++ b/src/lightning/data/streaming/downloader.py
@@ -16,7 +16,9 @@
 from abc import ABC
 from typing import Any, Dict, List
 from urllib import parse
+
 from filelock import FileLock, Timeout
+
 from lightning.data.streaming.client import S3Client
From dc191200a88789ebedfb7d57590e08b83a8a2124 Mon Sep 17 00:00:00 2001
From: thomas
Date: Wed, 24 Jan 2024 14:43:22 +0000
Subject: [PATCH 09/19] update

---
 src/lightning/data/streaming/dataloader.py | 6 +-----
 1 file changed, 1 insertion(+), 5 deletions(-)

diff --git a/src/lightning/data/streaming/dataloader.py b/src/lightning/data/streaming/dataloader.py
index aef66bbd343bc..491d6bc4ec2ee 100644
--- a/src/lightning/data/streaming/dataloader.py
+++ b/src/lightning/data/streaming/dataloader.py
@@ -347,10 +347,6 @@ def _get_iterator(self) -> "_BaseDataLoaderIter":
         return _MultiProcessingDataLoaderIterPatch(self)
 
 
-class StopRecordingException(Exception):
-    pass
-
-
 def _wrapper(fetcher, func, tracer, profile):
     counter = 0
 
@@ -361,7 +357,7 @@ def wrap(*args, **kwargs):
         if tracer.enable and counter == profile:
             tracer.stop()
             tracer.save()
-            print(f"Saved result.json file after {profile} batches.")
+            print(f"Saved {os.path.join(os.getcwd(), 'result.json')} file after {profile} batches.")
             fetcher.fetch = func
 
         counter += 1

From da9665295fdeb88b72375d31900b8b64e717c072 Mon Sep 17 00:00:00 2001
From: thomas
Date: Wed, 24 Jan 2024 14:44:26 +0000
Subject: [PATCH 10/19] update

---
 src/lightning/data/streaming/dataloader.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/src/lightning/data/streaming/dataloader.py b/src/lightning/data/streaming/dataloader.py
index 491d6bc4ec2ee..56653bca744a9 100644
--- a/src/lightning/data/streaming/dataloader.py
+++ b/src/lightning/data/streaming/dataloader.py
@@ -357,7 +357,10 @@ def wrap(*args, **kwargs):
         if tracer.enable and counter == profile:
             tracer.stop()
             tracer.save()
-            print(f"Saved {os.path.join(os.getcwd(), 'result.json')} file after {profile} batches.")
+            print(
+                f"Saved {os.path.join(os.getcwd(), 'result.json')} file after {profile} batches."
+                "Use chrome://tracing/ to view it."
+            )
             fetcher.fetch = func
 
         counter += 1
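Note: patches 09 and 10 drop the `StopRecordingException` control flow; instead the wrapper restores the original `fetcher.fetch` once enough batches were traced, so the worker keeps running without tracing overhead. A generic sketch of that self-unpatching idea (names are hypothetical):

    def patch_until(obj, name, limit):
        # Wrap method `name` on `obj`; after `limit` calls, restore the original.
        original = getattr(obj, name)
        counter = 0

        def wrapped(*args, **kwargs):
            nonlocal counter
            result = original(*args, **kwargs)
            counter += 1
            if counter == limit:
                setattr(obj, name, original)  # later calls bypass the wrapper
            return result

        setattr(obj, name, wrapped)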
From b213091a3ef70bd4c1f83826ec5b1da64d2de5e3 Mon Sep 17 00:00:00 2001
From: thomas chaton
Date: Wed, 24 Jan 2024 15:49:45 +0000
Subject: [PATCH 11/19] update

---
 src/lightning/data/streaming/dataloader.py | 10 +++++++++-
 src/lightning/data/streaming/downloader.py |  2 +-
 2 files changed, 10 insertions(+), 2 deletions(-)

diff --git a/src/lightning/data/streaming/dataloader.py b/src/lightning/data/streaming/dataloader.py
index 32f2f0489576f..4be1f23ecfe01 100644
--- a/src/lightning/data/streaming/dataloader.py
+++ b/src/lightning/data/streaming/dataloader.py
@@ -491,6 +491,7 @@ def __init__(
         batch_size: int = 1,
         num_workers: int = 0,
         profile_batches: Union[bool, int] = False,
+        prefetch_factor: Optional[int] = None,
         **kwargs: Any,
     ) -> None:  # pyright: ignore
         if not isinstance(dataset, (StreamingDataset, CombinedStreamingDataset)):
@@ -513,7 +514,14 @@ def __init__(
         self._worker_idx_iter: Optional[Any] = None
         self._latest_worker_idx = 0
         self.restore = False
-        super().__init__(dataset, *args, batch_size=batch_size, num_workers=num_workers, **kwargs)  # type: ignore
+        super().__init__(
+            dataset,
+            *args,
+            batch_size=batch_size,
+            num_workers=num_workers,
+            prefetch_factor=10 if num_workers > 0 else 0,
+            **kwargs
+        )  # type: ignore
 
     def __iter__(self) -> Any:
         if not self.restore:
diff --git a/src/lightning/data/streaming/downloader.py b/src/lightning/data/streaming/downloader.py
index ebf4a2e9dc79a..0ea4e49538da9 100644
--- a/src/lightning/data/streaming/downloader.py
+++ b/src/lightning/data/streaming/downloader.py
@@ -57,7 +57,7 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None:
         with FileLock(local_filepath + ".lock", timeout=0):
             if self._s5cmd_available:
                 proc = subprocess.Popen(
-                    f"s5cmd cp {remote_filepath} {local_filepath}", shell=True, stdout=subprocess.PIPE
+                    f"s5cmd --numworkers 64 cp {remote_filepath} {local_filepath}", shell=True, stdout=subprocess.PIPE
                 )
                 proc.wait()
             else:

From bcd34805403722c7a2ae06656129495a494792f6 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Wed, 24 Jan 2024 15:51:35 +0000
Subject: [PATCH 12/19] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/lightning/data/streaming/dataloader.py | 4 ++--
 src/lightning/data/streaming/downloader.py | 4 +++-
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/src/lightning/data/streaming/dataloader.py b/src/lightning/data/streaming/dataloader.py
index 99491871e8fc4..8a043484ab4a5 100644
--- a/src/lightning/data/streaming/dataloader.py
+++ b/src/lightning/data/streaming/dataloader.py
@@ -518,8 +518,8 @@ def __init__(
             *args,
             batch_size=batch_size,
             num_workers=num_workers,
-            prefetch_factor=10 if num_workers > 0 else 0,
-            **kwargs
+            prefetch_factor=10 if num_workers > 0 else 0,
+            **kwargs,
         )  # type: ignore
 
     def __iter__(self) -> Any:
diff --git a/src/lightning/data/streaming/downloader.py b/src/lightning/data/streaming/downloader.py
index f3ae7d6a542dd..d55a57db6fb46 100644
--- a/src/lightning/data/streaming/downloader.py
+++ b/src/lightning/data/streaming/downloader.py
@@ -59,7 +59,9 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None:
         with FileLock(local_filepath + ".lock", timeout=0):
             if self._s5cmd_available:
                 proc = subprocess.Popen(
-                    f"s5cmd --numworkers 64 cp {remote_filepath} {local_filepath}", shell=True, stdout=subprocess.PIPE
+                    f"s5cmd --numworkers 64 cp {remote_filepath} {local_filepath}",
+                    shell=True,
+                    stdout=subprocess.PIPE,
                 )
                 proc.wait()
             else:
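Note: patch 11 pins `prefetch_factor=10`. In `torch.utils.data.DataLoader`, `prefetch_factor` is the number of batches each worker loads ahead of consumption, so up to `num_workers * prefetch_factor` batches are in flight overall; recent PyTorch versions require it to be `None` when `num_workers == 0`, which patch 16 below accounts for. A usage sketch with a throwaway dataset:

    from torch.utils.data import DataLoader, Dataset

    class RangeDataset(Dataset):
        def __len__(self) -> int:
            return 100

        def __getitem__(self, index: int) -> int:
            return index

    # 4 workers x 10 batches each = up to 40 batches prefetched.
    loader = DataLoader(RangeDataset(), batch_size=8, num_workers=4, prefetch_factor=10)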
From 0af3fa5716fd03be9bb859d95d4907b8c875fa43 Mon Sep 17 00:00:00 2001
From: thomas chaton
Date: Wed, 24 Jan 2024 16:34:45 +0000
Subject: [PATCH 13/19] update

---
 src/lightning/data/streaming/item_loader.py | 4 ++--
 src/lightning/data/streaming/reader.py      | 4 ----
 2 files changed, 2 insertions(+), 6 deletions(-)

diff --git a/src/lightning/data/streaming/item_loader.py b/src/lightning/data/streaming/item_loader.py
index 4ddb4e4cd5526..779a683146182 100644
--- a/src/lightning/data/streaming/item_loader.py
+++ b/src/lightning/data/streaming/item_loader.py
@@ -90,7 +90,7 @@ def load_item_from_chunk(self, index: int, chunk_index: int, chunk_filepath: str
             exists = os.path.exists(chunk_filepath) and os.stat(chunk_filepath).st_size > 0
 
             while not exists:
-                sleep(0.01)
+                sleep(0.1)
                 exists = os.path.exists(chunk_filepath) and os.stat(chunk_filepath).st_size > 0
 
             self._chunk_filepaths[chunk_filepath] = True
@@ -188,7 +188,7 @@ def load_item_from_chunk(self, index: int, chunk_index: int, chunk_filepath: str
             exists = os.path.exists(chunk_filepath) and os.stat(chunk_filepath).st_size > 0
 
             while not exists:
-                sleep(0.01)
+                sleep(0.1)
                 exists = os.path.exists(chunk_filepath) and os.stat(chunk_filepath).st_size > 0
 
             self._chunk_filepaths[chunk_filepath] = True
diff --git a/src/lightning/data/streaming/reader.py b/src/lightning/data/streaming/reader.py
index 946ba916cf1fd..181994d5022bb 100644
--- a/src/lightning/data/streaming/reader.py
+++ b/src/lightning/data/streaming/reader.py
@@ -72,10 +72,6 @@ def __init__(
         num_bytes_per_nodes = self._config.num_bytes // self._distributed_env.num_nodes
         self._delete_chunks_when_processed = num_bytes_per_nodes > max_cache_size if max_cache_size else False
 
-        # When the dataset slice fits on the node, we don't need to wait for downloading the chunks.
-        if not self._delete_chunks_when_processed:
-            self._max_pre_download = 10e7
-
         self._has_exited = False
 
     def download(self, chunk_indexes: List[int]) -> None:

From 8e810655f069a886e80b707921d817610d843590 Mon Sep 17 00:00:00 2001
From: thomas
Date: Wed, 24 Jan 2024 16:39:23 +0000
Subject: [PATCH 14/19] update

---
 src/lightning/data/streaming/reader.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/lightning/data/streaming/reader.py b/src/lightning/data/streaming/reader.py
index 181994d5022bb..50705cb18f663 100644
--- a/src/lightning/data/streaming/reader.py
+++ b/src/lightning/data/streaming/reader.py
@@ -71,7 +71,6 @@ def __init__(
         # Check whether a dataset slice fits on the node
         num_bytes_per_nodes = self._config.num_bytes // self._distributed_env.num_nodes
         self._delete_chunks_when_processed = num_bytes_per_nodes > max_cache_size if max_cache_size else False
-        self._has_exited = False
 
     def download(self, chunk_indexes: List[int]) -> None:
From c613d951273ed273a97ba2ce5a1b388b452aa573 Mon Sep 17 00:00:00 2001
From: thomas
Date: Wed, 24 Jan 2024 16:44:21 +0000
Subject: [PATCH 15/19] update

---
 src/lightning/data/streaming/downloader.py | 43 +++++++--------------
 1 file changed, 13 insertions(+), 30 deletions(-)

diff --git a/src/lightning/data/streaming/downloader.py b/src/lightning/data/streaming/downloader.py
index d55a57db6fb46..b9097c843e66e 100644
--- a/src/lightning/data/streaming/downloader.py
+++ b/src/lightning/data/streaming/downloader.py
@@ -12,7 +12,6 @@
 # limitations under the License.
 import os
 import shutil
-import subprocess
 from abc import ABC
 from typing import Any, Dict, List
 from urllib import parse
@@ -41,10 +40,7 @@ def download_file(self, remote_chunkpath: str, local_chunkpath: str) -> None:
 class S3Downloader(Downloader):
     def __init__(self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]]):
         super().__init__(remote_dir, cache_dir, chunks)
-        self._s5cmd_available = os.system("s5cmd > /dev/null 2>&1") == 0
-
-        if not self._s5cmd_available:
-            self._client = S3Client()
+        self._client = S3Client()
 
     def download_file(self, remote_filepath: str, local_filepath: str) -> None:
         obj = parse.urlparse(remote_filepath)
@@ -52,34 +48,21 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None:
         if obj.scheme != "s3":
             raise ValueError(f"Expected obj.scheme to be `s3`, instead, got {obj.scheme} for remote={remote_filepath}")
 
-        if os.path.exists(local_filepath):
-            return
+        from boto3.s3.transfer import TransferConfig
+
+        extra_args: Dict[str, Any] = {}
 
         try:
-            with FileLock(local_filepath + ".lock", timeout=0):
-                if self._s5cmd_available:
-                    proc = subprocess.Popen(
-                        f"s5cmd --numworkers 64 cp {remote_filepath} {local_filepath}",
-                        shell=True,
-                        stdout=subprocess.PIPE,
-                    )
-                    proc.wait()
-                else:
-                    from boto3.s3.transfer import TransferConfig
-
-                    extra_args: Dict[str, Any] = {}
-
-                    # try:
-                    #     with FileLock(local_filepath + ".lock", timeout=1):
-                    if not os.path.exists(local_filepath):
-                        # Issue: https://github.com/boto/boto3/issues/3113
-                        self._client.client.download_file(
-                            obj.netloc,
-                            obj.path.lstrip("/"),
-                            local_filepath,
-                            ExtraArgs=extra_args,
-                            Config=TransferConfig(use_threads=False),
-                        )
+            with FileLock(local_filepath + ".lock", timeout=1):
+                if not os.path.exists(local_filepath):
+                    # Issue: https://github.com/boto/boto3/issues/3113
+                    self._client.client.download_file(
+                        obj.netloc,
+                        obj.path.lstrip("/"),
+                        local_filepath,
+                        ExtraArgs=extra_args,
+                        Config=TransferConfig(use_threads=False),
+                    )
         except Timeout:
             # another process is responsible to download that file, continue
             pass

From 29a905be374bfc1c99ceaa01494bad7c1730fb62 Mon Sep 17 00:00:00 2001
From: thomas
Date: Wed, 24 Jan 2024 17:08:16 +0000
Subject: [PATCH 16/19] update

---
 src/lightning/data/streaming/dataloader.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/lightning/data/streaming/dataloader.py b/src/lightning/data/streaming/dataloader.py
index 8a043484ab4a5..57c5425a506dc 100644
--- a/src/lightning/data/streaming/dataloader.py
+++ b/src/lightning/data/streaming/dataloader.py
@@ -347,10 +347,10 @@ def _get_iterator(self) -> "_BaseDataLoaderIter":
         return _MultiProcessingDataLoaderIterPatch(self)
 
 
-def _wrapper(fetcher, func, tracer, profile):
+def _wrapper(fetcher: Any, func: Callable, tracer: Any, profile: int) -> Callable:
     counter = 0
 
-    def wrap(*args, **kwargs):
+    def wrap(*args: Any, **kwargs: Any) -> Any:
         nonlocal counter
         result = func(*args, **kwargs)
 
@@ -518,7 +518,7 @@ def __init__(
             *args,
             batch_size=batch_size,
             num_workers=num_workers,
-            prefetch_factor=10 if num_workers > 0 else 0,
+            prefetch_factor=10 if num_workers > 0 else None,
             **kwargs,
         )  # type: ignore
From 46a1da3917dc86aa3d8ce3afd6f50e800f287275 Mon Sep 17 00:00:00 2001
From: thomas
Date: Wed, 24 Jan 2024 17:35:21 +0000
Subject: [PATCH 17/19] update

---
 requirements/data/test.txt                    |  1 +
 src/lightning/data/streaming/dataloader.py    | 19 ++++++++++--------
 tests/tests_data/streaming/test_dataloader.py | 20 +++++++++++++++++++
 3 files changed, 32 insertions(+), 8 deletions(-)

diff --git a/requirements/data/test.txt b/requirements/data/test.txt
index d30343b08a628..38439e2d6705a 100644
--- a/requirements/data/test.txt
+++ b/requirements/data/test.txt
@@ -4,3 +4,4 @@ pytest-cov ==4.1.0
 pytest-timeout ==2.1.0
 pytest-rerunfailures ==12.0
 pytest-random-order ==1.1.0
+viztracer
diff --git a/src/lightning/data/streaming/dataloader.py b/src/lightning/data/streaming/dataloader.py
index 57c5425a506dc..090976afb9104 100644
--- a/src/lightning/data/streaming/dataloader.py
+++ b/src/lightning/data/streaming/dataloader.py
@@ -347,7 +347,7 @@ def _get_iterator(self) -> "_BaseDataLoaderIter":
         return _MultiProcessingDataLoaderIterPatch(self)
 
 
-def _wrapper(fetcher: Any, func: Callable, tracer: Any, profile: int) -> Callable:
+def _wrapper(fetcher: Any, func: Callable, tracer: Any, profile: int, profile_dir: str) -> Callable:
     counter = 0
 
     def wrap(*args: Any, **kwargs: Any) -> Any:
@@ -358,7 +358,7 @@ def wrap(*args: Any, **kwargs: Any) -> Any:
             tracer.stop()
             tracer.save()
             print(
-                f"Saved {os.path.join(os.getcwd(), 'result.json')} file after {profile} batches."
+                f"Saved {os.path.join(profile_dir, 'result.json')} file after {profile} batches."
                 "Use chrome://tracing/ to view it."
             )
             fetcher.fetch = func
@@ -372,8 +372,9 @@ def wrap(*args: Any, **kwargs: Any) -> Any:
 class _ProfileWorkerLoop:
     """Wrap the PyTorch DataLoader WorkerLoop to add profiling."""
 
-    def __init__(self, profile: Union[int, bool]):
+    def __init__(self, profile: Union[int, bool], profile_dir: Optional[str] = None):
         self._profile = profile
+        self._profile_dir = profile_dir if profile_dir else os.getcwd()
 
     def __call__(
         self,
@@ -395,7 +396,7 @@ def __call__(
         from viztracer import VizTracer
 
         if worker_id == 0:
-            output_file = os.path.join(os.getcwd(), "result.json")
+            output_file = os.path.join(self._profile_dir, "result.json")
 
             if os.path.exists(output_file):
                 os.remove(output_file)
@@ -413,7 +414,7 @@ def create_fetcher_fn(*args: Any, **kwargs: Any) -> "_BaseDatasetFetcher":
             fetcher = create_fetcher(*args, **kwargs)
 
             if worker_id == 0 and isinstance(self._profile, int):
-                fetcher.fetch = _wrapper(fetcher, fetcher.fetch, tracer, self._profile)
+                fetcher.fetch = _wrapper(fetcher, fetcher.fetch, tracer, self._profile, self._profile_dir)
             return fetcher
 
         _DatasetKind.create_fetcher = create_fetcher_fn  # type: ignore
@@ -452,10 +453,10 @@ def __init__(self, loader: DataLoader) -> None:
 
         distributed_env = _DistributedEnv.detect()
 
-        if self._loader._profile_batches and distributed_env.global_rank == 0:
+        if self._loader._profile_batches and distributed_env.global_rank == 0 and _VIZ_TRACKER_AVAILABLE:
             from torch.utils.data._utils import worker
 
-            worker._worker_loop = _ProfileWorkerLoop(self._loader._profile_batches)
+            worker._worker_loop = _ProfileWorkerLoop(self._loader._profile_batches, self._loader._profile_dir)
 
         super().__init__(loader)
 
@@ -490,6 +491,7 @@ def __init__(
         batch_size: int = 1,
         num_workers: int = 0,
         profile_batches: Union[bool, int] = False,
+        profile_dir: Optional[str] = None,
         prefetch_factor: Optional[int] = None,
         **kwargs: Any,
     ) -> None:  # pyright: ignore
@@ -506,6 +508,7 @@ def __init__(
         self.current_epoch = 0
         self.batch_size = batch_size
         self.num_workers = num_workers
         self._profile_batches = profile_batches
+        self._profile_dir = profile_dir
         self._num_samples_yielded_streaming = 0
         self._num_samples_yielded_combined: Dict[int, List[Any]] = {}
         self.rng_state: Optional[Any] = None
@@ -518,7 +521,7 @@ def __init__(
             *args,
             batch_size=batch_size,
             num_workers=num_workers,
-            prefetch_factor=10 if num_workers > 0 else None,
+            prefetch_factor=(10 if num_workers > 0 else None) if prefetch_factor is None else prefetch_factor,
             **kwargs,
         )  # type: ignore
diff --git a/tests/tests_data/streaming/test_dataloader.py b/tests/tests_data/streaming/test_dataloader.py
index 6fdd7387db251..dd7d055bc2d8e 100644
--- a/tests/tests_data/streaming/test_dataloader.py
+++ b/tests/tests_data/streaming/test_dataloader.py
@@ -1,5 +1,8 @@
+import os
+
 import torch
 from lightning.data.streaming import CombinedStreamingDataset, StreamingDataLoader
+from lightning.data.streaming import dataloader as streaming_dataloader_module
 from torch import tensor
 
 
@@ -70,3 +73,20 @@ def test_streaming_dataloader():
         "latest_worker_idx": 0,
         "num_samples_yielded": {0: [11, 9]},
     }
+
+
+def test_dataloader_profiling(tmpdir, monkeypatch):
+    monkeypatch.setattr(streaming_dataloader_module, "_VIZ_TRACKER_AVAILABLE", True)
+
+    dataset = TestCombinedStreamingDataset(
+        [TestStatefulDataset(10, 1), TestStatefulDataset(10, -1)], 42, weights=(0.5, 0.5)
+    )
+    dataloader = StreamingDataLoader(
+        dataset, batch_size=2, profile_batches=True, profile_dir=str(tmpdir), num_workers=1
+    )
+    dataloader_iter = iter(dataloader)
+    batches = []
+    for batch in dataloader_iter:
+        batches.append(batch)
+
+    assert os.path.exists(os.path.join(tmpdir, "result.json"))
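Note: with patch 17 the profiling knob is public and exercised by a test. A condensed usage sketch of the same API (`dataset` stands in for any StreamingDataset or CombinedStreamingDataset; the directory is hypothetical):

    from lightning.data.streaming import StreamingDataLoader

    dataloader = StreamingDataLoader(
        dataset,
        batch_size=2,
        num_workers=1,               # profiling happens in worker 0
        profile_batches=2,           # trace the first 2 batches, then stop
        profile_dir="/tmp/profile",  # defaults to os.getcwd() when omitted
    )
    for batch in dataloader:
        pass
    # -> /tmp/profile/result.json, viewable at chrome://tracing/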
From 3ab5893c91bbfaa32b19680d143ae64773b4d2e4 Mon Sep 17 00:00:00 2001
From: thomas
Date: Wed, 24 Jan 2024 17:36:27 +0000
Subject: [PATCH 18/19] update

---
 src/lightning/data/streaming/dataloader.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/src/lightning/data/streaming/dataloader.py b/src/lightning/data/streaming/dataloader.py
index 090976afb9104..acd0ebef19af6 100644
--- a/src/lightning/data/streaming/dataloader.py
+++ b/src/lightning/data/streaming/dataloader.py
@@ -504,6 +504,9 @@ def __init__(
         if profile_batches and not _VIZ_TRACKER_AVAILABLE:
             raise ModuleNotFoundError("To use profile_batches, viztracer is required. Run `pip install viztracer`")
 
+        if profile_batches and num_workers == 0:
+            raise ValueError("Profiling is supported only with num_workers >= 1.")
+
         self.current_epoch = 0
         self.batch_size = batch_size
         self.num_workers = num_workers
Run `pip install viztracer`") + if profile_bactches and num_workers == 0: + raise ValueError("Profiling is supported only with num_workers >= 1.") + self.current_epoch = 0 self.batch_size = batch_size self.num_workers = num_workers From f1c3e58a383834ee3a25e234a0882736c6274137 Mon Sep 17 00:00:00 2001 From: thomas Date: Wed, 24 Jan 2024 17:39:20 +0000 Subject: [PATCH 19/19] update --- tests/tests_data/streaming/test_dataloader.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/tests_data/streaming/test_dataloader.py b/tests/tests_data/streaming/test_dataloader.py index dd7d055bc2d8e..293a96636adae 100644 --- a/tests/tests_data/streaming/test_dataloader.py +++ b/tests/tests_data/streaming/test_dataloader.py @@ -1,5 +1,6 @@ import os +import pytest import torch from lightning.data.streaming import CombinedStreamingDataset, StreamingDataLoader from lightning.data.streaming import dataloader as streaming_dataloader_module @@ -75,14 +76,15 @@ def test_streaming_dataloader(): } -def test_dataloader_profiling(tmpdir, monkeypatch): +@pytest.mark.parametrize("profile", [2, True]) +def test_dataloader_profiling(profile, tmpdir, monkeypatch): monkeypatch.setattr(streaming_dataloader_module, "_VIZ_TRACKER_AVAILABLE", True) dataset = TestCombinedStreamingDataset( [TestStatefulDataset(10, 1), TestStatefulDataset(10, -1)], 42, weights=(0.5, 0.5) ) dataloader = StreamingDataLoader( - dataset, batch_size=2, profile_bactches=True, profile_dir=str(tmpdir), num_workers=1 + dataset, batch_size=2, profile_bactches=profile, profile_dir=str(tmpdir), num_workers=1 ) dataloader_iter = iter(dataloader) batches = []