diff --git a/minerva/analysis/metrics/pixel_accuracy.py b/minerva/analysis/metrics/pixel_accuracy.py index 2c92b5a..6ac90e4 100644 --- a/minerva/analysis/metrics/pixel_accuracy.py +++ b/minerva/analysis/metrics/pixel_accuracy.py @@ -3,7 +3,7 @@ class PixelAccuracy(Metric): - def __init__(self, dist_sync_on_step=False): + def __init__(self, dist_sync_on_step: bool = False): """ Initializes a PixelAccuracy metric object. diff --git a/minerva/data/datasets/supervised_dataset.py b/minerva/data/datasets/supervised_dataset.py index 26abfa6..cccf98e 100644 --- a/minerva/data/datasets/supervised_dataset.py +++ b/minerva/data/datasets/supervised_dataset.py @@ -1,4 +1,4 @@ -from typing import List, Tuple +from typing import List, Optional, Tuple import numpy as np @@ -45,7 +45,7 @@ class SupervisedReconstructionDataset(SimpleDataset): ``` """ - def __init__(self, readers: List[_Reader], transforms: _Transform | None = None): + def __init__(self, readers: List[_Reader], transforms: Optional[_Transform] = None): """A simple dataset class for supervised reconstruction tasks. Parameters @@ -54,7 +54,7 @@ def __init__(self, readers: List[_Reader], transforms: _Transform | None = None) List of data readers. It must contain exactly 2 readers. The first reader for the input data and the second reader for the target data. - transforms: _Transform | None + transforms: Optional[_Transform] Optional data transformation pipeline. Raises diff --git a/minerva/data/readers/csv_reader.py b/minerva/data/readers/csv_reader.py index e03170e..de24f7b 100644 --- a/minerva/data/readers/csv_reader.py +++ b/minerva/data/readers/csv_reader.py @@ -1,17 +1,17 @@ -from typing import Union +from typing import List, Optional, Tuple, Union import pandas as pd + from minerva.data.readers.tabular_reader import TabularReader + class CSVReader(TabularReader): def __init__( self, path: str, - columns_to_select: Union[str, list[str]], - cast_to: str = None, - data_shape: tuple[int, ...] = None, + columns_to_select: Union[str, List[str]], + cast_to: Optional[str] = None, + data_shape: Optional[Tuple[int, ...]] = None, ): df = pd.read_csv(path) super().__init__(df, columns_to_select, cast_to, data_shape) - - diff --git a/minerva/data/readers/patched_array_reader.py b/minerva/data/readers/patched_array_reader.py index 6a7047b..c9cf8d2 100644 --- a/minerva/data/readers/patched_array_reader.py +++ b/minerva/data/readers/patched_array_reader.py @@ -1,4 +1,4 @@ -from typing import List, Tuple +from typing import Dict, List, Optional, Tuple import numpy as np from numpy.typing import ArrayLike @@ -22,10 +22,10 @@ def __init__( self, data: ArrayLike, data_shape: Tuple[int, ...], - stride: Tuple[int, ...] = None, - pad_width: Tuple[Tuple[int, int], ...] = None, + stride: Optional[Tuple[int, ...]] = None, + pad_width: Optional[Tuple[Tuple[int, int], ...]] = None, pad_mode: str = "constant", - pad_kwargs: dict = None, + pad_kwargs: Optional[Dict] = None, ): """Reads data from a NumPy array and generates patches from it. diff --git a/minerva/data/readers/tabular_reader.py b/minerva/data/readers/tabular_reader.py index d9d2706..f8efa74 100644 --- a/minerva/data/readers/tabular_reader.py +++ b/minerva/data/readers/tabular_reader.py @@ -1,18 +1,20 @@ +import re from pathlib import Path -from typing import Union +from typing import List, Optional, Tuple, Union import numpy as np -import re import pandas as pd + from minerva.data.readers.reader import _Reader + class TabularReader(_Reader): def __init__( self, df: pd.DataFrame, - columns_to_select: Union[str, list[str]], - cast_to: str = None, - data_shape: tuple[int, ...] = None, + columns_to_select: Union[str, List[str]], + cast_to: Optional[str] = None, + data_shape: Optional[Tuple[int, ...]] = None, ): """Reader to select columns from a DataFrame and return them as a NumPy array. The DataFrame is indexed by the row number. Each row of the @@ -64,9 +66,7 @@ def __getitem__(self, index: int) -> np.ndarray: # Filter valid columns based on columns_to_select list valid_columns = [] for pattern in self.columns_to_select: - valid_columns.extend( - [col for col in columns if re.match(pattern, col)] - ) + valid_columns.extend([col for col in columns if re.match(pattern, col)]) # Select the elements and return row = self.df.iloc[index][valid_columns] diff --git a/minerva/models/nets/image/setr.py b/minerva/models/nets/image/setr.py index 63c1618..094b4ec 100644 --- a/minerva/models/nets/image/setr.py +++ b/minerva/models/nets/image/setr.py @@ -1,5 +1,5 @@ import warnings -from typing import Dict, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Union import lightning as L import torch @@ -222,7 +222,7 @@ class _SetR_PUP(nn.Module): def __init__( self, - image_size: int | tuple[int, int], + image_size: Union[int, Tuple[int, int]], patch_size: int, num_layers: int, num_heads: int, @@ -241,14 +241,14 @@ def __init__( conv_act: nn.Module, align_corners: bool, aux_output: bool = False, - aux_output_layers: list[int] | None = None, + aux_output_layers: Optional[List[int]] = None, ): """ Initializes the SETR PUP model. Parameters ---------- - image_size : int or tuple[int, int] + image_size : int or Tuple[int, int] The size of the input image. patch_size : int The size of each patch in the input image. @@ -388,7 +388,7 @@ class SETR_PUP(L.LightningModule): def __init__( self, - image_size: int | tuple[int, int] = 512, + image_size: Union[int, Tuple[int, int]] = 512, patch_size: int = 16, num_layers: int = 24, num_heads: int = 16, @@ -411,15 +411,15 @@ def __init__( val_metrics: Optional[Dict[str, Metric]] = None, test_metrics: Optional[Dict[str, Metric]] = None, aux_output: bool = True, - aux_output_layers: list[int] | None = [9, 14, 19], - aux_weights: list[float] = [0.3, 0.3, 0.3], + aux_output_layers: Optional[List[int]] = [9, 14, 19], + aux_weights: List[float] = [0.3, 0.3, 0.3], ): """ Initializes the SetR model. Parameters ---------- - image_size : int or tuple[int, int] + image_size : int or Tuple[int, int] The input image size. Defaults to 512. patch_size : int The size of each patch. Defaults to 16. @@ -465,9 +465,9 @@ def __init__( The metrics to be used for testing evaluation. Defaults to None. aux_output : bool Whether to include auxiliary output heads in the model. Defaults to True. - aux_output_layers : list[int] | None + aux_output_layers : List[int], optional The indices of the layers to output auxiliary predictions. Defaults to [9, 14, 19]. - aux_weights : list[float] + aux_weights : List[float] The weights for the auxiliary predictions. Defaults to [0.3, 0.3, 0.3]. """ @@ -536,9 +536,9 @@ def _compute_metrics(self, y_hat: torch.Tensor, y: torch.Tensor, step_name: str) def _loss_func( self, - y_hat: ( - torch.Tensor | Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] - ), + y_hat: Union[ + torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] + ], y: torch.Tensor, ) -> torch.Tensor: """Calculate the loss between the output and the input data. @@ -626,7 +626,7 @@ def test_step(self, batch: torch.Tensor, batch_idx: int): return self._single_step(batch, batch_idx, "test") def predict_step( - self, batch: torch.Tensor, batch_idx: int, dataloader_idx: int | None = None + self, batch: torch.Tensor, batch_idx: int, dataloader_idx: Optional[int] = None ): x, _ = batch return self.model(x)[0] diff --git a/minerva/models/nets/image/vit.py b/minerva/models/nets/image/vit.py index d87465a..238c63f 100644 --- a/minerva/models/nets/image/vit.py +++ b/minerva/models/nets/image/vit.py @@ -1,7 +1,7 @@ import math from collections import OrderedDict from functools import partial -from typing import Callable, List, Optional +from typing import Callable, List, Optional, Tuple, Union import lightning as L import torch @@ -31,7 +31,7 @@ def __init__( dropout: float, attention_dropout: float, aux_output: bool = False, - aux_output_layers: List[int] | None = None, + aux_output_layers: Optional[List[int]] = None, norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), ): super().__init__() @@ -81,7 +81,7 @@ class _VisionTransformerBackbone(nn.Module): def __init__( self, - image_size: int | tuple[int, int], + image_size: Union[int, Tuple[int, int]], patch_size: int, num_layers: int, num_heads: int, @@ -91,7 +91,7 @@ def __init__( attention_dropout: float = 0.0, num_classes: int = 1000, aux_output: bool = False, - aux_output_layers: List[int] | None = None, + aux_output_layers: Optional[List[int]] = None, norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), conv_stem_configs: Optional[List[ConvStemConfig]] = None, ): @@ -100,7 +100,7 @@ def __init__( Parameters ---------- - image_size : int or tuple[int, int] + image_size : int or Tuple[int, int] The size of the input image. If an int is provided, it is assumed to be a square image. If a tuple of ints is provided, it represents the height and width of the image. patch_size : int @@ -234,14 +234,14 @@ def __init__( if self.conv_proj.conv_last.bias is not None: nn.init.zeros_(self.conv_proj.conv_last.bias) - def _process_input(self, x: torch.Tensor) -> tuple[torch.Tensor, int, int]: + def _process_input(self, x: torch.Tensor) -> Tuple[torch.Tensor, int, int]: """Process the input tensor and return the reshaped tensor and dimensions. Args: x (torch.Tensor): The input tensor. Returns: - tuple[torch.Tensor, int, int]: The reshaped tensor, number of rows, and number of columns. + Tuple[torch.Tensor, int, int]: The reshaped tensor, number of rows, and number of columns. """ n, c, h, w = x.shape p = self.patch_size diff --git a/minerva/models/nets/time_series/cnns.py b/minerva/models/nets/time_series/cnns.py index 15fb9af..edb44b6 100644 --- a/minerva/models/nets/time_series/cnns.py +++ b/minerva/models/nets/time_series/cnns.py @@ -1,4 +1,4 @@ -from typing import List, Tuple +from typing import List, Tuple, Union import torch from torchmetrics import Accuracy @@ -7,7 +7,7 @@ class ZeroPadder2D(torch.nn.Module): - def __init__(self, pad_at: List[int], padding_size: int): + def __init__(self, pad_at: Tuple[int], padding_size: int): super().__init__() self.pad_at = pad_at self.padding_size = padding_size @@ -56,7 +56,9 @@ def __init__( test_metrics={"acc": Accuracy(task="multiclass", num_classes=num_classes)}, ) - def _create_backbone(self, input_shape: Tuple[int, int]) -> torch.nn.Module: + def _create_backbone( + self, input_shape: Union[Tuple[int, int], Tuple[int, int, int]] + ) -> torch.nn.Module: return torch.nn.Sequential( # First 2D convolutional layer torch.nn.Conv2d( @@ -120,7 +122,7 @@ def _create_fc(self, input_features: int, num_classes: int) -> torch.nn.Module: class CNN_HaEtAl_2D(SimpleSupervisedModel): def __init__( self, - pad_at: List[int] = (3,), + pad_at: Tuple[int] = (3,), input_shape: Tuple[int, int, int] = (1, 6, 60), num_classes: int = 6, learning_rate: float = 1e-3, @@ -144,7 +146,9 @@ def __init__( test_metrics={"acc": Accuracy(task="multiclass", num_classes=num_classes)}, ) - def _create_backbone(self, input_shape: Tuple[int, int]) -> torch.nn.Module: + def _create_backbone( + self, input_shape: Union[Tuple[int, int], Tuple[int, int, int]] + ) -> torch.nn.Module: first_kernel_size = 4 return torch.nn.Sequential( # Add padding diff --git a/minerva/models/nets/time_series/resnet.py b/minerva/models/nets/time_series/resnet.py index a31edba..774c3ea 100644 --- a/minerva/models/nets/time_series/resnet.py +++ b/minerva/models/nets/time_series/resnet.py @@ -1,4 +1,4 @@ -from typing import Tuple +from typing import Optional, Tuple import torch from torchmetrics import Accuracy @@ -7,7 +7,7 @@ class ConvolutionalBlock(torch.nn.Module): - def __init__(self, in_channels: int, activation_cls: torch.nn.Module = None): + def __init__(self, in_channels: int, activation_cls: torch.nn.Module): super().__init__() self.in_channels = in_channels self.activation_cls = activation_cls diff --git a/minerva/pipelines/base.py b/minerva/pipelines/base.py index 635e58f..8274189 100644 --- a/minerva/pipelines/base.py +++ b/minerva/pipelines/base.py @@ -1,22 +1,22 @@ -from abc import abstractmethod import copy import os -from pathlib import Path -from typing import Any, List, Dict -from uuid import uuid4 -from time import time -import traceback +import platform import sys -from lightning.pytorch.core.mixins import HyperparametersMixin -import pkg_resources -import yaml +import traceback +from abc import abstractmethod from datetime import datetime +from functools import cached_property +from pathlib import Path +from time import time +from typing import Any, Dict, List, Optional, Union +from uuid import uuid4 import git -import sys -import platform -from functools import cached_property +import pkg_resources +import yaml +from lightning.pytorch.core.mixins import HyperparametersMixin from lightning.pytorch.utilities import rank_zero_only + from minerva.utils.typing import PathLike @@ -43,8 +43,8 @@ class Pipeline(HyperparametersMixin): def __init__( self, - log_dir: Path | str = None, - ignore: str | List[str] = None, + log_dir: Optional[PathLike] = None, + ignore: Union[str, List[str], None] = None, cache_result: bool = False, save_run_status: bool = False, ): @@ -52,8 +52,8 @@ def __init__( Parameters ---------- - log_dir : Path | str, optional - The default logging directory where all related pipeline files + log_dir : PathLike, optional + The default logging directory where all related pipeline files should be saved. By default None (uses current working directory) ignore : str | List[str], optional Pipeline __init__ attributes are saved into config attibute. This @@ -132,9 +132,9 @@ def log_dir(self) -> Path: Path to the pipeline's log_dir """ return self._log_dir - + @log_dir.setter - def log_dir(self, value: Path | str): + def log_dir(self, value: PathLike): """Set the log_dir. Parameters @@ -216,8 +216,7 @@ def system_info(self) -> Dict[str, Any]: # ---------- Add python information ---------- packages = [ - f"{pkg.project_name}=={pkg.version}" - for pkg in pkg_resources.working_set + f"{pkg.project_name}=={pkg.version}" for pkg in pkg_resources.working_set ] d["python"] = { "pip_packages": packages, @@ -342,9 +341,7 @@ def run(self, *args, **kwargs) -> Any: if self._save_run_status: self._cached_run_status.append(self.run_status) - self._save_pipeline_info( - self._log_dir / f"run_{self.pipeline_id}.yaml" - ) + self._save_pipeline_info(self._log_dir / f"run_{self.pipeline_id}.yaml") return result diff --git a/minerva/pipelines/lightning_pipeline.py b/minerva/pipelines/lightning_pipeline.py index fe9842e..8260b6f 100644 --- a/minerva/pipelines/lightning_pipeline.py +++ b/minerva/pipelines/lightning_pipeline.py @@ -1,12 +1,14 @@ +from collections import defaultdict from pathlib import Path -from typing import Any, Dict, Iterable, List, Literal +from typing import Any, Dict, Iterable, List, Literal, Optional + import lightning as L import torch -from minerva.pipelines.base import Pipeline +import yaml from torchmetrics import Metric + +from minerva.pipelines.base import Pipeline from minerva.utils.data import get_full_data_split -from collections import defaultdict -import yaml from minerva.utils.typing import PathLike @@ -46,7 +48,7 @@ def __init__( trainer : L.Trainer The Lightning Trainer to be used. log_dir : PathLike, optional - The default logging directory where all related pipeline files + The default logging directory where all related pipeline files should be saved. By default None (uses current working directory) save_run_status : bool, optional If True, save the status of each run in a YAML file. This file will @@ -74,10 +76,10 @@ def __init__( metrics. If False, the metrics will be calculated for the whole dataset and the results will be a single metric (single-element list). By default False - """ + """ if log_dir is None and trainer.log_dir is not None: log_dir = trainer.log_dir - + super().__init__( log_dir=log_dir, ignore=[ @@ -158,15 +160,14 @@ def _calculate_metrics( for metric_name, metric in metrics.items(): final_results = [ - metric(y_i, y_hat_i).float().item() - for y_i, y_hat_i in zip(y, y_hat) + metric(y_i, y_hat_i).float().item() for y_i, y_hat_i in zip(y, y_hat) ] results[metric_name] = final_results return results # Private methods - def _fit(self, data: L.LightningDataModule, ckpt_path: PathLike): + def _fit(self, data: L.LightningDataModule, ckpt_path: Optional[PathLike]): """Fit the model using the given data. Parameters @@ -181,7 +182,7 @@ def _fit(self, data: L.LightningDataModule, ckpt_path: PathLike): model=self._model, datamodule=data, ckpt_path=ckpt_path ) - def _test(self, data: L.LightningDataModule, ckpt_path: PathLike): + def _test(self, data: L.LightningDataModule, ckpt_path: Optional[PathLike]): """Test the model using the given data. Parameters @@ -199,7 +200,7 @@ def _test(self, data: L.LightningDataModule, ckpt_path: PathLike): def _predict( self, data: L.LightningDataModule, - ckpt_path: PathLike = None, + ckpt_path: Optional[PathLike] = None, ) -> torch.Tensor: """Predict using the given data. @@ -223,7 +224,7 @@ def _predict( def _evaluate( self, data: L.LightningDataModule, - ckpt_path: PathLike = None, + ckpt_path: Optional[PathLike] = None, ) -> Dict[str, Any]: """Evaluate the model and calculate regression and/or classification metrics. @@ -246,9 +247,7 @@ def _evaluate( X, y = get_full_data_split(data, "predict") y = torch.tensor(y, device="cpu") - y_hat = self.trainer.predict( - self._model, datamodule=data, ckpt_path=ckpt_path - ) + y_hat = self.trainer.predict(self._model, datamodule=data, ckpt_path=ckpt_path) y_hat = torch.cat(y_hat).detach().cpu() if self._classification_metrics is not None: @@ -280,7 +279,7 @@ def _run( self, data: L.LightningDataModule, task: Literal["fit", "test", "predict", "evaluate"], - ckpt_path: PathLike = None, + ckpt_path: Optional[PathLike] = None, ): """ Run the specified task on the given data. diff --git a/minerva/transforms/transform.py b/minerva/transforms/transform.py index cb58bda..9d2dd28 100644 --- a/minerva/transforms/transform.py +++ b/minerva/transforms/transform.py @@ -1,5 +1,5 @@ from itertools import product -from typing import Any, List, Sequence +from typing import Any, List, Sequence, Union import numpy as np import torch @@ -50,7 +50,7 @@ def __call__(self, x: Any) -> Any: class Flip(_Transform): """Flip the input data along the specified axis.""" - def __init__(self, axis: int | List[int] = 0): + def __init__(self, axis: Union[int, List[int]] = 0): """Flip the input data along the specified axis. Parameters @@ -95,9 +95,7 @@ def __init__(self, octaves: int, scale: float = 1): Optionally rescale the Perlin noise. Default is 1 (no rescaling) """ if octaves <= 0: - raise ValueError( - f"Number of octaves must be positive, but got {octaves=}" - ) + raise ValueError(f"Number of octaves must be positive, but got {octaves=}") if scale == 0: raise ValueError(f"Scale can't be 0") self.octaves = octaves @@ -161,7 +159,7 @@ def __call__(self, x: np.ndarray) -> np.ndarray: class CastTo(_Transform): """Cast the input data to the specified data type.""" - def __init__(self, dtype: type | str): + def __init__(self, dtype: Union[type, str]): """Cast the input data to the specified data type. Parameters @@ -192,4 +190,4 @@ def __call__(self, x: np.ndarray) -> np.ndarray: padded = np.pad(x, ((0, pad_h), (0, pad_w), (0, 0)), mode="reflect") padded = np.transpose(padded, (2, 0, 1)) - return padded \ No newline at end of file + return padded diff --git a/minerva/utils/data.py b/minerva/utils/data.py index 6a7bdf5..1bca1a3 100644 --- a/minerva/utils/data.py +++ b/minerva/utils/data.py @@ -1,11 +1,12 @@ -from typing import Tuple +from typing import Optional, Tuple, Union + +import lightning as L import torch from torch.utils.data import DataLoader -import lightning as L class SimpleDataset: - def __init__(self, data: torch.Tensor, label: torch.Tensor = None): + def __init__(self, data: torch.Tensor, label: Optional[torch.Tensor] = None): self.data = data self.label = label @@ -23,8 +24,8 @@ class RandomDataModule(L.LightningDataModule): def __init__( self, data_shape: Tuple[int, ...], - label_shape: int | Tuple[int, ...] = None, - num_classes: int = None, + label_shape: Union[int, Tuple[int, ...], None] = None, + num_classes: Optional[int] = None, num_train_samples: int = 128, num_val_samples: int = 8, num_test_samples: int = 8, @@ -54,17 +55,13 @@ def __init__( delattr(self, "val_dataloader") if num_test_samples is not None: - assert ( - num_test_samples > 0 - ), "num_test_samples must be greater than 0" + assert num_test_samples > 0, "num_test_samples must be greater than 0" else: delattr(self, "test_dataloader") # label_shape and num_classes are mutually exclusive if label_shape is not None and num_classes is not None: - raise ValueError( - "label_shape and num_classes are mutually exclusive" - ) + raise ValueError("label_shape and num_classes are mutually exclusive") def _generate_data(self, num_samples, data_shape, label_shape, num_classes): data = torch.rand((num_samples, *data_shape)) @@ -103,7 +100,7 @@ def setup(self, stage): self.num_classes, ) self.test_data = SimpleDataset(data, label) - + elif stage == "predict": if self.num_predict_samples is not None: data, label = self._generate_data( @@ -113,34 +110,24 @@ def setup(self, stage): self.num_classes, ) self.predict_data = SimpleDataset(data, label) - + else: raise ValueError(f"Invalid stage: {stage}") def train_dataloader(self): - return DataLoader( - self.train_data, batch_size=self.batch_size, shuffle=True - ) + return DataLoader(self.train_data, batch_size=self.batch_size, shuffle=True) def val_dataloader(self): - return DataLoader( - self.val_data, batch_size=self.batch_size, shuffle=False - ) + return DataLoader(self.val_data, batch_size=self.batch_size, shuffle=False) def test_dataloader(self): - return DataLoader( - self.test_data, batch_size=self.batch_size, shuffle=False - ) - + return DataLoader(self.test_data, batch_size=self.batch_size, shuffle=False) + def predict_dataloader(self): - return DataLoader( - self.predict_data, batch_size=self.batch_size, shuffle=False - ) + return DataLoader(self.predict_data, batch_size=self.batch_size, shuffle=False) -def get_split_dataloader( - stage: str, data_module: L.LightningDataModule -) -> DataLoader: +def get_split_dataloader(stage: str, data_module: L.LightningDataModule) -> DataLoader: if stage == "train": data_module.setup("fit") return data_module.train_dataloader() diff --git a/minerva/utils/position_embedding.py b/minerva/utils/position_embedding.py index bf793d1..0be7959 100644 --- a/minerva/utils/position_embedding.py +++ b/minerva/utils/position_embedding.py @@ -13,7 +13,7 @@ # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py # MoCo v3: https://github.com/facebookresearch/moco-v3 # -------------------------------------------------------- -def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): +def get_2d_sincos_pos_embed(embed_dim: int, grid_size: int, cls_token=False): """ grid_size: int of the grid height and width return: diff --git a/minerva/utils/tensor.py b/minerva/utils/tensor.py index 2902697..949f72a 100644 --- a/minerva/utils/tensor.py +++ b/minerva/utils/tensor.py @@ -9,14 +9,16 @@ def to_tensor(x, dtype=None) -> torch.Tensor: if dtype is not None: x = x.type(dtype) return x - if isinstance(x, np.ndarray): + elif isinstance(x, np.ndarray): x = torch.from_numpy(x) if dtype is not None: x = x.type(dtype) return x - if isinstance(x, Iterable): + elif isinstance(x, Iterable): x = np.array(x) x = torch.from_numpy(x) if dtype is not None: x = x.type(dtype) return x + else: + raise TypeError(f"Unsupported type: {type(x)}")