update v2 torch (#383)
* Add nelder-mead sampler (#346)

* add nelder_mead_sampler.py

---------

Co-authored-by: Yoshiaki Bando <yoshipon@users.noreply.github.com>

* Add -e. (#371)

* Support for suggest_int in neldermead sampler. (#373)

* Support for suggest_int.

* Fix for ruff.

* Add Job dispatcher (#360)

* Update aiaccel/job

* fix lint error

* fix aiaccel/job/dispatcher.py

* wip

* Fix issues pointed out in review

* rename

* fix test

* fix for test

* fix for test

* fix for test

* fix for test

* Reflected code review feedback

* fix test

* fix test

* refactoring

* Run code formatter ruff

* refactoring

* Run code formatter ruff

* fix lint.yaml

* fix pypoject.toml

* fix mypy.ini

* fix mypy.ini

* fix mypy.ini

* fix mypy.ini

* fix mypy.ini

* Refactor abci jobs

* Update pyproject.toml

* adjusted to work in the ABCI environment.

* fix lint error

* fix test

* fix test

* Changed from unittest to pytest

* Incorporate revisions based on review feedback

* Add test items

* fix test

* fix test

* fix test

* remove __future__.annotations

* A bit simplify qstat_xml

* Simplify tests

---------

Co-authored-by: Yoshiaki Bando <yoshipon@users.noreply.github.com>

* Add ignore_missing_imports. (#376)

* Draft Implementation of CLI Command (#369)

* Update aiaccel/job

* fix lint error

* fix aiaccel/job/dispatcher.py

* wip

* Fix issues pointed out in review

* rename

* fix test

* fix for test

* fix for test

* fix for test

* fix for test

* Reflected code review feedback

* fix test

* fix test

* refactoring

* Run code formatter ruff

* refactoring

* Run code formatter ruff

* fix lint.yaml

* fix pypoject.toml

* fix mypy.ini

* fix mypy.ini

* fix mypy.ini

* fix mypy.ini

* fix mypy.ini

* Refactor abci jobs

* Update pyproject.toml

* adjusted to work in the ABCI environment.

* fix lint error

* fix test

* fix test

* Changed from unittest to pytest

* Incorporate revisions based on review feedback

* Add test items

* Add 'start.py'

* Fix code format

* update pyproject.toml

* Remove unnecessary items from config

* Add random seed option

* Refactor examples/start.py into aiaccel/apps/optimize.py

* fix lint errors and warnings

* remove apps/config

* fix typo

* change config.yaml format

* fix apps/optimizer

* fix HparamManager __init__

* fix HparamsManager

* Refactor Suggest class to use generics for improved type safety

* auto fix by ruff

* add optuna suggest wrapper

* format fix

* Update wrapper.py

---------

Co-authored-by: Yoshiaki Bando <yoshipon@users.noreply.github.com>

* Add nm docstring and move file (#381)

* Add docstring.

* Move nelder_mead_sampler.py.

* Fix path.

* Move test files.

* autoformat and fix mypy error

* Add v2 torch docstring   (#384)

* add docstring

* Update abci_environment.py

---------

Co-authored-by: Yoshiaki Bando <yoshipon@users.noreply.github.com>

---------

Co-authored-by: KanaiYuma-aist <105629713+KanaiYuma-aist@users.noreply.github.com>
Co-authored-by: Yoshiaki Bando <yoshipon@users.noreply.github.com>
3 people committed Sep 2, 2024
1 parent d05c39e commit 798d4ba
Showing 22 changed files with 205 additions and 40 deletions.
6 changes: 4 additions & 2 deletions aiaccel/apps/optimize.py
@@ -1,11 +1,13 @@
from typing import Any

import argparse
import pickle as pkl
from collections.abc import Callable
from pathlib import Path
from typing import Any
import pickle as pkl

from hydra.utils import instantiate
from omegaconf import OmegaConf as oc # noqa: N813

from optuna.trial import Trial

from aiaccel.hpo.optuna.wrapper import Const, Suggest, SuggestFloat, T
5 changes: 1 addition & 4 deletions aiaccel/apps/train.py
@@ -1,14 +1,12 @@
import os
from argparse import ArgumentParser
import os
from pathlib import Path
import pickle as pkl

from hydra.utils import instantiate
from omegaconf import DictConfig, ListConfig
from omegaconf import OmegaConf as oc # noqa: N813

import lightning as lt
from lightning.fabric.utilities.rank_zero import rank_zero_only

from aiaccel.utils import print_config

@@ -41,7 +39,6 @@ def main() -> None:
    with open(config.working_directory / "config.pkl", "wb") as f:
        pkl.dump(config, f)


    # train
    trainer: lt.Trainer = instantiate(config.trainer)
    trainer.fit(
7 changes: 4 additions & 3 deletions aiaccel/hpo/algorithms/nelder_mead_algorithm.py
@@ -1,10 +1,11 @@
import queue
import threading
import numpy.typing as npt

from collections.abc import Generator
from dataclasses import dataclass
import queue
import threading

import numpy as np
import numpy.typing as npt


@dataclass
10 changes: 10 additions & 0 deletions aiaccel/hpo/optuna/samplers/nelder_mead_sampler.py
@@ -1,10 +1,20 @@
from __future__ import annotations

import warnings
from collections.abc import Sequence
from typing import Any

import numpy as np

import optuna
from optuna.distributions import BaseDistribution
from optuna.study import Study
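For orientation, a minimal sketch of driving an Optuna study with the relocated sampler follows. The import path mirrors the new file location above, but the constructor arguments (a search_space mapping of bounds and a seed) are assumptions for illustration and are not confirmed by this diff.

# Hedged sketch: NelderMeadSampler's constructor arguments are assumed, not taken from this diff.
import optuna

from aiaccel.hpo.optuna.samplers.nelder_mead_sampler import NelderMeadSampler


def objective(trial: optuna.trial.Trial) -> float:
    x = trial.suggest_float("x", -5.0, 5.0)
    y = trial.suggest_float("y", -5.0, 5.0)
    return x**2 + y**2


sampler = NelderMeadSampler(search_space={"x": (-5.0, 5.0), "y": (-5.0, 5.0)}, seed=42)
study = optuna.create_study(sampler=sampler)
study.optimize(objective, n_trials=30)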
3 changes: 2 additions & 1 deletion aiaccel/hpo/optuna/wrapper.py
@@ -1,6 +1,7 @@
from typing import Generic, TypeVar

from collections.abc import Sequence
from dataclasses import dataclass
from typing import Generic, TypeVar

from optuna.trial import Trial

7 changes: 4 additions & 3 deletions aiaccel/job/abci_job.py
@@ -1,11 +1,12 @@
from __future__ import annotations

from typing import Any

from enum import IntEnum, auto
from pathlib import Path
import re
import subprocess
import time
from enum import IntEnum, auto
from pathlib import Path
from typing import Any
from xml.etree import ElementTree


5 changes: 3 additions & 2 deletions aiaccel/job/abci_job_manager.py
@@ -1,7 +1,8 @@
import time
from pathlib import Path
from typing import Any

from pathlib import Path
import time

from aiaccel.job.abci_job import AbciJob, JobStatus


51 changes: 51 additions & 0 deletions aiaccel/torch/datasets/cached_dataset.py
@@ -9,14 +9,43 @@


class NumpiedTensor:
    """
    A wrapper class that converts a PyTorch tensor to a NumPy array and vice versa.

    Args:
        tensor (torch.Tensor): The input PyTorch tensor.

    Attributes:
        array (np.ndarray): The NumPy array representation of the tensor.

    Methods:
        to_tensor: Converts the NumPy array back to a PyTorch tensor.
    """

    def __init__(self, tensor: torch.Tensor) -> None:
        self.array = tensor.numpy()

    def to_tensor(self) -> torch.Tensor:
        """
        Converts the NumPy array back to a PyTorch tensor.

        Returns:
            torch.Tensor: The PyTorch tensor representation of the NumPy array.
        """
        return torch.tensor(self.array)


def numpize_sample(sample: Any) -> Any:
    """
    Converts the input sample to a NumPy-compatible format.

    Args:
        sample (Any): The input sample to be converted.

    Returns:
        Any: The converted sample in a NumPy-compatible format.
    """

    if isinstance(sample, torch.Tensor):
        return NumpiedTensor(sample)
    elif isinstance(sample, tuple):
@@ -30,6 +59,16 @@ def numpize_sample(sample: Any) -> Any:


def tensorize_sample(sample: Any) -> Any:
    """
    Converts the given sample into a tensor representation.

    Args:
        sample (Any): The input sample to be tensorized.

    Returns:
        Any: The tensorized representation of the input sample.
    """

    if isinstance(sample, NumpiedTensor):
        return sample.to_tensor()
    elif isinstance(sample, tuple):
@@ -46,6 +85,18 @@ def tensorize_sample(sample: Any) -> Any:


class CachedDataset(Dataset[T_co]):
    """
    A dataset wrapper that caches the samples to improve performance.

    Args:
        dataset (Dataset): The original dataset to be wrapped.

    Attributes:
        dataset (Dataset): The original dataset.
        manager (Manager): The multiprocessing manager.
        cache (dict): The cache dictionary to store the cached samples.
    """

    def __init__(self, dataset: Dataset[T_co]) -> None:
        self.dataset = dataset

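As a rough illustration of the wrapper described in the docstring above, here is a hedged sketch; the ToyDataset class is hypothetical, and it assumes CachedDataset forwards __getitem__ to the wrapped dataset and stores each result in the shared cache.

# Illustrative sketch only; ToyDataset is hypothetical and the import path is inferred from the file layout.
import torch
from torch.utils.data import Dataset

from aiaccel.torch.datasets.cached_dataset import CachedDataset


class ToyDataset(Dataset[torch.Tensor]):
    def __len__(self) -> int:
        return 8

    def __getitem__(self, index: int) -> torch.Tensor:
        # Stand-in for an expensive load; the cached wrapper should pay this cost once per index.
        return torch.full((4,), float(index))


dataset = CachedDataset(ToyDataset())
first = dataset[3]   # populates the cache
second = dataset[3]  # served from the cache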
39 changes: 37 additions & 2 deletions aiaccel/torch/datasets/hdf5_dataset.py
@@ -16,14 +16,33 @@
]


class RawHDF5Dataset(Dataset[int]):
class RawHDF5Dataset(Dataset[dict[str, Any]]):
    """
    A dataset class for reading data from HDF5 files.

    Args:
        dataset_path (Union[Path, str]): The path to the HDF5 dataset file.
        grp_list (Union[Path, str, List[str], None], optional): The list of groups to load from the dataset.
            If None, all groups in the dataset will be loaded. If a string or Path, it should be the path to a file
            containing the list of groups. If a list, it should directly specify the groups to load. Defaults to None.

    Raises:
        NotImplementedError: If grp_list is of an unsupported type.

    Attributes:
        dataset_path (Union[Path, str]): The path to the HDF5 dataset file.
        grp_list (List[str]): The list of groups to load from the dataset.
        f (Optional[h5.File]): The HDF5 file object used for reading the dataset.
    """

    def __init__(self, dataset_path: Path | str, grp_list: Path | str | list[str] | None = None) -> None:
        self.dataset_path = dataset_path

        if grp_list is None:
            with h5.File(self.dataset_path, "r") as f:
                self.grp_list = list(f.keys())
        elif isinstance(grp_list, (str, Path)):
        elif isinstance(grp_list, (str | Path)):
            with open(grp_list, "rb") as f:
                self.grp_list = pkl.load(f)
        elif isinstance(grp_list, list):
@@ -49,5 +68,21 @@ def __del__(self) -> None:


class HDF5Dataset(RawHDF5Dataset):
    """
    A dataset class for loading data from an HDF5 file.

    This class extends the `RawHDF5Dataset` class and provides a convenient way to load data from an HDF5 file
    and convert it into a dictionary of torch tensors.

    Args:
        path (str): The path to the HDF5 file.
        transform (callable, optional): A function/transform that takes in a dictionary of data and returns a
            modified version. Default is None.

    Returns:
        dict[str, torch.Tensor]: A dictionary containing the data loaded from the HDF5 file, where the keys are
            the names of the data fields and the values are torch tensors.
    """

    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
        return {k: torch.as_tensor(v) for k, v in super().__getitem__(index).items()}
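A short, hedged usage sketch of the dataset class follows; the file name and group layout are hypothetical, and only the constructor arguments visible in this diff (dataset_path, grp_list) are used.

# Hypothetical example; "features.h5" and its group layout are invented for illustration.
from aiaccel.torch.datasets.hdf5_dataset import HDF5Dataset

# Each top-level group in the HDF5 file is treated as one sample, and every array
# inside the group is returned as a torch.Tensor keyed by its name.
dataset = HDF5Dataset("features.h5")
sample = dataset[0]
for name, tensor in sample.items():
    print(name, tensor.shape)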
14 changes: 14 additions & 0 deletions aiaccel/torch/datasets/scatter_dataset.py
@@ -15,6 +15,20 @@ def scatter_dataset(
    dataset: Dataset[T],
    permute_fn: Callable[[npt.NDArray[np.int64]], npt.NDArray[np.int64]] | None = None,
) -> Subset[T]:
    """
    Splits a dataset into subsets and returns the subset corresponding to the current process rank.

    Args:
        dataset (Dataset[T]): The input dataset to be split.
        permute_fn (Callable[[npt.NDArray[np.int64]], npt.NDArray[np.int64]] | None, optional):
            A function that takes an array of indices and returns a permuted version of the array.
            If None, a default permutation function using np.random.Generator is used.
            Defaults to None.

    Returns:
        Subset[T]: The subset of the input dataset corresponding to the current process rank.
    """

    if permute_fn is None:
        permute_fn = np.random.Generator(np.random.PCG64(0)).permutation

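Below is a hedged sketch of sharding a dataset per process; it assumes the current rank and world size are taken from an initialized torch.distributed process group (e.g. under Lightning or torchrun), which this diff does not show.

# Sketch under the assumption that rank/world size come from an initialized torch.distributed group.
import numpy as np
import numpy.typing as npt
import torch
from torch.utils.data import TensorDataset

from aiaccel.torch.datasets.scatter_dataset import scatter_dataset


def permute(indices: npt.NDArray[np.int64]) -> npt.NDArray[np.int64]:
    # Fixed seed so every rank agrees on the same global ordering before it is split.
    return np.random.Generator(np.random.PCG64(123)).permutation(indices)


full_dataset = TensorDataset(torch.arange(100))
local_shard = scatter_dataset(full_dataset, permute_fn=permute)  # subset owned by this rank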
10 changes: 8 additions & 2 deletions aiaccel/torch/lightning/abci_environment.py
@@ -7,6 +7,13 @@


class ABCIEnvironment(ClusterEnvironment):
    """
    Environment class for ABCI.

    This class provides methods to interact with the ABCI environment,
    such as retrieving the world size, global rank, node rank, and local rank.
    """

    def __init__(self) -> None:
        self._world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"])
        self._rank = int(os.environ["OMPI_COMM_WORLD_RANK"])
@@ -50,8 +57,7 @@ def set_world_size(self, size: int) -> None:

    def set_global_rank(self, rank: int) -> None:
        if rank != self.global_rank():
            raise ValueError(f"`rank` is expected to be {self.get_global_rank()}, but {rank} is given.")

            raise ValueError(f"`rank` is expected to be {self.global_rank()}, but {rank} is given.")

    def validate_settings(self, num_devices: int, num_nodes: int) -> None:
        if num_devices != self._local_size:
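For context, a hedged sketch of plugging this environment into a Lightning Trainer; it assumes the job is launched through Open MPI on ABCI so the OMPI_COMM_WORLD_* variables read in __init__ are present. The device and node counts are placeholders.

# Sketch: the OMPI_COMM_WORLD_* variables must be set by the MPI launcher for __init__ to succeed.
import lightning as lt

from aiaccel.torch.lightning.abci_environment import ABCIEnvironment

trainer = lt.Trainer(
    accelerator="gpu",
    devices=4,        # placeholder: GPUs per node
    num_nodes=2,      # placeholder: nodes in the ABCI job
    strategy="ddp",
    plugins=[ABCIEnvironment()],  # rank/world-size discovery delegated to the MPI environment
)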
39 changes: 33 additions & 6 deletions aiaccel/torch/lightning/opt_lightning_module.py
@@ -1,28 +1,55 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from collections.abc import Callable
from dataclasses import dataclass

import torch
from torch.optim import Optimizer

import lightning as lt
from lightning.pytorch.utilities.types import OptimizerLRSchedulerConfig

if TYPE_CHECKING:
    from torch import optim


@dataclass
class OptimizerConfig:
    optimizer_generator: Callable[..., torch.optim.Optimizer]
    scheduler_generator: Callable[..., torch.optim.lr_scheduler.LRScheduler] | None = None
    """
    Configuration class for the optimizer and scheduler in the LightningModule.
    """

    optimizer_generator: Callable[..., optim.optimizer.Optimizer]
    scheduler_generator: Callable[..., optim.lr_scheduler.LRScheduler] | None = None
    scheduler_interval: str | None = "step"
    scheduler_monitor: str | None = "validation/loss"


class OptimizerLightningModule(lt.LightningModule):
    """
    LightningModule subclass for models that use custom optimizers and schedulers.

    Args:
        optimizer_config (OptimizerConfig): Configuration object for the optimizer.

    Attributes:
        optcfg (OptimizerConfig): Configuration object for the optimizer.

    Methods:
        configure_optimizers: Configures the optimizer and scheduler for training.
    """

    def __init__(self, optimizer_config: OptimizerConfig):
        super().__init__()

        self.optcfg = optimizer_config

    def configure_optimizers(self) -> Optimizer | OptimizerLRSchedulerConfig:
    def configure_optimizers(self) -> optim.optimizer.Optimizer | OptimizerLRSchedulerConfig:
        """
        Configures the optimizer and scheduler for training.

        Returns:
            Union[optim.optimizer.Optimizer, OptimizerLRSchedulerConfig]: The optimizer and scheduler configuration.
        """
        optimizer = self.optcfg.optimizer_generator(params=self.parameters())
        if self.optcfg.scheduler_generator is None:
            return optimizer
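To make the generator-based configuration concrete, here is a hedged sketch; MyLitModule is hypothetical, optimizer_generator is called with params= as shown in configure_optimizers above, and the scheduler_generator call signature (receiving the freshly built optimizer) is an assumption since that part of the file is collapsed in this diff.

# Hedged sketch; MyLitModule is hypothetical and the scheduler wiring is assumed.
from functools import partial

import torch
from torch import nn

from aiaccel.torch.lightning.opt_lightning_module import OptimizerConfig, OptimizerLightningModule


class MyLitModule(OptimizerLightningModule):
    def __init__(self) -> None:
        super().__init__(
            OptimizerConfig(
                # Generators are callables that build the real objects when invoked;
                # configure_optimizers() calls optimizer_generator(params=self.parameters()).
                optimizer_generator=partial(torch.optim.Adam, lr=1e-3),
                scheduler_generator=partial(torch.optim.lr_scheduler.StepLR, step_size=10),
                scheduler_interval="epoch",
                scheduler_monitor="validation/loss",
            )
        )
        self.model = nn.Linear(16, 1)

    def training_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
        x, y = batch
        return nn.functional.mse_loss(self.model(x), y)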
2 changes: 0 additions & 2 deletions aiaccel/utils/config.py
@@ -4,8 +4,6 @@
from pathlib import Path
import re

import yaml

from colorama import Fore
from omegaconf import DictConfig, ListConfig
from omegaconf import OmegaConf as oc # noqa:N813
5 changes: 5 additions & 0 deletions examples/job/optuna/main_parallel.py
@@ -1,5 +1,10 @@
import pickle as pkl
from pathlib import Path

import optuna

5 changes: 5 additions & 0 deletions examples/job/optuna/main_serial.py
@@ -1,5 +1,10 @@
import pickle as pkl
from pathlib import Path

import optuna
