From 40c6b19f67da8e274eee0bb4ef474b3b655a5051 Mon Sep 17 00:00:00 2001 From: Gabriel Gutierrez Date: Thu, 4 Jul 2024 23:24:10 -0300 Subject: [PATCH] fixing issues --- minerva/models/nets/time_series/cnns.py | 8 ++------ minerva/pipelines/base.py | 4 ++-- minerva/pipelines/lightning_pipeline.py | 4 ++-- 3 files changed, 6 insertions(+), 10 deletions(-) diff --git a/minerva/models/nets/time_series/cnns.py b/minerva/models/nets/time_series/cnns.py index edb44b6..4d4452d 100644 --- a/minerva/models/nets/time_series/cnns.py +++ b/minerva/models/nets/time_series/cnns.py @@ -56,9 +56,7 @@ def __init__( test_metrics={"acc": Accuracy(task="multiclass", num_classes=num_classes)}, ) - def _create_backbone( - self, input_shape: Union[Tuple[int, int], Tuple[int, int, int]] - ) -> torch.nn.Module: + def _create_backbone(self, input_shape: Tuple[int, int, int]) -> torch.nn.Module: return torch.nn.Sequential( # First 2D convolutional layer torch.nn.Conv2d( @@ -146,9 +144,7 @@ def __init__( test_metrics={"acc": Accuracy(task="multiclass", num_classes=num_classes)}, ) - def _create_backbone( - self, input_shape: Union[Tuple[int, int], Tuple[int, int, int]] - ) -> torch.nn.Module: + def _create_backbone(self, input_shape: Tuple[int, int, int]) -> torch.nn.Module: first_kernel_size = 4 return torch.nn.Sequential( # Add padding diff --git a/minerva/pipelines/base.py b/minerva/pipelines/base.py index 8274189..f626a3f 100644 --- a/minerva/pipelines/base.py +++ b/minerva/pipelines/base.py @@ -44,7 +44,7 @@ class Pipeline(HyperparametersMixin): def __init__( self, log_dir: Optional[PathLike] = None, - ignore: Union[str, List[str], None] = None, + ignore: Optional[Union[str, List[str]]] = None, cache_result: bool = False, save_run_status: bool = False, ): @@ -55,7 +55,7 @@ def __init__( log_dir : PathLike, optional The default logging directory where all related pipeline files should be saved. By default None (uses current working directory) - ignore : str | List[str], optional + ignore : Union[str, List[str]], optional Pipeline __init__ attributes are saved into config attibute. This option allows to ignore some attributes from being saved. This is quite useful when the attributes are not serializable or very large. diff --git a/minerva/pipelines/lightning_pipeline.py b/minerva/pipelines/lightning_pipeline.py index 8260b6f..a05dfc4 100644 --- a/minerva/pipelines/lightning_pipeline.py +++ b/minerva/pipelines/lightning_pipeline.py @@ -167,7 +167,7 @@ def _calculate_metrics( return results # Private methods - def _fit(self, data: L.LightningDataModule, ckpt_path: Optional[PathLike]): + def _fit(self, data: L.LightningDataModule, ckpt_path: Optional[PathLike] = None): """Fit the model using the given data. Parameters @@ -182,7 +182,7 @@ def _fit(self, data: L.LightningDataModule, ckpt_path: Optional[PathLike]): model=self._model, datamodule=data, ckpt_path=ckpt_path ) - def _test(self, data: L.LightningDataModule, ckpt_path: Optional[PathLike]): + def _test(self, data: L.LightningDataModule, ckpt_path: Optional[PathLike] = None): """Test the model using the given data. Parameters