diff --git a/aiaccel/apps/optimize.py b/aiaccel/apps/optimize.py
index 852b38ef..fd6c2168 100644
--- a/aiaccel/apps/optimize.py
+++ b/aiaccel/apps/optimize.py
@@ -1,11 +1,13 @@
+from typing import Any
+
 import argparse
-import pickle as pkl
 from collections.abc import Callable
 from pathlib import Path
-from typing import Any
+import pickle as pkl
 
 from hydra.utils import instantiate
 from omegaconf import OmegaConf as oc  # noqa: N813
+
 from optuna.trial import Trial
 
 from aiaccel.hpo.optuna.wrapper import Const, Suggest, SuggestFloat, T
diff --git a/aiaccel/apps/train.py b/aiaccel/apps/train.py
index 698b2eab..c300ed2c 100755
--- a/aiaccel/apps/train.py
+++ b/aiaccel/apps/train.py
@@ -1,14 +1,12 @@
-import os
 from argparse import ArgumentParser
+import os
 from pathlib import Path
 import pickle as pkl
 
 from hydra.utils import instantiate
-from omegaconf import DictConfig, ListConfig
 from omegaconf import OmegaConf as oc  # noqa: N813
 
 import lightning as lt
-from lightning.fabric.utilities.rank_zero import rank_zero_only
 
 from aiaccel.utils import print_config
@@ -41,7 +39,6 @@ def main() -> None:
 
     with open(config.working_directory / "config.pkl", "wb") as f:
         pkl.dump(config, f)
 
-    # train
     trainer: lt.Trainer = instantiate(config.trainer)
     trainer.fit(
diff --git a/aiaccel/hpo/algorithms/nelder_mead_algorithm.py b/aiaccel/hpo/algorithms/nelder_mead_algorithm.py
index ba7aa4aa..98652798 100644
--- a/aiaccel/hpo/algorithms/nelder_mead_algorithm.py
+++ b/aiaccel/hpo/algorithms/nelder_mead_algorithm.py
@@ -1,10 +1,11 @@
-import queue
-import threading
+import numpy.typing as npt
+
 from collections.abc import Generator
 from dataclasses import dataclass
+import queue
+import threading
 
 import numpy as np
-import numpy.typing as npt
 
 
 @dataclass
diff --git a/aiaccel/hpo/optuna/samplers/nelder_mead_sampler.py b/aiaccel/hpo/optuna/samplers/nelder_mead_sampler.py
index 64e3e367..1ca63993 100644
--- a/aiaccel/hpo/optuna/samplers/nelder_mead_sampler.py
+++ b/aiaccel/hpo/optuna/samplers/nelder_mead_sampler.py
@@ -1,10 +1,12 @@
 from __future__ import annotations
 
-import warnings
-from collections.abc import Sequence
 from typing import Any
 
+from collections.abc import Sequence
+import warnings
+
 import numpy as np
+
 import optuna
 from optuna.distributions import BaseDistribution
 from optuna.study import Study
diff --git a/aiaccel/hpo/optuna/wrapper.py b/aiaccel/hpo/optuna/wrapper.py
index ea0aba51..cce5edfe 100644
--- a/aiaccel/hpo/optuna/wrapper.py
+++ b/aiaccel/hpo/optuna/wrapper.py
@@ -1,6 +1,7 @@
+from typing import Generic, TypeVar
+
 from collections.abc import Sequence
 from dataclasses import dataclass
-from typing import Generic, TypeVar
 
 from optuna.trial import Trial
diff --git a/aiaccel/job/abci_job.py b/aiaccel/job/abci_job.py
index 55317a6b..fee9e260 100644
--- a/aiaccel/job/abci_job.py
+++ b/aiaccel/job/abci_job.py
@@ -1,11 +1,12 @@
 from __future__ import annotations
 
+from typing import Any
+
+from enum import IntEnum, auto
+from pathlib import Path
 import re
 import subprocess
 import time
-from enum import IntEnum, auto
-from pathlib import Path
-from typing import Any
 from xml.etree import ElementTree
diff --git a/aiaccel/job/abci_job_manager.py b/aiaccel/job/abci_job_manager.py
index 9b23dcff..c1951168 100644
--- a/aiaccel/job/abci_job_manager.py
+++ b/aiaccel/job/abci_job_manager.py
@@ -1,7 +1,8 @@
-import time
-from pathlib import Path
 from typing import Any
 
+from pathlib import Path
+import time
+
 from aiaccel.job.abci_job import AbciJob, JobStatus
diff --git a/aiaccel/torch/datasets/cached_dataset.py b/aiaccel/torch/datasets/cached_dataset.py
index 402ce7d2..f5c363d9 100644
--- a/aiaccel/torch/datasets/cached_dataset.py
+++ b/aiaccel/torch/datasets/cached_dataset.py
@@ -9,14 +9,43 @@
 
 
 class NumpiedTensor:
+    """
+    A wrapper class that converts a PyTorch tensor to a NumPy array and vice versa.
+
+    Args:
+        tensor (torch.Tensor): The input PyTorch tensor.
+
+    Attributes:
+        array (np.ndarray): The NumPy array representation of the tensor.
+
+    Methods:
+        to_tensor: Converts the NumPy array back to a PyTorch tensor.
+    """
+
     def __init__(self, tensor: torch.Tensor) -> None:
         self.array = tensor.numpy()
 
     def to_tensor(self) -> torch.Tensor:
+        """
+        Converts the NumPy array back to a PyTorch tensor.
+
+        Returns:
+            torch.Tensor: The PyTorch tensor representation of the NumPy array.
+        """
         return torch.tensor(self.array)
 
 
 def numpize_sample(sample: Any) -> Any:
+    """
+    Converts the input sample to a NumPy-compatible format.
+
+    Args:
+        sample (Any): The input sample to be converted.
+
+    Returns:
+        Any: The converted sample in a NumPy-compatible format.
+    """
+
     if isinstance(sample, torch.Tensor):
         return NumpiedTensor(sample)
     elif isinstance(sample, tuple):
@@ -30,6 +59,16 @@ def numpize_sample(sample: Any) -> Any:
 
 
 def tensorize_sample(sample: Any) -> Any:
+    """
+    Converts the given sample into a tensor representation.
+
+    Args:
+        sample (Any): The input sample to be tensorized.
+
+    Returns:
+        Any: The tensorized representation of the input sample.
+    """
+
     if isinstance(sample, NumpiedTensor):
         return sample.to_tensor()
     elif isinstance(sample, tuple):
@@ -46,6 +85,23 @@ def tensorize_sample(sample: Any) -> Any:
 
 
 class CachedDataset(Dataset[T_co]):
+    """
+    A dataset wrapper that caches the samples to improve performance.
+
+    Args:
+        dataset (Dataset): The original dataset to be wrapped.
+
+    Attributes:
+        dataset (Dataset): The original dataset.
+        manager (Manager): The multiprocessing manager.
+        cache (dict): The cache dictionary to store the cached samples.
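+
+    Example (illustrative sketch; ``base_dataset`` stands for any map-style dataset;
+    the first access reads from ``dataset``, later accesses are served from ``cache``):
+        >>> cached = CachedDataset(base_dataset)  # doctest: +SKIP
+        >>> sample = cached[0]  # doctest: +SKIP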
+    """
+
     def __init__(self, dataset: Dataset[T_co]) -> None:
         self.dataset = dataset
diff --git a/aiaccel/torch/datasets/hdf5_dataset.py b/aiaccel/torch/datasets/hdf5_dataset.py
index 9e6c306c..fd2566cf 100644
--- a/aiaccel/torch/datasets/hdf5_dataset.py
+++ b/aiaccel/torch/datasets/hdf5_dataset.py
@@ -16,14 +16,37 @@
 ]
 
 
-class RawHDF5Dataset(Dataset[int]):
+class RawHDF5Dataset(Dataset[dict[str, Any]]):
+    """
+    A dataset class for reading data from HDF5 files.
+
+    Args:
+        dataset_path (Path | str): The path to the HDF5 dataset file.
+        grp_list (Path | str | list[str] | None, optional): The list of groups to load from the dataset.
+            If None, all groups in the dataset will be loaded. If a string or Path, it should be the path to a
+            file containing a pickled list of groups. If a list, it should directly specify the groups to load.
+            Defaults to None.
+
+    Raises:
+        NotImplementedError: If grp_list is of an unsupported type.
+
+    Attributes:
+        dataset_path (Path | str): The path to the HDF5 dataset file.
+        grp_list (list[str]): The list of groups to load from the dataset.
+        f (h5.File | None): The HDF5 file object used for reading the dataset.
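+
+    Example (illustrative sketch; assumes ``data.h5`` exists and holds one group per sample):
+        >>> dataset = RawHDF5Dataset("data.h5")  # doctest: +SKIP
+        >>> item = dataset[0]  # doctest: +SKIP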
+ + """ + def __init__(self, dataset_path: Path | str, grp_list: Path | str | list[str] | None = None) -> None: self.dataset_path = dataset_path if grp_list is None: with h5.File(self.dataset_path, "r") as f: self.grp_list = list(f.keys()) - elif isinstance(grp_list, (str, Path)): + elif isinstance(grp_list, (str | Path)): with open(grp_list, "rb") as f: self.grp_list = pkl.load(f) elif isinstance(grp_list, list): @@ -49,5 +68,21 @@ def __del__(self) -> None: class HDF5Dataset(RawHDF5Dataset): + """ + A dataset class for loading data from an HDF5 file. + + This class extends the `RawHDF5Dataset` class and provides a convenient way to load data from an HDF5 file + and convert it into a dictionary of torch tensors. + + Args: + path (str): The path to the HDF5 file. + transform (callable, optional): A function/transform that takes in a dictionary of data and returns a + modified version. Default is None. + + Returns: + dict[str, torch.Tensor]: A dictionary containing the data loaded from the HDF5 file, where the keys are + the names of the data fields and the values are torch tensors. + """ + def __getitem__(self, index: int) -> dict[str, torch.Tensor]: return {k: torch.as_tensor(v) for k, v in super().__getitem__(index).items()} diff --git a/aiaccel/torch/datasets/scatter_dataset.py b/aiaccel/torch/datasets/scatter_dataset.py index dd4ebec8..a0ca46dc 100644 --- a/aiaccel/torch/datasets/scatter_dataset.py +++ b/aiaccel/torch/datasets/scatter_dataset.py @@ -15,6 +15,20 @@ def scatter_dataset( dataset: Dataset[T], permute_fn: Callable[[npt.NDArray[np.int64]], npt.NDArray[np.int64]] | None = None, ) -> Subset[T]: + """ + Splits a dataset into subsets and returns the subset corresponding to the current process rank. + + Args: + dataset (Dataset[T]): The input dataset to be split. + permute_fn (Callable[[npt.NDArray[np.int64]], npt.NDArray[np.int64]] | None, optional): + A function that takes an array of indices and returns a permuted version of the array. + If None, a default permutation function using np.random.Generator is used. + Defaults to None. + + Returns: + Subset[T]: The subset of the input dataset corresponding to the current process rank. + """ + if permute_fn is None: permute_fn = np.random.Generator(np.random.PCG64(0)).permutation diff --git a/aiaccel/torch/lightning/abci_environment.py b/aiaccel/torch/lightning/abci_environment.py index 54c2db81..113c964b 100644 --- a/aiaccel/torch/lightning/abci_environment.py +++ b/aiaccel/torch/lightning/abci_environment.py @@ -7,6 +7,13 @@ class ABCIEnvironment(ClusterEnvironment): + """ + Environment class for ABCI. + + This class provides methods to interact with the ABCI environment, + such as retrieving the world size, global rank, node rank, and local rank. 
+ """ + def __init__(self) -> None: self._world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"]) self._rank = int(os.environ["OMPI_COMM_WORLD_RANK"]) @@ -50,8 +57,7 @@ def set_world_size(self, size: int) -> None: def set_global_rank(self, rank: int) -> None: if rank != self.global_rank(): - raise ValueError(f"`rank` is expected to be {self.get_global_rank()}, buf {rank} is given.") - + raise ValueError(f"`rank` is expected to be {self.global_rank()}, buf {rank} is given.") def validate_settings(self, num_devices: int, num_nodes: int) -> None: if num_devices != self._local_size: diff --git a/aiaccel/torch/lightning/opt_lightning_module.py b/aiaccel/torch/lightning/opt_lightning_module.py index 4007fc29..d8a46183 100644 --- a/aiaccel/torch/lightning/opt_lightning_module.py +++ b/aiaccel/torch/lightning/opt_lightning_module.py @@ -1,28 +1,55 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + from collections.abc import Callable from dataclasses import dataclass -import torch -from torch.optim import Optimizer - import lightning as lt from lightning.pytorch.utilities.types import OptimizerLRSchedulerConfig +if TYPE_CHECKING: + from torch import optim + @dataclass class OptimizerConfig: - optimizer_generator: Callable[..., torch.optim.Optimizer] - scheduler_generator: Callable[..., torch.optim.lr_scheduler.LRScheduler] | None = None + """ + Configuration class for the optimizer and scheduler in the LightningModule. + """ + + optimizer_generator: Callable[..., optim.optimizer.Optimizer] + scheduler_generator: Callable[..., optim.lr_scheduler.LRScheduler] | None = None scheduler_interval: str | None = "step" scheduler_monitor: str | None = "validation/loss" class OptimizerLightningModule(lt.LightningModule): + """ + LightningModule subclass for models that use custom optimizers and schedulers. + + Args: + optimizer_config (OptimizerConfig): Configuration object for the optimizer. + + Attributes: + optcfg (OptimizerConfig): Configuration object for the optimizer. + + Methods: + configure_optimizers: Configures the optimizer and scheduler for training. + """ + def __init__(self, optimizer_config: OptimizerConfig): super().__init__() self.optcfg = optimizer_config - def configure_optimizers(self) -> Optimizer | OptimizerLRSchedulerConfig: + def configure_optimizers(self) -> optim.optimizer.Optimizer | OptimizerLRSchedulerConfig: + """ + Configures the optimizer and scheduler for training. + + Returns: + Union[optim.optimizer.Optimizer, OptimizerLRSchedulerConfig]: The optimizer and scheduler configuration. 
+ """ optimizer = self.optcfg.optimizer_generator(params=self.parameters()) if self.optcfg.scheduler_generator is None: return optimizer diff --git a/aiaccel/utils/config.py b/aiaccel/utils/config.py index 70845248..034fbb3d 100644 --- a/aiaccel/utils/config.py +++ b/aiaccel/utils/config.py @@ -4,8 +4,6 @@ from pathlib import Path import re -import yaml - from colorama import Fore from omegaconf import DictConfig, ListConfig from omegaconf import OmegaConf as oc # noqa:N813 diff --git a/examples/job/optuna/main_parallel.py b/examples/job/optuna/main_parallel.py index ecaff933..504984ca 100644 --- a/examples/job/optuna/main_parallel.py +++ b/examples/job/optuna/main_parallel.py @@ -1,5 +1,10 @@ +<<<<<<< HEAD import pickle as pkl from pathlib import Path +======= +from pathlib import Path +import pickle as pkl +>>>>>>> 0caccda (update v2 torch (#383)) import optuna diff --git a/examples/job/optuna/main_serial.py b/examples/job/optuna/main_serial.py index 215c65ec..a2625d45 100644 --- a/examples/job/optuna/main_serial.py +++ b/examples/job/optuna/main_serial.py @@ -1,5 +1,10 @@ +<<<<<<< HEAD import pickle as pkl from pathlib import Path +======= +from pathlib import Path +import pickle as pkl +>>>>>>> 0caccda (update v2 torch (#383)) import optuna diff --git a/examples/job/optuna/objective.py b/examples/job/optuna/objective.py index bddd398e..188f6d1e 100644 --- a/examples/job/optuna/objective.py +++ b/examples/job/optuna/objective.py @@ -1,6 +1,12 @@ +<<<<<<< HEAD import pickle as pkl from argparse import ArgumentParser from pathlib import Path +======= +from argparse import ArgumentParser +from pathlib import Path +import pickle as pkl +>>>>>>> 0caccda (update v2 torch (#383)) def main() -> None: diff --git a/pyproject.toml b/pyproject.toml index b4a8e727..b7d9a05d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,15 +16,15 @@ classifiers = [ "License :: OSI Approved :: MIT License", ] dependencies = [ - "colorama", "numpy", "optuna", "optuna>=3.4.0", - "lightning>=2.2.1", - "hydra-core", - "torch>=2.2.0", "omegaconf", "hydra-core", + "colorama", + "lightning>=2.2.1", + "torch>=2.2.0", + "h5py" ] [project.scripts] diff --git a/tests/hpo/algorithms/test_nelder_mead_algorithm.py b/tests/hpo/algorithms/test_nelder_mead_algorithm.py index 62637758..bee3375e 100644 --- a/tests/hpo/algorithms/test_nelder_mead_algorithm.py +++ b/tests/hpo/algorithms/test_nelder_mead_algorithm.py @@ -1,7 +1,9 @@ +import numpy.typing as npt + from unittest.mock import patch import numpy as np -import numpy.typing as npt + import pytest from aiaccel.hpo.algorithms.nelder_mead_algorithm import ( diff --git a/tests/hpo/optuna/samplers/test_nelder_mead_sampler.py b/tests/hpo/optuna/samplers/test_nelder_mead_sampler.py index 90f60ab4..ecb2a722 100644 --- a/tests/hpo/optuna/samplers/test_nelder_mead_sampler.py +++ b/tests/hpo/optuna/samplers/test_nelder_mead_sampler.py @@ -1,15 +1,17 @@ +import numpy.typing as npt +from typing import Any + +from collections.abc import Callable import csv import datetime import math -import time -from collections.abc import Callable from multiprocessing import Pool from pathlib import Path -from typing import Any +import time from unittest.mock import patch import numpy as np -import numpy.typing as npt + import optuna import pytest diff --git a/tests/job/test_abci_job.py b/tests/job/test_abci_job.py index bfec8b97..a31f5e5a 100644 --- a/tests/job/test_abci_job.py +++ b/tests/job/test_abci_job.py @@ -1,7 +1,7 @@ -import re -import shutil from collections.abc import Generator from 
+        """
         optimizer = self.optcfg.optimizer_generator(params=self.parameters())
         if self.optcfg.scheduler_generator is None:
             return optimizer
diff --git a/aiaccel/utils/config.py b/aiaccel/utils/config.py
index 70845248..034fbb3d 100644
--- a/aiaccel/utils/config.py
+++ b/aiaccel/utils/config.py
@@ -4,8 +4,6 @@
 from pathlib import Path
 import re
 
-import yaml
-
 from colorama import Fore
 from omegaconf import DictConfig, ListConfig
 from omegaconf import OmegaConf as oc  # noqa:N813
diff --git a/examples/job/optuna/main_parallel.py b/examples/job/optuna/main_parallel.py
index ecaff933..504984ca 100644
--- a/examples/job/optuna/main_parallel.py
+++ b/examples/job/optuna/main_parallel.py
@@ -1,5 +1,5 @@
-import pickle as pkl
 from pathlib import Path
+import pickle as pkl
 
 import optuna
diff --git a/examples/job/optuna/main_serial.py b/examples/job/optuna/main_serial.py
index 215c65ec..a2625d45 100644
--- a/examples/job/optuna/main_serial.py
+++ b/examples/job/optuna/main_serial.py
@@ -1,5 +1,5 @@
-import pickle as pkl
 from pathlib import Path
+import pickle as pkl
 
 import optuna
diff --git a/examples/job/optuna/objective.py b/examples/job/optuna/objective.py
index bddd398e..188f6d1e 100644
--- a/examples/job/optuna/objective.py
+++ b/examples/job/optuna/objective.py
@@ -1,6 +1,6 @@
-import pickle as pkl
 from argparse import ArgumentParser
 from pathlib import Path
+import pickle as pkl
 
 
 def main() -> None:
diff --git a/pyproject.toml b/pyproject.toml
index b4a8e727..b7d9a05d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -16,15 +16,14 @@ classifiers = [
     "License :: OSI Approved :: MIT License",
 ]
 dependencies = [
-    "colorama",
     "numpy",
-    "optuna",
     "optuna>=3.4.0",
-    "lightning>=2.2.1",
-    "hydra-core",
-    "torch>=2.2.0",
     "omegaconf",
     "hydra-core",
+    "colorama",
+    "lightning>=2.2.1",
+    "torch>=2.2.0",
+    "h5py",
 ]
 
 [project.scripts]
diff --git a/tests/hpo/algorithms/test_nelder_mead_algorithm.py b/tests/hpo/algorithms/test_nelder_mead_algorithm.py
index 62637758..bee3375e 100644
--- a/tests/hpo/algorithms/test_nelder_mead_algorithm.py
+++ b/tests/hpo/algorithms/test_nelder_mead_algorithm.py
@@ -1,7 +1,9 @@
+import numpy.typing as npt
+
 from unittest.mock import patch
 
 import numpy as np
-import numpy.typing as npt
+
 import pytest
 
 from aiaccel.hpo.algorithms.nelder_mead_algorithm import (
diff --git a/tests/hpo/optuna/samplers/test_nelder_mead_sampler.py b/tests/hpo/optuna/samplers/test_nelder_mead_sampler.py
index 90f60ab4..ecb2a722 100644
--- a/tests/hpo/optuna/samplers/test_nelder_mead_sampler.py
+++ b/tests/hpo/optuna/samplers/test_nelder_mead_sampler.py
@@ -1,15 +1,17 @@
+import numpy.typing as npt
+from typing import Any
+
+from collections.abc import Callable
 import csv
 import datetime
 import math
-import time
-from collections.abc import Callable
 from multiprocessing import Pool
 from pathlib import Path
-from typing import Any
+import time
 from unittest.mock import patch
 
 import numpy as np
-import numpy.typing as npt
+
 import optuna
 import pytest
diff --git a/tests/job/test_abci_job.py b/tests/job/test_abci_job.py
index bfec8b97..a31f5e5a 100644
--- a/tests/job/test_abci_job.py
+++ b/tests/job/test_abci_job.py
@@ -1,7 +1,7 @@
-import re
-import shutil
 from collections.abc import Generator
 from pathlib import Path
+import re
+import shutil
 from unittest.mock import patch
 
 import pytest
diff --git a/tests/job/test_abci_job_manager.py b/tests/job/test_abci_job_manager.py
index 0c926788..1d0a9e4c 100644
--- a/tests/job/test_abci_job_manager.py
+++ b/tests/job/test_abci_job_manager.py
@@ -1,6 +1,6 @@
-import shutil
 from collections.abc import Generator
 from pathlib import Path
+import shutil
 from unittest.mock import MagicMock, patch
 
 import pytest
diff --git a/tests/job/utils.py b/tests/job/utils.py
index 8da86e29..77e7c912 100644
--- a/tests/job/utils.py
+++ b/tests/job/utils.py
@@ -1,6 +1,7 @@
-import subprocess
 from typing import Any
 
+import subprocess
+
 
 def qstat_xml(txt_data_path: str = "tests/job/qstat_dat.txt") -> Any:
     with open(txt_data_path) as f: