chore: enable use of torch 2 (#54)
* chore: update dependencies

* tests: fix failing tests

* ci: add tests for torch 1

* fix: linting issues

* ci: ignore line length in flake8

Line length is already covered by the more lenient black rules.

* fix: type disagreement between lightning 1 and 2

Lightning 2 introduces OptimizerLRSchedulerConfig, which is needed as the return type of configure_optimizers. To keep compatibility with Lightning 1, some import hacks were needed.
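The shim in question, condensed from the change to rul_adapt/utils.py shown further down in this diff:

# Lightning 2 ships OptimizerLRSchedulerConfig; Lightning 1 does not, so a
# plain dict stands in when an older version is installed.
import pytorch_lightning

if pytorch_lightning.__version__.startswith("2."):
    from pytorch_lightning.utilities.types import OptimizerLRSchedulerConfig
else:
    OptimizerLRSchedulerConfig = dict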
tilman151 authored Jan 12, 2024
1 parent 2cf9fd0 commit 9bc8d4b
Showing 13 changed files with 1,408 additions and 1,299 deletions.
2 changes: 1 addition & 1 deletion .flake8
@@ -1,6 +1,6 @@
[flake8]
max-line-length = 88
-extend-ignore = E203
+extend-ignore = E203, E501
per-file-ignores =
# imported but unused
__init__.py: F401
23 changes: 22 additions & 1 deletion .github/workflows/test.yaml
@@ -12,12 +12,33 @@ jobs:
poetry-version: "1.2.2"
- run: poetry run pytest tests -m "not integration"

+legacy-unit:
+runs-on: ubuntu-latest
+steps:
+- uses: actions/checkout@v3
+- uses: ./.github/actions/install-poetry
+with:
+poetry-version: "1.2.2"
+- run: poetry run pip install "pytorch-lightning<2.0.0"
+- run: poetry run pytest tests -m "not integration"

integration:
runs-on: ubuntu-latest
-needs: [unit]
+needs: [ legacy-unit ]
steps:
- uses: actions/checkout@v3
- uses: ./.github/actions/install-poetry
with:
poetry-version: "1.2.2"
- run: poetry run pytest tests -m integration

+legacy-integration:
+runs-on: ubuntu-latest
+needs: [ legacy-unit ]
+steps:
+- uses: actions/checkout@v3
+- uses: ./.github/actions/install-poetry
+with:
+poetry-version: "1.2.2"
+- run: poetry run pip install "pytorch-lightning<2.0.0"
+- run: poetry run pytest tests -m integration
2,620 changes: 1,349 additions & 1,271 deletions poetry.lock

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions pyproject.toml
@@ -15,8 +15,8 @@ packages = [{include = "rul_adapt"}]

[tool.poetry.dependencies]
python = "^3.8"
pytorch-lightning = "^1.8.0.post1"
rul-datasets = ">=0.10.5"
pytorch-lightning = ">1.8.0.post1"
rul-datasets = ">=0.14.0"
tqdm = "^4.62.2"
hydra-core = "^1.3.1"
pywavelets = "^1.4.1"
@@ -34,7 +34,7 @@ pytest-mock = "^3.10.0"
optional = true

[tool.poetry.group.docs.dependencies]
-mkdocstrings = {extras = ["python"], version = "^0.22.0"}
+mkdocstrings = {extras = ["python"], version = "^0.24.0"}
mkdocs-gen-files = "^0.5.0"
mkdocs-literate-nav = "^0.5.0"
mkdocs-section-index = "^0.3.4"
8 changes: 4 additions & 4 deletions rul_adapt/approach/adarul.py
@@ -10,7 +10,7 @@
https://doi.org/10.1109/ICPHM49022.2020.9187053) and evaluated on the CMAPSS dataset."""

import copy
-from typing import Optional, Any, List, Dict, Literal
+from typing import Optional, Any, List, Literal

import torch
from torch import nn
@@ -154,7 +154,7 @@ def domain_disc(self):
else:
raise RuntimeError("Domain disc used before 'set_model' was called.")

-def configure_optimizers(self) -> List[Dict[str, Any]]:
+def configure_optimizers(self) -> List[utils.OptimizerLRSchedulerConfig]:  # type: ignore[override]
"""Configure an optimizer for the generator and discriminator respectively."""
return [
self._get_optimizer(self.domain_disc.parameters()),
@@ -196,12 +196,12 @@ def training_step(self, batch: List[torch.Tensor], batch_idx: int) -> torch.Tens
self._reset_update_counters()

if self._should_update_disc():
-optim, _ = self.optimizers()  # type: ignore[misc]
+optim, _ = self.optimizers()  # type: ignore[attr-defined, misc]
loss = self._get_disc_loss(source, target)
self.log("train/disc_loss", loss)
self._disc_counter += 1
elif self._should_update_gen():
-_, optim = self.optimizers()  # type: ignore[misc]
+_, optim = self.optimizers()  # type: ignore[attr-defined, misc]
loss = self._get_gen_loss(target)
self.log("train/gen_loss", loss)
self._gen_counter += 1
7 changes: 4 additions & 3 deletions rul_adapt/approach/conditional.py
@@ -9,7 +9,7 @@
[Cheng et al.](https://doi.org/10.1007/s10845-021-01814-y) in 2021."""

from copy import deepcopy
-from typing import List, Tuple, Literal, Optional, Any, Dict
+from typing import List, Tuple, Literal, Optional, Any

import torch
from torch import nn
@@ -19,6 +19,7 @@
from rul_adapt.approach.abstract import AdaptionApproach
from rul_adapt.approach.evaluation import AdaptionEvaluator
from rul_adapt.model import FullyConnectedHead
+from rul_adapt.utils import OptimizerLRSchedulerConfig


class ConditionalMmdApproach(AdaptionApproach):
@@ -104,7 +105,7 @@ def __init__(
def fuzzy_sets(self) -> List[Tuple[float, float]]:
return self.conditional_mmd_loss.fuzzy_sets

-def configure_optimizers(self) -> Dict[str, Any]:
+def configure_optimizers(self) -> OptimizerLRSchedulerConfig:
"""Configure an Adam optimizer."""
return self._get_optimizer(self.parameters())

@@ -333,7 +334,7 @@ def _check_domain_disc(self, domain_disc: Optional[nn.Module]) -> nn.Module:
def domain_disc(self) -> nn.Module:
return self.dann_loss.domain_disc

-def configure_optimizers(self) -> Dict[str, Any]:
+def configure_optimizers(self) -> utils.OptimizerLRSchedulerConfig:
"""Configure an Adam optimizer."""
return self._get_optimizer(self.parameters())

4 changes: 2 additions & 2 deletions rul_adapt/approach/consistency.py
@@ -27,7 +27,7 @@
import copy
import math
from itertools import chain
-from typing import Optional, Any, List, Tuple, Dict, Literal
+from typing import Optional, Any, List, Tuple, Literal

import numpy as np
import torch
@@ -183,7 +183,7 @@ def dann_factor(self):
"""
return 2 / (1 + math.exp(-10 * self.current_epoch / self.max_epochs)) - 1

-def configure_optimizers(self) -> Dict[str, Any]:
+def configure_optimizers(self) -> utils.OptimizerLRSchedulerConfig:
"""Configure an optimizer to train the feature extractor, regressor and
domain discriminator."""
parameters = chain(
4 changes: 2 additions & 2 deletions rul_adapt/approach/dann.py
@@ -25,7 +25,7 @@
[10.1109/ICPHM49022.2020.9187058](https://doi.org/10.1109/ICPHM49022.2020.9187058)
"""

-from typing import Any, Optional, Dict, Literal, List
+from typing import Any, Optional, Literal, List

import torch
from torch import nn
@@ -163,7 +163,7 @@ def domain_disc(self):
else:
raise RuntimeError("Domain disc used before 'set_model' was called.")

-def configure_optimizers(self) -> Dict[str, Any]:
+def configure_optimizers(self) -> utils.OptimizerLRSchedulerConfig:
"""Configure an optimizer for the whole model."""
return self._get_optimizer(self.parameters())

7 changes: 4 additions & 3 deletions rul_adapt/approach/latent_align.py
@@ -24,7 +24,7 @@
[rul_adapt.approach.latent_align.LatentAlignFttpApproach] introduced by [Li et al.](
https://doi.org/10.1016/j.knosys.2020.105843) in 2020."""

-from typing import Tuple, List, Any, Optional, Literal, Dict
+from typing import Tuple, List, Any, Optional, Literal

import numpy as np
import torch
@@ -35,6 +35,7 @@
from rul_adapt import utils
from rul_adapt.approach.abstract import AdaptionApproach
from rul_adapt.approach.evaluation import AdaptionEvaluator
+from rul_adapt.utils import OptimizerLRSchedulerConfig


class LatentAlignFttpApproach(AdaptionApproach):
@@ -134,7 +135,7 @@ def generator(self):
else:
raise RuntimeError("Generator used before 'set_model' was called.")

-def configure_optimizers(self) -> Dict[str, Any]:
+def configure_optimizers(self) -> OptimizerLRSchedulerConfig:
"""Configure an optimizer for the generator and discriminator."""
return self._get_optimizer(self.parameters())

@@ -397,7 +398,7 @@ def __init__(

self.save_hyperparameters()

-def configure_optimizers(self) -> Dict[str, Any]:
+def configure_optimizers(self) -> utils.OptimizerLRSchedulerConfig:
"""Configure an optimizer."""
optim = self._get_optimizer(self.parameters())

4 changes: 2 additions & 2 deletions rul_adapt/approach/mmd.py
@@ -25,7 +25,7 @@
[10.1109/ICPHM49022.2020.9187058](https://doi.org/10.1109/ICPHM49022.2020.9187058)
"""

-from typing import List, Literal, Any, Dict
+from typing import List, Literal, Any

import torch

@@ -99,7 +99,7 @@ def __init__(

self.save_hyperparameters()

-def configure_optimizers(self) -> Dict[str, Any]:
+def configure_optimizers(self) -> utils.OptimizerLRSchedulerConfig:
"""Configure an optimizer."""
return self._get_optimizer(self.parameters())

4 changes: 2 additions & 2 deletions rul_adapt/approach/supervised.py
@@ -6,7 +6,7 @@
```
"""

-from typing import Literal, Any, Dict, List
+from typing import Literal, Any, List

import torch
import torchmetrics
@@ -78,7 +78,7 @@ def __init__(

self.save_hyperparameters()

-def configure_optimizers(self) -> Dict[str, Any]:
+def configure_optimizers(self) -> utils.OptimizerLRSchedulerConfig:
return self._get_optimizer(self.parameters())

def forward(self, inputs: torch.Tensor) -> torch.Tensor:
14 changes: 11 additions & 3 deletions rul_adapt/utils.py
@@ -1,11 +1,17 @@
import warnings
from itertools import tee
-from typing import Union, Callable, Literal, Any, Dict, Optional, Iterable
+from typing import Union, Callable, Literal, Any, Optional, Iterable

+import pytorch_lightning
import torch
import torchmetrics
from torch import nn

+if pytorch_lightning.__version__.startswith("2."):
+    from pytorch_lightning.utilities.types import OptimizerLRSchedulerConfig  # type: ignore
+else:
+    OptimizerLRSchedulerConfig = dict  # type: ignore


def pairwise(iterable):
"""s -> (s0,s1), (s1,s2), (s2, s3), ..."""
@@ -109,7 +115,9 @@ def _is_excess_kwarg(key: str) -> bool:
f"will be ignored: {excess_kwargs}."
)

-def __call__(self, parameters: Iterable[nn.Parameter]) -> Dict[str, Any]:
+def __call__(
+    self, parameters: Iterable[nn.Parameter]
+) -> OptimizerLRSchedulerConfig:
"""
Create an optimizer with an optional scheduler for the given parameters.
@@ -128,7 +136,7 @@ def __call__(self, parameters: Iterable[nn.Parameter]) -> Dict[str, Any]:
if key.startswith("optim_")
}
optim = self._optim_func(parameters, lr=self.lr, **optim_kwargs)
optim_config = {"optimizer": optim}
optim_config = OptimizerLRSchedulerConfig(optimizer=optim)

if self.scheduler_type is not None:
scheduler_kwargs = {
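Both branches of the shim behave the same at runtime: under Lightning 2, OptimizerLRSchedulerConfig is a TypedDict, so constructing it with keyword arguments yields an ordinary dict, and under Lightning 1 the name is simply bound to dict. A minimal sketch of the resulting return value (the linear layer and Adam optimizer are illustrative only, not part of this commit):

import torch
from torch import nn

from rul_adapt.utils import OptimizerLRSchedulerConfig

model = nn.Linear(4, 1)
optim = torch.optim.Adam(model.parameters(), lr=1e-3)

# Identical under Lightning 1 and 2: a plain dict with an "optimizer" key,
# which is a valid return value for configure_optimizers in both versions.
config = OptimizerLRSchedulerConfig(optimizer=optim)
assert config == {"optimizer": optim}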
4 changes: 2 additions & 2 deletions tests/test_approach/test_evaluation.py
@@ -28,8 +28,8 @@ def validation_step(self, batch, batch_idx, dataloader_idx):
domain = "source" if dataloader_idx == 0 else "target"
self.evaluator.validation(batch, domain)

-def test_step(self, batch, batch_idx, data_loader_idx):
-domain = "source" if data_loader_idx == 0 else "target"
+def test_step(self, batch, batch_idx, dataloader_idx):
+domain = "source" if dataloader_idx == 0 else "target"
self.evaluator.test(batch, domain)


