From bf42f18a8ebe600107621bacdfc19b164960456b Mon Sep 17 00:00:00 2001 From: DejunL <142548330+DejunL@users.noreply.github.com> Date: Fri, 30 Aug 2024 10:01:03 -0700 Subject: [PATCH] LightningDataModule for webdataset (#100) This implements a LightningDataModule of webdataset called WebDataModule. It takes a set of webdataset tar files and various webdataset config settings as input and setups the WebDataset and WebLoader to be used later by the Lightning.Trainer workflows. This implements another LightningDataModule PickledDataWDS that inherits from the aforementioned WebDataModule that allows the user to experiment with different train/val/test splits of the input pickled data to be used in creating the webdataset tar files. Add tests, docstring and README for the above See PR#100 for details --- launch.sh | 4 +- sub-packages/bionemo-webdatamodule/LICENSE | 202 +++++++ sub-packages/bionemo-webdatamodule/README.md | 353 +++++++++++++ .../bionemo-webdatamodule/pyproject.toml | 32 ++ .../bionemo-webdatamodule/requirements.txt | 1 + .../src/bionemo/webdatamodule/__init__.py | 14 + .../src/bionemo/webdatamodule/datamodule.py | 491 ++++++++++++++++++ .../src/bionemo/webdatamodule/utils.py | 130 +++++ .../tests/bionemo/webdatamodule/conftest.py | 301 +++++++++++ .../bionemo/webdatamodule/test_datamodule.py | 268 ++++++++++ 10 files changed, 1794 insertions(+), 2 deletions(-) create mode 100644 sub-packages/bionemo-webdatamodule/LICENSE create mode 100644 sub-packages/bionemo-webdatamodule/README.md create mode 100644 sub-packages/bionemo-webdatamodule/pyproject.toml create mode 100644 sub-packages/bionemo-webdatamodule/requirements.txt create mode 100644 sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/__init__.py create mode 100644 sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/datamodule.py create mode 100644 sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/utils.py create mode 100644 sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/conftest.py create mode 100644 sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/test_datamodule.py diff --git a/launch.sh b/launch.sh index 69a1aff00..34d9e5855 100755 --- a/launch.sh +++ b/launch.sh @@ -103,7 +103,7 @@ build() { } # Check Docker version - docker_version=$(docker --version | grep -oE '[0-9]+\.[0-9]+\.[0-9]+') + docker_version=$(docker --version | awk -F'[, ]' '{print $3}') required_docker_version="23.0.1" if ! version_ge "$docker_version" "$required_docker_version"; then @@ -112,7 +112,7 @@ build() { fi # Check Buildx version - buildx_version=$(docker buildx version | grep -oE '[0-9]+\.[0-9]+\.[0-9]+') + buildx_version=$(docker buildx version | awk '{print $2}') required_buildx_version="0.10.2" if ! version_ge "$buildx_version" "$required_buildx_version"; then diff --git a/sub-packages/bionemo-webdatamodule/LICENSE b/sub-packages/bionemo-webdatamodule/LICENSE new file mode 100644 index 000000000..d64569567 --- /dev/null +++ b/sub-packages/bionemo-webdatamodule/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. 
+ + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. 
This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/sub-packages/bionemo-webdatamodule/README.md b/sub-packages/bionemo-webdatamodule/README.md
new file mode 100644
index 000000000..b06442c66
--- /dev/null
+++ b/sub-packages/bionemo-webdatamodule/README.md
@@ -0,0 +1,353 @@
+# bionemo-webdatamodule
+
+To install, execute the following:
+```bash
+pip install -e .
+```
+
+To run unit tests, execute:
+```bash
+pytest -v .
+```
+
+## WebDataModule
+
+```python
+class WebDataModule(L.LightningDataModule)
+```
+
+A LightningDataModule for using webdataset tar files to set up the dataset and
+dataloader. This data module takes as input a dictionary: Split -> tar file
+directory, plus various webdataset config settings. In its setup() function, it
+creates the WebDataset object, chaining together the input `pipeline_wds`
+workflow. In its train/val/test_dataloader(), it creates the WebLoader object,
+chaining together the `pipeline_prebatch_wld` workflow.
+
+Examples
+--------
+
+1. Create the data module with the input directories of webdataset tar files.
+Depending on which of the downstream Lightning.Trainer methods are called,
+e.g., `Trainer.fit()`, `Trainer.validate()`, `Trainer.test()` or
+`Trainer.predict()`, only a subset of the train, val and test splits needs to
+be specified in the various input options to the data module:
+
+- `Trainer.fit()` requires the `train` and `val` splits
+- `Trainer.validate()` requires the `val` split
+- `Trainer.test()` requires the `test` split
+- `Trainer.predict()` requires the `test` split
+
+Here is an example of constructing the data module for `Trainer.fit()`:
+```
+>>> from bionemo.webdatamodule.datamodule import Split, WebDataModule
+>>>
+>>> tar_file_prefix = "shards"
+>>>
+>>> dirs_of_tar_files = {
+>>>     Split.train: "/path/to/train/split/tars",
+>>>     Split.val: "/path/to/val/split/tars",
+>>>     }
+>>>
+>>> n_samples = {
+>>>     Split.train: 1000,
+>>>     Split.val: 100,
+>>>     }
+>>>
+>>> # this is the string to retrieve the corresponding data object from the
+>>> # webdataset file (see
+>>> # https://github.com/webdataset/webdataset?tab=readme-ov-file#the-webdataset-format
+>>> # for details)
+>>> suffix_keys_wds = "tensor.pyd"
+>>>
+>>> # see the API doc for the definition of global_batch_size
+>>> global_batch_size = 16
+>>>
+>>> seed = 27193781
+>>>
+>>> # Specify the routines to process the samples in the WebDataset object.
+>>> # Each routine is a composable that maps an iterator of samples to another
+>>> # iterator, and the routines are chained together by nested function calls.
+>>> # The following is equivalent to defining an overall generator of
+>>> # `shuffle(untuple(...))`, which untuples the samples and shuffles them.
+>>> # See webdataset's documentation for details.
+>>> # NOTE: the `untuple` is almost always necessary due to webdataset's
+>>> # file parsing rule.
+>>>
+>>> untuple = lambda source : (sample for (sample,) in source)
+>>>
+>>> import random
+>>> from webdataset.filters import shuffle
+>>> pipeline_wds = {
+>>>     Split.train : [untuple, shuffle(n_samples[Split.train],
+>>>                                     rng=random.Random(seed))],
+>>>     Split.val: untuple
+>>>     }
+>>>
+>>> # Similarly the user can optionally define the processing routine on the
+>>> # WebLoader (the dataloader of webdataset).
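+>>> # The imports below are shown for completeness: `batched` is webdataset's
+>>> # builtin batching filter, and `local_batch_size` is the per-device batch
+>>> # size (assumed here to equal `global_batch_size`, i.e., single-device
+>>> # training):
+>>> import torch
+>>> import webdataset as wds
+>>> from webdataset.filters import batched
+>>> local_batch_size = global_batch_size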
+>>> # NOTE: these routines by default take unbatched samples as input so the
+>>> # user can customize their batching routines here
+>>>
+>>> batch = batched(local_batch_size, collation_fn=lambda
+                    list_samples : torch.vstack(list_samples))
+>>> pipeline_prebatch_wld = {
+        Split.train: [shuffle(n_samples[Split.train],
+                              rng=random.Random(seed)), batch],
+        Split.val : batch,
+        Split.test : batch
+    }
+>>>
+>>> # the user can optionally specify the kwargs for WebDataset and
+>>> # WebLoader
+>>>
+>>> kwargs_wds = {
+>>>     split : {'shardshuffle' : split == Split.train,
+>>>              'nodesplitter' : wds.split_by_node,
+>>>              'seed' : seed}
+>>>     for split in Split
+>>>     }
+>>>
+>>> kwargs_wld = {
+>>>     split : {"num_workers": 2} for split in Split
+>>>     }
+>>>
+>>> # construct the data module
+>>> data_module = WebDataModule(n_samples, suffix_keys_wds, dirs_of_tar_files,
+                                global_batch_size,
+                                prefix_tars_wds=tar_file_prefix,
+                                pipeline_wds=pipeline_wds,
+                                pipeline_prebatch_wld=pipeline_prebatch_wld,
+                                kwargs_wds=kwargs_wds,
+                                kwargs_wld=kwargs_wld)
+```
+
+
+
+#### \_\_init\_\_
+
+```python
+def __init__(
+        n_samples: Dict[Split, int],
+        suffix_keys_wds: Union[str, Iterable[str]],
+        dirs_tars_wds: Dict[Split, str],
+        global_batch_size: int,
+        prefix_tars_wds: str = "wdshards",
+        pipeline_wds: Optional[Dict[Split, Union[Iterable[Iterable[Any]],
+                                                 Iterable[Any]]]] = None,
+        pipeline_prebatch_wld: Optional[Dict[Split,
+                                             Union[Iterable[Iterable[Any]],
+                                                   Iterable[Any]]]] = None,
+        kwargs_wds: Optional[Dict[Split, Dict[str, Any]]] = None,
+        kwargs_wld: Optional[Dict[Split, Dict[str, Any]]] = None)
+```
+
+Constructor.
+
+**Arguments**:
+
+- `n_samples` _Dict[Split, int]_ - input dictionary: Split -> number of
+  data samples for each split
+- `suffix_keys_wds` _Union[str, Iterable[str]]_ - a set of keys, each
+  corresponding to a data object in the webdataset tar file
+  dictionary. The data objects of these keys will be extracted and
+  tupled for each sample in the tar files
+- `dirs_tars_wds` _Dict[Split, str]_ - input dictionary: Split -> tar file
+  directory that contains the webdataset tar files for each split
+- `global_batch_size` _int_ - size of the batch summed across nodes in Data
+  Distributed Parallel, i.e., local_batch_size * n_nodes. NOTE:
+  this data module doesn't rely on the input `global_batch_size`
+  for batching the samples. The batching is supposed to be done as
+  part of the input `pipeline_prebatch_wld`. `global_batch_size`
+  is only used to compute a (pseudo-) epoch length for the data
+  loader so that the loader yields approximately n_samples //
+  global_batch_size batches
+  Kwargs:
+- `prefix_tars_wds` _str_ - name prefix of the input webdataset tar
+  files. The input tar files are globbed by
+  "{dirs_tars_wds[split]}/{prefix_tars_wds}-*.tar"
+- `pipeline_wds` _Optional[Dict[Split, Union[Iterable[Iterable[Any]],
+  Iterable[Any]]]]_ - a dictionary of webdataset composables, i.e.,
+  functors that map an iterator to another iterator that
+  transforms the data samples yielded from the dataset object, for
+  different splits, or an iterable of such
+  iterators.
+  For example, this can be used to transform the
+  sample in the worker before sending it to the main process of
+  the dataloader
+- `pipeline_prebatch_wld` _Optional[Dict[Split,
+  Union[Iterable[Iterable[Any]], Iterable[Any]]]]_ - a dictionary
+  of WebLoader composables, i.e., functors that map an iterator to
+  another iterator that transforms the data samples yielded from the
+  WebLoader object, for different splits, or an iterable of such
+  iterators. For example, this can be used for
+  batching the samples. NOTE: this is applied before the batch is
+  yielded from the WebLoader
+- `kwargs_wds` _Optional[Dict[Split, Dict[str, Any]]]_ - kwargs for the
+  WebDataset.__init__()
+- `kwargs_wld` _Optional[Dict[Split, Dict[str, Any]]]_ - kwargs for the
+  WebLoader.__init__(), e.g., num_workers, of each split
+
+
+
+#### prepare\_data
+
+```python
+def prepare_data() -> None
+```
+
+This is called only by the main process by the Lightning workflow. Do
+not rely on this data module object's state update here, as there is no
+way to communicate the state update to other subprocesses.
+
+Returns: None
+
+
+
+#### setup
+
+```python
+def setup(stage: str) -> None
+```
+
+This is called on all Lightning-managed nodes in a multi-node
+training session.
+
+
+**Arguments**:
+
+- `stage` _str_ - "fit", "validate", "test" or "predict"
+
+Returns: None
+
+## PickledDataWDS
+
+```python
+class PickledDataWDS(WebDataModule)
+```
+
+A LightningDataModule to process pickled data into webdataset tar files
+and set up the dataset and dataloader. This inherits the webdataset setup from
+its parent module `WebDataModule`. This data module takes a directory of
+pickled data files, data filename prefixes for the train/val/test splits, and
+data filename suffixes, and prepares webdataset tar files by globbing the
+specific pickle data files `{dir_pickles}/{names_subset[split]}.{suffix_pickles}`
+and outputting to webdataset tar files with the dict structure:
+```
+ {"__key__" : name.replace(".", "-"),
+  suffix_pickles : pickle.dumps(data) }
+```
+NOTE: this assumes only one pickled file is processed for each sample. In
+its setup() function, it creates the webdataset object, chaining together the
+input `pipeline_wds` workflow. In its train/val/test_dataloader(), it creates
+the WebLoader object, chaining together the `pipeline_prebatch_wld` workflow.
+
+Examples
+--------
+
+1. Create the data module with a directory of pickle files and the file name
+prefixes thereof for the different splits, to be used by `Lightning.Trainer.fit()`:
+
+```
+>>> from bionemo.webdatamodule.datamodule import Split, PickledDataWDS
+
+>>> dir_pickles = "/path/to/my/pickles/dir"
+
+>>> # the following will use `sample1.mydata.pt` and `sample2.mydata.pt` as the
+>>> # training dataset and `sample4.mydata.pt` and `sample5.mydata.pt` as the
+>>> # validation dataset
+
+>>> suffix_pickles = "mydata.pt"
+
+>>> names_subset = {
+>>>     Split.train: ["sample1", "sample2"],
+>>>     Split.val: ["sample4", "sample5"],
+>>>     }
+
+>>> # the following setting will attempt to create at least 5 tar files in
+>>> # each output directory, e.g.,
+>>> # `/path/to/output/tars/dir-train/myshards-00000{0-5}.tar`
+
+>>> n_tars_wds = 5
+>>> prefix_tars_wds = "myshards"
+>>> output_dir_tar_files = {
+>>>     Split.train : "/path/to/output/tars/dir-train",
+>>>     Split.val : "/path/to/output/tars/dir-val",
+>>>     Split.test : "/path/to/output/tars/dir-test",
+>>>     }
+
+>>> # see the `WebDataModule` API doc for the definition of global_batch_size
+>>> global_batch_size = 16
+
+>>> # user can optionally customize the data processing routines and kwargs used
+>>> # in the WebDataset and WebLoader (see the examples in `WebDataModule`)
+
+>>> pipeline_wds = { Split.train: ... }
+
+>>> pipeline_prebatch_wld = { Split.train: ... }
+
+>>> kwargs_wds = { Split.train: ..., Split.val: ... }
+
+>>> kwargs_wld = { Split.train: ..., Split.val: ... }
+
+>>> # create the data module
+>>> data_module = PickledDataWDS(
+>>>     dir_pickles,
+>>>     names_subset,
+>>>     suffix_pickles, # `WebDataModule` args
+>>>     output_dir_tar_files, # `WebDataModule` args
+>>>     global_batch_size, # `WebDataModule` args
+>>>     n_tars_wds=n_tars_wds,
+>>>     prefix_tars_wds=prefix_tars_wds, # `WebDataModule` kwargs
+>>>     pipeline_wds=pipeline_wds, # `WebDataModule` kwargs
+>>>     pipeline_prebatch_wld=pipeline_prebatch_wld, # `WebDataModule` kwargs
+>>>     kwargs_wds=kwargs_wds, # `WebDataModule` kwargs
+>>>     kwargs_wld=kwargs_wld, # `WebDataModule` kwargs
+>>> )
+
+```
+
+
+
+#### \_\_init\_\_
+
+```python
+def __init__(dir_pickles: str,
+             names_subset: Dict[Split, List[str]],
+             *args,
+             n_tars_wds: Optional[int] = None,
+             **kwargs)
+```
+
+Constructor.
+
+**Arguments**:
+
+- `dir_pickles` _str_ - input directory of pickled data files
+- `names_subset` _Dict[Split, List[str]]_ - list of filename prefixes of
+  the data samples to be loaded in the dataset and dataloader for
+  each of the splits
+- `*args` - arguments passed to the parent WebDataModule after its
+  `n_samples` argument (where `n_samples` is deduced from the lengths
+  of the lists in `names_subset`), i.e., `suffix_keys_wds`,
+  `dirs_tars_wds` and `global_batch_size`
+
+  Kwargs:
+- `n_tars_wds` _int_ - attempt to create at least this number of
+  webdataset shards
+- `**kwargs` - arguments passed to the parent WebDataModule
+
+
+
+#### prepare\_data
+
+```python
+def prepare_data() -> None
+```
+
+This is called only by the main process by the Lightning workflow. Do
+not rely on this data module object's state update here, as there is no
+way to communicate the state update to other subprocesses. The nested
+`pickles_to_tars` function goes through the data name prefixes in the
+different splits, reads the corresponding pickled files and outputs a
+webdataset tar archive with the dict structure: {"__key__" :
+name.replace(".", "-"), suffix_pickles : pickle.dumps(data) }.
+ +Returns: None diff --git a/sub-packages/bionemo-webdatamodule/pyproject.toml b/sub-packages/bionemo-webdatamodule/pyproject.toml new file mode 100644 index 000000000..30a44da5c --- /dev/null +++ b/sub-packages/bionemo-webdatamodule/pyproject.toml @@ -0,0 +1,32 @@ +[build-system] +requires = ["setuptools", "wheel"] +build-backend = "setuptools.build_meta" + +# For guidance, see: https://packaging.python.org/en/latest/guides/writing-pyproject-toml/ +[project] +name = "bionemo-webdatamodule" +version = "0.0.1" +authors = [ + { name = "Dejun Lin", email = "dejunl@nvidia.com" }, +] +description = "" +readme = "README.md" +requires-python = ">=3.10" +keywords = [] +license = {file = "LICENSE"} +classifiers = [ + "Programming Language :: Python :: 3.10", + "Private :: Do Not Upload", +] +dynamic = ["dependencies"] + +[project.optional-dependencies] +test = [ + "pytest", +] + +[tool.setuptools.dynamic] +dependencies = {file = ["requirements.txt"]} + +[tool.ruff] +lint.ignore = ["C901", "E741", "E501", "E731"] diff --git a/sub-packages/bionemo-webdatamodule/requirements.txt b/sub-packages/bionemo-webdatamodule/requirements.txt new file mode 100644 index 000000000..24ef528b0 --- /dev/null +++ b/sub-packages/bionemo-webdatamodule/requirements.txt @@ -0,0 +1 @@ +webdataset==0.2.96 diff --git a/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/__init__.py b/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/__init__.py new file mode 100644 index 000000000..25e6abfbc --- /dev/null +++ b/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/datamodule.py b/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/datamodule.py new file mode 100644 index 000000000..33fa7936b --- /dev/null +++ b/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/datamodule.py @@ -0,0 +1,491 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import glob +from enum import Enum, auto +from typing import Any, Dict, Iterable, List, Optional, Union, get_args + +import lightning as L +import webdataset as wds + +from bionemo.webdatamodule.utils import pickles_to_tars + + +class Split(Enum): + train = auto() + val = auto() + test = auto() + + +class WebDataModule(L.LightningDataModule): + """A LightningDataModule for using webdataset tar files to setup dataset and + dataloader. This data module takes as input a dictionary: Split -> tar file + directory and vaiours webdataset config settings. In its setup() function, + it creates the webdataset object chaining up the input `pipeline_wds` + workflow. In its train/val/test_dataloader(), it creates the WebLoader + object chaining up the `pipeline_prebatch_wld` workflow + + Examples + -------- + + 1. create the data module with input directory to webdataset tar files. + Depending on which of the downstream Lightning.Trainer methods are called, + e.g., `Trainer.fit()`, `Trainer.validate()`, `Trainer.test()` or + `Trainer.predict()`, only a subset of the train, val and test splits need to + be specified in the various input options to the data module: + + - `Trainer.fit()` requires the `train` and `val` splits + - `Trainer.validate()` requires the `val` split + - `Trainer.test()` requires the `test` splits + - `Trainer.predict()` requires the `test` splits + + Here is an example of constructing the data module for `Trainer.fit()`: + ``` + >>> from bionemo.core.data.datamodule import Split, WebDataModule + >>> + >>> tar_file_prefix = "shards" + >>> + >>> dirs_of_tar_files = { + >>> Split.train: "/path/to/train/split/tars", + >>> Split.val: "/path/to/val/split/tars", + >>> } + >>> + >>> n_samples { + >>> Split.train: 1000, + >>> Split.val: 100, + >>> } + >>> + >>> # this is the string to retrieve the corresponding data object from the + >>> # webdataset file (see + >>> # https://github.com/webdataset/webdataset?tab=readme-ov-file#the-webdataset-format + >>> # for details) + >>> suffix_keys_wds = "tensor.pyd" + >>> + >>> # see the API doc for the definition of global_batch_size + >>> global_batch_size = 16 + >>> + >>> seed = 27193781 + >>> + >>> # Specify the routines to process the samples in the WebDataset object. + >>> # The routine is a generator of an Iterable of generators that are chained + >>> # together by nested function calling. The following is equivalent of + >>> # defining a overall generator of `shuffle(untuple(...))` which + >>> # untuples the samples and shuffles them. See webdataset's Documentation + >>> # for details. + >>> # NOTE: the `untuple` is almost always necessary due to the webdataset's + >>> # file parsing rule. + >>> + >>> untuple = lambda source : (sample for (sample,) in source) + >>> + >>> from webdatast import shuffle + >>> pipeline_wds = { + >>> Split.train : [untuple, shuffle(n_samples[Split.train], + >>> rng=random.Random(seed_rng_shfl))], + >>> Split.val: untuple + >>> } + >>> + >>> # Similarly the user can optionally define the processing routine on the + >>> # WebLoader (the dataloader of webdataset). 
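+    >>> # The lines below are shown for completeness: `batched` is webdataset's
+    >>> # builtin batching filter, `local_batch_size` is the per-device batch
+    >>> # size (assumed here to equal `global_batch_size`, i.e., single-device
+    >>> # training), and `seed_rng_shfl` aliases the seed used by the shuffle
+    >>> # calls below:
+    >>> import torch
+    >>> import webdataset as wds
+    >>> from webdataset.filters import batched
+    >>> local_batch_size = global_batch_size
+    >>> seed_rng_shfl = seed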
+ >>> # NOTE: these routines by default take unbatched sample as input so the + >>> # user can customize their batching routines here + >>> + >>> batch = batched(local_batch_size, collation_fn=lambda + list_samples : torch.vstack(list_samples)) + >>> pipeline_prebatch_wld = { + Split.train: [shuffle(n_samples[Split.train], + rng=random.Random(seed_rng_shfl)), batch], + Split.val : batch, + Split.test : batch + } + >>> + >>> # the user can optionally specify the kwargs for WebDataset and + >>> # WebLoader + >>> + >>> kwargs_wds = { + >>> split : {'shardshuffle' : split == Split.train, + >>> 'nodesplitter' : wds.split_by_node, + >>> 'seed' : seed_rng_shfl} + >>> for split in Split + >>> } + >>> + >>> kwargs_wld = { + >>> split : {"num_workers": 2} for split in Split + >>> } + >>> + >>> # construct the data module + >>> data_module = WebDataModule(n_samples, suffix_keys_wds, + dirs_of_tar_files, global_batch_size, + prefix_tars_wds=tar_file_prefix, + pipeline_wds=pipeline_wds, + pipeline_prebatch_wld=pipeline_prebatch_wld, + kwargs_wds=kwargs_wds, + kwargs_wld=kwargs_wld) + ``` + + """ + + def __init__( + self, + n_samples: Dict[Split, int], + suffix_keys_wds: Union[str, Iterable[str]], + dirs_tars_wds: Dict[Split, str], + global_batch_size: int, + prefix_tars_wds: str = "wdshards", + pipeline_wds: Optional[ + Dict[Split, Union[Iterable[Iterable[Any]], Iterable[Any]]] + ] = None, + pipeline_prebatch_wld: Optional[ + Dict[Split, Union[Iterable[Iterable[Any]], Iterable[Any]]] + ] = None, + kwargs_wds: Optional[Dict[Split, Dict[str, Any]]] = None, + kwargs_wld: Optional[Dict[Split, Dict[str, Any]]] = None, + ): + """constructor + + Args: + n_samples (Dict[Split, int]): input dictionary: Split -> number of + data samples for each split + suffix_keys_wds (Union[str, Iterable[str]]): a set of keys each + corresponding to a data object in the webdataset tar file + dictionary. The data objects of these keys will be extracted and + tupled for each sample in the tar files + dirs_tars_wds (Dict[Split, str]): input dictionary: Split -> tar file + directory that contains the webdataset tar files for each split + global_batch_size (int): size of batch summing across nodes in Data + Distributed Parallel, i.e., local_batch_size * n_nodes. NOTE: + this data module doesn't rely on the input `global_batch_size` + for batching the samples. The batching is supposed to be done as + a part of the input `pipeline_prebatch_wld`. `global_batch_size` + is only used to compute a (pseudo-) epoch length for the data + loader so that the loader yield approximately n_samples // + global_batch_size batches + Kwargs: + prefix_tars_wds (str): name prefix of the input webdataset tar + files. The input tar files are globbed by + "{dirs_tars_wds[split]}/{prefix_tars_wds}-*.tar" + pipeline_wds (Optional[Dict[Split, Union[Iterable[Iterable[Any]], + Iterable[Any]]]]): a dictionary of webdatast composable, i.e., + functor that maps a iterator to another iterator that + transforms the data sample yield from the dataset object, for + different splits, or an iterable to such a sequence of such + iterators. 
For example, this can be used to transform the + sample in the worker before sending it to the main process of + the dataloader + pipeline_prebatch_wld (Optional[Dict[Split, + Union[Iterable[Iterable[Any]], Iterable[Any]]]]): a dictionary + of webloader composable, i.e., functor that maps a iterator to + another iterator that transforms the data sample yield from the + WebLoader object, for different splits, or an iterable to a + seuqnence of such iterators. For example, this can be used for + batching the samples. NOTE: this is applied before batching is + yield from the WebLoader + kwargs_wds (Optional[Dict[Split, Dict[str, Any]]]): kwargs for the + WebDataset.__init__() + kwargs_wld (Optional[Dict[Split, Dict[str, Any]]]): kwargs for the + WebLoader.__init__(), e.g., num_workers, of each split + + + """ + super().__init__() + + self._dirs_tars_wds = dirs_tars_wds + + keys_subset = self._dirs_tars_wds.keys() + + if n_samples.keys() != keys_subset: + raise RuntimeError( + f"Input n_samples has different keys than " + f"dirs_tars_wds: {n_samples.keys()} vs " + f"{keys_subset}" + ) + + self._n_samples = n_samples + + self._global_batch_size = global_batch_size + + if not isinstance(suffix_keys_wds, get_args(Union[str, Iterable])): + raise TypeError("suffix_keys_wds can only be str or Iterable[str]") + + self._suffix_keys_wds = suffix_keys_wds + + self._prefix_tars_wds = prefix_tars_wds + self._pipeline_wds = pipeline_wds + self._pipeline_prebatch_wld = pipeline_prebatch_wld + + self._kwargs_wld = kwargs_wld + + self._kwargs_wds = kwargs_wds + + # to be created later in setup + self._dataset = {} + + def prepare_data(self) -> None: + """This is called only by the main process by the Lightning workflow. Do + not rely on this data module object's state update here as there is no + way to communicate the state update to other subprocesses. + + Returns: None + """ + pass + + def _setup_wds(self, split: Split) -> wds.WebDataset: + """setup webdataset and webloader. 
This is called by setup() + + Args: + split (Split): train, val or test split + + Returns: WebDataset + + """ + if split not in self._dirs_tars_wds.keys(): + raise RuntimeError( + f"_setup_wds() is called with {split} " + f"split that doesn't have the input tar dir" + ) + urls = sorted( + glob.glob(f"{self._dirs_tars_wds[split]}/{self._prefix_tars_wds}-*.tar") + ) + kwargs = self._kwargs_wds[split] if self._kwargs_wds is not None else None + dataset = wds.WebDataset( + urls, **(kwargs if kwargs is not None else {}) + ).decode() + if isinstance(self._suffix_keys_wds, str): + dataset = dataset.extract_keys(f"*.{self._suffix_keys_wds}") + else: + dataset = dataset.extract_keys( + *[f"*.{key}" for key in self._suffix_keys_wds] + ) + + if self._pipeline_wds is not None and self._pipeline_wds[split] is not None: + if isinstance(self._pipeline_wds[split], Iterable): + dataset = dataset.compose(*self._pipeline_wds[split]) + else: + dataset = dataset.compose(self._pipeline_wds[split]) + return dataset + + def setup(self, stage: str) -> None: + """This is called on all Lightning-managed nodes in a multi-node + training session + + + Args: + stage (str): "fit", "test" or "predict" + Returns: None + """ + if stage == "fit": + self._dataset[Split.train] = self._setup_wds(Split.train) + self._dataset[Split.val] = self._setup_wds(Split.val) + elif stage == "validate": + self._dataset[Split.val] = self._setup_wds(Split.val) + elif stage == "test": + self._dataset[Split.test] = self._setup_wds(Split.test) + elif stage == "predict": + self._dataset[Split.test] = self._setup_wds(Split.test) + else: + raise NotImplementedError( + f"Data setup with stage = {stage} " f"is not implmented" + ) + + def _setup_dataloader(self, split: Split) -> wds.WebLoader: + """setup the dataloader for the input dataset split + + Args: + split (Split): input split type + + Returns: WebLoader object + + """ + if self._dataset[split] is None: + raise RuntimeError( + f"_setup_dataloader() is called with {split} " + f"split without setting up the corresp. dataset" + ) + dataset = self._dataset[split] + n_samples = self._n_samples[split] + n_batches = (n_samples + self._global_batch_size - 1) // self._global_batch_size + kwargs = self._kwargs_wld[split] if self._kwargs_wld is not None else None + loader = wds.WebLoader( + dataset, batch_size=None, **(kwargs if kwargs is not None else {}) + ) + + if ( + self._pipeline_prebatch_wld is not None + and self._pipeline_prebatch_wld[split] is not None + ): + if isinstance(self._pipeline_prebatch_wld[split], Iterable): + loader = loader.compose(*self._pipeline_prebatch_wld[split]) + else: + loader = loader.compose(self._pipeline_prebatch_wld[split]) + + loader = loader.with_epoch(n_batches) + + return loader + + def train_dataloader(self) -> wds.WebLoader: + return self._setup_dataloader(Split.train) + + def val_dataloader(self) -> wds.WebLoader: + return self._setup_dataloader(Split.val) + + def test_dataloader(self) -> wds.WebLoader: + return self._setup_dataloader(Split.test) + + def predict_dataloader(self) -> wds.WebLoader: + return self._setup_dataloader(Split.test) + + +class PickledDataWDS(WebDataModule): + """A LightningDataModule to process pickled data into webdataset tar files + and setup dataset and dataloader. This inherits the webdataset setup from + its parent module `WebDataModule`. 
This data module takes a directory of + pickled data files, data filename prefixes for train/val/test splits, data + filename suffixes and prepare webdataset tar files by globbing the specific + pickle data files `{dir_pickles}/{name_subset[split]}.{suffix_pickles}` and + outputing to webdataset tar file with the dict structure: + ``` + {"__key__" : name.replace(".", "-"), + suffix_pickles : pickled.dumps(data) } + ``` + NOTE: this assumes only one pickled file is processed for each sample. In + its setup() function, it creates the webdataset object chaining up the input + `pipeline_wds` workflow. In its train/val/test_dataloader(), it creates the + WebLoader object chaining up the `pipeline_prebatch_wld` workflow. + + Examples + -------- + + 1. create the data module with a directory of pickle files and the file name + prefix thereof for different splits to used by `Lightning.Trainer.fit()` + + ``` + >>> from bionemo.core.data.datamodule import Split, PickledDataWDS + + >>> dir_pickles = "/path/to/my/pickles/dir" + + >>> # the following will use `sample1.mydata.pt` and `sample2.mydata.pt` as the + >>> # training dataset and `sample4.mydata.pt` and `sample5.mydata.pt` as the + >>> # validation dataset + + >>> suffix_pickles = "mydata.pt" + + >>> names_subset = { + >>> Split.train: [sample1, sample2], + >>> Split.val: [sample4, sample5], + >>> } + + >>> # the following setting will attempt to create at least 5 tar files in + >>> # `/path/to/output/tars/dir/myshards-00000{0-5}.tar` + + >>> n_tars_wds = 5 + >>> prefix_tars_wds = "myshards" + >>> output_dir_tar_files = { + Split.train : "/path/to/output/tars/dir-train", + Split.val : "/path/to/output/tars/dir-val", + Split.test : "/path/to/output/tars/dir-test", + } + + >>> # see the `WebDataModule` API doc for the definition of global_batch_size + >>> global_batch_size = 16 + + >>> # user can optionally customize the data processing routines and kwargs used + >>> # in the WebDataset and WebLoader (see the examples in `WebDataModule`) + + >>> pipeline_wds = { Split.train: ... } + + >>> pipeline_prebatch_wld = { Split.train: ... } + + >>> kwargs_wds = { Split.train: ..., Split.val: ... } + + >>> kwargs_wld = { Split.train: ..., Split.val: ... 
} + + >>> # create the data module + >>> data_module = PickledDataWDS( + >>> dir_pickles, + >>> names_subset, + >>> suffix_pickles, # `WebDataModule` args + >>> output_dir_tar_files, # `WebDataModule` args + >>> global_batch_size, # `WebDataModule` args + >>> n_tars_wds=n_tars_wds, + >>> prefix_tars_wds=prefix_tars_wds, # `WebDataModule` kwargs + >>> pipeline_wds=pipeline_wds, # `WebDataModule` kwargs + >>> pipeline_prebatch_wld=pipelines_wdl_batch, # `WebDataModule` kwargs + >>> kwargs_wds=kwargs_wds, # `WebDataModule` kwargs + >>> kwargs_wld=kwargs_wld, # `WebDataModule` kwargs + >>> ) + + ``` + + """ + + def __init__( + self, + dir_pickles: str, + names_subset: Dict[Split, List[str]], + *args, + n_tars_wds: Optional[int] = None, + **kwargs, + ): + """constructor + + Args: + dir_pickles (str): input directory of pickled data files + names_subset (Dict[Split, List[str]]): list of filename prefix of + the data samples to be loaded in the dataset and dataloader for + each of the split + *args: arguments passed to the parent WebDataModule after its + `n_samples` args (where `n_samples` is deduced from the length of + `names_subset` arg of this class) + + Kwargs: + n_tars_wds (int): attempt to create at least this number of + webdataset shards + **kwargs: arguments passed to the parent WebDataModule + + + """ + super().__init__( + {split: len(names_subset[split]) for split in names_subset.keys()}, + *args, + **kwargs, + ) + + self._dir_pickles = dir_pickles + + self._names_subset = names_subset + + self._n_tars_wds = n_tars_wds + + def prepare_data(self) -> None: + """This is called only by the main process by the Lightning workflow. Do + not rely on this data module object's state update here as there is no + way to communicate the state update to other subprocesses. The nesting + `pickles_to_tars` function goes through the data name prefixes in the + different splits, read the corresponding pickled file and output a + webdataset tar archive with the dict structure: {"__key__" : + name.replace(".", "-"), suffix_pickles : pickled.dumps(data) }. + + Returns: None + """ + for split in self._names_subset.keys(): + # create wds shards (tar files) for train set + pickles_to_tars( + self._dir_pickles, + self._names_subset[split], + self._suffix_keys_wds, + self._dirs_tars_wds[split], + self._prefix_tars_wds, + min_num_shards=self._n_tars_wds, + ) diff --git a/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/utils.py b/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/utils.py new file mode 100644 index 000000000..541957edd --- /dev/null +++ b/sub-packages/bionemo-webdatamodule/src/bionemo/webdatamodule/utils.py @@ -0,0 +1,130 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import os +import pickle +from pathlib import Path +from typing import Any, Callable, Dict, Iterable, List, Optional, Union, get_args + +import webdataset as wds +from nemo.utils import logging + + +def pickles_to_tars( + dir_input: str, + input_prefix_subset: List[str], + input_suffix: Union[str, Iterable[str]], + dir_output: str, + output_prefix: str, + func_output_data: Callable[[str, Dict[str, Any]], Dict[str, Any]] = lambda prefix, + suffix_to_data: {"__key__": prefix, **suffix_to_data}, + min_num_shards: Optional[int] = None, +) -> None: + """Convert a subset of pickle files from a directory to Webdataset tar files + Input path and name pattern for sample 0: + f"{dir_input}/{input_prefix_subset[0]}.{input_suffix[0]}" + f"{dir_input}/{input_prefix_subset[0]}.{input_suffix[1]}" + Input path and name pattern for sample 1: + f"{dir_input}/{input_prefix_subset[1]}.{input_suffix[0]}" + f"{dir_input}/{input_prefix_subset[1]}.{input_suffix[1]}" + ... + Output path and name pattern: + f"{dir_output}/{output_prefix}-%06d.tar" + + The webdataset tar archive is specified by the dictionary: + { + "__key__" : sample_filename_preifx, + sample_filename_suffix_1 : data_1, + sample_filename_suffix_2 : data_2, + ... + } + so that parsing the tar archive is equivalent of reading + {sample_filename_preifx}.{sample_filename_suffix_1} etc. + + Here, each sample data get its name prefix from one element of + `input_prefix_subset` and its name suffixes from the list `input_suffix`. + Per the webdataset file format specification, the `sample_filename_preifx` + can't contain dots '.' so this function removes it for the user by calling + .replace(".", "-") on the elements of `input_prefix_subset` + + Args: + dir_input (str): Input directory + input_prefix_subset (List[str]): Input subset of pickle files' prefix + input_suffix (Union[str, Iterable[str]]): Input pickle file name + suffixes, each for one type of data object, for all the samples + dir_output (str): Output directory + output_prefix (str): Output tar file name prefix + func_output_data (Callable[[str, Dict[str, Any]], Dict[str, Any]]) : + function that maps the name prefix, name suffix and data object to a + webdataset tar archive dictionary. Refer to the webdataset github + repo for the archive file format specification. + min_num_shards (int) : create at least this number of tar files. 
+ WebDataset has bugs when reading small number of tar files in a + multi-node lightening + DDP setting so this option can be used to + guarantee the tar file counts + + Returns: None + + """ + if not isinstance(input_suffix, get_args(Union[str, Iterable])): + raise TypeError("input_suffix can only be str or Iterable[str]") + os.makedirs(dir_output, exist_ok=True) + wd_subset_pattern = os.path.join(dir_output, f"{output_prefix}-%06d.tar") + n_samples_per_shard_max = 100000 + if min_num_shards is not None: + if min_num_shards <= 0: + raise ValueError(f"Invalid min_num_shards = {min_num_shards} <= 0") + n_samples_per_shard_max = len(input_prefix_subset) // min_num_shards + with wds.ShardWriter( + wd_subset_pattern, + encoder=False, + maxcount=n_samples_per_shard_max, + compress=False, + mode=0o777, + ) as sink: + for name in input_prefix_subset: + try: + if isinstance(input_suffix, str): + suffix_to_data = { + input_suffix: pickle.dumps( + pickle.loads( + ( + Path(dir_input) / f"{name}.{input_suffix}" + ).read_bytes() + ) + ) + } + else: + suffix_to_data = { + suffix: pickle.dumps( + pickle.loads( + (Path(dir_input) / f"{name}.{suffix}").read_bytes() + ) + ) + for suffix in input_suffix + } + # the prefix name shouldn't contain any "." per webdataset's + # specification + sample = func_output_data(name.replace(".", "-"), suffix_to_data) + sink.write(sample) + except ModuleNotFoundError as e: + logging.error( + f"Dependency for parsing input pickle data not found: {e}" + ) + raise e + except Exception as e: + logging.error(f"Failed to write {name} into tar files due to error {e}") + raise e diff --git a/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/conftest.py b/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/conftest.py new file mode 100644 index 000000000..a43f4c0be --- /dev/null +++ b/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/conftest.py @@ -0,0 +1,301 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import pickle +import random + +import lightning as L +import pytest +import torch +import webdataset as wds +from webdataset.filters import batched, shuffle + +from bionemo.webdatamodule.datamodule import PickledDataWDS, Split, WebDataModule +from bionemo.webdatamodule.utils import pickles_to_tars + + +@pytest.fixture(scope="module") +def gen_pickle_files(tmp_path_factory): + dir_pickles = tmp_path_factory.mktemp("pickleddatawds").as_posix() + prefix_sample = "sample" + suffix_sample = ["tensor.pyd", "tensor_copy.pyd"] + n_samples_per_split = 10 + prefixes = [] + # generate the pickles for train, val, and test + for i in range(n_samples_per_split * 3): + prefix = f"{prefix_sample}-{i:04}" + prefixes.append(prefix) + t = torch.tensor(i, dtype=torch.int32) + for suffix in suffix_sample: + with open(f"{dir_pickles}/{prefix}.{suffix}", "wb") as fh: + pickle.dump(t, fh) + prefixes_pickle = { + Split.train: prefixes[0:n_samples_per_split], + Split.val: prefixes[n_samples_per_split : n_samples_per_split * 2], + Split.test: prefixes[n_samples_per_split * 2 : n_samples_per_split * 3], + } + return ( + dir_pickles, + prefix_sample, + suffix_sample, + prefixes_pickle, + n_samples_per_split, + ) + + +@pytest.fixture(scope="module", params=[1, 2]) +def gen_test_data(tmp_path_factory, gen_pickle_files, request): + dir_pickles, prefix_sample, suffixes, prefixes_pickle, n_samples_per_split = ( + gen_pickle_files + ) + n_suffixes = request.param + if n_suffixes <= 1: + suffix_sample = suffixes[0] + else: + suffix_sample = suffixes[0:n_suffixes] + dir_tars_tmp = tmp_path_factory.mktemp("webdatamodule").as_posix() + dir_tars = {split: f"{dir_tars_tmp}{str(split).split('.')[-1]}" for split in Split} + prefix_tar = "tensor" + n_samples = {split: n_samples_per_split for split in Split} + # generate the tars + pickles_to_tars( + dir_pickles, + prefixes_pickle[Split.train], + suffix_sample, + dir_tars[Split.train], + prefix_tar, + min_num_shards=3, + ) + pickles_to_tars( + dir_pickles, + prefixes_pickle[Split.val], + suffix_sample, + dir_tars[Split.val], + prefix_tar, + min_num_shards=3, + ) + pickles_to_tars( + dir_pickles, + prefixes_pickle[Split.test], + suffix_sample, + dir_tars[Split.test], + prefix_tar, + min_num_shards=3, + ) + return ( + dir_pickles, + dir_tars, + prefix_sample, + suffix_sample, + prefix_tar, + n_samples, + prefixes_pickle, + ) + + +def _create_webdatamodule(gen_test_data, num_workers=2): + (_, dirs_tars_wds, _, suffix_keys_wds, prefix_tars_wds, n_samples, _) = ( + gen_test_data + ) + local_batch_size = 2 + global_batch_size = 2 + seed_rng_shfl = 82838392 + + batch = batched( + local_batch_size, collation_fn=lambda list_samples: torch.vstack(list_samples) + ) + + if isinstance(suffix_keys_wds, str): + untuple = lambda source: (sample[0] for sample in source) + elif isinstance(suffix_keys_wds, list): + untuple = lambda source: (torch.vstack(sample) for sample in source) + + pipeline_wds = { + Split.train: [ + untuple, + shuffle(n_samples[Split.train], rng=random.Random(seed_rng_shfl)), + ], + Split.val: untuple, + Split.test: untuple, + } + + pipeline_prebatch_wld = { + Split.train: [ + shuffle(n_samples[Split.train], rng=random.Random(seed_rng_shfl)), + batch, + ], + Split.val: batch, + Split.test: batch, + } + + kwargs_wds = { + split: { + "shardshuffle": split == Split.train, + "nodesplitter": wds.split_by_node, + "seed": seed_rng_shfl, + } + for split in Split + } + + kwargs_wld = {split: {"num_workers": num_workers} for split in Split} + + data_module = WebDataModule( + 
n_samples, + suffix_keys_wds, + dirs_tars_wds, + global_batch_size, + prefix_tars_wds=prefix_tars_wds, + pipeline_wds=pipeline_wds, + pipeline_prebatch_wld=pipeline_prebatch_wld, + kwargs_wds=kwargs_wds, + kwargs_wld=kwargs_wld, + ) + + return data_module, dirs_tars_wds + + +@pytest.fixture(scope="module") +def create_webdatamodule(gen_test_data): + return _create_webdatamodule(gen_test_data) + + +@pytest.fixture(scope="module") +def create_another_webdatamodule(gen_test_data): + return _create_webdatamodule(gen_test_data) + + +@pytest.fixture(scope="module") +def create_webdatamodule_with_5_workers(gen_test_data): + return _create_webdatamodule(gen_test_data, num_workers=5) + + +class ModelTestWebDataModule(L.LightningModule): + def __init__(self) -> None: + super().__init__() + self._model = torch.nn.Linear(1, 1) + self._samples = {split: [] for split in Split} + + def forward(self, x): + return self._model(x.float()) + + def training_step(self, batch): + self._samples[Split.train].append(batch) + loss = self(batch).sum() + return loss + + def validation_step(self, batch, batch_index): + self._samples[Split.val].append(batch) + return torch.zeros(1) + + def test_step(self, batch, batch_index): + self._samples[Split.test].append(batch) + + def predict_step(self, batch, batch_index): + self._samples[Split.test].append(batch) + return torch.zeros(1) + + def configure_optimizers(self): + optimizer = torch.optim.Adam(self.parameters(), lr=2e-4) + return optimizer + + +@pytest.fixture(scope="function") +def create_trainer_and_model(): + trainer = L.Trainer( + max_epochs=1, accelerator="gpu", devices=1, val_check_interval=1 + ) + model = ModelTestWebDataModule() + return trainer, model + + +def _create_pickleddatawds(tmp_path_factory, gen_test_data): + ( + dir_pickles, + _, + _, + suffix_keys_wds, + prefix_tars_wds, + n_samples, + names, + ) = gen_test_data + local_batch_size = 2 + global_batch_size = 2 + seed_rng_shfl = 82838392 + n_tars_wds = 3 + + prefix_dir_tars_wds = tmp_path_factory.mktemp("pickleddatawds_tars_wds").as_posix() + dirs_tars_wds = {s: f"{prefix_dir_tars_wds}{str(s).split('.')[-1]}" for s in Split} + + batch = batched( + local_batch_size, collation_fn=lambda list_samples: torch.vstack(list_samples) + ) + + untuple = lambda source: (sample[0] for sample in source) + + pipeline_wds = { + Split.train: [ + untuple, + shuffle(n_samples[Split.train], rng=random.Random(seed_rng_shfl)), + ], + Split.val: untuple, + Split.test: untuple, + } + + pipeline_prebatch_wld = { + Split.train: [ + shuffle(n_samples[Split.train], rng=random.Random(seed_rng_shfl)), + batch, + ], + Split.val: batch, + Split.test: batch, + } + + kwargs_wds = { + split: { + "shardshuffle": split == Split.train, + "nodesplitter": wds.split_by_node, + "seed": seed_rng_shfl, + } + for split in Split + } + + kwargs_wld = {split: {"num_workers": 2} for split in Split} + + data_module = PickledDataWDS( + dir_pickles, + names, + suffix_keys_wds, + dirs_tars_wds, + global_batch_size, + n_tars_wds=n_tars_wds, + prefix_tars_wds=prefix_tars_wds, + pipeline_wds=pipeline_wds, + pipeline_prebatch_wld=pipeline_prebatch_wld, + kwargs_wds=kwargs_wds, + kwargs_wld=kwargs_wld, + ) + + return data_module, dirs_tars_wds, n_tars_wds + + +@pytest.fixture(scope="module") +def create_pickleddatawds(tmp_path_factory, gen_test_data): + return _create_pickleddatawds(tmp_path_factory, gen_test_data) + + +@pytest.fixture(scope="module") +def create_another_pickleddatawds(tmp_path_factory, gen_test_data): + return 
_create_pickleddatawds(tmp_path_factory, gen_test_data)
diff --git a/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/test_datamodule.py b/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/test_datamodule.py
new file mode 100644
index 000000000..692905a41
--- /dev/null
+++ b/sub-packages/bionemo-webdatamodule/tests/bionemo/webdatamodule/test_datamodule.py
@@ -0,0 +1,268 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-Apache2
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import glob
+from enum import Enum, auto
+
+import lightning as L
+import pytest
+import torch
+
+from bionemo.webdatamodule.datamodule import Split
+
+
+@pytest.mark.parametrize("split", list(Split))
+def test_webdatamodule_init(split, create_webdatamodule):
+    data_module, dirs_tars_wds = create_webdatamodule
+    assert data_module._n_samples[split] == 10, (
+        f"Wrong {split}-set size: "
+        f"expected 10 "
+        f"but got {data_module._n_samples[split]}"
+    )
+    assert data_module._dirs_tars_wds[split] == dirs_tars_wds[split], (
+        f"Wrong tar files directory: "
+        f"expected {dirs_tars_wds[split]} "
+        f"but got {data_module._dirs_tars_wds[split]}"
+    )
+
+
+@pytest.mark.parametrize("split", list(Split))
+def test_webdatamodule_setup_dataset(
+    split, create_webdatamodule, create_another_webdatamodule
+):
+    data_modules = [create_webdatamodule[0], create_another_webdatamodule[0]]
+    lists_tensors = []
+    for m in data_modules:
+        m.prepare_data()
+        # run through all the possible stages first to set up all the corresponding
+        # dataset objects
+        m.setup("fit")
+        m.setup("test")
+        L.seed_everything(2823828)
+        tensors = []
+        for sample in m._dataset[split]:
+            assert isinstance(
+                sample, torch.Tensor
+            ), "Sample yielded from dataset is not a tensor"
+            tensors.append(sample)
+        lists_tensors.append(tensors)
+
+    assert len(lists_tensors[0]) > 0, f"No samples in {split} dataset"
+    torch.testing.assert_close(
+        torch.vstack(lists_tensors[0]), torch.vstack(lists_tensors[1])
+    )
+
+
+@pytest.mark.parametrize("split", list(Split))
+def test_webdatamodule_setup_dataloader(
+    split, create_webdatamodule, create_another_webdatamodule
+):
+    data_modules = [create_webdatamodule[0], create_another_webdatamodule[0]]
+    lists_tensors = []
+    for m in data_modules:
+        m.prepare_data()
+        # run through all the possible stages first to set up all the corresponding
+        # dataset objects
+        m.setup("fit")
+        m.setup("test")
+        L.seed_everything(2823828)
+        tensors = []
+        loader = None
+        if split == Split.train:
+            loader = m.train_dataloader()
+        elif split == Split.val:
+            loader = m.val_dataloader()
+        elif split == Split.test:
+            loader = m.test_dataloader()
+        else:
+            raise RuntimeError(f"Test for split {split} not implemented")
+        assert loader is not None, "dataloader not instantiated"
+        for samples in loader:
+            # each batch yielded by the WebLoader should be a torch.Tensor
+            assert isinstance(
+                samples, torch.Tensor
+            ), "Sample object is not torch.Tensor"
+            tensors.append(samples)
+        lists_tensors.append(tensors)
+
+    assert len(lists_tensors[0]) > 0, f"No samples in {split} dataloader"
+    torch.testing.assert_close(
+        torch.vstack(lists_tensors[0]), torch.vstack(lists_tensors[1])
+    )
+
+
+@pytest.mark.parametrize("split", list(Split))
+def test_webdatamodule_throw_on_many_workers(
+    split, create_webdatamodule_with_5_workers
+):
+    data_module = create_webdatamodule_with_5_workers[0]
+    urls = glob.glob(
+        f"{data_module._dirs_tars_wds[split]}/" f"{data_module._prefix_tars_wds}-*.tar"
+    )
+    n_tars = len(urls)
+    data_module._kwargs_wld[split]["num_workers"] = n_tars + 1
+    data_module.prepare_data()
+    data_module.setup("fit")
+    data_module.setup("test")
+    loader = None
+    if split == Split.train:
+        loader = data_module.train_dataloader()
+    elif split == Split.val:
+        loader = data_module.val_dataloader()
+    elif split == Split.test:
+        loader = data_module.test_dataloader()
+    else:
+        raise RuntimeError(f"Test for split {split} not implemented")
+    assert loader is not None, "dataloader not instantiated"
+    try:
+        for _ in loader:
+            pass
+    except ValueError as e:
+        # this is the expected error when there are more workers than shards
+        assert "have fewer shards than workers" in str(e), (
+            f"'have fewer shards than workers' not found in exception "
+            f"raised from data loading: {e}"
+        )
+    except Exception as e:
+        raise RuntimeError(
+            f"WebLoader doesn't raise ValueError with fewer "
+            f"shards than workers but raises this instead: {e}"
+        )
+    else:
+        raise NotImplementedError(
+            "WebLoader doesn't throw an error with num_workers > num_shards. "
+            "Users should report this issue to webdataset and create "
+            "fewer shards than workers in practice as a workaround"
+        )
+
+
+class Stage(Enum):
+    fit = auto()
+    validate = auto()
+    test = auto()
+    predict = auto()
+
+
+@pytest.mark.parametrize("stage", list(Stage))
+def test_webdatamodule_in_lightning(
+    stage, create_webdatamodule, create_another_webdatamodule, create_trainer_and_model
+):
+    data_modules = [create_webdatamodule[0], create_another_webdatamodule[0]]
+    trainer, model = create_trainer_and_model
+    # set up the first data module to collect reference samples directly from its loader
+    L.seed_everything(2823828)
+    data_modules[0].prepare_data()
+    split = None
+    if stage == Stage.fit:
+        split = Split.train
+    elif stage == Stage.validate:
+        split = Split.val
+    elif stage == Stage.test or stage == Stage.predict:
+        split = Split.test
+    else:
+        raise RuntimeError(f"{stage} stage not implemented")
+    name_stage = str(stage).split(".")[-1]
+    data_modules[0].setup(name_stage)
+    # build the reference loader, then run the trainer workflow on the second data module
+    get_dataloader = getattr(data_modules[0], f"{str(split).split('.')[-1]}_dataloader")
+    loader = get_dataloader()
+    L.seed_everything(2823828)
+    workflow = getattr(trainer, name_stage)
+    workflow(model, data_modules[1])
+    device = model._samples[split][0].device
+    samples = [sample.to(device=device) for sample in loader]
+    torch.testing.assert_close(
+        torch.stack(model._samples[split], dim=0), torch.stack(samples, dim=0)
+    )
+
+
+@pytest.mark.parametrize("split", list(Split))
+def test_pickleddatawds_init(split, create_pickleddatawds):
+    data_module, dirs_tars_wds, _ = create_pickleddatawds
+    assert data_module._n_samples[split] == 10, (
+        f"Wrong {split}-set size: "
+        f"expected 10 "
+        f"but got {data_module._n_samples[split]}"
+    )
+    assert data_module._dirs_tars_wds[split] == dirs_tars_wds[split], (
+        f"Wrong tar files directory: "
+        f"expected {dirs_tars_wds[split]} "
+        f"but got {data_module._dirs_tars_wds[split]}"
+    )
+
+
+@pytest.mark.parametrize("split", list(Split))
+def test_pickleddatawds_prepare_data(split, create_pickleddatawds):
+    data_module, _, n_tars_min = create_pickleddatawds
+    data_module.prepare_data()
+    dir_tars = data_module._dirs_tars_wds[split]
+    tars = glob.glob(f"{dir_tars}/{data_module._prefix_tars_wds}-*.tar")
+    n_tars = len(tars)
+    assert n_tars_min <= n_tars <= n_tars_min + 1, (
+        f"Number of tar files: {n_tars} in {dir_tars} is outside the range "
+        f"[{n_tars_min}, {n_tars_min + 1}]"
+    )
+
+
+@pytest.mark.parametrize("split", list(Split))
+def test_pickleddatawds_setup_dataset(
+    split, create_pickleddatawds, create_another_pickleddatawds
+):
+    data_modules = [create_pickleddatawds[0], create_another_pickleddatawds[0]]
+    lists_tensors = []
+    for m in data_modules:
+        m.prepare_data()
+        # run through all the possible stages first to set up all the corresponding
+        # dataset objects
+        m.setup("fit")
+        m.setup("test")
+        L.seed_everything(2823828)
+        tensors = []
+        for sample in m._dataset[split]:
+            assert isinstance(
+                sample, torch.Tensor
+            ), "Sample yielded from dataset is not a tensor"
+            tensors.append(sample)
+        lists_tensors.append(tensors)
+
+    assert len(lists_tensors[0]) > 0, f"No samples in {split} dataset"
+    torch.testing.assert_close(
+        torch.vstack(lists_tensors[0]), torch.vstack(lists_tensors[1])
+    )
+
+
+def test_pickleddatawds_sample_overlap(create_pickleddatawds):
+    data_module = create_pickleddatawds[0]
+    # this writes the tar files to disk
+    data_module.prepare_data()
+    # read the data back by setting up the dataset objects and looping over them
+    data_module.setup("fit")
+    data_module.setup("test")
+    results = {
+        split: set([sample.item() for sample in data_module._dataset[split]])
+        for split in Split
+    }
+    overlap_train_val = results[Split.train] & results[Split.val]
+    overlap_train_test = results[Split.train] & results[Split.test]
+    overlap_val_test = results[Split.val] & results[Split.test]
+    assert (
+        len(overlap_train_val) == 0
+    ), "Shared samples found between train and val datasets"
+    assert (
+        len(overlap_train_test) == 0
+    ), "Shared samples found between train and test datasets"
+    assert (
+        len(overlap_val_test) == 0
+    ), "Shared samples found between val and test datasets"