From 481b5dcf1de9ad5d714872912ba62c273b31ce21 Mon Sep 17 00:00:00 2001 From: Tilman Krokotsch Date: Thu, 7 Dec 2023 14:52:21 +0100 Subject: [PATCH] feat: use data module for paired dataset (#45) * feat: use data module for paired dataset This enables using the feature extractors of the data modules, e.g., for extracting spectra. * fix: truncated validation --- rul_datasets/adaption.py | 43 +++++++++++++++++-------------------- rul_datasets/baseline.py | 46 ++++++++++++++++++++-------------------- rul_datasets/core.py | 35 ++++++++++++++++-------------- tests/test_adaption.py | 2 +- tests/test_baseline.py | 4 ++-- tests/test_core.py | 45 ++++++++++++++++++++++----------------- tests/test_hydra.py | 8 +++---- 7 files changed, 95 insertions(+), 88 deletions(-) diff --git a/rul_datasets/adaption.py b/rul_datasets/adaption.py index 1751fe4..a5d2d00 100644 --- a/rul_datasets/adaption.py +++ b/rul_datasets/adaption.py @@ -69,8 +69,8 @@ def __init__( self.batch_size = source.batch_size self.inductive = inductive - self.target_truncated = deepcopy(self.target.reader) - self.target_truncated.truncate_val = True + self.target_truncated = deepcopy(self.target) + self.target_truncated.reader.truncate_val = True self._check_compatibility() @@ -85,7 +85,7 @@ def __init__( def _check_compatibility(self): self.source.check_compatibility(self.target) - self.target.reader.check_compatibility(self.target_truncated) + self.target.reader.check_compatibility(self.target_truncated.reader) if self.source.reader.fd == self.target.reader.fd: raise ValueError( f"FD of source and target has to be different for " @@ -463,58 +463,55 @@ def __init__( self.min_distance = min_distance self.distance_mode = distance_mode - self.target_loader = self.target.reader - self.source_loader = self.source.reader - self._check_compatibility() self.save_hyperparameters( { - "fd_source": self.source_loader.fd, - "fd_target": self.target_loader.fd, + "fd_source": self.source.reader.fd, + "fd_target": self.target.reader.fd, "num_samples": self.num_samples, "batch_size": self.batch_size, - "window_size": self.source_loader.window_size, - "max_rul": self.source_loader.max_rul, + "window_size": self.source.reader.window_size, + "max_rul": self.source.reader.max_rul, "min_distance": self.min_distance, - "percent_broken": self.target_loader.percent_broken, - "percent_fail_runs": self.target_loader.percent_fail_runs, - "truncate_target_val": self.target_loader.truncate_val, + "percent_broken": self.target.reader.percent_broken, + "percent_fail_runs": self.target.reader.percent_fail_runs, + "truncate_target_val": self.target.reader.truncate_val, "distance_mode": self.distance_mode, } ) def _check_compatibility(self): self.source.check_compatibility(self.target) - if self.source_loader.fd == self.target_loader.fd: + if self.source.reader.fd == self.target.reader.fd: raise ValueError( f"FD of source and target has to be different for " - f"domain adaption, but is {self.source_loader.fd} bot times." + f"domain adaption, but is {self.source.reader.fd} both times." ) if ( - self.target_loader.percent_broken is None - or self.target_loader.percent_broken == 1.0 + self.target.reader.percent_broken is None + or self.target.reader.percent_broken == 1.0 ): raise ValueError( "Target data needs a percent_broken smaller than 1 for pre-training." 
) if ( - self.source_loader.percent_broken is not None - and self.source_loader.percent_broken < 1.0 + self.source.reader.percent_broken is not None + and self.source.reader.percent_broken < 1.0 ): raise ValueError( "Source data cannot have a percent_broken smaller than 1, " "otherwise it would not be failed, labeled data." ) - if not self.target_loader.truncate_val: + if not self.target.reader.truncate_val: warnings.warn( "Validation data of unfailed runs is not truncated. " "The validation metrics will not be valid." ) def prepare_data(self, *args, **kwargs): - self.source_loader.prepare_data() - self.target_loader.prepare_data() + self.source.reader.prepare_data() + self.target.reader.prepare_data() def setup(self, stage: Optional[str] = None): self.source.setup(stage) @@ -539,7 +536,7 @@ def _get_paired_dataset(self, split: str) -> PairedRulDataset: min_distance = 1 if split == "val" else self.min_distance num_samples = 50000 if split == "val" else self.num_samples paired = PairedRulDataset( - [self.source_loader, self.target_loader], + [self.source, self.target], split, num_samples, min_distance, diff --git a/rul_datasets/baseline.py b/rul_datasets/baseline.py index e5f1411..c569976 100644 --- a/rul_datasets/baseline.py +++ b/rul_datasets/baseline.py @@ -133,80 +133,80 @@ def __init__( ): super().__init__() - self.failed_loader = failed_data_module.reader - self.unfailed_loader = unfailed_data_module.reader + self.failed = failed_data_module + self.unfailed = unfailed_data_module self.num_samples = num_samples self.batch_size = failed_data_module.batch_size self.min_distance = min_distance self.distance_mode = distance_mode - self.window_size = self.unfailed_loader.window_size + self.window_size = self.unfailed.reader.window_size self.source = unfailed_data_module self._check_loaders() self.save_hyperparameters( { - "fd_source": self.unfailed_loader.fd, + "fd_source": self.unfailed.reader.fd, "num_samples": self.num_samples, "batch_size": self.batch_size, "window_size": self.window_size, - "max_rul": self.unfailed_loader.max_rul, + "max_rul": self.unfailed.reader.max_rul, "min_distance": self.min_distance, - "percent_broken": self.unfailed_loader.percent_broken, - "percent_fail_runs": self.failed_loader.percent_fail_runs, - "truncate_val": self.unfailed_loader.truncate_val, + "percent_broken": self.unfailed.reader.percent_broken, + "percent_fail_runs": self.failed.reader.percent_fail_runs, + "truncate_val": self.unfailed.reader.truncate_val, "distance_mode": self.distance_mode, } ) def _check_loaders(self): - self.failed_loader.check_compatibility(self.unfailed_loader) - if not self.failed_loader.fd == self.unfailed_loader.fd: + self.failed.reader.check_compatibility(self.unfailed.reader) + if not self.failed.reader.fd == self.unfailed.reader.fd: raise ValueError("Failed and unfailed data need to come from the same FD.") - if self.failed_loader.percent_fail_runs is None or isinstance( - self.failed_loader.percent_fail_runs, float + if self.failed.reader.percent_fail_runs is None or isinstance( + self.failed.reader.percent_fail_runs, float ): raise ValueError( "Failed data needs list of failed runs " "for pre-training but uses a float or is None." ) - if self.unfailed_loader.percent_fail_runs is None or isinstance( - self.unfailed_loader.percent_fail_runs, float + if self.unfailed.reader.percent_fail_runs is None or isinstance( + self.unfailed.reader.percent_fail_runs, float ): raise ValueError( "Unfailed data needs list of failed runs " "for pre-training but uses a float or is None." 
             )
-        if set(self.failed_loader.percent_fail_runs).intersection(
-            self.unfailed_loader.percent_fail_runs
+        if set(self.failed.reader.percent_fail_runs).intersection(
+            self.unfailed.reader.percent_fail_runs
         ):
             raise ValueError(
                 "Runs of failed and unfailed data overlap. "
                 "Please use mututally exclusive sets of runs."
             )
         if (
-            self.unfailed_loader.percent_broken is None
-            or self.unfailed_loader.percent_broken == 1.0
+            self.unfailed.reader.percent_broken is None
+            or self.unfailed.reader.percent_broken == 1.0
         ):
             raise ValueError(
                 "Unfailed data needs a percent_broken smaller than 1 for pre-training."
             )
         if (
-            self.failed_loader.percent_broken is not None
-            and self.failed_loader.percent_broken < 1.0
+            self.failed.reader.percent_broken is not None
+            and self.failed.reader.percent_broken < 1.0
         ):
             raise ValueError(
                 "Failed data cannot have a percent_broken smaller than 1, "
                 "otherwise it would not be failed data."
             )
-        if not self.unfailed_loader.truncate_val:
+        if not self.unfailed.reader.truncate_val:
             warnings.warn(
                 "Validation data of unfailed runs is not truncated. "
                 "The validation metrics will not be valid."
             )
 
     def prepare_data(self, *args, **kwargs):
-        self.unfailed_loader.prepare_data()
+        self.unfailed.reader.prepare_data()
 
     def setup(self, stage: Optional[str] = None):
         self.source.setup(stage)
@@ -229,7 +229,7 @@ def _get_paired_dataset(self, split: str) -> PairedRulDataset:
         min_distance = 1 if split == "val" else self.min_distance
         num_samples = 25000 if split == "val" else self.num_samples
         paired = PairedRulDataset(
-            [self.unfailed_loader, self.failed_loader],
+            [self.unfailed, self.failed],
             split,
             num_samples,
             min_distance,
diff --git a/rul_datasets/core.py b/rul_datasets/core.py
index e5a2300..ca33cd9 100644
--- a/rul_datasets/core.py
+++ b/rul_datasets/core.py
@@ -377,7 +377,7 @@ class PairedRulDataset(IterableDataset):
 
     def __init__(
         self,
-        readers: List[AbstractReader],
+        dms: List[RulDataModule],
         split: str,
         num_samples: int,
         min_distance: int,
@@ -386,19 +386,19 @@ def __init__(
     ):
         super().__init__()
 
-        self.readers = readers
+        self.dms = dms
         self.split = split
         self.min_distance = min_distance
         self.num_samples = num_samples
         self.deterministic = deterministic
         self.mode = mode
 
-        for reader in self.readers:
-            reader.check_compatibility(self.readers[0])
+        for dm in self.dms:
+            dm.check_compatibility(self.dms[0])
 
         self._run_domain_idx: np.ndarray
-        self._features: List[np.ndarray]
-        self._labels: List[np.ndarray]
+        self._features: List[torch.Tensor]
+        self._labels: List[torch.Tensor]
 
         self._prepare_datasets()
         self._max_rul = self._get_max_rul()
@@ -412,13 +412,16 @@ def __init__(
         self._get_pair_func = self._get_labeled_pair_idx
 
     def _get_max_rul(self):
-        max_ruls = [reader.max_rul for reader in self.readers]
-        if any(m is None for m in max_ruls):
+        max_ruls = [dm.reader.max_rul for dm in self.dms]
+        if all(m is None for m in max_ruls):
+            max_rul = 1e10
+        elif any(m is None for m in max_ruls):
             raise ValueError(
-                "PairedRulDataset needs a set max_rul for all readers "
-                "but at least one of them has is None."
+                "PairedRulDataset needs a set max_rul for all or none of the readers, "
+                "but it is None for only some of them."
) - max_rul = max(max_ruls) + else: + max_rul = max(max_ruls) return max_rul @@ -426,8 +429,8 @@ def _prepare_datasets(self): run_domain_idx = [] features = [] labels = [] - for domain_idx, reader in enumerate(self.readers): - run_features, run_labels = reader.load_split(self.split) + for domain_idx, dm in enumerate(self.dms): + run_features, run_labels = dm.load_split(self.split) for feat, lab in zip(run_features, run_labels): if len(feat) > self.min_distance: run_domain_idx.append(domain_idx) @@ -530,14 +533,14 @@ def _get_labeled_pair_idx(self) -> Tuple[int, int, int, int, int]: def _build_pair( self, - run: np.ndarray, + run: torch.Tensor, anchor_idx: int, query_idx: int, distance: int, domain_label: int, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - anchors = utils.feature_to_tensor(run[anchor_idx], torch.float) - queries = utils.feature_to_tensor(run[query_idx], torch.float) + anchors = run[anchor_idx] + queries = run[query_idx] domain_tensor = torch.tensor(domain_label, dtype=torch.float) distances = torch.tensor(distance, dtype=torch.float) / self._max_rul distances = torch.clamp_max(distances, max=1) # max distance is max_rul diff --git a/tests/test_adaption.py b/tests/test_adaption.py index 3dac59d..0c70ff2 100644 --- a/tests/test_adaption.py +++ b/tests/test_adaption.py @@ -131,7 +131,7 @@ def test_test_dataloader(self): def test_truncated_loader(self): self.assertIsNot(self.dataset.target.reader, self.dataset.target_truncated) - self.assertTrue(self.dataset.target_truncated.truncate_val) + self.assertTrue(self.dataset.target_truncated.reader.truncate_val) def test_hparams(self): expected_hparams = { diff --git a/tests/test_baseline.py b/tests/test_baseline.py index f53c5d8..0ec60c2 100644 --- a/tests/test_baseline.py +++ b/tests/test_baseline.py @@ -128,8 +128,8 @@ def test_both_source_datasets_used(self): ) for split in ["dev", "val"]: with self.subTest(split): - num_broken_runs = len(dataset.unfailed_loader.load_split(split)[0]) - num_fail_runs = len(dataset.failed_loader.load_split(split)[0]) + num_broken_runs = len(dataset.unfailed.reader.load_split(split)[0]) + num_fail_runs = len(dataset.failed.reader.load_split(split)[0]) paired_dataset = dataset._get_paired_dataset(split) self.assertEqual( num_broken_runs + num_fail_runs, len(paired_dataset._features) diff --git a/tests/test_core.py b/tests/test_core.py index 4e8c19a..6cc8c2b 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -9,7 +9,7 @@ import torch from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset -from rul_datasets import core, reader +from rul_datasets import core, reader, RulDataModule @pytest.fixture() @@ -249,6 +249,10 @@ class DummyRul(reader.AbstractReader): fd: int = 1 window_size: int = 30 max_rul: int = 125 + percent_broken = None + percent_fail_runs = None + truncate_val = False + truncate_degraded_only = False def __init__(self, length): self.data = { @@ -279,11 +283,11 @@ def check_compatibility(self, other) -> None: def prepare_data(self): pass - def load_complete_split(self, split): + def load_complete_split(self, split, alias): return self.data[split] - def load_split(self, split): - return self.load_complete_split(split) + def load_split(self, split, alias): + return self.load_complete_split(split, alias) @dataclass @@ -293,6 +297,10 @@ class DummyRulShortRuns(reader.AbstractReader): fd: int = 1 window_size: int = 30 max_rul: int = 125 + percent_broken = None + percent_fail_runs = None + truncate_val = False + 
truncate_degraded_only = False
     data = {
         "dev": (
             [
@@ -329,14 +337,14 @@ def check_compatibility(self, other) -> None:
     def prepare_data(self):
         pass
 
-    def load_complete_split(self, split):
+    def load_complete_split(self, split, alias):
         if not split == "dev":
             raise ValueError(f"DummyRulShortRuns does not have a '{split}' split")
 
         return self.data["dev"]
 
-    def load_split(self, split):
-        return self.load_complete_split(split)
+    def load_split(self, split, alias):
+        return self.load_complete_split(split, alias)
 
 
 @pytest.fixture(scope="module")
 def length():
     return 300
 
 
 @pytest.fixture
 def cmapss_normal(length):
-    return DummyRul(length)
+    return RulDataModule(DummyRul(length), 32)
 
 
 @pytest.fixture
 def cmapss_short():
-    return DummyRulShortRuns()
+    return RulDataModule(DummyRulShortRuns(), 32)
 
 
 class TestPairedDataset:
@@ -405,12 +413,8 @@ def test_sampled_data(self, cmapss_short):
         for i, sample in enumerate(data):
             idx = 3 * i
             expected_run = data._features[fixed_idx[idx]]
-            expected_anchor = torch.tensor(expected_run[fixed_idx[idx + 1]]).transpose(
-                1, 0
-            )
-            expected_query = torch.tensor(expected_run[fixed_idx[idx + 2]]).transpose(
-                1, 0
-            )
+            expected_anchor = torch.tensor(expected_run[fixed_idx[idx + 1]])
+            expected_query = torch.tensor(expected_run[fixed_idx[idx + 2]])
             expected_distance = min(125, fixed_idx[idx + 2] - fixed_idx[idx + 1]) / 125
             expected_domain_idx = 0
             assert 0 == torch.dist(expected_anchor, sample[0])
@@ -523,10 +527,13 @@ def _is_same_batch(b0, b1):
 
     def test_compatability_check(self):
         mock_check_compat = mock.MagicMock(name="check_compatibility")
-        loaders = [DummyRulShortRuns(), DummyRulShortRuns(window_size=20)]
-        for lod in loaders:
-            lod.check_compatibility = mock_check_compat
+        dms = [
+            RulDataModule(DummyRulShortRuns(), 32),
+            RulDataModule(DummyRulShortRuns(window_size=20), 32),
+        ]
+        for dm in dms:
+            dm.check_compatibility = mock_check_compat
 
-        core.PairedRulDataset(loaders, "dev", 1000, 1)
+        core.PairedRulDataset(dms, "dev", 1000, 1)
 
         assert 2 == mock_check_compat.call_count
diff --git a/tests/test_hydra.py b/tests/test_hydra.py
index d88ef2d..b3334de 100644
--- a/tests/test_hydra.py
+++ b/tests/test_hydra.py
@@ -46,8 +46,8 @@ def test_dm_pre(self):
         )
         cmapss_dm = hydra.utils.instantiate(cfg.dm_pre)
         self.assertIsInstance(cmapss_dm, rul_datasets.PretrainingBaselineDataModule)
-        self.assertIsInstance(cmapss_dm.failed_loader, rul_datasets.CmapssReader)
-        self.assertIsInstance(cmapss_dm.failed_loader, rul_datasets.CmapssReader)
+        self.assertIsInstance(cmapss_dm.failed.reader, rul_datasets.CmapssReader)
+        self.assertIsInstance(cmapss_dm.unfailed.reader, rul_datasets.CmapssReader)
 
         with self.subTest("femto"):
             cfg = hydra.compose(
@@ -55,8 +55,8 @@ def test_dm_pre(self):
             )
             femto_dm = hydra.utils.instantiate(cfg.dm_pre)
             self.assertIsInstance(femto_dm, rul_datasets.PretrainingBaselineDataModule)
-            self.assertIsInstance(femto_dm.failed_loader, rul_datasets.FemtoReader)
-            self.assertIsInstance(femto_dm.failed_loader, rul_datasets.FemtoReader)
+            self.assertIsInstance(femto_dm.failed.reader, rul_datasets.FemtoReader)
+            self.assertIsInstance(femto_dm.unfailed.reader, rul_datasets.FemtoReader)
 
 
 class TestAdaption(unittest.TestCase):
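
Usage sketch (illustrative, not part of the diff): PairedRulDataset and the
pretraining data modules now consume RulDataModule instances instead of bare
readers, so a feature extractor attached to a data module is applied to the
paired samples as well. The wrapping and construction calls below mirror the
updated tests; the CmapssReader(fd=...) constructor and the optional
feature_extractor/window_size arguments of RulDataModule are assumed from the
library's public API, and spectral_features is a hypothetical helper:

    import numpy as np

    import rul_datasets
    from rul_datasets.core import PairedRulDataset

    def spectral_features(windows):
        # Hypothetical extractor, assuming the contract: take the windowed
        # features of one run, return one feature vector per window; the data
        # module then re-windows the result to the given window_size.
        return np.abs(np.fft.rfft(windows, axis=1)).mean(axis=1)

    # Use the same extractor object for both modules so they stay compatible.
    source = rul_datasets.RulDataModule(
        rul_datasets.CmapssReader(fd=3), 32, spectral_features, window_size=10
    )
    target = rul_datasets.RulDataModule(
        rul_datasets.CmapssReader(fd=1), 32, spectral_features, window_size=10
    )
    source.reader.prepare_data()
    target.reader.prepare_data()

    # Same call pattern as _get_paired_dataset above: data modules, not readers.
    paired = PairedRulDataset([source, target], "dev", 1000, min_distance=1)

Handing whole data modules to PairedRulDataset keeps windowing and feature
extraction in one place; the dataset can rely on the tensors produced by the
modules, which is why _build_pair above no longer converts features itself.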