From 2cdf26663634080169c52a719e8c866447099a43 Mon Sep 17 00:00:00 2001 From: Otavio Napoli Date: Fri, 12 Apr 2024 17:24:43 +0000 Subject: [PATCH 1/4] Draft implementation v1 with MLFlow Signed-off-by: Otavio Napoli --- sslt/callbacks/performance.py | 45 ++++ sslt/callbacks/save_best.py | 81 ++++++++ sslt/pipelines/base.py | 16 ++ sslt/pipelines/cli.py | 52 +++++ sslt/pipelines/configs/trainer/default.yaml | 33 +++ sslt/pipelines/configs/trainer/dry_run.yaml | 7 + sslt/pipelines/mlflow_train.py | 217 ++++++++++++++++++++ sslt/pipelines/utils.py | 59 ++++++ 8 files changed, 510 insertions(+) create mode 100644 sslt/callbacks/performance.py create mode 100644 sslt/callbacks/save_best.py create mode 100644 sslt/pipelines/base.py create mode 100644 sslt/pipelines/cli.py create mode 100644 sslt/pipelines/configs/trainer/default.yaml create mode 100644 sslt/pipelines/configs/trainer/dry_run.yaml create mode 100644 sslt/pipelines/mlflow_train.py create mode 100644 sslt/pipelines/utils.py diff --git a/sslt/callbacks/performance.py b/sslt/callbacks/performance.py new file mode 100644 index 0000000..e577b7c --- /dev/null +++ b/sslt/callbacks/performance.py @@ -0,0 +1,45 @@ +import lightning as L +from lightning.pytorch.callbacks import Callback +import time + + +class PerformanceLogger(Callback): + """This callback logs the time taken for each epoch and the overall fit + time. + """ + def __init__(self): + super().__init__() + self.train_epoch_start_time = None + self.fit_start_time = None + + def on_train_epoch_start( + self, trainer: L.Trainer, module: L.LightningModule + ): + """Called when the train epoch begins.""" + self.train_epoch_start_time = time.time() + + def on_train_epoch_end(self, trainer: L.Trainer, module: L.LightningModule): + """Called when the train epoch ends. 
+ """ + end = time.time() + duration = end - self.train_epoch_start_time + module.log( + "train_epoch_time", + duration, + on_step=False, + on_epoch=True, + prog_bar=False, + logger=True, + sync_dist=False, + ) + self.train_epoch_start_time = end + + def on_fit_start(self, trainer: L.Trainer, module: L.LightningModule) -> None: + """Called when fit begins.""" + self.fit_start_time = time.time() + + def on_fit_end(self, trainer: L.Trainer, module: L.LightningModule) -> None: + """Called when fit ends.""" + end = time.time() + duration = end - self.fit_start_time + print(f"--> Overall fit time: {duration:.3f} seconds") \ No newline at end of file diff --git a/sslt/callbacks/save_best.py b/sslt/callbacks/save_best.py new file mode 100644 index 0000000..61076aa --- /dev/null +++ b/sslt/callbacks/save_best.py @@ -0,0 +1,81 @@ +import os +from pathlib import Path +from tempfile import TemporaryDirectory +from typing import Any, Dict, Optional +import torch +import lightning as L +from mlflow.exceptions import MlflowException +import fsspec +from io import BytesIO + + +class PickleBestModelAndLoad(L.Callback): + def __init__( + self, + model_name: str, + filename: str = "best_model.pt", + model_tags: Optional[Dict[str, Any]] = None, + model_description: Optional[str] = None, + ): + self.model_name = model_name + self.filename = filename + self.model_tags = model_tags + self.model_description = model_description + + def on_train_end( + self, trainer: L.Trainer, module: L.LightningModule + ) -> None: + # Check if it is rank 0 + if trainer.global_rank != 0: + return + + # Does it have any + if trainer.checkpoint_callback is not None: + # Get the best model path + best_model_path = getattr( + trainer.checkpoint_callback, "best_model_path", None + ) + + if best_model_path is None: + return + + # Load the model + best_model_path = Path(best_model_path) + # Load the best model checkpoint + with fsspec.open(best_model_path, "rb") as f: + model_bytes = f.read() + model_bytes = BytesIO(model_bytes) + module.load_state_dict(torch.load(model_bytes)["state_dict"]) + + # Lets pickle the model + with TemporaryDirectory(prefix="test", suffix="test") as tempdir: + # Save the whle model in a temporary directory + save_file = Path(tempdir) / self.filename + torch.save(module, save_file) + + # Save the model as an MLFlow artifact + trainer.logger.experiment.log_artifact( + trainer.logger.run_id, save_file, artifact_path=f"model" + ) + + # Locate the artifact path + src = f"runs:/{trainer.logger.run_id}/model/{self.filename}" + + + try: + trainer.logger.experiment.create_registered_model( + name=self.model_name, + tags={ + "pickable": True + } + ) + except MlflowException: + pass + + trainer.logger.experiment.create_model_version( + name=self.model_name, + source=src, + run_id=trainer.logger.run_id, + tags=self.model_tags, + description=self.model_description, + ) diff --git a/sslt/pipelines/base.py b/sslt/pipelines/base.py new file mode 100644 index 0000000..7e5093c --- /dev/null +++ b/sslt/pipelines/base.py @@ -0,0 +1,16 @@ +from abc import abstractmethod +from lightning.pytorch.core.mixins import HyperparametersMixin + +from typing import Any + + +class Pipeline(HyperparametersMixin): + def __init__(self): + self.save_hyperparameters() + + @abstractmethod + def run(self) -> Any: + raise NotImplementedError + + def __call__(self): + return self.run() \ No newline at end of file diff --git a/sslt/pipelines/cli.py b/sslt/pipelines/cli.py new file mode 100644 index 0000000..6743c53 --- /dev/null +++ 
b/sslt/pipelines/cli.py @@ -0,0 +1,52 @@ +from typing import Any, Dict + + +from typing import Dict +import yaml +from typing import Any, Dict +from jsonargparse import ActionConfigFile, ArgumentParser + +from ssl_tools.pipelines.base import Pipeline + +def get_parser(commands: Dict[str, Pipeline] | Pipeline): + parser = ArgumentParser() + + if isinstance(commands, Pipeline): + commands = {"run": commands} + + subcommands = parser.add_subcommands() + + for name, command in commands.items(): + subparser = ArgumentParser() + subparser.add_class_arguments(command) + subparser.add_argument( + "--config", + action=ActionConfigFile, + help="Path to a configuration file, in YAML format", + ) + subcommands.add_subcommand(name, subparser) + + return parser + + +def auto_main(commands: Dict[str, Pipeline] | Pipeline, print_args: bool = False) -> Any: + parser = get_parser(commands) + + args = parser.parse_args() + config_file = args[args.subcommand].pop("config", None) + + if config_file: + config_file = config_file[0].absolute + with open(config_file, "r") as f: + config_from_file = yaml.safe_load(f) + + config = config_from_file + config.update(args[args.subcommand]) + else: + config = dict(args[args.subcommand]) + + if print_args: + print(config) + + pipeline: Pipeline = commands[args.subcommand](**config) + return pipeline.run() diff --git a/sslt/pipelines/configs/trainer/default.yaml b/sslt/pipelines/configs/trainer/default.yaml new file mode 100644 index 0000000..fba97b2 --- /dev/null +++ b/sslt/pipelines/configs/trainer/default.yaml @@ -0,0 +1,33 @@ +accelerator: gpu +devices: 1 +max_epochs: 100 +strategy: auto +enable_checkpointing: True +log_every_n_steps: 10 +logger: + class_path: lightning.pytorch.loggers.CSVLogger + init_args: {} +callbacks: + # Checkpointing + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: val_loss + mode: min + save_top_k: 5 + verbose: False + save_last: True + save_weights_only: True + every_n_epochs: 2 + save_on_train_epoch_end: False + auto_insert_metric_name: True + filename: "{epoch}-{step}-{val_loss:.2f}" + # Early stopping + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + monitor: val_loss + mode: min + patience: 50 + verbose: False + # Performance logging + - class_path: ssl_tools.callbacks.performance.PerformanceLogger + init_args: {} \ No newline at end of file diff --git a/sslt/pipelines/configs/trainer/dry_run.yaml b/sslt/pipelines/configs/trainer/dry_run.yaml new file mode 100644 index 0000000..fc1c762 --- /dev/null +++ b/sslt/pipelines/configs/trainer/dry_run.yaml @@ -0,0 +1,7 @@ +accelerator: gpu +devices: 1 +max_epochs: 2 +limit_train_batches: 0.1 +limit_val_batches: 0.1 +limit_test_batches: 1 +strategy: auto \ No newline at end of file diff --git a/sslt/pipelines/mlflow_train.py b/sslt/pipelines/mlflow_train.py new file mode 100644 index 0000000..6baf851 --- /dev/null +++ b/sslt/pipelines/mlflow_train.py @@ -0,0 +1,217 @@ +from io import BytesIO +from typing import Dict, List, final +import fsspec +import lightning as L +from lightning.pytorch.loggers import Logger, MLFlowLogger +from lightning.pytorch.callbacks import ( + ModelCheckpoint, + RichProgressBar, + EarlyStopping, + ModelSummary, +) +import torch + +from ssl_tools.callbacks.performance import PerformanceLogger +from ssl_tools.callbacks.save_best import PickleBestModelAndLoad +from ssl_tools.pipelines.base import Pipeline +import mlflow +from ssl_tools.models.ssl.classifier import SSLDiscriminator +from ssl_tools.pipelines.utils import 
load_model_mlflow + + +class LightningTrainMLFlow(Pipeline): + def __init__( + self, + # Required paramters + experiment_name: str, + model_name: str, + # Optional parameters + run_name: str = None, + accelerator: str = "cpu", + devices: int = 1, + num_nodes: int = 1, + strategy: str = "auto", + max_epochs: int = 1, + batch_size: int = 1, + limit_train_batches: int | float = 1.0, + limit_val_batches: int | float = 1.0, + checkpoint_monitor_metric: str = None, + checkpoint_monitor_mode: str = "min", + patience: int = None, + log_dir: str = "./mlruns", + model_tags: Dict[str, str] = None + ): + """Train a model using Lightning framework. + + Parameters + ---------- + experiment_name : str + Name of the experiment. + model_name : str + Name of the model. + dataset_name : str + Name of the dataset. + run_name : str, optional + The name of the run, by default None + accelerator : str, optional + The accelerator to use, by default "cpu" + devices : int, optional + Number of accelerators to use, by default 1 + num_nodes : int, optional + Number of nodes, by default 1 + strategy : str, optional + Training strategy, by default "auto" + max_epochs : int, optional + Maximium number of epochs, by default 1 + batch_size : int, optional + Batch size, by default 1 + limit_train_batches : int | float, optional + Limit the number of batches to train, by default 1.0 + limit_val_batches : int | float, optional + Limit the number of batches to test, by default 1.0 + checkpoint_monitor_metric : str, optional + The metric to monitor for checkpointing, by default None + checkpoint_monitor_mode : str, optional + The mode for checkpointing, by default "min" + patience : int, optional + The patience for early stopping, by default None + log_dir : str, optional + Location where logs will be saved, by default "./runs" + """ + + super().__init__() + self.experiment_name = experiment_name + self.model_name = model_name + self.run_name = run_name + self.accelerator = accelerator + self.devices = devices + self.num_nodes = num_nodes + self.strategy = strategy + self.max_epochs = max_epochs + self.batch_size = batch_size + self.limit_train_batches = limit_train_batches + self.limit_val_batches = limit_val_batches + self.checkpoint_monitor_metric = checkpoint_monitor_metric + self.checkpoint_monitor_mode = checkpoint_monitor_mode + self.patience = patience + self.log_dir = log_dir + self.model_tags = model_tags + self._hparams = dict(self.hparams) + + def get_model(self) -> L.LightningModule: + raise NotImplementedError + + def get_data_module(self) -> L.LightningDataModule: + raise NotImplementedError + + def get_trainer( + self, logger: Logger, callacks: List[L.Callback] + ) -> L.Trainer: + return L.Trainer( + accelerator=self.accelerator, + devices=self.devices, + num_nodes=self.num_nodes, + strategy=self.strategy, + max_epochs=self.max_epochs, + logger=logger, + callbacks=callacks, + limit_train_batches=self.limit_train_batches, + limit_val_batches=self.limit_val_batches, + ) + + def get_callbacks(self) -> List[L.Callback]: + callbacks = [] + + model_summary = ModelSummary(max_depth=3) + callbacks.append(model_summary) + + model_checkpoint = ModelCheckpoint( + monitor=self.checkpoint_monitor_metric, + mode=self.checkpoint_monitor_mode, + # save_last=True + ) + callbacks.append(model_checkpoint) + + if self.patience: + early_stopping = EarlyStopping( + monitor=self.checkpoint_monitor_metric, + patience=self.patience, + mode=self.checkpoint_monitor_mode, + ) + callbacks.append(early_stopping) + + performance_logger = 
PerformanceLogger() + callbacks.append(performance_logger) + + # device_stats_monitor = DeviceStatsMonitor() + # callbacks.append(device_stats_monitor) + best_logger = PickleBestModelAndLoad( + model_name=self.model_name, + model_tags=self.model_tags, + ) + callbacks.append(best_logger) + + rich_progress_bar = RichProgressBar() + callbacks.append(rich_progress_bar) + + return callbacks + + def get_logger(self) -> Logger: + return MLFlowLogger( + experiment_name=self.experiment_name, + run_name=self.run_name, + save_dir=self.log_dir, + log_model=True, + tags=self.model_tags + ) + + def run(self) -> L.LightningModule: + # Get all required components + model = self.get_model() + datamodule = self.get_data_module() + logger = self.get_logger() + callbacks = self.get_callbacks() + trainer = self.get_trainer(logger, callbacks) + + # Log the experiment hyperparameters + logger.log_hyperparams(self._hparams) + + # Do fit and return trained model + trainer.fit(model, datamodule) + return model + + + +class LightningFineTuneMLFlow(LightningTrainMLFlow): + def __init__( + self, + registered_model_name: str, + registered_model_tags: Dict[str, str] = None, + update_backbone: bool = False, + **kwargs, + ): + super().__init__(**kwargs) + self.registered_model_name = registered_model_name + self.registered_model_tags = registered_model_tags + self.model_version = dict() + self._mlflow_client = None + self.update_backbone = update_backbone + + @property + def client(self): + if self._mlflow_client is None: + print(f"Initializing MLFlow client at {self.log_dir}...") + self._mlflow_client = mlflow.client.MlflowClient( + tracking_uri=self.log_dir + ) + return self._mlflow_client + + def load_model(self) -> L.LightningModule: + model, model_version = load_model_mlflow( + self.client, + self.registered_model_name, + self.registered_model_tags, + ) + self.model_version = model_version + self._hparams.update({"model_version": model_version}) + return model \ No newline at end of file diff --git a/sslt/pipelines/utils.py b/sslt/pipelines/utils.py new file mode 100644 index 0000000..3dfb1e5 --- /dev/null +++ b/sslt/pipelines/utils.py @@ -0,0 +1,59 @@ +from io import BytesIO +from typing import Dict +import fsspec +import mlflow + +import lightning as L +import torch + + +def tags2str(d: Dict[str, str]) -> str: + """ + Convert a dictionary of tags to a search string compatible with MLflow's search_model_versions method. + + Parameters: + - d: A dictionary containing tags where keys are tag names and values are tag values. + + Returns: + - search_str: A search string formatted for MLflow's search_model_versions method. + """ + search_str = " and ".join( + [f"tags.`{key}`='{value}'" for key, value in d.items()] + ) + return search_str + + +def load_model_mlflow( + client: mlflow.client.MlflowClient, + registered_model_name: str, + registered_model_tags: Dict[str, str] = None, +) -> Dict[L.LightningModule, Dict[str, str]]: + search_string = f"name='{registered_model_name}'" + if registered_model_tags is not None: + search_string += " and " + tags2str(registered_model_tags) + + registered_model = client.search_model_versions( + search_string, order_by=["creation_timestamp DESC"], max_results=1 + ) + + if len(registered_model) == 0: + raise ValueError( + f"No model found with the name '{registered_model_name}' and tags '{registered_model_tags}'. 
Query string used: {search_string}" + ) + + model_version = registered_model[0] + run_id = model_version.run_id + artifact_path = "/".join(model_version.source.split("/")[2:]) + artifact_uri = ( + client.get_run(run_id).info.artifact_uri + "/" + artifact_path + ) + + # print(f"Loading model from: {artifact_uri}") + + with fsspec.open(artifact_uri, "rb") as f: + model_bytes = f.read() + model_bytes = BytesIO(model_bytes) + model = torch.load(model_bytes) + + print(f"Model loaded from: {artifact_uri}.") + return model, dict(model_version) From 899fa19e8bf79ac2cfb9f5f1968b461b2366a9f0 Mon Sep 17 00:00:00 2001 From: Otavio Napoli Date: Fri, 12 Apr 2024 17:43:05 +0000 Subject: [PATCH 2/4] Added configs from --- .../pipelines/configs/data_modules/2D/f3.yaml | 0 sslt/pipelines/configs/models/2D/wisenet.yaml | 9 + sslt/pipelines/main_regression.py | 359 ++++++++++++++++++ 3 files changed, 368 insertions(+) create mode 100644 sslt/pipelines/configs/data_modules/2D/f3.yaml create mode 100644 sslt/pipelines/configs/models/2D/wisenet.yaml create mode 100644 sslt/pipelines/main_regression.py diff --git a/sslt/pipelines/configs/data_modules/2D/f3.yaml b/sslt/pipelines/configs/data_modules/2D/f3.yaml new file mode 100644 index 0000000..e69de29 diff --git a/sslt/pipelines/configs/models/2D/wisenet.yaml b/sslt/pipelines/configs/models/2D/wisenet.yaml new file mode 100644 index 0000000..c0fd2f0 --- /dev/null +++ b/sslt/pipelines/configs/models/2D/wisenet.yaml @@ -0,0 +1,9 @@ +class_path: minerva.models.nets.WiseNet +init_args: + in_channels: 1 + out_channels: 1 + loss_fn: + class_path: torch.nn.MSELoss + init_args: + reduction: mean + learning_rate: 0.001 \ No newline at end of file diff --git a/sslt/pipelines/main_regression.py b/sslt/pipelines/main_regression.py new file mode 100644 index 0000000..a211cba --- /dev/null +++ b/sslt/pipelines/main_regression.py @@ -0,0 +1,359 @@ +#!/usr/bin/env python3 + +import copy +from dataclasses import dataclass +from datetime import datetime +import random +import traceback +from pathlib import Path +from typing import Any, Dict, List + +import lightning as L +import ray +import yaml +from jsonargparse import CLI +from lightning.pytorch.cli import LightningArgumentParser, LightningCLI +from torchmetrics import Accuracy + +from typing import Any, Dict, List + + +@dataclass +class ExperimentArgs: + trainer: Dict[str, Any] + model: Dict[str, Any] + data: Dict[str, Any] + test_data: Dict[str, Any] + seed: int = 42 + + +def cli_main(experiment: ExperimentArgs): + class DummyModel(L.LightningModule): + def __init__(self, *args, **kwargs): + pass + + class DummyTrainer(L.Trainer): + def __init__(self, *args, **kwargs): + pass + + # Unpack experiment into a dict, ignoring the test_data for now + cli_args = { + "trainer": experiment.trainer, + "model": experiment.model, + "data": experiment.data, + "seed_everything": experiment.seed, + } + + # print(cli_args) + + # Instantiate model, trainer, and train_datamodule + train_cli = LightningCLI( + args=cli_args, run=False, parser_kwargs={"parser_mode": "omegaconf"} + ) + + test_cli = LightningCLI( + model_class=DummyModel, + trainer_class=DummyTrainer, + args={ + "trainer": {}, + "model": {}, + "data": experiment.test_data, + }, + run=False, + ) + + # Shortcut to access the trainer, model and datamodule + trainer = train_cli.trainer + model = train_cli.model + train_data_module = train_cli.datamodule + test_data_module = test_cli.datamodule + + # Attach model test metrics + model.metrics["test"]["accuracy"] = Accuracy( + 
task="multiclass", num_classes=7 + ) + + # Perform FIT + trainer.fit(model, train_data_module) + + # Perform test and return metrics + metrics = trainer.test(model, test_data_module, ckpt_path="best") + return metrics + + +def _run_experiment_wrapper(experiment_args: ExperimentArgs): + try: + print() + print("*" * 80) + print(f"Running Experiment") + print(f" Model: {experiment_args.model['class_path']}") + print( + f" Train Data: {experiment_args.data['init_args']['data_path']}" + ) + print( + f" Test Data: {experiment_args.test_data['init_args']['data_path']}" + ) + print("*" * 80) + print() + + return cli_main(experiment_args) + except Exception as e: + print(f" ------- Error running evaluator: {e} ----------") + traceback.print_exc() + print("----------------------------------------------------") + raise e + + +def run_using_ray(experiments: List[ExperimentArgs], ray_address: str = None): + print(f"Running {len(experiments)} experiments using RAY...") + ray.init(address=ray_address) + remotes_to_run = [ + ray.remote( + num_gpus=0.25, + num_cpus=4, + max_calls=1, + max_retries=0, + retry_exceptions=False, + )(_run_experiment_wrapper).remote(exp_args) + for exp_args in experiments + ] + ready, not_ready = ray.wait(remotes_to_run, num_returns=len(remotes_to_run)) + print(f"Ready: {len(ready)}. Not ready: {len(not_ready)}") + ray.shutdown() + return ready, not_ready + + +def run_serial(experiments: List[ExperimentArgs]): + print(f"Running {len(experiments)} experiments...") + for exp_args in experiments: + _run_experiment_wrapper(exp_args) + + +class SupervisedConfigParser: + def __init__( + self, + data_path: str, + default_trainer_config: str, + data_module_configs: str, + model_configs: str, + output_dir: str = "benchmarks/", + skip_existing: bool = True, + enable_checkpointing: bool = False, + seed: int = 42, + leave_one_out: bool = False, + ): + self.data_path = data_path + self.default_trainer_config = default_trainer_config + self.data_module_configs = data_module_configs + self.model_configs = model_configs + self.output_dir = output_dir + self.skip_existing = skip_existing + self.enable_checkpointing = enable_checkpointing + self.seed = seed + self.leave_one_out = leave_one_out + + # TODO automate this, using the a query string, sql like + def filter_experiments(self, experiments: List[ExperimentArgs]): + return experiments + # return [ + # exp + # for exp in experiments + # if exp.data["class_path"] == "ssl_tools.data.data_modules.har.AugmentedMultiModalHARSeriesDataModule" + # ] + + def __call__(self) -> List[ExperimentArgs]: + input_type = ["1D", "2D"] + now = datetime.now().strftime("%d-%m-%Y_%H-%M-%S") + + models = list(sorted(self.model_configs.rglob("*.yaml"))) + data_modules = list(sorted(self.data_module_configs.rglob("*.yaml"))) + datasets = list(sorted(self.data_path.glob("*"))) + experiments = [] + + initial_trainer_config = yaml.safe_load( + self.default_trainer_config.read_text() + ) + initial_trainer_config["enable_checkpointing"] = ( + self.enable_checkpointing + ) + + # Segment datasets into train and test + if self.leave_one_out: + datasets = [ + { + "train": [d1 for d1 in datasets if d1 != dataset], + "test": dataset, + } + for dataset in datasets + ] + else: + datasets = [ + { + "train": dataset, + "test": dataset, + } + for dataset in datasets + ] + + # Scanning for configs + for model in models: + input_type = model.parent.stem + model_name = model.stem + initial_model_config = yaml.safe_load(model.read_text()) + + for data_module in data_modules: + 
data_module_input_type = data_module.parent.stem + data_module_name = data_module.stem + initial_data_module_config = yaml.safe_load( + data_module.read_text() + ) + + if input_type != data_module_input_type: + continue + + for dataset in datasets: + train_set = dataset["train"] + test_set = dataset["test"] + + if isinstance(train_set, list): + train_set_name = "+".join([d.stem for d in train_set]) + else: + train_set_name = train_set.stem + test_set_name = test_set.stem + + name = f"{model_name}-{data_module_name}-train_on_{train_set_name}-test_on_{test_set_name}" + version = now + + if self.skip_existing: + if (self.output_dir / name).exists(): + print(f"-- Skipping existing experiment: {name}") + continue + + trainer_config = copy.deepcopy(initial_trainer_config) + model_config = copy.deepcopy(initial_model_config) + data_module_config = copy.deepcopy( + initial_data_module_config + ) + + trainer_config["logger"] = { + "class_path": "lightning.pytorch.loggers.CSVLogger", + "init_args": { + "save_dir": self.output_dir, + "name": name, + "version": version, + }, + } + + # TODO ------- remove this when all gpu bugs were fixed -------- + if ( + "lstm" in model_name + or "inception" in model_name + or "cnn_haetal_2d" in model_name + or "cnn_pf" in model_name + or "cnn_pff" in model_name + ): + trainer_config["accelerator"] = "cpu" + else: + trainer_config["accelerator"] = "gpu" + + data_module_config["init_args"]["data_path"] = train_set + + test_data_module_config = copy.deepcopy(data_module_config) + test_data_module_config["init_args"]["data_path"] = test_set + + experiments.append( + ExperimentArgs( + trainer=trainer_config, + model=model_config, + data=data_module_config, + test_data=test_data_module_config, + seed=self.seed, + ) + ) + experiments = self.filter_experiments(experiments) + + return experiments + + +def hack_to_avoid_lightning_cli_sys_argv_warning(func, *args, **kwargs): + # Hack to avoid LightningCLI parse warning + # The warning is something like: + # /usr/local/lib/python3.10/dist-packages/lightning/pytorch/cli.py:520: + # LightningCLI's args parameter is intended to run from with in Python like + # if it were from the command line. To prevent mistakes it is not + # recommended to provide both args and command line arguments, got: + # sys.argv[1:]=['--config', 'benchmark_a1_dry_run.yaml'], + def hack_to_avoid_lightning_cli_sys_argv_warning_wrapper(*args, **kwargs): + import sys + + old_args = sys.argv + sys.argv = sys.argv[:1] + func(*args, **kwargs) + sys.argv = old_args + + return hack_to_avoid_lightning_cli_sys_argv_warning_wrapper + + +@hack_to_avoid_lightning_cli_sys_argv_warning +def run( + config_parser: SupervisedConfigParser, + use_ray: bool, + ray_address: str = None, + dry_run: bool = False, + dry_run_limit: int = 3, +): + experiments = config_parser() + if dry_run: + if dry_run_limit is None: + dry_run_limit = len(experiments) + + print( + f"** Dry run. 
Limiting to a maximum of {dry_run_limit} experiments, shuffled **" + ) + experiments = random.sample( + experiments, min(dry_run_limit, len(experiments)) + ) + + if use_ray: + return run_using_ray(experiments, ray_address) + else: + return run_serial(experiments) + + +def main( + data_path: str, + default_trainer_config_file: str, + data_module_configs_path: str, + model_configs_path: str, + output_path: str = "benchmarks/", + skip_existing: bool = True, + ray_address: str = None, + use_ray: bool = True, + enable_checkpointing: bool = False, + seed: int = 42, + dry_run: bool = False, + dry_run_limit: int = 5, + leave_one_out: bool = False, +): + parser = SupervisedConfigParser( + data_path=Path(data_path), + default_trainer_config=Path(default_trainer_config_file), + data_module_configs=Path(data_module_configs_path), + model_configs=Path(model_configs_path), + output_dir=Path(output_path), + skip_existing=skip_existing, + enable_checkpointing=enable_checkpointing, + seed=seed, + leave_one_out=leave_one_out, + ) + return run( + parser, + use_ray, + ray_address, + dry_run=dry_run, + dry_run_limit=dry_run_limit, + ) + + +if __name__ == "__main__": + CLI(main) From 8a5b919b8ae724f07fefc3322c0cd7acd51bc96d Mon Sep 17 00:00:00 2001 From: Otavio Napoli Date: Wed, 8 May 2024 03:58:53 +0000 Subject: [PATCH 3/4] Implemented base pipeline interface --- minerva/pipelines/base.py | 104 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 104 insertions(+) create mode 100644 minerva/pipelines/base.py diff --git a/minerva/pipelines/base.py b/minerva/pipelines/base.py new file mode 100644 index 0000000..c1d48c5 --- /dev/null +++ b/minerva/pipelines/base.py @@ -0,0 +1,104 @@ + +from abc import abstractmethod +import copy +from pathlib import Path +from lightning.pytorch.core.mixins import HyperparametersMixin +from typing import Any, List, Dict, Tuple +from uuid import uuid4 +from time import time +import traceback +import sys +from jsonargparse import CLI + + +class Pipeline(HyperparametersMixin): + def __init__( + self, + cwd: Path | str = None, + ignore: str | List[str] = None, + cache_result: bool = False, + ): + self._initialize_vars() + self.pipeline_id = str(uuid4().hex) + self._cache_result = cache_result + + self._cwd = cwd or Path.cwd() + if not isinstance(self._cwd, Path): + self._cwd = Path(self._cwd) + self._cwd = self._cwd.absolute() + + ignore = ignore or [] + if isinstance(ignore, str): + ignore = [ignore] + ignore.append("ignore") + + self.save_hyperparameters(ignore=ignore) + + def _initialize_vars(self): + self._created_at = time() + self._run_count = 0 + self._run_start_time = None + self._run_end_time = None + self._result = None + self._run_status = "NOT STARTED" + self._run_exception = None + + def _run(self, *args, **kwargs) -> Any: + raise NotImplementedError + + @staticmethod + def clone(other: "Pipeline") -> "Pipeline": + clone_pipeline = copy.deepcopy(other) + clone_pipeline._initialize_vars() + return clone_pipeline + + @abstractmethod + def run(self, *args, **kwargs): + self._run_count += 1 + self._run_start_time = time() + self._run_status = "RUNNING" + self._result = None + + try: + result = self._run(*args, **kwargs) + except Exception as e: + self._run_status = "FAILED" + exception = "".join(traceback.format_exception(*sys.exc_info())) + self._run_exception = exception + raise e + finally: + self._run_end_time = time() + + self._run_status = "SUCCESS" + + if self._cache_result: + self._result = result + + return result + + @property + def config(self): + params = 
self.hparams + return dict(params) + + @property + def status(self) -> Dict[str, Any]: + return { + "status": self._run_status, + "working_dir": str(self._cwd), + "id": self.pipeline_id, + "count": self._run_count, + "created": self._created_at, + "start_time": self._run_start_time, + "end_time": self._run_end_time, + "exception_info": self._run_exception, + "cached_result": self._result is not None, + } + + @property + def result(self) -> Any: + return self._result + + @property + def working_dir(self): + return self._cwd From c5507c304f957daa38dc83ad70eb2404b1b9dccb Mon Sep 17 00:00:00 2001 From: Otavio Napoli Date: Wed, 8 May 2024 04:01:44 +0000 Subject: [PATCH 4/4] Removed old files --- sslt/callbacks/performance.py | 45 --- sslt/callbacks/save_best.py | 81 ---- sslt/pipelines/base.py | 16 - sslt/pipelines/cli.py | 52 --- .../pipelines/configs/data_modules/2D/f3.yaml | 0 sslt/pipelines/configs/models/2D/wisenet.yaml | 9 - sslt/pipelines/configs/trainer/default.yaml | 33 -- sslt/pipelines/configs/trainer/dry_run.yaml | 7 - sslt/pipelines/main_regression.py | 359 ------------------ sslt/pipelines/mlflow_train.py | 217 ----------- sslt/pipelines/utils.py | 59 --- 11 files changed, 878 deletions(-) delete mode 100644 sslt/callbacks/performance.py delete mode 100644 sslt/callbacks/save_best.py delete mode 100644 sslt/pipelines/base.py delete mode 100644 sslt/pipelines/cli.py delete mode 100644 sslt/pipelines/configs/data_modules/2D/f3.yaml delete mode 100644 sslt/pipelines/configs/models/2D/wisenet.yaml delete mode 100644 sslt/pipelines/configs/trainer/default.yaml delete mode 100644 sslt/pipelines/configs/trainer/dry_run.yaml delete mode 100644 sslt/pipelines/main_regression.py delete mode 100644 sslt/pipelines/mlflow_train.py delete mode 100644 sslt/pipelines/utils.py diff --git a/sslt/callbacks/performance.py b/sslt/callbacks/performance.py deleted file mode 100644 index e577b7c..0000000 --- a/sslt/callbacks/performance.py +++ /dev/null @@ -1,45 +0,0 @@ -import lightning as L -from lightning.pytorch.callbacks import Callback -import time - - -class PerformanceLogger(Callback): - """This callback logs the time taken for each epoch and the overall fit - time. - """ - def __init__(self): - super().__init__() - self.train_epoch_start_time = None - self.fit_start_time = None - - def on_train_epoch_start( - self, trainer: L.Trainer, module: L.LightningModule - ): - """Called when the train epoch begins.""" - self.train_epoch_start_time = time.time() - - def on_train_epoch_end(self, trainer: L.Trainer, module: L.LightningModule): - """Called when the train epoch ends. 
- """ - end = time.time() - duration = end - self.train_epoch_start_time - module.log( - "train_epoch_time", - duration, - on_step=False, - on_epoch=True, - prog_bar=False, - logger=True, - sync_dist=False, - ) - self.train_epoch_start_time = end - - def on_fit_start(self, trainer: L.Trainer, module: L.LightningModule) -> None: - """Called when fit begins.""" - self.fit_start_time = time.time() - - def on_fit_end(self, trainer: L.Trainer, module: L.LightningModule) -> None: - """Called when fit ends.""" - end = time.time() - duration = end - self.fit_start_time - print(f"--> Overall fit time: {duration:.3f} seconds") \ No newline at end of file diff --git a/sslt/callbacks/save_best.py b/sslt/callbacks/save_best.py deleted file mode 100644 index 61076aa..0000000 --- a/sslt/callbacks/save_best.py +++ /dev/null @@ -1,81 +0,0 @@ -import os -from pathlib import Path -from tempfile import TemporaryDirectory -from typing import Any, Dict, Optional -import torch -import lightning as L -from mlflow.exceptions import MlflowException -import fsspec -from io import BytesIO - - -class PickleBestModelAndLoad(L.Callback): - def __init__( - self, - model_name: str, - filename: str = "best_model.pt", - model_tags: Optional[Dict[str, Any]] = None, - model_description: Optional[str] = None, - ): - self.model_name = model_name - self.filename = filename - self.model_tags = model_tags - self.model_description = model_description - - def on_train_end( - self, trainer: L.Trainer, module: L.LightningModule - ) -> None: - # Check if it is rank 0 - if trainer.global_rank != 0: - return - - # Does it have any - if trainer.checkpoint_callback is not None: - # Get the best model path - best_model_path = getattr( - trainer.checkpoint_callback, "best_model_path", None - ) - - if best_model_path is None: - return - - # Load the model - best_model_path = Path(best_model_path) - # Load the best model checkpoint - with fsspec.open(best_model_path, "rb") as f: - model_bytes = f.read() - model_bytes = BytesIO(model_bytes) - module.load_state_dict(torch.load(model_bytes)["state_dict"]) - - # Lets pickle the model - with TemporaryDirectory(prefix="test", suffix="test") as tempdir: - # Save the whle model in a temporary directory - save_file = Path(tempdir) / self.filename - torch.save(module, save_file) - - # Save the model as an MLFlow artifact - trainer.logger.experiment.log_artifact( - trainer.logger.run_id, save_file, artifact_path=f"model" - ) - - # Locate the artifact path - src = f"runs:/{trainer.logger.run_id}/model/{self.filename}" - - - try: - trainer.logger.experiment.create_registered_model( - name=self.model_name, - tags={ - "pickable": True - } - ) - except MlflowException: - pass - - trainer.logger.experiment.create_model_version( - name=self.model_name, - source=src, - run_id=trainer.logger.run_id, - tags=self.model_tags, - description=self.model_description, - ) diff --git a/sslt/pipelines/base.py b/sslt/pipelines/base.py deleted file mode 100644 index 7e5093c..0000000 --- a/sslt/pipelines/base.py +++ /dev/null @@ -1,16 +0,0 @@ -from abc import abstractmethod -from lightning.pytorch.core.mixins import HyperparametersMixin - -from typing import Any - - -class Pipeline(HyperparametersMixin): - def __init__(self): - self.save_hyperparameters() - - @abstractmethod - def run(self) -> Any: - raise NotImplementedError - - def __call__(self): - return self.run() \ No newline at end of file diff --git a/sslt/pipelines/cli.py b/sslt/pipelines/cli.py deleted file mode 100644 index 6743c53..0000000 --- 
a/sslt/pipelines/cli.py +++ /dev/null @@ -1,52 +0,0 @@ -from typing import Any, Dict - - -from typing import Dict -import yaml -from typing import Any, Dict -from jsonargparse import ActionConfigFile, ArgumentParser - -from ssl_tools.pipelines.base import Pipeline - -def get_parser(commands: Dict[str, Pipeline] | Pipeline): - parser = ArgumentParser() - - if isinstance(commands, Pipeline): - commands = {"run": commands} - - subcommands = parser.add_subcommands() - - for name, command in commands.items(): - subparser = ArgumentParser() - subparser.add_class_arguments(command) - subparser.add_argument( - "--config", - action=ActionConfigFile, - help="Path to a configuration file, in YAML format", - ) - subcommands.add_subcommand(name, subparser) - - return parser - - -def auto_main(commands: Dict[str, Pipeline] | Pipeline, print_args: bool = False) -> Any: - parser = get_parser(commands) - - args = parser.parse_args() - config_file = args[args.subcommand].pop("config", None) - - if config_file: - config_file = config_file[0].absolute - with open(config_file, "r") as f: - config_from_file = yaml.safe_load(f) - - config = config_from_file - config.update(args[args.subcommand]) - else: - config = dict(args[args.subcommand]) - - if print_args: - print(config) - - pipeline: Pipeline = commands[args.subcommand](**config) - return pipeline.run() diff --git a/sslt/pipelines/configs/data_modules/2D/f3.yaml b/sslt/pipelines/configs/data_modules/2D/f3.yaml deleted file mode 100644 index e69de29..0000000 diff --git a/sslt/pipelines/configs/models/2D/wisenet.yaml b/sslt/pipelines/configs/models/2D/wisenet.yaml deleted file mode 100644 index c0fd2f0..0000000 --- a/sslt/pipelines/configs/models/2D/wisenet.yaml +++ /dev/null @@ -1,9 +0,0 @@ -class_path: minerva.models.nets.WiseNet -init_args: - in_channels: 1 - out_channels: 1 - loss_fn: - class_path: torch.nn.MSELoss - init_args: - reduction: mean - learning_rate: 0.001 \ No newline at end of file diff --git a/sslt/pipelines/configs/trainer/default.yaml b/sslt/pipelines/configs/trainer/default.yaml deleted file mode 100644 index fba97b2..0000000 --- a/sslt/pipelines/configs/trainer/default.yaml +++ /dev/null @@ -1,33 +0,0 @@ -accelerator: gpu -devices: 1 -max_epochs: 100 -strategy: auto -enable_checkpointing: True -log_every_n_steps: 10 -logger: - class_path: lightning.pytorch.loggers.CSVLogger - init_args: {} -callbacks: - # Checkpointing - - class_path: lightning.pytorch.callbacks.ModelCheckpoint - init_args: - monitor: val_loss - mode: min - save_top_k: 5 - verbose: False - save_last: True - save_weights_only: True - every_n_epochs: 2 - save_on_train_epoch_end: False - auto_insert_metric_name: True - filename: "{epoch}-{step}-{val_loss:.2f}" - # Early stopping - - class_path: lightning.pytorch.callbacks.EarlyStopping - init_args: - monitor: val_loss - mode: min - patience: 50 - verbose: False - # Performance logging - - class_path: ssl_tools.callbacks.performance.PerformanceLogger - init_args: {} \ No newline at end of file diff --git a/sslt/pipelines/configs/trainer/dry_run.yaml b/sslt/pipelines/configs/trainer/dry_run.yaml deleted file mode 100644 index fc1c762..0000000 --- a/sslt/pipelines/configs/trainer/dry_run.yaml +++ /dev/null @@ -1,7 +0,0 @@ -accelerator: gpu -devices: 1 -max_epochs: 2 -limit_train_batches: 0.1 -limit_val_batches: 0.1 -limit_test_batches: 1 -strategy: auto \ No newline at end of file diff --git a/sslt/pipelines/main_regression.py b/sslt/pipelines/main_regression.py deleted file mode 100644 index a211cba..0000000 --- 
a/sslt/pipelines/main_regression.py +++ /dev/null @@ -1,359 +0,0 @@ -#!/usr/bin/env python3 - -import copy -from dataclasses import dataclass -from datetime import datetime -import random -import traceback -from pathlib import Path -from typing import Any, Dict, List - -import lightning as L -import ray -import yaml -from jsonargparse import CLI -from lightning.pytorch.cli import LightningArgumentParser, LightningCLI -from torchmetrics import Accuracy - -from typing import Any, Dict, List - - -@dataclass -class ExperimentArgs: - trainer: Dict[str, Any] - model: Dict[str, Any] - data: Dict[str, Any] - test_data: Dict[str, Any] - seed: int = 42 - - -def cli_main(experiment: ExperimentArgs): - class DummyModel(L.LightningModule): - def __init__(self, *args, **kwargs): - pass - - class DummyTrainer(L.Trainer): - def __init__(self, *args, **kwargs): - pass - - # Unpack experiment into a dict, ignoring the test_data for now - cli_args = { - "trainer": experiment.trainer, - "model": experiment.model, - "data": experiment.data, - "seed_everything": experiment.seed, - } - - # print(cli_args) - - # Instantiate model, trainer, and train_datamodule - train_cli = LightningCLI( - args=cli_args, run=False, parser_kwargs={"parser_mode": "omegaconf"} - ) - - test_cli = LightningCLI( - model_class=DummyModel, - trainer_class=DummyTrainer, - args={ - "trainer": {}, - "model": {}, - "data": experiment.test_data, - }, - run=False, - ) - - # Shortcut to access the trainer, model and datamodule - trainer = train_cli.trainer - model = train_cli.model - train_data_module = train_cli.datamodule - test_data_module = test_cli.datamodule - - # Attach model test metrics - model.metrics["test"]["accuracy"] = Accuracy( - task="multiclass", num_classes=7 - ) - - # Perform FIT - trainer.fit(model, train_data_module) - - # Perform test and return metrics - metrics = trainer.test(model, test_data_module, ckpt_path="best") - return metrics - - -def _run_experiment_wrapper(experiment_args: ExperimentArgs): - try: - print() - print("*" * 80) - print(f"Running Experiment") - print(f" Model: {experiment_args.model['class_path']}") - print( - f" Train Data: {experiment_args.data['init_args']['data_path']}" - ) - print( - f" Test Data: {experiment_args.test_data['init_args']['data_path']}" - ) - print("*" * 80) - print() - - return cli_main(experiment_args) - except Exception as e: - print(f" ------- Error running evaluator: {e} ----------") - traceback.print_exc() - print("----------------------------------------------------") - raise e - - -def run_using_ray(experiments: List[ExperimentArgs], ray_address: str = None): - print(f"Running {len(experiments)} experiments using RAY...") - ray.init(address=ray_address) - remotes_to_run = [ - ray.remote( - num_gpus=0.25, - num_cpus=4, - max_calls=1, - max_retries=0, - retry_exceptions=False, - )(_run_experiment_wrapper).remote(exp_args) - for exp_args in experiments - ] - ready, not_ready = ray.wait(remotes_to_run, num_returns=len(remotes_to_run)) - print(f"Ready: {len(ready)}. 
Not ready: {len(not_ready)}") - ray.shutdown() - return ready, not_ready - - -def run_serial(experiments: List[ExperimentArgs]): - print(f"Running {len(experiments)} experiments...") - for exp_args in experiments: - _run_experiment_wrapper(exp_args) - - -class SupervisedConfigParser: - def __init__( - self, - data_path: str, - default_trainer_config: str, - data_module_configs: str, - model_configs: str, - output_dir: str = "benchmarks/", - skip_existing: bool = True, - enable_checkpointing: bool = False, - seed: int = 42, - leave_one_out: bool = False, - ): - self.data_path = data_path - self.default_trainer_config = default_trainer_config - self.data_module_configs = data_module_configs - self.model_configs = model_configs - self.output_dir = output_dir - self.skip_existing = skip_existing - self.enable_checkpointing = enable_checkpointing - self.seed = seed - self.leave_one_out = leave_one_out - - # TODO automate this, using the a query string, sql like - def filter_experiments(self, experiments: List[ExperimentArgs]): - return experiments - # return [ - # exp - # for exp in experiments - # if exp.data["class_path"] == "ssl_tools.data.data_modules.har.AugmentedMultiModalHARSeriesDataModule" - # ] - - def __call__(self) -> List[ExperimentArgs]: - input_type = ["1D", "2D"] - now = datetime.now().strftime("%d-%m-%Y_%H-%M-%S") - - models = list(sorted(self.model_configs.rglob("*.yaml"))) - data_modules = list(sorted(self.data_module_configs.rglob("*.yaml"))) - datasets = list(sorted(self.data_path.glob("*"))) - experiments = [] - - initial_trainer_config = yaml.safe_load( - self.default_trainer_config.read_text() - ) - initial_trainer_config["enable_checkpointing"] = ( - self.enable_checkpointing - ) - - # Segment datasets into train and test - if self.leave_one_out: - datasets = [ - { - "train": [d1 for d1 in datasets if d1 != dataset], - "test": dataset, - } - for dataset in datasets - ] - else: - datasets = [ - { - "train": dataset, - "test": dataset, - } - for dataset in datasets - ] - - # Scanning for configs - for model in models: - input_type = model.parent.stem - model_name = model.stem - initial_model_config = yaml.safe_load(model.read_text()) - - for data_module in data_modules: - data_module_input_type = data_module.parent.stem - data_module_name = data_module.stem - initial_data_module_config = yaml.safe_load( - data_module.read_text() - ) - - if input_type != data_module_input_type: - continue - - for dataset in datasets: - train_set = dataset["train"] - test_set = dataset["test"] - - if isinstance(train_set, list): - train_set_name = "+".join([d.stem for d in train_set]) - else: - train_set_name = train_set.stem - test_set_name = test_set.stem - - name = f"{model_name}-{data_module_name}-train_on_{train_set_name}-test_on_{test_set_name}" - version = now - - if self.skip_existing: - if (self.output_dir / name).exists(): - print(f"-- Skipping existing experiment: {name}") - continue - - trainer_config = copy.deepcopy(initial_trainer_config) - model_config = copy.deepcopy(initial_model_config) - data_module_config = copy.deepcopy( - initial_data_module_config - ) - - trainer_config["logger"] = { - "class_path": "lightning.pytorch.loggers.CSVLogger", - "init_args": { - "save_dir": self.output_dir, - "name": name, - "version": version, - }, - } - - # TODO ------- remove this when all gpu bugs were fixed -------- - if ( - "lstm" in model_name - or "inception" in model_name - or "cnn_haetal_2d" in model_name - or "cnn_pf" in model_name - or "cnn_pff" in model_name - ): - 
trainer_config["accelerator"] = "cpu" - else: - trainer_config["accelerator"] = "gpu" - - data_module_config["init_args"]["data_path"] = train_set - - test_data_module_config = copy.deepcopy(data_module_config) - test_data_module_config["init_args"]["data_path"] = test_set - - experiments.append( - ExperimentArgs( - trainer=trainer_config, - model=model_config, - data=data_module_config, - test_data=test_data_module_config, - seed=self.seed, - ) - ) - experiments = self.filter_experiments(experiments) - - return experiments - - -def hack_to_avoid_lightning_cli_sys_argv_warning(func, *args, **kwargs): - # Hack to avoid LightningCLI parse warning - # The warning is something like: - # /usr/local/lib/python3.10/dist-packages/lightning/pytorch/cli.py:520: - # LightningCLI's args parameter is intended to run from with in Python like - # if it were from the command line. To prevent mistakes it is not - # recommended to provide both args and command line arguments, got: - # sys.argv[1:]=['--config', 'benchmark_a1_dry_run.yaml'], - def hack_to_avoid_lightning_cli_sys_argv_warning_wrapper(*args, **kwargs): - import sys - - old_args = sys.argv - sys.argv = sys.argv[:1] - func(*args, **kwargs) - sys.argv = old_args - - return hack_to_avoid_lightning_cli_sys_argv_warning_wrapper - - -@hack_to_avoid_lightning_cli_sys_argv_warning -def run( - config_parser: SupervisedConfigParser, - use_ray: bool, - ray_address: str = None, - dry_run: bool = False, - dry_run_limit: int = 3, -): - experiments = config_parser() - if dry_run: - if dry_run_limit is None: - dry_run_limit = len(experiments) - - print( - f"** Dry run. Limiting to a maximum of {dry_run_limit} experiments, shuffled **" - ) - experiments = random.sample( - experiments, min(dry_run_limit, len(experiments)) - ) - - if use_ray: - return run_using_ray(experiments, ray_address) - else: - return run_serial(experiments) - - -def main( - data_path: str, - default_trainer_config_file: str, - data_module_configs_path: str, - model_configs_path: str, - output_path: str = "benchmarks/", - skip_existing: bool = True, - ray_address: str = None, - use_ray: bool = True, - enable_checkpointing: bool = False, - seed: int = 42, - dry_run: bool = False, - dry_run_limit: int = 5, - leave_one_out: bool = False, -): - parser = SupervisedConfigParser( - data_path=Path(data_path), - default_trainer_config=Path(default_trainer_config_file), - data_module_configs=Path(data_module_configs_path), - model_configs=Path(model_configs_path), - output_dir=Path(output_path), - skip_existing=skip_existing, - enable_checkpointing=enable_checkpointing, - seed=seed, - leave_one_out=leave_one_out, - ) - return run( - parser, - use_ray, - ray_address, - dry_run=dry_run, - dry_run_limit=dry_run_limit, - ) - - -if __name__ == "__main__": - CLI(main) diff --git a/sslt/pipelines/mlflow_train.py b/sslt/pipelines/mlflow_train.py deleted file mode 100644 index 6baf851..0000000 --- a/sslt/pipelines/mlflow_train.py +++ /dev/null @@ -1,217 +0,0 @@ -from io import BytesIO -from typing import Dict, List, final -import fsspec -import lightning as L -from lightning.pytorch.loggers import Logger, MLFlowLogger -from lightning.pytorch.callbacks import ( - ModelCheckpoint, - RichProgressBar, - EarlyStopping, - ModelSummary, -) -import torch - -from ssl_tools.callbacks.performance import PerformanceLogger -from ssl_tools.callbacks.save_best import PickleBestModelAndLoad -from ssl_tools.pipelines.base import Pipeline -import mlflow -from ssl_tools.models.ssl.classifier import SSLDiscriminator -from 
ssl_tools.pipelines.utils import load_model_mlflow - - -class LightningTrainMLFlow(Pipeline): - def __init__( - self, - # Required paramters - experiment_name: str, - model_name: str, - # Optional parameters - run_name: str = None, - accelerator: str = "cpu", - devices: int = 1, - num_nodes: int = 1, - strategy: str = "auto", - max_epochs: int = 1, - batch_size: int = 1, - limit_train_batches: int | float = 1.0, - limit_val_batches: int | float = 1.0, - checkpoint_monitor_metric: str = None, - checkpoint_monitor_mode: str = "min", - patience: int = None, - log_dir: str = "./mlruns", - model_tags: Dict[str, str] = None - ): - """Train a model using Lightning framework. - - Parameters - ---------- - experiment_name : str - Name of the experiment. - model_name : str - Name of the model. - dataset_name : str - Name of the dataset. - run_name : str, optional - The name of the run, by default None - accelerator : str, optional - The accelerator to use, by default "cpu" - devices : int, optional - Number of accelerators to use, by default 1 - num_nodes : int, optional - Number of nodes, by default 1 - strategy : str, optional - Training strategy, by default "auto" - max_epochs : int, optional - Maximium number of epochs, by default 1 - batch_size : int, optional - Batch size, by default 1 - limit_train_batches : int | float, optional - Limit the number of batches to train, by default 1.0 - limit_val_batches : int | float, optional - Limit the number of batches to test, by default 1.0 - checkpoint_monitor_metric : str, optional - The metric to monitor for checkpointing, by default None - checkpoint_monitor_mode : str, optional - The mode for checkpointing, by default "min" - patience : int, optional - The patience for early stopping, by default None - log_dir : str, optional - Location where logs will be saved, by default "./runs" - """ - - super().__init__() - self.experiment_name = experiment_name - self.model_name = model_name - self.run_name = run_name - self.accelerator = accelerator - self.devices = devices - self.num_nodes = num_nodes - self.strategy = strategy - self.max_epochs = max_epochs - self.batch_size = batch_size - self.limit_train_batches = limit_train_batches - self.limit_val_batches = limit_val_batches - self.checkpoint_monitor_metric = checkpoint_monitor_metric - self.checkpoint_monitor_mode = checkpoint_monitor_mode - self.patience = patience - self.log_dir = log_dir - self.model_tags = model_tags - self._hparams = dict(self.hparams) - - def get_model(self) -> L.LightningModule: - raise NotImplementedError - - def get_data_module(self) -> L.LightningDataModule: - raise NotImplementedError - - def get_trainer( - self, logger: Logger, callacks: List[L.Callback] - ) -> L.Trainer: - return L.Trainer( - accelerator=self.accelerator, - devices=self.devices, - num_nodes=self.num_nodes, - strategy=self.strategy, - max_epochs=self.max_epochs, - logger=logger, - callbacks=callacks, - limit_train_batches=self.limit_train_batches, - limit_val_batches=self.limit_val_batches, - ) - - def get_callbacks(self) -> List[L.Callback]: - callbacks = [] - - model_summary = ModelSummary(max_depth=3) - callbacks.append(model_summary) - - model_checkpoint = ModelCheckpoint( - monitor=self.checkpoint_monitor_metric, - mode=self.checkpoint_monitor_mode, - # save_last=True - ) - callbacks.append(model_checkpoint) - - if self.patience: - early_stopping = EarlyStopping( - monitor=self.checkpoint_monitor_metric, - patience=self.patience, - mode=self.checkpoint_monitor_mode, - ) - 
callbacks.append(early_stopping) - - performance_logger = PerformanceLogger() - callbacks.append(performance_logger) - - # device_stats_monitor = DeviceStatsMonitor() - # callbacks.append(device_stats_monitor) - best_logger = PickleBestModelAndLoad( - model_name=self.model_name, - model_tags=self.model_tags, - ) - callbacks.append(best_logger) - - rich_progress_bar = RichProgressBar() - callbacks.append(rich_progress_bar) - - return callbacks - - def get_logger(self) -> Logger: - return MLFlowLogger( - experiment_name=self.experiment_name, - run_name=self.run_name, - save_dir=self.log_dir, - log_model=True, - tags=self.model_tags - ) - - def run(self) -> L.LightningModule: - # Get all required components - model = self.get_model() - datamodule = self.get_data_module() - logger = self.get_logger() - callbacks = self.get_callbacks() - trainer = self.get_trainer(logger, callbacks) - - # Log the experiment hyperparameters - logger.log_hyperparams(self._hparams) - - # Do fit and return trained model - trainer.fit(model, datamodule) - return model - - - -class LightningFineTuneMLFlow(LightningTrainMLFlow): - def __init__( - self, - registered_model_name: str, - registered_model_tags: Dict[str, str] = None, - update_backbone: bool = False, - **kwargs, - ): - super().__init__(**kwargs) - self.registered_model_name = registered_model_name - self.registered_model_tags = registered_model_tags - self.model_version = dict() - self._mlflow_client = None - self.update_backbone = update_backbone - - @property - def client(self): - if self._mlflow_client is None: - print(f"Initializing MLFlow client at {self.log_dir}...") - self._mlflow_client = mlflow.client.MlflowClient( - tracking_uri=self.log_dir - ) - return self._mlflow_client - - def load_model(self) -> L.LightningModule: - model, model_version = load_model_mlflow( - self.client, - self.registered_model_name, - self.registered_model_tags, - ) - self.model_version = model_version - self._hparams.update({"model_version": model_version}) - return model \ No newline at end of file diff --git a/sslt/pipelines/utils.py b/sslt/pipelines/utils.py deleted file mode 100644 index 3dfb1e5..0000000 --- a/sslt/pipelines/utils.py +++ /dev/null @@ -1,59 +0,0 @@ -from io import BytesIO -from typing import Dict -import fsspec -import mlflow - -import lightning as L -import torch - - -def tags2str(d: Dict[str, str]) -> str: - """ - Convert a dictionary of tags to a search string compatible with MLflow's search_model_versions method. - - Parameters: - - d: A dictionary containing tags where keys are tag names and values are tag values. - - Returns: - - search_str: A search string formatted for MLflow's search_model_versions method. - """ - search_str = " and ".join( - [f"tags.`{key}`='{value}'" for key, value in d.items()] - ) - return search_str - - -def load_model_mlflow( - client: mlflow.client.MlflowClient, - registered_model_name: str, - registered_model_tags: Dict[str, str] = None, -) -> Dict[L.LightningModule, Dict[str, str]]: - search_string = f"name='{registered_model_name}'" - if registered_model_tags is not None: - search_string += " and " + tags2str(registered_model_tags) - - registered_model = client.search_model_versions( - search_string, order_by=["creation_timestamp DESC"], max_results=1 - ) - - if len(registered_model) == 0: - raise ValueError( - f"No model found with the name '{registered_model_name}' and tags '{registered_model_tags}'. 
Query string used: {search_string}" - ) - - model_version = registered_model[0] - run_id = model_version.run_id - artifact_path = "/".join(model_version.source.split("/")[2:]) - artifact_uri = ( - client.get_run(run_id).info.artifact_uri + "/" + artifact_path - ) - - # print(f"Loading model from: {artifact_uri}") - - with fsspec.open(artifact_uri, "rb") as f: - model_bytes = f.read() - model_bytes = BytesIO(model_bytes) - model = torch.load(model_bytes) - - print(f"Model loaded from: {artifact_uri}.") - return model, dict(model_version)
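
For reference, a minimal usage sketch of the Pipeline base class introduced in PATCH 3/4 (minerva/pipelines/base.py) follows. Everything below is an illustrative assumption rather than part of the patch: the SquarePipeline class, its exponent parameter, and the printed values are hypothetical; only the base-class behavior (implementing _run, calling run, and reading status, config, result, and using clone) comes from the diff above.

from pathlib import Path

from minerva.pipelines.base import Pipeline  # added in PATCH 3/4


class SquarePipeline(Pipeline):
    """Toy pipeline that raises its input to a fixed power (illustrative only)."""

    def __init__(self, exponent: int = 2, cwd: Path | str = None):
        # The base __init__ assigns a pipeline_id, resolves cwd, and calls
        # save_hyperparameters(), which backs the `config` property.
        super().__init__(cwd=cwd, cache_result=True)
        self.exponent = exponent

    def _run(self, x: float) -> float:
        # Only the actual work lives here; the inherited run() wraps this call
        # with timing, run counting, status tracking, and exception capture.
        return x ** self.exponent


if __name__ == "__main__":
    pipe = SquarePipeline(exponent=3)
    print(pipe.run(2))       # 8; an exception inside _run would mark the run FAILED
    print(pipe.status)       # id, run count, start/end times, SUCCESS/FAILED, ...
    print(pipe.config)       # hyperparameters captured at construction time
    print(pipe.result)       # 8 again, since cache_result=True was passed above
    fresh = Pipeline.clone(pipe)  # deep copy with run state reset to NOT STARTED

Note that run() stores the return value only when cache_result is enabled, so the result property stays None for pipelines constructed with the default cache_result=False.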