diff --git a/pyproject.toml b/pyproject.toml index 787cc53..dffa127 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,7 @@ requires-python = ">=3.8" version = "0.0.1-dev" dependencies = [ + "ipywidgets", "jsonargparse[all]", "librep@git+https://github.com/discovery-unicamp/hiaac-librep.git@0.0.4-dev", "lightly", diff --git a/ssl_tools/experiments/__init__.py b/ssl_tools/experiments/__init__.py index 7295e5c..f5ef6d6 100644 --- a/ssl_tools/experiments/__init__.py +++ b/ssl_tools/experiments/__init__.py @@ -1,2 +1,7 @@ -from .lightning_cli import LightningTrain, LightningTest -from .ssl_experiment import SSLTrain, SSLTest \ No newline at end of file +from .experiment import Experiment, auto_main +from .lightning_experiment import ( + LightningExperiment, + LightningTrain, + LightningTest, + LightningSSLTrain, +) diff --git a/ssl_tools/experiments/experiment.py b/ssl_tools/experiments/experiment.py new file mode 100644 index 0000000..4dd66a6 --- /dev/null +++ b/ssl_tools/experiments/experiment.py @@ -0,0 +1,79 @@ +from pathlib import Path +from typing import Any, Dict, Union +from abc import ABC, abstractmethod +from datetime import datetime +from jsonargparse import ArgumentParser + +EXPERIMENT_VERSION_FORMAT = "%Y-%m-%d_%H-%M-%S" + + +class Experiment(ABC): + def __init__( + self, + name: str = "experiment", + run_id: Union[str, int] = None, + log_dir: str = "logs", + seed: int = None, + ): + self.name = name + self.run_id = run_id or datetime.now().strftime( + EXPERIMENT_VERSION_FORMAT + ) + self.log_dir = log_dir + self.seed = seed + + @property + def experiment_dir(self) -> Path: + return Path(self.log_dir) / self.name / str(self.run_id) + + def setup(self): + pass + + @abstractmethod + def run(self) -> Any: + raise NotImplementedError + + def teardown(self): + pass + + def execute(self): + print(f"Setting up experiment: {self.name}...") + self.setup() + print(f"Running experiment: {self.name}...") + result = self.run() + print(f"Teardown experiment: {self.name}...") + self.teardown() + return result + + def __call__(self): + return self.execute() + + def __str__(self): + return f"Experiment(name={self.name}, run_id={self.run_id}, cwd={self.experiment_dir})" + + def __repr__(self) -> str: + return str(self) + + +def get_parser(commands: Dict[str, Experiment]): + parser = ArgumentParser() + subcommands = parser.add_subcommands() + + for name, command in commands.items(): + subparser = ArgumentParser() + subparser.add_class_arguments(command) + subcommands.add_subcommand(name, subparser) + + return parser + + +def auto_main(commands: Dict[str, Experiment]): + parser = get_parser(commands) + args = parser.parse_args() + # print(args) + + experiment = commands[args.subcommand](**args[args.subcommand]) + experiment.execute() + + # command = args.subcommand + # command(**args).execute() diff --git a/ssl_tools/experiments/har_classification/cpc.py b/ssl_tools/experiments/har_classification/cpc.py index bfd3d53..107e4a8 100755 --- a/ssl_tools/experiments/har_classification/cpc.py +++ b/ssl_tools/experiments/har_classification/cpc.py @@ -1,16 +1,9 @@ #!/usr/bin/env python -# TODO: A way of removing the need to add the path to the root of -# the project -import sys -from jsonargparse import CLI import lightning as L import torch -sys.path.append("../../../") - - -from ssl_tools.experiments import SSLTrain, SSLTest +from ssl_tools.experiments import LightningSSLTrain, LightningTest, auto_main from ssl_tools.models.ssl.cpc import build_cpc from ssl_tools.data.data_modules import ( 
MultiModalHARSeriesDataModule, @@ -21,7 +14,7 @@ from ssl_tools.models.ssl.modules.heads import CPCPredictionHead -class CPCTrain(SSLTrain): +class CPCTrain(LightningSSLTrain): _MODEL_NAME = "CPC" def __init__( @@ -64,7 +57,7 @@ def __init__( self.num_classes = num_classes self.update_backbone = update_backbone - def _get_pretrain_model(self) -> L.LightningModule: + def get_pretrain_model(self) -> L.LightningModule: model = build_cpc( encoding_size=self.encoding_size, in_channel=self.in_channel, @@ -74,7 +67,7 @@ def _get_pretrain_model(self) -> L.LightningModule: ) return model - def _get_pretrain_data_module(self) -> L.LightningDataModule: + def get_pretrain_data_module(self) -> L.LightningDataModule: data_module = UserActivityFolderDataModule( data_path=self.data, batch_size=self.batch_size, @@ -83,17 +76,17 @@ def _get_pretrain_data_module(self) -> L.LightningDataModule: ) return data_module - def _get_finetune_model( + def get_finetune_model( self, load_backbone: str = None ) -> L.LightningModule: - model = self._get_pretrain_model() + model = self.get_pretrain_model() if load_backbone is not None: - self._load_model(model, load_backbone) + self.load_checkpoint(model, load_backbone) classifier = CPCPredictionHead( input_dim=self.encoding_size, - hidden_dim2=self.num_classes, + output_dim=self.num_classes, ) task = "multiclass" if self.num_classes > 2 else "binary" @@ -107,18 +100,19 @@ def _get_finetune_model( ) return model - def _get_finetune_data_module(self) -> L.LightningDataModule: + def get_finetune_data_module(self) -> L.LightningDataModule: data_module = MultiModalHARSeriesDataModule( data_path=self.data, batch_size=self.batch_size, label="standard activity code", features_as_channels=True, + num_workers=self.num_workers, ) return data_module -class CPCTest(SSLTest): +class CPCTest(LightningTest): _MODEL_NAME = "CPC" def __init__( @@ -154,7 +148,7 @@ def __init__( self.window_size = window_size self.num_classes = num_classes - def _get_test_model(self, load_backbone: str = None) -> L.LightningModule: + def get_model(self, load_backbone: str = None) -> L.LightningModule: model = build_cpc( encoding_size=self.encoding_size, in_channel=self.in_channel, @@ -162,9 +156,12 @@ def _get_test_model(self, load_backbone: str = None) -> L.LightningModule: n_size=5, ) + if load_backbone is not None: + self.load_checkpoint(model, load_backbone) + classifier = CPCPredictionHead( input_dim=self.encoding_size, - hidden_dim2=self.num_classes, + output_dim=self.num_classes, ) task = "multiclass" if self.num_classes > 2 else "binary" @@ -176,23 +173,21 @@ def _get_test_model(self, load_backbone: str = None) -> L.LightningModule: ) return model - def _get_test_data_module(self) -> L.LightningDataModule: + def get_data_module(self) -> L.LightningDataModule: data_module = MultiModalHARSeriesDataModule( data_path=self.data, batch_size=self.batch_size, label="standard activity code", features_as_channels=True, + num_workers=self.num_workers, ) + return data_module -def main(): - components = { +if __name__ == "__main__": + options = { "fit": CPCTrain, "test": CPCTest, } - CLI(components=components, as_positional=False)() - - -if __name__ == "__main__": - main() + auto_main(options) diff --git a/ssl_tools/experiments/har_classification/scripts/finetune_cpc_motionsense_motionsense.sh b/ssl_tools/experiments/har_classification/scripts/finetune_cpc_motionsense_motionsense.sh index 9ab5254..b814ddd 100755 --- a/ssl_tools/experiments/har_classification/scripts/finetune_cpc_motionsense_motionsense.sh +++ 
b/ssl_tools/experiments/har_classification/scripts/finetune_cpc_motionsense_motionsense.sh @@ -8,7 +8,7 @@ cd .. --batch_size 128 \ --accelerator gpu \ --devices 1 \ - --load_backbone logs/pretrain/CPC/2024-01-29_19-54-51/checkpoints/last.ckpt \ + --load_backbone logs/pretrain/CPC/2024-01-31_21-14-31/checkpoints/last.ckpt \ --training_mode finetune \ --window_size 60 \ --num_classes 6 \ diff --git a/ssl_tools/experiments/har_classification/scripts/finetune_tfc_motionsense_motionsense.sh b/ssl_tools/experiments/har_classification/scripts/finetune_tfc_motionsense_motionsense.sh index 7aa5f18..eedc667 100755 --- a/ssl_tools/experiments/har_classification/scripts/finetune_tfc_motionsense_motionsense.sh +++ b/ssl_tools/experiments/har_classification/scripts/finetune_tfc_motionsense_motionsense.sh @@ -8,7 +8,7 @@ cd .. --batch_size 128 \ --accelerator gpu \ --devices 1 \ - --load_backbone logs/pretrain/TFC/2024-01-28_22-20-23/checkpoints/last.ckpt \ + --load_backbone logs/pretrain/TFC/2024-01-31_21-07-06/checkpoints/last.ckpt \ --training_mode finetune \ --encoding_size 128 \ --features_as_channels True \ diff --git a/ssl_tools/experiments/har_classification/scripts/finetune_tnc_motionsense_motionsense.sh b/ssl_tools/experiments/har_classification/scripts/finetune_tnc_motionsense_motionsense.sh index e2f8226..ff98aa4 100755 --- a/ssl_tools/experiments/har_classification/scripts/finetune_tnc_motionsense_motionsense.sh +++ b/ssl_tools/experiments/har_classification/scripts/finetune_tnc_motionsense_motionsense.sh @@ -8,7 +8,7 @@ cd .. --batch_size 128 \ --accelerator gpu \ --devices 1 \ - --load_backbone logs/pretrain/TNC/2024-01-29_19-28-11/checkpoints/last.ckpt \ + --load_backbone logs/pretrain/TNC/2024-01-31_21-13-05/checkpoints/last.ckpt \ --training_mode finetune \ --repeat 5 \ --mc_sample_size 20 \ diff --git a/ssl_tools/experiments/har_classification/scripts/inference_cpc_all.sh b/ssl_tools/experiments/har_classification/scripts/inference_cpc_all.sh index b75e64c..b57aa65 100755 --- a/ssl_tools/experiments/har_classification/scripts/inference_cpc_all.sh +++ b/ssl_tools/experiments/har_classification/scripts/inference_cpc_all.sh @@ -6,7 +6,7 @@ for dset in "KuHar" "MotionSense" "RealWorld_thigh" "RealWorld_waist" "UCI"; do ./cpc.py test \ --data /workspaces/hiaac-m4/ssl_tools/data/standartized_balanced/${dset} \ - --load logs/finetune/CPC/2024-01-29_19-58-58/checkpoints/last.ckpt \ + --load logs/finetune/CPC/2024-01-31_21-19-23/checkpoints/last.ckpt \ --batch_size 128 \ --accelerator gpu \ --devices 1 \ diff --git a/ssl_tools/experiments/har_classification/scripts/inference_tfc_all.sh b/ssl_tools/experiments/har_classification/scripts/inference_tfc_all.sh index b2b8fd4..90edad4 100755 --- a/ssl_tools/experiments/har_classification/scripts/inference_tfc_all.sh +++ b/ssl_tools/experiments/har_classification/scripts/inference_tfc_all.sh @@ -6,7 +6,7 @@ for dset in "KuHar" "MotionSense" "RealWorld_thigh" "RealWorld_waist" "UCI"; do ./tfc.py test \ --data /workspaces/hiaac-m4/ssl_tools/data/standartized_balanced/${dset} \ - --load logs/finetune/TFC/2023-12-06_19-35-44/checkpoints/last.ckpt \ + --load logs/finetune/TFC/2024-01-31_21-17-05/checkpoints/last.ckpt \ --batch_size 128 \ --accelerator gpu \ --devices 1 \ diff --git a/ssl_tools/experiments/har_classification/scripts/inference_tnc_all.sh b/ssl_tools/experiments/har_classification/scripts/inference_tnc_all.sh index c23e726..5822dda 100755 --- a/ssl_tools/experiments/har_classification/scripts/inference_tnc_all.sh +++ 
b/ssl_tools/experiments/har_classification/scripts/inference_tnc_all.sh @@ -6,7 +6,7 @@ for dset in "KuHar" "MotionSense" "RealWorld_thigh" "RealWorld_waist" "UCI"; do ./tnc.py test \ --data /workspaces/hiaac-m4/ssl_tools/data/standartized_balanced/${dset} \ - --load logs/finetune/TNC/2024-01-29_19-36-22/checkpoints/last.ckpt \ + --load logs/finetune/TNC/2024-01-31_21-22-05/checkpoints/last.ckpt \ --batch_size 128 \ --accelerator gpu \ --devices 1 \ diff --git a/ssl_tools/experiments/har_classification/tfc.py b/ssl_tools/experiments/har_classification/tfc.py index 85fb772..e78d2fb 100755 --- a/ssl_tools/experiments/har_classification/tfc.py +++ b/ssl_tools/experiments/har_classification/tfc.py @@ -1,24 +1,17 @@ #!/usr/bin/env python -# TODO: A way of removing the need to add the path to the root of -# the project -import sys -from jsonargparse import CLI import lightning as L import torch -sys.path.append("../../../") - - -from ssl_tools.experiments import SSLTrain, SSLTest -from ssl_tools.models.ssl.tfc import build_tfc_transformer -from ssl_tools.models.ssl.modules.heads import TFCPredictionHead -from ssl_tools.data.data_modules import TFCDataModule +from ssl_tools.experiments import LightningSSLTrain, LightningTest, auto_main from torchmetrics import Accuracy from ssl_tools.models.ssl.classifier import SSLDiscriminator +from ssl_tools.models.ssl.modules.heads import TFCPredictionHead +from ssl_tools.models.ssl.tfc import build_tfc_transformer +from ssl_tools.data.data_modules import TFCDataModule -class TFCTrain(SSLTrain): +class TFCTrain(LightningSSLTrain): _MODEL_NAME = "TFC" def __init__( @@ -88,7 +81,7 @@ def __init__( self.num_classes = num_classes self.update_backbone = update_backbone - def _get_pretrain_model(self) -> L.LightningModule: + def get_pretrain_model(self) -> L.LightningModule: model = build_tfc_transformer( encoding_size=self.encoding_size, in_channels=self.in_channels, @@ -99,7 +92,7 @@ def _get_pretrain_model(self) -> L.LightningModule: ) return model - def _get_pretrain_data_module(self) -> L.LightningDataModule: + def get_pretrain_data_module(self) -> L.LightningDataModule: data_module = TFCDataModule( self.data, batch_size=self.batch_size, @@ -112,13 +105,13 @@ def _get_pretrain_data_module(self) -> L.LightningDataModule: ) return data_module - def _get_finetune_model( + def get_finetune_model( self, load_backbone: str = None ) -> L.LightningModule: - model = self._get_pretrain_model() + model = self.get_pretrain_model() if load_backbone is not None: - self._load_model(model, load_backbone) + self.load_checkpoint(model, load_backbone) classifier = TFCPredictionHead( input_dim=2 * self.encoding_size, @@ -136,7 +129,7 @@ def _get_finetune_model( ) return model - def _get_finetune_data_module(self) -> L.LightningDataModule: + def get_finetune_data_module(self) -> L.LightningDataModule: data_module = TFCDataModule( self.data, batch_size=self.batch_size, @@ -150,7 +143,7 @@ def _get_finetune_data_module(self) -> L.LightningDataModule: return data_module -class TFCTest(SSLTest): +class TFCTest(LightningTest): _MODEL_NAME = "TFC" def __init__( @@ -216,7 +209,7 @@ def __init__( self.features_as_channels = features_as_channels self.num_classes = num_classes - def _get_test_model(self) -> L.LightningModule: + def get_model(self) -> L.LightningModule: model = build_tfc_transformer( encoding_size=self.encoding_size, in_channels=self.in_channels, @@ -239,7 +232,7 @@ def _get_test_model(self) -> L.LightningModule: ) return model - def _get_test_data_module(self) -> 
L.LightningDataModule: + def get_data_module(self) -> L.LightningDataModule: data_module = TFCDataModule( self.data, batch_size=self.batch_size, @@ -252,13 +245,9 @@ def _get_test_data_module(self) -> L.LightningDataModule: return data_module -def main(): - components = { +if __name__ == "__main__": + options = { "fit": TFCTrain, "test": TFCTest, } - CLI(components=components, as_positional=False)() - - -if __name__ == "__main__": - main() + auto_main(options) diff --git a/ssl_tools/experiments/har_classification/tnc.py b/ssl_tools/experiments/har_classification/tnc.py index 1981f82..b684402 100755 --- a/ssl_tools/experiments/har_classification/tnc.py +++ b/ssl_tools/experiments/har_classification/tnc.py @@ -2,15 +2,10 @@ # TODO: A way of removing the need to add the path to the root of # the project -import sys -from jsonargparse import CLI import lightning as L import torch -sys.path.append("../../../") - -from ssl_tools.experiments import SSLTrain, SSLTest from ssl_tools.models.ssl.tnc import build_tnc from ssl_tools.data.data_modules import ( TNCHARDataModule, @@ -20,8 +15,15 @@ from ssl_tools.models.ssl.classifier import SSLDiscriminator from ssl_tools.models.ssl.modules.heads import TNCPredictionHead +import lightning as L +import torch -class TNCTrain(SSLTrain): +from ssl_tools.experiments import LightningSSLTrain, LightningTest, auto_main +from torchmetrics import Accuracy +from ssl_tools.models.ssl.classifier import SSLDiscriminator + + +class TNCTrain(LightningSSLTrain): _MODEL_NAME = "TNC" def __init__( @@ -72,7 +74,7 @@ def __init__( self.num_classes = num_classes self.update_backbone = update_backbone - def _get_pretrain_model(self) -> L.LightningModule: + def get_pretrain_model(self) -> L.LightningModule: model = build_tnc( encoding_size=self.encoding_size, in_channel=self.in_channel, @@ -82,7 +84,7 @@ def _get_pretrain_model(self) -> L.LightningModule: ) return model - def _get_pretrain_data_module(self) -> L.LightningDataModule: + def get_pretrain_data_module(self) -> L.LightningDataModule: data_module = TNCHARDataModule( self.data, pad=self.pad_length, @@ -95,17 +97,22 @@ def _get_pretrain_data_module(self) -> L.LightningDataModule: ) return data_module - def _get_finetune_model( + def get_finetune_model( self, load_backbone: str = None ) -> L.LightningModule: - model = self._get_pretrain_model() + model = build_tnc( + encoding_size=self.encoding_size, + in_channel=self.in_channel, + mc_sample_size=self.mc_sample_size, + w=self.w, + ) if load_backbone is not None: - self._load_model(model, load_backbone) + self.load_checkpoint(model, load_backbone) classifier = TNCPredictionHead( input_dim=self.encoding_size, - hidden_dim2=self.num_classes, + output_dim=self.num_classes, ) task = "multiclass" if self.num_classes > 2 else "binary" @@ -115,11 +122,10 @@ def _get_finetune_model( loss_fn=torch.nn.CrossEntropyLoss(), learning_rate=self.learning_rate, metrics={"acc": Accuracy(task=task, num_classes=self.num_classes)}, - update_backbone=self.update_backbone, ) return model - def _get_finetune_data_module(self) -> L.LightningDataModule: + def get_finetune_data_module(self) -> L.LightningDataModule: data_module = MultiModalHARSeriesDataModule( self.data, batch_size=self.batch_size, @@ -130,7 +136,7 @@ def _get_finetune_data_module(self) -> L.LightningDataModule: return data_module -class TNCTest(SSLTest): +class TNCTest(LightningTest): _MODEL_NAME = "TNC" def __init__( @@ -168,7 +174,7 @@ def __init__( self.w = w self.num_classes = num_classes - def _get_test_model(self) -> 
L.LightningModule: + def get_model(self) -> L.LightningModule: model = build_tnc( encoding_size=self.encoding_size, in_channel=self.in_channel, @@ -177,7 +183,7 @@ def _get_test_model(self) -> L.LightningModule: ) classifier = TNCPredictionHead( input_dim=self.encoding_size, - hidden_dim2=self.num_classes, + output_dim=self.num_classes, ) task = "multiclass" if self.num_classes > 2 else "binary" @@ -189,7 +195,7 @@ def _get_test_model(self) -> L.LightningModule: ) return model - def _get_test_data_module(self) -> L.LightningDataModule: + def get_data_module(self) -> L.LightningDataModule: data_module = MultiModalHARSeriesDataModule( self.data, batch_size=self.batch_size, @@ -200,13 +206,9 @@ def _get_test_data_module(self) -> L.LightningDataModule: return data_module -def main(): - components = { +if __name__ == "__main__": + options = { "fit": TNCTrain, "test": TNCTest, } - CLI(components=components, as_positional=False)() - - -if __name__ == "__main__": - main() + auto_main(options) diff --git a/ssl_tools/experiments/lightning_cli.py b/ssl_tools/experiments/lightning_cli.py deleted file mode 100644 index 86ef68a..0000000 --- a/ssl_tools/experiments/lightning_cli.py +++ /dev/null @@ -1,156 +0,0 @@ -from typing import Union -import logging -import os -from datetime import datetime - -EXPERIMENT_VERSION_FORMAT = "%Y-%m-%d_%H-%M-%S" - - -class LightningTrain: - def __init__( - self, - epochs: int = 1, - batch_size: int = 1, - learning_rate: float = 1e-3, - log_dir: str = "logs", - name: str = None, - version: Union[str, int] = None, - load: str = None, - checkpoint_metric: str = None, - checkpoint_metric_mode: str = "min", - accelerator: str = "cpu", - devices: int = 1, - strategy: str = "auto", - limit_train_batches: Union[float, int] = 1.0, - limit_val_batches: Union[float, int] = 1.0, - num_nodes: int = 1, - num_workers: int = None, - seed: int = None, - ): - """Defines the parameters for training a Lightning model. This class - may be used to define the parameters for a Lightning experiment and - CLI. - - Parameters - ---------- - epochs : int, optional - Number of epochs to pre-train the model - batch_size : int, optional - The batch size - learning_rate : float, optional - The learning rate of the optimizer - log_dir : str, optional - Path to the location where logs will be stored - name: str, optional - The name of the experiment (it will be used to compose the path of - the experiments, such as logs and checkpoints) - version: Union[int, str], optional - The version of the experiment. If not is provided the current date - and time will be used as the version - load: str, optional - The path to a checkpoint to load - checkpoint_metric: str, optional - The metric to monitor for checkpointing. If not provided, the last - model will be saved - checkpoint_metric_mode: str, optional - The mode of the metric to monitor (min, max or mean). Defaults to - "min" - accelerator: str, optional - The accelerator to use. Defaults to "cpu" - devices: int, optional - The number of devices to use. Defaults to 1 - strategy: str, optional - The strategy to use. Defaults to "auto" - limit_train_batches: Union[float, int], optional - The number of batches to use for training. Defaults to 1.0 (use - all batches) - limit_val_batches: Union[float, int], optional - The number of batches to use for validation. Defaults to 1.0 (use - all batches) - num_nodes: int, optional - The number of nodes to use. Defaults to 1 - num_workers: int, optional - The number of workers to use for the dataloader. 
- seed: int, optional - The seed to use. - """ - self.epochs = epochs - self.batch_size = batch_size - self.learning_rate = learning_rate - self.log_dir = log_dir - self.experiment_name = name - self.experiment_version = version or datetime.now().strftime( - EXPERIMENT_VERSION_FORMAT - ) - self.load = load - self.checkpoint_metric = checkpoint_metric - self.checkpoint_metric_mode = checkpoint_metric_mode - self.accelerator = accelerator - self.devices = devices - self.strategy = strategy - self.limit_train_batches = limit_train_batches - self.limit_val_batches = limit_val_batches - self.num_nodes = num_nodes - self.num_workers = num_workers - self.seed = seed - - -class LightningTest: - def __init__( - self, - load: str, - batch_size: int = 1, - log_dir="logs", - name: str = None, - version: str = None, - accelerator: str = "cpu", - devices: int = 1, - limit_test_batches: Union[float, int] = 1.0, - num_nodes: int = 1, - num_workers: int = None, - seed: int = None, - ): - """Defines the parameters for testing a Lightning model. This class - may be used to define the parameters for a Lightning experiment and - CLI. - - Parameters - ---------- - load : str - Path to the checkpoint to load - batch_size : int, optional - The batch size - log_dir : str, optional - Path to the location where logs will be stored - name: str, optional - The name of the experiment (it will be used to compose the path of - the experiments, such as logs and checkpoints) - version: Union[int, str], optional - The version of the experiment. If not is provided the current date - and time will be used as the version - accelerator: str, optional - The accelerator to use. Defaults to "cpu" - devices: int, optional - The number of devices to use. Defaults to 1 - limit_test_batches : Union[float, int], optional - Limit the number of batches to use for testing. - num_nodes: int, optional - The number of nodes to use. Defaults to 1 - num_workers: int, optional - The number of workers to use for the dataloader. - seed: int, optional - The seed to use. 
- """ - self.load = load - self.batch_size = batch_size - self.log_dir = log_dir - self.experiment_name = name - self.experiment_version = version or datetime.now().strftime( - EXPERIMENT_VERSION_FORMAT - ) - self.accelerator = accelerator - self.devices = devices - self.limit_test_batches = limit_test_batches - self.num_nodes = num_nodes - self.num_workers = num_workers - self.seed = seed diff --git a/ssl_tools/experiments/lightning_experiment.py b/ssl_tools/experiments/lightning_experiment.py new file mode 100644 index 0000000..daa5315 --- /dev/null +++ b/ssl_tools/experiments/lightning_experiment.py @@ -0,0 +1,525 @@ +from pathlib import Path +from typing import Any, List, Union +from abc import abstractmethod +import lightning as L +from lightning.pytorch.loggers import Logger, CSVLogger +from lightning.pytorch.callbacks import ModelCheckpoint, RichProgressBar +import torch +from ssl_tools.callbacks.performance import PerformanceLog +from ssl_tools.experiments.experiment import Experiment + +class LightningExperiment(Experiment): + _MODEL_NAME: str = "model" + _STAGE_NAME: str = "stage" + + def __init__( + self, + name: str = None, + stage_name: str = None, + batch_size: int = 1, + load: str = None, + accelerator: str = "cpu", + devices: int = 1, + strategy: str = "auto", + num_nodes: int = 1, + num_workers: int = None, + log_every_n_steps: int = 50, + *args, + **kwargs, + ): + name = name or self._MODEL_NAME + super().__init__(name=name, *args, **kwargs) + + self.stage_name = stage_name or self._STAGE_NAME + self.batch_size = batch_size + self.load = load + self.accelerator = accelerator + self.devices = devices + self.strategy = strategy + self.num_nodes = num_nodes + self.num_workers = num_workers + self.log_every_n_steps = log_every_n_steps + + self._model = None + self._logger = None + self._callbacks = None + self._data_module = None + self._trainer = None + self._result = None + self._run_count = 0 + + @property + def experiment_dir(self) -> Path: + return ( + Path(self.log_dir) / self.stage_name / self.name / str(self.run_id) + ) + + @property + def checkpoint_dir(self) -> Path: + return self.experiment_dir / "checkpoints" + + @property + def model(self) -> L.LightningModule: + if self._model is None: + self._model = self.get_model() + return self._model + + @property + def data_module(self) -> L.LightningDataModule: + if self._data_module is None: + self._data_module = self.get_data_module() + return self._data_module + + @property + def logger(self) -> Logger: + if self._logger is None: + self._logger = self.get_logger() + return self._logger + + @property + def callbacks(self) ->List[L.Callback]: + if self._callbacks is None: + self._callbacks = self.get_callbacks() + return self._callbacks + + @property + def hyperparameters(self) -> dict: + def nested_convert(data): + if isinstance(data, dict): + return { + key: nested_convert(value) for key, value in data.items() if not key.startswith("_") + } + elif isinstance(data, Path): + return str(data.expanduser()) + else: + return data + + hyperparams = self.__dict__.copy() + + + if getattr(self.model, "get_config", None): + hyperparams.update(self.model.get_config()) + hyperparams = nested_convert(hyperparams) + return hyperparams + + @property + def trainer(self) -> L.Trainer: + if self._trainer is None: + self._trainer = self.get_trainer(self.logger, self.callbacks) + return self._trainer + + @property + def finished(self) -> bool: + return self._run_count > 0 + + def setup(self): + if self.seed is not None: + 
L.seed_everything(self.seed) + + self.experiment_dir.mkdir(parents=True, exist_ok=True) + self.checkpoint_dir.mkdir(parents=True, exist_ok=True) + + def get_logger(self) -> Logger: + """Get the logger to use for the experiment. + + Returns + ------- + Logger + The logger to use for the experiment + """ + experiment_dir = self.experiment_dir + + logger = CSVLogger( + save_dir=experiment_dir.parents[1], + name=self.experiment_dir.parents[0].name, + version=self.experiment_dir.name, + ) + return logger + + def get_callbacks(self) -> List[L.Callback]: + """Get the callbacks to use for the experiment. + + Returns + ------- + List[L.Callback] + A list of callbacks to use for the experiment + """ + return [] + + def load_checkpoint( + self, model: L.LightningModule, path: Path + ) -> L.LightningModule: + """Load the model to use for the experiment. + + Returns + ------- + L.LightningModule + The model to use for the experiment + """ + print(f"Loading model from: {path}...") + state_dict = torch.load(path)["state_dict"] + model.load_state_dict(state_dict) + print("Model loaded successfully") + return model + + def log_hyperparams(self, logger: Logger) -> dict: + """Log the hyperparameters for reproducibility purposes. + + Parameters + ---------- + model : L.LightningModule + The model to log the hyperparameters from + logger : Logger + The logger to use for logging the hyperparameters + """ + hparams = self.hyperparameters + logger.log_hyperparams(hparams) + + def run(self): + """Runs the experiment. This method: + 1. Instantiates the model and data module (depending on the + ``training_mode``) and load the checkpoint if provided + 2. Instantiates the trainer specific resources (logger, callbacks, etc.) + 3. Logs the hyperparameters (for reproducibility purposes) + 4. Instantiates the trainer + 5. Trains/Tests the model + """ + + # ---------------------------------------------------------------------- + # 1. Instantiate model and data module + # ---------------------------------------------------------------------- + model = self.model + data_module = self.data_module + + if self.load: + model = self.load_checkpoint(model, self.load) + + # ---------------------------------------------------------------------- + # 2. Instantiate trainer specific resources (logger, callbacks, etc.) + # ---------------------------------------------------------------------- + logger = self.logger + + # ---------------------------------------------------------------------- + # 3. Log the hyperparameters (for reproducibility purposes) + # ---------------------------------------------------------------------- + self.log_hyperparams(logger) + + # ---------------------------------------------------------------------- + # 4. Instantiate the trainer + # ---------------------------------------------------------------------- + trainer = self.trainer + # ---------------------------------------------------------------------- + # 5. Train/Tests the model + # ---------------------------------------------------------------------- + self._result = self.run_model(model, data_module, trainer) + + self._run_count += 1 + return self._result + + @abstractmethod + def get_trainer( + self, logger: Logger, callbacks: List[L.Callback] + ) -> L.Trainer: + """Get trainer to use for the experiment. 
+ + Parameters + ---------- + logger : _type_ + The logger to use for the experiment + callbacks : List[L.Callback] + A list of callbacks to use for the experiment + + Returns + ------- + L.Trainer + The trainer to use for the experiment + """ + raise NotImplementedError + + @abstractmethod + def run_model( + self, + model: L.LightningModule, + data_module: L.LightningDataModule, + trainer: L.Trainer, + ): + raise NotImplementedError + + @abstractmethod + def get_model(self) -> L.LightningModule: + """Get the model to use for the experiment. + + Returns + ------- + L.LightningModule + The model to use for the experiment + """ + raise NotImplementedError + + @abstractmethod + def get_data_module(self) -> L.LightningDataModule: + """Get the datamodule to use for the experiment. + + Returns + ------- + L.LightningDataModule + The datamodule to use for the experiment + """ + raise NotImplementedError + + def __str__(self) -> str: + return f"LightningExperiment(experiment_dir={self.experiment_dir}, model={self._MODEL_NAME}, run_id={self.run_id}, finished={self.finished})" + + +class LightningTrain(LightningExperiment): + _STAGE_NAME="train" + + def __init__( + self, + stage_name: str = "train", + epochs: int = 1, + learning_rate: float = 1e-3, + checkpoint_metric: str = None, + checkpoint_metric_mode: str = "min", + limit_train_batches: Union[float, int] = 1.0, + limit_val_batches: Union[float, int] = 1.0, + *args, + **kwargs, + ): + super().__init__(stage_name=stage_name, *args, **kwargs) + self.epochs = epochs + self.learning_rate = learning_rate + self.checkpoint_metric = checkpoint_metric + self.checkpoint_metric_mode = checkpoint_metric_mode + self.limit_train_batches = limit_train_batches + self.limit_val_batches = limit_val_batches + + def get_callbacks(self) -> List[L.Callback]: + """Get the callbacks to use for the experiment. + + Returns + ------- + List[L.Callback] + A list of callbacks to use for the experiment + """ + # Get the checkpoint callback + checkpoint_callback = ModelCheckpoint( + monitor=self.checkpoint_metric, + mode=self.checkpoint_metric_mode, + dirpath=self.checkpoint_dir, + save_last=True, + ) + + performance_log = PerformanceLog() + + rich_progress_bar = RichProgressBar( + leave=False, console_kwargs={"soft_wrap": True} + ) + + return [checkpoint_callback, rich_progress_bar, performance_log] + + def get_trainer( + self, logger: Logger, callbacks: List[L.Callback] + ) -> L.Trainer: + """Get trainer to use for the experiment. 
+ + Parameters + ---------- + logger : _type_ + The logger to use for the experiment + callbacks : List[L.Callback] + A list of callbacks to use for the experiment + + Returns + ------- + L.Trainer + The trainer to use for the experiment + """ + return L.Trainer( + logger=logger, + callbacks=callbacks, + max_epochs=self.epochs, + accelerator=self.accelerator, + devices=self.devices, + strategy=self.strategy, + num_nodes=self.num_nodes, + limit_train_batches=self.limit_train_batches, + limit_val_batches=self.limit_val_batches, + log_every_n_steps=self.log_every_n_steps, + ) + + def run_model( + self, + model: L.LightningModule, + data_module: L.LightningDataModule, + trainer: L.Trainer, + ): + print(f"Training will start") + print(f"\tExperiment path: {self.experiment_dir}") + result = trainer.fit(model, data_module) + + print(f"Training finished") + print(f"Last checkpoint saved at: {self.checkpoint_dir}/last.ckpt") + return result + + +class LightningTest(LightningExperiment): + _STAGE_NAME="test" + + def __init__(self, limit_test_batches: Union[float, int] = 1.0, *args, **kwargs): + super().__init__(*args, **kwargs) + self.limit_test_batches = limit_test_batches + + def get_callbacks(self) -> List[L.Callback]: + """Get the callbacks to use for the experiment. + + Returns + ------- + List[L.Callback] + The list of callbacks to use for the experiment. + """ + performance_log = PerformanceLog() + rich_progress_bar = RichProgressBar( + leave=False, console_kwargs={"soft_wrap": True} + ) + return [rich_progress_bar, performance_log] + + def get_trainer( + self, logger: Logger, callbacks: List[L.Callback] + ) -> L.Trainer: + """Get trainer to use for the experiment. + + Parameters + ---------- + logger : _type_ + The logger to use for the experiment + callbacks : List[L.Callback] + A list of callbacks to use for the experiment + + Returns + ------- + L.Trainer + The trainer to use for the experiment + """ + trainer = L.Trainer( + logger=logger, + callbacks=callbacks, + accelerator=self.accelerator, + devices=self.devices, + num_nodes=self.num_nodes, + limit_test_batches=self.limit_test_batches, + log_every_n_steps=self.log_every_n_steps + ) + return trainer + + def run_model( + self, + model: L.LightningModule, + data_module: L.LightningDataModule, + trainer: L.Trainer, + ) -> Any: + return trainer.test(model, data_module) + + +class LightningSSLTrain(LightningTrain): + def __init__( + self, + training_mode: str = "pretrain", + load_backbone: str = None, + *args, + **kwargs, + ): + """Wraps the LightningTrain class to provide a more specific interface + for SSL experiments (training). + + Parameters + ---------- + training_mode : str, optional + The training mode. It could be either "pretrain" or "finetune" + load_backbone : str, optional + Path to the backbone to load. This is only used when training_mode + is "finetune". In fine-tuning, the backbone is loaded and the + using ``load_backbone``. The ``load`` parameter is used to load the + full model (backbone + head). + """ + super().__init__(stage_name=training_mode,*args, **kwargs) + self.training_mode = training_mode + self.load_backbone = load_backbone + assert self.training_mode in ["pretrain", "finetune"] + + def get_model(self) -> L.LightningModule: + """Get the model to use for the experiment. 
+ + Returns + ------- + L.LightningModule + The model to use for the experiment + """ + if self.training_mode == "pretrain": + return self.get_pretrain_model() + else: + return self.get_finetune_model(self.load_backbone) + + def get_data_module(self) -> L.LightningDataModule: + if self.training_mode == "pretrain": + return self.get_pretrain_data_module() + else: + return self.get_finetune_data_module() + + @abstractmethod + def get_pretrain_model(self) -> L.LightningModule: + """Get the model to use for the pretraining phase. + + Returns + ------- + L.LightningModule + The model to use for the pretraining phase + """ + raise NotImplementedError + + @abstractmethod + def get_finetune_model( + self, load_backbone: str = None + ) -> L.LightningModule: + """Get the model to use for fine-tuning. + + Parameters + ---------- + load_backbone : str, optional + The path to the backbone to load. The backbone must be loaded + inside this method, if it is not None. + + Returns + ------- + L.LightningModule + The model to use for fine-tuning + """ + raise NotImplementedError + + @abstractmethod + def get_pretrain_data_module(self) -> L.LightningDataModule: + """The data module to use for pre-training. + + Returns + ------- + L.LightningDataModule + The data module to use for pre-training + """ + raise NotImplementedError + + @abstractmethod + def get_finetune_data_module(self) -> L.LightningDataModule: + """The data module to use for fine-tuning. + + Returns + ------- + L.LightningDataModule + The data module to use for fine-tuning + + Raises + ------ + NotImplementedError + _description_ + """ + raise NotImplementedError + diff --git a/ssl_tools/experiments/ssl_experiment.py b/ssl_tools/experiments/ssl_experiment.py deleted file mode 100644 index b6c9a54..0000000 --- a/ssl_tools/experiments/ssl_experiment.py +++ /dev/null @@ -1,523 +0,0 @@ -from typing import List, Union -import torch - -from datetime import datetime - -from ssl_tools.experiments import LightningTrain -from ssl_tools.callbacks.performance import PerformanceLog -import lightning as L -from lightning.pytorch.loggers import CSVLogger -from lightning.pytorch.callbacks import ModelCheckpoint, RichProgressBar - -from .lightning_cli import LightningTrain, LightningTest - -from pathlib import Path -import pandas as pd - - -class SSLTrain(LightningTrain): - _MODEL_NAME = "model" - - def __init__( - self, - training_mode: str = "pretrain", - load_backbone: str = None, - *args, - **kwargs, - ): - """Wraps the LightningTrain class to provide a more specific interface - for SSL experiments (training). - - Parameters - ---------- - training_mode : str, optional - The training mode. It could be either "pretrain" or "finetune" - load_backbone : str, optional - Path to the backbone to load. This is only used when training_mode - is "finetune". In fine-tuning, the backbone is loaded and the - using ``load_backbone``. The ``load`` parameter is used to load the - full model (backbone + head). - """ - super().__init__(*args, **kwargs) - self.training_mode = training_mode - self.load_backbone = load_backbone - self.checkpoint_path = None - self.experiment_path = None - - assert self.training_mode in ["pretrain", "finetune"] - - def _set_experiment(self): - """Set the experiment name and version. This method is called before - instantiating the model and data module. It is used to set the - experiment path and checkpoint path. The experiment path is used to - store the logs and checkpoints. The checkpoint path is used to store - the checkpoints. 
- """ - if self.seed is not None: - L.seed_everything(self.seed) - - self.log_dir = Path(self.log_dir) / self.training_mode - - if self.experiment_name is None: - self.experiment_name = self._MODEL_NAME - - # Same as format as logger - self.experiment_path = ( - self.log_dir / self.experiment_name / self.experiment_version - ) - - self.checkpoint_path = self.experiment_path / "checkpoints" - - def _get_logger(self): - """Get the logger to use for the experiment. - - Returns - ------- - __type__ - The logger to use for the experiment - """ - logger = CSVLogger( - save_dir=self.log_dir, - name=self.experiment_name, - version=self.experiment_version, - # flush_logs_every_n_steps=100, - ) - return logger - - def _get_callbacks(self) -> List[L.Callback]: - """Get the callbacks to use for the experiment. - - Returns - ------- - List[L.Callback] - A list of callbacks to use for the experiment - """ - # Get the checkpoint callback - checkpoint_callback = ModelCheckpoint( - monitor=self.checkpoint_metric, - mode=self.checkpoint_metric_mode, - dirpath=self.checkpoint_path, - save_last=True, - ) - - performance_log = PerformanceLog() - - rich_progress_bar = RichProgressBar( - leave=False, console_kwargs={"soft_wrap": True} - ) - - return [checkpoint_callback, rich_progress_bar, performance_log] - - def _get_trainer(self, logger, callbacks: List[L.Callback]) -> L.Trainer: - """Get trainer to use for the experiment. - - Parameters - ---------- - logger : _type_ - The logger to use for the experiment - callbacks : List[L.Callback] - A list of callbacks to use for the experiment - - Returns - ------- - L.Trainer - The trainer to use for the experiment - """ - - trainer = L.Trainer( - max_epochs=self.epochs, - logger=logger, - # enable_checkpointing=True, - callbacks=callbacks, - accelerator=self.accelerator, - devices=self.devices, - strategy=self.strategy, - limit_train_batches=self.limit_train_batches, - limit_val_batches=self.limit_val_batches, - num_nodes=self.num_nodes, - ) - return trainer - - def _load_model(self, model: L.LightningModule, path: str): - """Load a model from a checkpoint. - - Parameters - ---------- - model : L.LightningModule - The model to load the checkpoint into - path : str - The path to the checkpoint - """ - print(f"Loading model from: {path}") - state_dict = torch.load(path)["state_dict"] - model.load_state_dict(state_dict) - print("Model loaded successfully") - - def _log_hyperparams(self, model: L.LightningModule, logger): - """Log the hyperparameters for reproducibility purposes. - - Parameters - ---------- - model : L.LightningModule - The model to log the hyperparameters from - logger : _type_ - The logger to use for logging the hyperparameters - """ - - def nested_convert(data): - if isinstance(data, dict): - return { - key: nested_convert(value) for key, value in data.items() - } - elif isinstance(data, Path): - return str(data.expanduser()) - else: - return data - - hyperparams = self.__dict__.copy() - if getattr(model, "get_config", None): - hyperparams.update(model.get_config()) - hyperparams = nested_convert(hyperparams) - logger.log_hyperparams(hyperparams) - return hyperparams - - def _get_pretrain_model(self) -> L.LightningModule: - """Get the model to use for pre-training. - - Returns - ------- - L.LightningModule - The model to use for pre-training - """ - raise NotImplementedError - - def _get_pretrain_data_module(self) -> L.LightningDataModule: - """The data module to use for pre-training. 
- - Returns - ------- - L.LightningDataModule - The data module to use for pre-training - """ - raise NotImplementedError - - def _get_finetune_model( - self, load_backbone: str = None - ) -> L.LightningModule: - """Get the model to use for fine-tuning. - - Parameters - ---------- - load_backbone : str, optional - The path to the backbone to load. The backbone must be loaded - inside this method, if it is not None. - - Returns - ------- - L.LightningModule - The model to use for fine-tuning - """ - raise NotImplementedError - - def _get_finetune_data_module(self) -> L.LightningDataModule: - """The data module to use for fine-tuning. - - Returns - ------- - L.LightningDataModule - The data module to use for fine-tuning - - Raises - ------ - NotImplementedError - _description_ - """ - raise NotImplementedError - - def _train( - self, - model: L.LightningModule, - data_module: L.LightningDataModule, - trainer: L.Trainer, - ) -> None: - """Train the model using the provided trainer. - - Parameters - ---------- - model : L.LightningModule - The model to train - data_module : L.LightningDataModule - The data module to use for training - trainer : L.Trainer - The trainer to use for training - """ - print(f"Training will start") - print(f"\tExperiment path: {self.experiment_path}") - - return trainer.fit(model, data_module) - - def _run(self): - """Runs the experiment. This method is called when the experiment is - called as a function. This method: - 1. Sets the experiment name and version - 2. Instantiates the model and data module (depending on the - ``training_mode``) - 3. Instantiates the trainer specific resources (logger, callbacks, etc.) - 4. Logs the hyperparameters (for reproducibility purposes) - 5. Instantiates the trainer - 6. Trains the model - """ - # ---------------------------------------------------------------------- - # 1. Set experiment name and version - # ---------------------------------------------------------------------- - self._set_experiment() - - # ---------------------------------------------------------------------- - # 2. Instantiate model and data module - # ---------------------------------------------------------------------- - if self.training_mode == "pretrain": - model = self._get_pretrain_model() - data_module = self._get_pretrain_data_module() - else: - model = self._get_finetune_model(load_backbone=self.load_backbone) - data_module = self._get_finetune_data_module() - - if self.load is not None: - self._load_model(model, self.load) - - # ---------------------------------------------------------------------- - # 3. Instantiate trainer specific resources (logger, callbacks, etc.) - # ---------------------------------------------------------------------- - logger = self._get_logger() - callbacks = self._get_callbacks() - - # ---------------------------------------------------------------------- - # 4. Log the hyperparameters (for reproducibility purposes) - # ---------------------------------------------------------------------- - hyperparams = self._log_hyperparams(model, logger) - - # ---------------------------------------------------------------------- - # 5. Instantiate the trainer - # ---------------------------------------------------------------------- - trainer = self._get_trainer(logger, callbacks) - - # ---------------------------------------------------------------------- - # 6. 
Train the model - # ---------------------------------------------------------------------- - self._train(model, data_module, trainer) - - print(f"Training completed successfully.") - print(f"Last checkpoint saved to: {self.checkpoint_path}/last.ckpt") - - def __call__(self): - self._run() - - -class SSLTest(LightningTest): - _MODEL_NAME = "model" - - def __init__(self, *args, **kwargs): - """Wraps the LightningTest class to provide a more specific interface - for SSL experiments (testing). - """ - super().__init__(*args, **kwargs) - self.experiment_path = None - - def _set_experiment(self): - """Set the experiment name and version. This method is called before - instantiating the model and data module. It sets the experiment path - and results path. The experiment path is used to store the logs and - the results path is used to store the results. - """ - if self.seed is not None: - L.seed_everything(self.seed) - - self.log_dir = Path(self.log_dir) / "test" - if self.experiment_name is None: - self.experiment_name = self._MODEL_NAME - - # Same as format as logger - self.experiment_path = ( - self.log_dir / self.experiment_name / self.experiment_version - ) - - self.results_path = self.experiment_path / "results.csv" - - def _get_logger(self): - """Get the logger to use for the experiment. - - Returns - ------- - _type_ - Get the logger to use for the experiment - """ - logger = CSVLogger( - save_dir=self.log_dir, - name=self.experiment_name, - version=self.experiment_version, - # flush_logs_every_n_steps=100, - ) - return logger - - def _get_callbacks(self) -> List[L.Callback]: - """Get the callbacks to use for the experiment. - - Returns - ------- - List[L.Callback] - The list of callbacks to use for the experiment. - """ - performance_log = PerformanceLog() - rich_progress_bar = RichProgressBar( - leave=False, console_kwargs={"soft_wrap": True} - ) - return [rich_progress_bar, performance_log] - - def _get_trainer(self, logger, callbacks): - trainer = L.Trainer( - logger=logger, - callbacks=callbacks, - accelerator=self.accelerator, - devices=self.devices, - num_nodes=self.num_nodes, - limit_test_batches=self.limit_test_batches, - ) - return trainer - - def _load_model(self, model: L.LightningModule, path: str): - """Loads a model from a checkpoint. - - Parameters - ---------- - model : L.LightningModule - The model to load the checkpoint into - path : str - The path to the checkpoint - """ - print(f"Loading model from: {path}") - state_dict = torch.load(path)["state_dict"] - model.load_state_dict(state_dict) - print("Model loaded successfully") - - def _log_hyperparams(self, model, logger): - def nested_convert(data): - if isinstance(data, dict): - return { - key: nested_convert(value) for key, value in data.items() - } - elif isinstance(data, Path): - return str(data.expanduser()) - else: - return data - - hyperparams = self.__dict__.copy() - if getattr(model, "get_config", None): - hyperparams.update(model.get_config()) - hyperparams = nested_convert(hyperparams) - logger.log_hyperparams(hyperparams) - return hyperparams - - def _get_test_model(self) -> L.LightningModule: - """Get the model to use for testing. - - Returns - ------- - L.LightningModule - The model to use for testing - """ - raise NotImplementedError - - def _get_test_data_module(self) -> L.LightningDataModule: - """The data module to use for testing. 
- - Returns - ------- - L.LightningDataModule - The data module to use for testing - """ - raise NotImplementedError - - def _test( - self, - model: L.LightningModule, - data_module: L.LightningDataModule, - trainer: L.Trainer, - ): - """Test the model using the provided trainer. - - Parameters - ---------- - model : L.LightningModule - The model to test - data_module : L.LightningDataModule - The data module to use for testing - trainer : L.Trainer - The trainer to use for testing - - Returns - ------- - _type_ - A list of dictionary with the results - """ - return trainer.test(model, data_module) - - def _run(self) -> List[dict]: - """Runs the experiment. This method is called when the experiment is - called as a function. This method: - 1. Sets the experiment name and version - 2. Instantiates the model and data module - 3. Instantiates the trainer specific resources (logger, callbacks, etc.) - 4. Logs the hyperparameters (for reproducibility purposes) - 5. Instantiates the trainer - 6. Tests the model - - Note - ---- - The results are converted to a pandas DataFrame and saved to the - ``results_path``. The results are also returned by this method (as - a list of dictionaries). - - Returns - ------- - List[dict] - A list of dictionary with the results - """ - # ---------------------------------------------------------------------- - # 1. Set experiment name and version - # ---------------------------------------------------------------------- - self._set_experiment() - - # ---------------------------------------------------------------------- - # 2. Instantiate model and data module - # ---------------------------------------------------------------------- - model = self._get_test_model() - data_module = self._get_test_data_module() - self._load_model(model, self.load) - - # ---------------------------------------------------------------------- - # 3. Instantiate trainer specific resources (logger, callbacks, etc.) - # ---------------------------------------------------------------------- - logger = self._get_logger() - callbacks = self._get_callbacks() - - # ---------------------------------------------------------------------- - # 4. Log the hyperparameters (for reproducibility purposes) - # ---------------------------------------------------------------------- - hyperparams = self._log_hyperparams(model, logger) - - # ---------------------------------------------------------------------- - # 5. Instantiate the trainer - # ---------------------------------------------------------------------- - trainer = self._get_trainer(logger, callbacks) - - # ---------------------------------------------------------------------- - # 6. 
Train the model - # ---------------------------------------------------------------------- - result = self._test(model, data_module, trainer) - - pd.DataFrame(result).to_csv(self.results_path, index=False) - print(f"Results saved to: {self.results_path}") - return result - - def __call__(self): - return self._run() diff --git a/ssl_tools/models/ssl/classifier.py b/ssl_tools/models/ssl/classifier.py index 2462342..f2656c1 100644 --- a/ssl_tools/models/ssl/classifier.py +++ b/ssl_tools/models/ssl/classifier.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any, Callable, Dict import lightning as L import torch from torchmetrics import Metric @@ -108,6 +108,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: encodings = self.backbone.forward(*x) else: encodings = self.backbone.forward(x) + + if len(encodings.shape) == 1: + encodings = encodings.unsqueeze(0) + predictions = self.head.forward(encodings) return predictions diff --git a/ssl_tools/models/ssl/modules/heads.py b/ssl_tools/models/ssl/modules/heads.py index 1148abe..c72223f 100644 --- a/ssl_tools/models/ssl/modules/heads.py +++ b/ssl_tools/models/ssl/modules/heads.py @@ -44,8 +44,8 @@ def __init__( self, input_dim: int = 10, hidden_dim1: int = 64, - hidden_size2: int = 64, - hidden_dim2: int = 6, + hidden_dim2: int = 64, + output_dim: int = 6, dropout_prob: float = 0, ): super().__init__( @@ -58,15 +58,15 @@ def __init__( ), ( hidden_dim1, - hidden_size2, + hidden_dim2, None, torch.nn.Sequential( torch.nn.ReLU(), torch.nn.Dropout(p=dropout_prob) ), ), ( - hidden_size2, hidden_dim2, + output_dim, None, torch.nn.Softmax(dim=1), ),
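
Usage sketch (not part of the diff above): a minimal, hypothetical script showing how the new Experiment base class and auto_main helper from ssl_tools/experiments/experiment.py are meant to be wired together, in the same way cpc.py, tfc.py and tnc.py now do. The DummyExperiment class, its repeats parameter, and the "demo" subcommand name are illustrative assumptions; the sketch also assumes jsonargparse's add_class_arguments resolves the *args/**kwargs forwarded to Experiment.__init__, as the refactored scripts rely on.

#!/usr/bin/env python
# Hypothetical example script -- not introduced by this diff.
from ssl_tools.experiments import Experiment, auto_main


class DummyExperiment(Experiment):
    """Toy experiment: prints a message `repeats` times (illustrative only)."""

    def __init__(self, repeats: int = 3, *args, **kwargs):
        super().__init__(*args, **kwargs)  # name, run_id, log_dir, seed
        self.repeats = repeats             # assumed extra CLI parameter

    def run(self):
        for i in range(self.repeats):
            print(f"[{self.name}] iteration {i}; logs under {self.experiment_dir}")
        return self.repeats


if __name__ == "__main__":
    # auto_main() turns each dict entry into a jsonargparse subcommand, e.g.:
    #   ./dummy.py demo --repeats 5 --name my-run --log_dir logs
    auto_main({"demo": DummyExperiment})

Outside the CLI, the same class can be driven directly: DummyExperiment(name="dummy").execute() runs setup(), run() and teardown() in order, exactly as auto_main does for the fit/test subcommands of the refactored scripts.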
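
A second self-contained sketch (assumed example, not repository code) of the situation the ssl_tools/models/ssl/classifier.py change guards against: the prediction heads end in Softmax(dim=1) and therefore expect a leading batch dimension, so a single un-batched encoding is unsqueezed before the head is applied.

import torch

# Stand-in head shaped like the repo's prediction heads, which end in Softmax(dim=1)
# and therefore require a (batch, features) input.
head = torch.nn.Sequential(torch.nn.Linear(10, 6), torch.nn.Softmax(dim=1))

encodings = torch.randn(10)              # a single, un-batched encoding
if len(encodings.shape) == 1:            # the check added to SSLDiscriminator.forward
    encodings = encodings.unsqueeze(0)   # -> shape (1, 10)

print(head(encodings).shape)             # torch.Size([1, 6])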