Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: degraded-only evaluation mode #35

Merged
merged 3 commits into from
Jun 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion rul_adapt/approach/adarul.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def __init__(
num_disc_updates: int,
num_gen_updates: int,
rul_score_mode: Literal["phm08", "phm12"] = "phm08",
evaluate_degraded_only: bool = False,
**optim_kwargs: Any,
) -> None:
"""
Expand Down Expand Up @@ -83,14 +84,17 @@ def __init__(
self.num_disc_updates = num_disc_updates
self.num_gen_updates = num_gen_updates
self.rul_score_mode = rul_score_mode
self.evaluate_degraded_only = evaluate_degraded_only
self.optim_kwargs = optim_kwargs

self._disc_counter, self._gen_counter = 0, 0
self._get_optimizer = utils.OptimizerFactory(**self.optim_kwargs)

self.gan_loss = nn.BCEWithLogitsLoss()

self.evaluator = AdaptionEvaluator(self.forward, self.log, self.rul_score_mode)
self.evaluator = AdaptionEvaluator(
self.forward, self.log, self.rul_score_mode, self.evaluate_degraded_only
)

self.save_hyperparameters()

Expand Down
12 changes: 10 additions & 2 deletions rul_adapt/approach/conditional.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def __init__(
fuzzy_sets: List[Tuple[float, float]],
loss_type: Literal["mse", "rmse", "mae"] = "mae",
rul_score_mode: Literal["phm08", "phm12"] = "phm08",
evaluate_degraded_only: bool = False,
**optim_kwargs: Any,
) -> None:
"""
Expand Down Expand Up @@ -75,6 +76,7 @@ def __init__(
self.dynamic_adaptive_factor = dynamic_adaptive_factor
self.loss_type = loss_type
self.rul_score_mode = rul_score_mode
self.evaluate_degraded_only = evaluate_degraded_only
self.optim_kwargs = optim_kwargs

self._get_optimizer = utils.OptimizerFactory(**self.optim_kwargs)
Expand All @@ -90,7 +92,9 @@ def __init__(
conditional_mmd_losses, fuzzy_sets, mean_over_sets=True
)

self.evaluator = AdaptionEvaluator(self.forward, self.log, self.rul_score_mode)
self.evaluator = AdaptionEvaluator(
self.forward, self.log, self.rul_score_mode, self.evaluate_degraded_only
)

self.save_hyperparameters()

Expand Down Expand Up @@ -221,6 +225,7 @@ def __init__(
fuzzy_sets: List[Tuple[float, float]],
loss_type: Literal["mse", "rmse", "mae"] = "mae",
rul_score_mode: Literal["phm08", "phm12"] = "phm08",
evaluate_degraded_only: bool = False,
**optim_kwargs: Any,
) -> None:
"""
Expand Down Expand Up @@ -255,11 +260,14 @@ def __init__(
self.loss_type = loss_type
self._fuzzy_sets = fuzzy_sets
self.rul_score_mode = rul_score_mode
self.evaluate_degraded_only = evaluate_degraded_only
self.optim_kwargs = optim_kwargs

self.train_source_loss = utils.get_loss(self.loss_type)
self._get_optimizer = utils.OptimizerFactory(**self.optim_kwargs)
self.evaluator = AdaptionEvaluator(self.forward, self.log, self.rul_score_mode)
self.evaluator = AdaptionEvaluator(
self.forward, self.log, self.rul_score_mode, self.evaluate_degraded_only
)

self.save_hyperparameters()

Expand Down
6 changes: 5 additions & 1 deletion rul_adapt/approach/consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def __init__(
max_epochs: int,
loss_type: Literal["mse", "mae", "rmse"] = "rmse",
rul_score_mode: Literal["phm08", "phm12"] = "phm08",
evaluate_degraded_only: bool = False,
**optim_kwargs: Any,
) -> None:
"""
Expand Down Expand Up @@ -105,12 +106,15 @@ def __init__(
self.max_epochs = max_epochs
self.loss_type = loss_type
self.rul_score_mode = rul_score_mode
self.evaluate_degraded_only = evaluate_degraded_only
self.optim_kwargs = optim_kwargs

self.train_source_loss = utils.get_loss(loss_type)
self.consistency_loss = rul_adapt.loss.ConsistencyLoss()
self._get_optimizer = utils.OptimizerFactory(**self.optim_kwargs)
self.evaluator = AdaptionEvaluator(self.forward, self.log, self.rul_score_mode)
self.evaluator = AdaptionEvaluator(
self.forward, self.log, self.rul_score_mode, self.evaluate_degraded_only
)

self.save_hyperparameters()

Expand Down
6 changes: 5 additions & 1 deletion rul_adapt/approach/dann.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def __init__(
dann_factor: float,
loss_type: Literal["mae", "mse", "rmse"] = "mae",
rul_score_mode: Literal["phm08", "phm12"] = "phm08",
evaluate_degraded_only: bool = False,
**optim_kwargs: Any,
):
"""
Expand Down Expand Up @@ -97,12 +98,15 @@ def __init__(
self.dann_factor = dann_factor
self.loss_type = loss_type
self.rul_score_mode = rul_score_mode
self.evaluate_degraded_only = evaluate_degraded_only
self.optim_kwargs = optim_kwargs

self._get_optimizer = utils.OptimizerFactory(**self.optim_kwargs)

self.train_source_loss = utils.get_loss(self.loss_type)
self.evaluator = AdaptionEvaluator(self.forward, self.log, self.rul_score_mode)
self.evaluator = AdaptionEvaluator(
self.forward, self.log, self.rul_score_mode, self.evaluate_degraded_only
)

self.save_hyperparameters()

Expand Down
27 changes: 25 additions & 2 deletions rul_adapt/approach/evaluation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Callable, List, Literal
from typing import Callable, List, Literal, Tuple

import torch
import torchmetrics
Expand All @@ -13,12 +13,14 @@ def __init__(
network_func: Callable[[torch.Tensor], torch.Tensor],
log_func: Callable[[str, torchmetrics.Metric], None],
score_mode: Literal["phm08", "phm12"] = "phm08",
degraded_only: bool = False,
):
super().__init__()

self.network_func = network_func
self.log_func = log_func
self.score_mode = score_mode
self.degraded_only = degraded_only

self.val_metrics = self._get_default_metrics()
self.test_metrics = self._get_default_metrics()
Expand Down Expand Up @@ -52,10 +54,15 @@ def test(
self._evaluate("test", self.test_metrics, batch, domain)

def _evaluate(
self, prefix, metrics, batch, domain: Literal["source", "target"]
self,
prefix: str,
metrics: nn.ModuleDict,
batch: List[torch.Tensor],
domain: Literal["source", "target"],
) -> None:
self._check_domain(domain, prefix)
features, labels = batch
features, labels = filter_batch(features, labels, self.degraded_only)
labels = labels[:, None]
predictions = self.network_func(features)
for metric_name, metric in metrics[domain].items():
Expand All @@ -68,3 +75,19 @@ def _check_domain(self, domain: str, prefix: str) -> None:
f"Unexpected {prefix} domain '{domain}'. "
"Use either 'source' or 'target'."
)


def filter_batch(
    features: torch.Tensor, labels: torch.Tensor, degraded_only: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Optionally keep only the samples whose degradation has started.

    When ``degraded_only`` is False, the batch is returned untouched. When it
    is True, only samples with a label strictly below 1.0 are kept, which
    assumes max-normalized RUL labels. Labels above 1.0 indicate
    un-normalized RUL and raise a ``RuntimeError``.
    """
    if not degraded_only:
        return features, labels

    if torch.any(labels > 1.0):
        raise RuntimeError(
            "Degradation-only evaluation configured which works only with "
            "normalized RUL, but labels contain values greater than 1.0."
        )

    # boolean mask selects only samples past the healthy (label == 1.0) phase
    mask = labels < 1.0

    return features[mask], labels[mask]
6 changes: 5 additions & 1 deletion rul_adapt/approach/latent_align.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,7 @@ def __init__(
alpha_fusion: float,
loss_type: Literal["mse", "mae", "rmse"] = "mse",
rul_score_mode: Literal["phm08", "phm12"] = "phm08",
evaluate_degraded_only: bool = False,
**optim_kwargs: Any,
) -> None:
"""
Expand Down Expand Up @@ -374,6 +375,7 @@ def __init__(
self.alpha_fusion = alpha_fusion
self.loss_type = loss_type
self.rul_score_mode = rul_score_mode
self.evaluate_degraded_only = evaluate_degraded_only
self.optim_kwargs = optim_kwargs

# training metrics
Expand All @@ -384,7 +386,9 @@ def __init__(
self.fusion_align = rul_adapt.loss.MaximumMeanDiscrepancyLoss(num_kernels=5)
self._get_optimizer = utils.OptimizerFactory(**self.optim_kwargs)

self.evaluator = AdaptionEvaluator(self.forward, self.log, self.rul_score_mode)
self.evaluator = AdaptionEvaluator(
self.forward, self.log, self.rul_score_mode, self.evaluate_degraded_only
)

self.save_hyperparameters()

Expand Down
6 changes: 5 additions & 1 deletion rul_adapt/approach/mmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def __init__(
num_mmd_kernels: int = 5,
loss_type: Literal["mse", "rmse", "mae"] = "mse",
rul_score_mode: Literal["phm08", "phm12"] = "phm08",
evaluate_degraded_only: bool = False,
**optim_kwargs: Any,
) -> None:
"""
Expand All @@ -83,13 +84,16 @@ def __init__(
self.num_mmd_kernels = num_mmd_kernels
self.loss_type = loss_type
self.rul_score_mode = rul_score_mode
self.evaluate_degraded_only = evaluate_degraded_only
self.optim_kwargs = optim_kwargs

# training metrics
self.train_source_loss = utils.get_loss(self.loss_type)
self.mmd_loss = rul_adapt.loss.MaximumMeanDiscrepancyLoss(self.num_mmd_kernels)
self._get_optimizer = utils.OptimizerFactory(**self.optim_kwargs)
self.evaluator = AdaptionEvaluator(self.forward, self.log)
self.evaluator = AdaptionEvaluator(
self.forward, self.log, self.rul_score_mode, self.evaluate_degraded_only
)

self.save_hyperparameters()

Expand Down
5 changes: 4 additions & 1 deletion rul_adapt/approach/supervised.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from rul_adapt import utils
from rul_adapt.approach.abstract import AdaptionApproach
from rul_adapt.approach.evaluation import filter_batch


class SupervisedApproach(AdaptionApproach):
Expand All @@ -38,6 +39,7 @@ def __init__(
self,
loss_type: Literal["mse", "mae", "rmse"],
rul_scale: int = 1,
evaluate_degraded_only: bool = False,
**optim_kwargs: Any,
) -> None:
"""
Expand All @@ -58,6 +60,7 @@ def __init__(

self.loss_type = loss_type
self.rul_scale = rul_scale
self.evaluate_degraded_only = evaluate_degraded_only
self.optim_kwargs = optim_kwargs

self.train_loss = utils.get_loss(loss_type)
Expand Down Expand Up @@ -109,7 +112,7 @@ def validation_step(
batch: A list of feature and label tensors.
batch_idx: The index of the current batch.
"""
inputs, labels = batch
inputs, labels = filter_batch(*batch, degraded_only=self.evaluate_degraded_only)
predictions = self.forward(inputs)
self.val_loss(predictions, labels[:, None])
self.log("val/loss", self.val_loss)
21 changes: 21 additions & 0 deletions tests/test_approach/test_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,24 @@ def test_metric_aggregation(lightning_module, eval_func):

npt.assert_almost_equal(exp_rmse, actual_rmse, decimal=5)
npt.assert_almost_equal(exp_score, actual_score, decimal=5)


@pytest.mark.parametrize(
    ["step_func", "prefix"], [("validation", "val"), ("test", "test")]
)
@pytest.mark.parametrize("domain", ["source", "target"])
def test_degraded_only_evaluation(mocked_evaluator, domain, prefix, step_func):
    """Degraded-only mode should pass only degraded samples to the network."""
    mocked_evaluator.degraded_only = True
    healthy_labels = torch.ones(5)  # fixed typo: was 'healthy_lables'
    healthy_features = torch.ones(5, 3, 5)
    degraded_labels = torch.rand(5)  # torch.rand yields values in [0, 1)
    degraded_features = torch.zeros(5, 3, 5)
    batch = [
        torch.cat([healthy_features, degraded_features]),
        torch.cat([healthy_labels, degraded_labels]),
    ]

    getattr(mocked_evaluator, step_func)(batch, domain)

    # only the degraded half of the batch should reach the network function
    (actual_network_input,) = mocked_evaluator.network_func.call_args.args
    assert torch.dist(actual_network_input, degraded_features) == 0.0