diff --git a/mambular/arch_utils/mamba_arch.py b/mambular/arch_utils/mamba_arch.py index 3db39ed..6417c7f 100644 --- a/mambular/arch_utils/mamba_arch.py +++ b/mambular/arch_utils/mamba_arch.py @@ -43,6 +43,9 @@ def __init__( activation=F.silu, bidirectional=False, use_learnable_interaction=False, + layer_norm_eps=1e-05, + AB_weight_decay=False, + AB_layer_norm=True, ): super().__init__() @@ -66,6 +69,9 @@ def __init__( activation, bidirectional, use_learnable_interaction, + layer_norm_eps, + AB_weight_decay, + AB_layer_norm, ) for _ in range(n_layers) ] @@ -105,6 +111,9 @@ def __init__( activation=F.silu, bidirectional=False, use_learnable_interaction=False, + layer_norm_eps=1e-05, + AB_weight_decay=False, + AB_layer_norm=False, ): super().__init__() @@ -149,8 +158,11 @@ def __init__( activation=activation, bidirectional=bidirectional, use_learnable_interaction=use_learnable_interaction, + layer_norm_eps=layer_norm_eps, + AB_weight_decay=AB_weight_decay, + AB_layer_norm=AB_layer_norm, ) - self.norm = norm(d_model) + self.norm = norm(d_model, eps=layer_norm_eps) def forward(self, x): output = self.layers(self.norm(x)) + x @@ -189,6 +201,9 @@ def __init__( activation=F.silu, bidirectional=False, use_learnable_interaction=False, + layer_norm_eps=1e-05, + AB_weight_decay=False, + AB_layer_norm=False, ): super().__init__() self.d_inner = d_model * expand_factor @@ -239,6 +254,7 @@ def __init__( elif dt_init == "random": nn.init.uniform_(self.dt_proj_fwd.weight, -dt_init_std, dt_init_std) if self.bidirectional: + nn.init.uniform_(self.dt_proj_bwd.weight, -dt_init_std, dt_init_std) else: raise NotImplementedError @@ -262,17 +278,35 @@ def __init__( A = torch.arange(1, d_state + 1, dtype=torch.float32).repeat(self.d_inner, 1) self.A_log_fwd = nn.Parameter(torch.log(A)) + self.D_fwd = nn.Parameter(torch.ones(self.d_inner)) + if self.bidirectional: self.A_log_bwd = nn.Parameter(torch.log(A)) + self.D_bwd = nn.Parameter(torch.ones(self.d_inner)) + + if not AB_weight_decay: + self.A_log_fwd._no_weight_decay = True + self.D_fwd._no_weight_decay = True - self.D_fwd = nn.Parameter(torch.ones(self.d_inner)) if self.bidirectional: - self.D_bwd = nn.Parameter(torch.ones(self.d_inner)) + + if not AB_weight_decay: + self.A_log_bwd._no_weight_decay = True + self.D_bwd._no_weight_decay = True self.out_proj = nn.Linear(self.d_inner, d_model, bias=bias) self.dt_rank = dt_rank self.d_state = d_state + if AB_layer_norm: + self.dt_layernorm = RMSNorm(self.dt_rank, eps=layer_norm_eps) + self.B_layernorm = RMSNorm(self.d_state, eps=layer_norm_eps) + self.C_layernorm = RMSNorm(self.d_state, eps=layer_norm_eps) + else: + self.dt_layernorm = None + self.B_layernorm = None + self.C_layernorm = None + def forward(self, x): _, L, _ = x.shape @@ -316,6 +350,15 @@ def forward(self, x): return output + def _apply_layernorms(self, dt, B, C): + if self.dt_layernorm is not None: + dt = self.dt_layernorm(dt) + if self.B_layernorm is not None: + B = self.B_layernorm(B) + if self.C_layernorm is not None: + C = self.C_layernorm(C) + return dt, B, C + def ssm(self, x, forward=True): if forward: A = -torch.exp(self.A_log_fwd.float()) @@ -324,6 +367,7 @@ def ssm(self, x, forward=True): delta, B, C = torch.split( deltaBC, [self.dt_rank, self.d_state, self.d_state], dim=-1 ) + delta, B, C = self._apply_layernorms(delta, B, C) delta = F.softplus(self.dt_proj_fwd(delta)) else: A = -torch.exp(self.A_log_bwd.float()) @@ -332,6 +376,7 @@ def ssm(self, x, forward=True): delta, B, C = torch.split( deltaBC, [self.dt_rank, self.d_state, 
self.d_state], dim=-1 ) + delta, B, C = self._apply_layernorms(delta, B, C) delta = F.softplus(self.dt_proj_bwd(delta)) y = self.selective_scan_seq(x, delta, A, B, C, D) diff --git a/mambular/base_models/mambular.py b/mambular/base_models/mambular.py index 087f0a9..1b3a332 100644 --- a/mambular/base_models/mambular.py +++ b/mambular/base_models/mambular.py @@ -1,209 +1,224 @@ -import torch -import torch.nn as nn -from ..arch_utils.mamba_arch import Mamba -from ..arch_utils.mlp_utils import MLP -from ..arch_utils.normalization_layers import ( - RMSNorm, - LayerNorm, - LearnableLayerScaling, - BatchNorm, - InstanceNorm, - GroupNorm, -) -from ..configs.mambular_config import DefaultMambularConfig -from .basemodel import BaseModel -from ..arch_utils.embedding_layer import EmbeddingLayer - - -class Mambular(BaseModel): - """ - A PyTorch model for tasks utilizing the Mamba architecture and various normalization techniques. - - Parameters - ---------- - cat_feature_info : dict - Dictionary containing information about categorical features. - num_feature_info : dict - Dictionary containing information about numerical features. - num_classes : int, optional - Number of output classes (default is 1). - config : DefaultMambularConfig, optional - Configuration object containing default hyperparameters for the model (default is DefaultMambularConfig()). - **kwargs : dict - Additional keyword arguments. - - Attributes - ---------- - lr : float - Learning rate. - lr_patience : int - Patience for learning rate scheduler. - weight_decay : float - Weight decay for optimizer. - lr_factor : float - Factor by which the learning rate will be reduced. - pooling_method : str - Method to pool the features. - cat_feature_info : dict - Dictionary containing information about categorical features. - num_feature_info : dict - Dictionary containing information about numerical features. - embedding_activation : callable - Activation function for embeddings. - mamba : Mamba - Mamba architecture component. - norm_f : nn.Module - Normalization layer. - num_embeddings : nn.ModuleList - Module list for numerical feature embeddings. - cat_embeddings : nn.ModuleList - Module list for categorical feature embeddings. - tabular_head : MLP - Multi-layer perceptron head for tabular data. - cls_token : nn.Parameter - Class token parameter. - embedding_norm : nn.Module, optional - Layer normalization applied after embedding if specified. 
- """ - - def __init__( - self, - cat_feature_info, - num_feature_info, - num_classes=1, - config: DefaultMambularConfig = DefaultMambularConfig(), - **kwargs, - ): - super().__init__(**kwargs) - self.save_hyperparameters(ignore=["cat_feature_info", "num_feature_info"]) - - self.lr = self.hparams.get("lr", config.lr) - self.lr_patience = self.hparams.get("lr_patience", config.lr_patience) - self.weight_decay = self.hparams.get("weight_decay", config.weight_decay) - self.lr_factor = self.hparams.get("lr_factor", config.lr_factor) - self.pooling_method = self.hparams.get("pooling_method", config.pooling_method) - self.shuffle_embeddings = self.hparams.get( - "shuffle_embeddings", config.shuffle_embeddings - ) - self.cat_feature_info = cat_feature_info - self.num_feature_info = num_feature_info - - self.mamba = Mamba( - d_model=self.hparams.get("d_model", config.d_model), - n_layers=self.hparams.get("n_layers", config.n_layers), - expand_factor=self.hparams.get("expand_factor", config.expand_factor), - bias=self.hparams.get("bias", config.bias), - d_conv=self.hparams.get("d_conv", config.d_conv), - conv_bias=self.hparams.get("conv_bias", config.conv_bias), - dropout=self.hparams.get("dropout", config.dropout), - dt_rank=self.hparams.get("dt_rank", config.dt_rank), - d_state=self.hparams.get("d_state", config.d_state), - dt_scale=self.hparams.get("dt_scale", config.dt_scale), - dt_init=self.hparams.get("dt_init", config.dt_init), - dt_max=self.hparams.get("dt_max", config.dt_max), - dt_min=self.hparams.get("dt_min", config.dt_min), - dt_init_floor=self.hparams.get("dt_init_floor", config.dt_init_floor), - norm=globals()[self.hparams.get("norm", config.norm)], - activation=self.hparams.get("activation", config.activation), - bidirectional=self.hparams.get("bidiretional", config.bidirectional), - use_learnable_interaction=self.hparams.get( - "use_learnable_interactions", config.use_learnable_interaction - ), - ) - - norm_layer = self.hparams.get("norm", config.norm) - if norm_layer == "RMSNorm": - self.norm_f = RMSNorm(self.hparams.get("d_model", config.d_model)) - elif norm_layer == "LayerNorm": - self.norm_f = LayerNorm(self.hparams.get("d_model", config.d_model)) - elif norm_layer == "BatchNorm": - self.norm_f = BatchNorm(self.hparams.get("d_model", config.d_model)) - elif norm_layer == "InstanceNorm": - self.norm_f = InstanceNorm(self.hparams.get("d_model", config.d_model)) - elif norm_layer == "GroupNorm": - self.norm_f = GroupNorm(1, self.hparams.get("d_model", config.d_model)) - elif norm_layer == "LearnableLayerScaling": - self.norm_f = LearnableLayerScaling( - self.hparams.get("d_model", config.d_model) - ) - else: - raise ValueError(f"Unsupported normalization layer: {norm_layer}") - - self.embedding_layer = EmbeddingLayer( - num_feature_info=num_feature_info, - cat_feature_info=cat_feature_info, - d_model=self.hparams.get("d_model", config.d_model), - embedding_activation=self.hparams.get( - "embedding_activation", config.embedding_activation - ), - layer_norm_after_embedding=self.hparams.get("layer_norm_after_embedding"), - use_cls=True, - cls_position=0, - ) - - head_activation = self.hparams.get("head_activation", config.head_activation) - - self.tabular_head = MLP( - self.hparams.get("d_model", config.d_model), - hidden_units_list=self.hparams.get( - "head_layer_sizes", config.head_layer_sizes - ), - dropout_rate=self.hparams.get("head_dropout", config.head_dropout), - use_skip_layers=self.hparams.get( - "head_skip_layers", config.head_skip_layers - ), - 
activation_fn=head_activation, - use_batch_norm=self.hparams.get( - "head_use_batch_norm", config.head_use_batch_norm - ), - n_output_units=num_classes, - ) - - if self.pooling_method == "cls": - self.use_cls = True - else: - self.use_cls = self.hparams.get("use_cls", config.use_cls) - - if self.shuffle_embeddings: - self.perm = torch.randperm(self.embedding_layer.seq_len) - - def forward(self, num_features, cat_features): - """ - Defines the forward pass of the model. - - Parameters - ---------- - num_features : Tensor - Tensor containing the numerical features. - cat_features : Tensor - Tensor containing the categorical features. - - Returns - ------- - Tensor - The output predictions of the model. - """ - x = self.embedding_layer(num_features, cat_features) - - if self.shuffle_embeddings: - x = x[:, self.perm, :] - - x = self.mamba(x) - - if self.pooling_method == "avg": - x = torch.mean(x, dim=1) - elif self.pooling_method == "max": - x, _ = torch.max(x, dim=1) - elif self.pooling_method == "sum": - x = torch.sum(x, dim=1) - elif self.pooling_method == "cls_token": - x = x[:, -1] - elif self.pooling_method == "last": - x = x[:, -1] - else: - raise ValueError(f"Invalid pooling method: {self.pooling_method}") - - x = self.norm_f(x) - preds = self.tabular_head(x) - - return preds +import torch +import torch.nn as nn +from ..arch_utils.mamba_arch import Mamba +from ..arch_utils.mlp_utils import MLP +from ..arch_utils.normalization_layers import ( + RMSNorm, + LayerNorm, + LearnableLayerScaling, + BatchNorm, + InstanceNorm, + GroupNorm, +) +from ..configs.mambular_config import DefaultMambularConfig +from .basemodel import BaseModel +from ..arch_utils.embedding_layer import EmbeddingLayer + + +class Mambular(BaseModel): + """ + A PyTorch model for tasks utilizing the Mamba architecture and various normalization techniques. + + Parameters + ---------- + cat_feature_info : dict + Dictionary containing information about categorical features. + num_feature_info : dict + Dictionary containing information about numerical features. + num_classes : int, optional + Number of output classes (default is 1). + config : DefaultMambularConfig, optional + Configuration object containing default hyperparameters for the model (default is DefaultMambularConfig()). + **kwargs : dict + Additional keyword arguments. + + Attributes + ---------- + lr : float + Learning rate. + lr_patience : int + Patience for learning rate scheduler. + weight_decay : float + Weight decay for optimizer. + lr_factor : float + Factor by which the learning rate will be reduced. + pooling_method : str + Method to pool the features. + cat_feature_info : dict + Dictionary containing information about categorical features. + num_feature_info : dict + Dictionary containing information about numerical features. + embedding_activation : callable + Activation function for embeddings. + mamba : Mamba + Mamba architecture component. + norm_f : nn.Module + Normalization layer. + num_embeddings : nn.ModuleList + Module list for numerical feature embeddings. + cat_embeddings : nn.ModuleList + Module list for categorical feature embeddings. + tabular_head : MLP + Multi-layer perceptron head for tabular data. + cls_token : nn.Parameter + Class token parameter. + embedding_norm : nn.Module, optional + Layer normalization applied after embedding if specified. 
+ """ + + def __init__( + self, + cat_feature_info, + num_feature_info, + num_classes=1, + config: DefaultMambularConfig = DefaultMambularConfig(), + **kwargs, + ): + super().__init__(**kwargs) + self.save_hyperparameters(ignore=["cat_feature_info", "num_feature_info"]) + + self.lr = self.hparams.get("lr", config.lr) + self.lr_patience = self.hparams.get("lr_patience", config.lr_patience) + self.weight_decay = self.hparams.get("weight_decay", config.weight_decay) + self.lr_factor = self.hparams.get("lr_factor", config.lr_factor) + self.pooling_method = self.hparams.get("pooling_method", config.pooling_method) + self.shuffle_embeddings = self.hparams.get( + "shuffle_embeddings", config.shuffle_embeddings + ) + self.cat_feature_info = cat_feature_info + self.num_feature_info = num_feature_info + + self.mamba = Mamba( + d_model=self.hparams.get("d_model", config.d_model), + n_layers=self.hparams.get("n_layers", config.n_layers), + expand_factor=self.hparams.get("expand_factor", config.expand_factor), + bias=self.hparams.get("bias", config.bias), + d_conv=self.hparams.get("d_conv", config.d_conv), + conv_bias=self.hparams.get("conv_bias", config.conv_bias), + dropout=self.hparams.get("dropout", config.dropout), + dt_rank=self.hparams.get("dt_rank", config.dt_rank), + d_state=self.hparams.get("d_state", config.d_state), + dt_scale=self.hparams.get("dt_scale", config.dt_scale), + dt_init=self.hparams.get("dt_init", config.dt_init), + dt_max=self.hparams.get("dt_max", config.dt_max), + dt_min=self.hparams.get("dt_min", config.dt_min), + dt_init_floor=self.hparams.get("dt_init_floor", config.dt_init_floor), + norm=globals()[self.hparams.get("norm", config.norm)], + activation=self.hparams.get("activation", config.activation), + bidirectional=self.hparams.get("bidiretional", config.bidirectional), + use_learnable_interaction=self.hparams.get( + "use_learnable_interactions", config.use_learnable_interaction + ), + AB_weight_decay=self.hparams.get("AB_weight_decay", config.AB_weight_decay), + AB_layer_norm=self.hparams.get("AB_layer_norm", config.AB_layer_norm), + layer_norm_eps=self.hparams.get("layer_norm_eps", config.layer_norm_eps), + ) + + norm_layer = self.hparams.get("norm", config.norm) + if norm_layer == "RMSNorm": + self.norm_f = RMSNorm( + self.hparams.get("d_model", config.d_model), eps=config.layer_norm_eps + ) + elif norm_layer == "LayerNorm": + self.norm_f = LayerNorm( + self.hparams.get("d_model", config.d_model), eps=config.layer_norm_eps + ) + elif norm_layer == "BatchNorm": + self.norm_f = BatchNorm( + self.hparams.get("d_model", config.d_model), eps=config.layer_norm_eps + ) + elif norm_layer == "InstanceNorm": + self.norm_f = InstanceNorm( + self.hparams.get("d_model", config.d_model), eps=config.layer_norm_eps + ) + elif norm_layer == "GroupNorm": + self.norm_f = GroupNorm( + 1, + self.hparams.get("d_model", config.d_model), + eps=config.layer_norm_eps, + ) + elif norm_layer == "LearnableLayerScaling": + self.norm_f = LearnableLayerScaling( + self.hparams.get("d_model", config.d_model) + ) + else: + raise ValueError(f"Unsupported normalization layer: {norm_layer}") + + self.embedding_layer = EmbeddingLayer( + num_feature_info=num_feature_info, + cat_feature_info=cat_feature_info, + d_model=self.hparams.get("d_model", config.d_model), + embedding_activation=self.hparams.get( + "embedding_activation", config.embedding_activation + ), + layer_norm_after_embedding=self.hparams.get("layer_norm_after_embedding"), + use_cls=True, + cls_position=0, + ) + + head_activation = 
self.hparams.get("head_activation", config.head_activation) + + self.tabular_head = MLP( + self.hparams.get("d_model", config.d_model), + hidden_units_list=self.hparams.get( + "head_layer_sizes", config.head_layer_sizes + ), + dropout_rate=self.hparams.get("head_dropout", config.head_dropout), + use_skip_layers=self.hparams.get( + "head_skip_layers", config.head_skip_layers + ), + activation_fn=head_activation, + use_batch_norm=self.hparams.get( + "head_use_batch_norm", config.head_use_batch_norm + ), + n_output_units=num_classes, + ) + + if self.pooling_method == "cls": + self.use_cls = True + else: + self.use_cls = self.hparams.get("use_cls", config.use_cls) + + if self.shuffle_embeddings: + self.perm = torch.randperm(self.embedding_layer.seq_len) + + def forward(self, num_features, cat_features): + """ + Defines the forward pass of the model. + + Parameters + ---------- + num_features : Tensor + Tensor containing the numerical features. + cat_features : Tensor + Tensor containing the categorical features. + + Returns + ------- + Tensor + The output predictions of the model. + """ + x = self.embedding_layer(num_features, cat_features) + + if self.shuffle_embeddings: + x = x[:, self.perm, :] + + x = self.mamba(x) + + if self.pooling_method == "avg": + x = torch.mean(x, dim=1) + elif self.pooling_method == "max": + x, _ = torch.max(x, dim=1) + elif self.pooling_method == "sum": + x = torch.sum(x, dim=1) + elif self.pooling_method == "cls_token": + x = x[:, -1] + elif self.pooling_method == "last": + x = x[:, -1] + else: + raise ValueError(f"Invalid pooling method: {self.pooling_method}") + + x = self.norm_f(x) + preds = self.tabular_head(x) + + return preds diff --git a/mambular/configs/mambular_config.py b/mambular/configs/mambular_config.py index 24ce13f..176ae96 100644 --- a/mambular/configs/mambular_config.py +++ b/mambular/configs/mambular_config.py @@ -1,109 +1,116 @@ -from dataclasses import dataclass -import torch.nn as nn - - -@dataclass -class DefaultMambularConfig: - """ - Configuration class for the Default Mambular model with predefined hyperparameters. - - Parameters - ---------- - lr : float, default=1e-04 - Learning rate for the optimizer. - lr_patience : int, default=10 - Number of epochs with no improvement after which learning rate will be reduced. - weight_decay : float, default=1e-06 - Weight decay (L2 penalty) for the optimizer. - lr_factor : float, default=0.1 - Factor by which the learning rate will be reduced. - d_model : int, default=64 - Dimensionality of the model. - n_layers : int, default=8 - Number of layers in the model. - expand_factor : int, default=2 - Expansion factor for the feed-forward layers. - bias : bool, default=False - Whether to use bias in the linear layers. - d_conv : int, default=16 - Dimensionality of the convolutional layers. - conv_bias : bool, default=True - Whether to use bias in the convolutional layers. - dropout : float, default=0.05 - Dropout rate for regularization. - dt_rank : str, default="auto" - Rank of the decision tree. - d_state : int, default=32 - Dimensionality of the state in recurrent layers. - dt_scale : float, default=1.0 - Scaling factor for decision tree. - dt_init : str, default="random" - Initialization method for decision tree. - dt_max : float, default=0.1 - Maximum value for decision tree initialization. - dt_min : float, default=1e-04 - Minimum value for decision tree initialization. - dt_init_floor : float, default=1e-04 - Floor value for decision tree initialization. 
- norm : str, default="RMSNorm" - Normalization method to be used. - activation : callable, default=nn.SELU() - Activation function for the model. - embedding_activation : callable, default=nn.Identity() - Activation function for embeddings. - head_layer_sizes : list, default=(128, 64, 32) - Sizes of the layers in the head of the model. - head_dropout : float, default=0.5 - Dropout rate for the head layers. - head_skip_layers : bool, default=False - Whether to skip layers in the head. - head_activation : callable, default=nn.SELU() - Activation function for the head layers. - head_use_batch_norm : bool, default=False - Whether to use batch normalization in the head layers. - layer_norm_after_embedding : bool, default=False - Whether to apply layer normalization after embedding. - pooling_method : str, default="avg" - Pooling method to be used ('avg', 'max', etc.). - bidirectional : bool, default=False - Whether to use bidirectional processing of the input sequences. - use_learnable_interaction : bool, default=False - Whether to use learnable feature interactions before passing through mamba blocks. - use_cls : bool, default=True - Whether to append a cls to the end of each 'sequence'. - shuffle_embeddings : bool, default=False. - Whether to shuffle the embeddings before being passed to the Mamba layers. - """ - - lr: float = 1e-04 - lr_patience: int = 10 - weight_decay: float = 1e-06 - lr_factor: float = 0.1 - d_model: int = 64 - n_layers: int = 4 - expand_factor: int = 2 - bias: bool = False - d_conv: int = 4 - conv_bias: bool = True - dropout: float = 0.0 - dt_rank: str = "auto" - d_state: int = 128 - dt_scale: float = 1.0 - dt_init: str = "random" - dt_max: float = 0.1 - dt_min: float = 1e-04 - dt_init_floor: float = 1e-04 - norm: str = "LayerNorm" - activation: callable = nn.SiLU() - embedding_activation: callable = nn.Identity() - head_layer_sizes: list = () - head_dropout: float = 0.5 - head_skip_layers: bool = False - head_activation: callable = nn.SELU() - head_use_batch_norm: bool = False - layer_norm_after_embedding: bool = False - pooling_method: str = "avg" - bidirectional: bool = False - use_learnable_interaction: bool = False - use_cls: bool = False - shuffle_embeddings: bool = False +from dataclasses import dataclass +import torch.nn as nn + + +@dataclass +class DefaultMambularConfig: + """ + Configuration class for the Default Mambular model with predefined hyperparameters. + + Parameters + ---------- + lr : float, default=1e-04 + Learning rate for the optimizer. + lr_patience : int, default=10 + Number of epochs with no improvement after which learning rate will be reduced. + weight_decay : float, default=1e-06 + Weight decay (L2 penalty) for the optimizer. + lr_factor : float, default=0.1 + Factor by which the learning rate will be reduced. + d_model : int, default=64 + Dimensionality of the model. + n_layers : int, default=8 + Number of layers in the model. + expand_factor : int, default=2 + Expansion factor for the feed-forward layers. + bias : bool, default=False + Whether to use bias in the linear layers. + d_conv : int, default=16 + Dimensionality of the convolutional layers. + conv_bias : bool, default=True + Whether to use bias in the convolutional layers. + dropout : float, default=0.05 + Dropout rate for regularization. + dt_rank : str, default="auto" + Rank of the dt (step size) projection. + d_state : int, default=32 + Dimensionality of the state in recurrent layers. + dt_scale : float, default=1.0 + Scaling factor for the dt projection initialization.
+ dt_init : str, default="random" + Initialization method for the dt projection weights. + dt_max : float, default=0.1 + Maximum value for the dt (step size) initialization. + dt_min : float, default=1e-04 + Minimum value for the dt (step size) initialization. + dt_init_floor : float, default=1e-04 + Floor value for the dt (step size) initialization. + norm : str, default="RMSNorm" + Normalization method to be used. + activation : callable, default=nn.SELU() + Activation function for the model. + embedding_activation : callable, default=nn.Identity() + Activation function for embeddings. + head_layer_sizes : list, default=(128, 64, 32) + Sizes of the layers in the head of the model. + head_dropout : float, default=0.5 + Dropout rate for the head layers. + head_skip_layers : bool, default=False + Whether to skip layers in the head. + head_activation : callable, default=nn.SELU() + Activation function for the head layers. + head_use_batch_norm : bool, default=False + Whether to use batch normalization in the head layers. + layer_norm_after_embedding : bool, default=False + Whether to apply layer normalization after embedding. + pooling_method : str, default="avg" + Pooling method to be used ('avg', 'max', etc.). + bidirectional : bool, default=False + Whether to use bidirectional processing of the input sequences. + use_learnable_interaction : bool, default=False + Whether to use learnable feature interactions before passing through mamba blocks. + use_cls : bool, default=True + Whether to append a CLS token to the end of each 'sequence'. + shuffle_embeddings : bool, default=False + Whether to shuffle the embeddings before being passed to the Mamba layers. + layer_norm_eps : float, default=1e-05 + Epsilon value for layer normalization. + AB_weight_decay : bool, default=False + Whether weight decay is also applied to the A and D matrices of the SSM. + AB_layer_norm : bool, default=True + Whether to apply RMSNorm to the dt, B and C projections within each Mamba block. + """ + + lr: float = 1e-04 + lr_patience: int = 10 + weight_decay: float = 1e-06 + lr_factor: float = 0.1 + d_model: int = 64 + n_layers: int = 4 + expand_factor: int = 2 + bias: bool = False + d_conv: int = 4 + conv_bias: bool = True + dropout: float = 0.0 + dt_rank: str = "auto" + d_state: int = 128 + dt_scale: float = 1.0 + dt_init: str = "random" + dt_max: float = 0.1 + dt_min: float = 1e-04 + dt_init_floor: float = 1e-04 + norm: str = "LayerNorm" + activation: callable = nn.SiLU() + embedding_activation: callable = nn.Identity() + head_layer_sizes: list = () + head_dropout: float = 0.5 + head_skip_layers: bool = False + head_activation: callable = nn.SELU() + head_use_batch_norm: bool = False + layer_norm_after_embedding: bool = False + pooling_method: str = "avg" + bidirectional: bool = False + use_learnable_interaction: bool = False + use_cls: bool = False + shuffle_embeddings: bool = False + layer_norm_eps: float = 1e-05 + AB_weight_decay: bool = False + AB_layer_norm: bool = True diff --git a/mambular/models/sklearn_base_classifier.py b/mambular/models/sklearn_base_classifier.py index c2d6f17..0e67db3 100644 --- a/mambular/models/sklearn_base_classifier.py +++ b/mambular/models/sklearn_base_classifier.py @@ -1,557 +1,557 @@ -import lightning as pl -import pandas as pd -import torch -from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint -from sklearn.base import BaseEstimator -from sklearn.metrics import accuracy_score -import warnings -from ..base_models.lightning_wrapper import TaskModel -from ..data_utils.datamodule import MambularDataModule -from ..preprocessing import Preprocessor -import numpy as np - - -class SklearnBaseClassifier(BaseEstimator): - def __init__(self,
model, config, **kwargs): - preprocessor_arg_names = [ - "n_bins", - "numerical_preprocessing", - "use_decision_tree_bins", - "binning_strategy", - "task", - "cat_cutoff", - "treat_all_integers_as_numerical", - "knots", - "degree", - ] - - self.config_kwargs = { - k: v for k, v in kwargs.items() if k not in preprocessor_arg_names - } - self.config = config(**self.config_kwargs) - - preprocessor_kwargs = { - k: v for k, v in kwargs.items() if k in preprocessor_arg_names - } - - self.preprocessor = Preprocessor(**preprocessor_kwargs) - self.model = None - - # Raise a warning if task is set to 'classification' - if preprocessor_kwargs.get("task") == "regression": - warnings.warn( - "The task is set to 'regression'. The Classifier is designed for classification tasks.", - UserWarning, - ) - - self.base_model = model - self.built = False - - def get_params(self, deep=True): - """ - Get parameters for this estimator. Overrides the BaseEstimator method. - - Parameters - ---------- - deep : bool, default=True - If True, returns the parameters for this estimator and contained sub-objects that are estimators. - - Returns - ------- - params : dict - Parameter names mapped to their values. - """ - params = self.config_kwargs # Parameters used to initialize DefaultConfig - - # If deep=True, include parameters from nested components like preprocessor - if deep: - # Assuming Preprocessor has a get_params method - preprocessor_params = { - "preprocessor__" + key: value - for key, value in self.preprocessor.get_params().items() - } - params.update(preprocessor_params) - - return params - - def set_params(self, **parameters): - """ - Set the parameters of this estimator. Overrides the BaseEstimator method. - - Parameters - ---------- - **parameters : dict - Estimator parameters to be set. - - Returns - ------- - self : object - The instance with updated parameters. - """ - # Update config_kwargs with provided parameters - valid_config_keys = self.config_kwargs.keys() - config_updates = {k: v for k, v in parameters.items() if k in valid_config_keys} - self.config_kwargs.update(config_updates) - - # Update the config object - for key, value in config_updates.items(): - setattr(self.config, key, value) - - # Handle preprocessor parameters (prefixed with 'preprocessor__') - preprocessor_params = { - k.split("__")[1]: v - for k, v in parameters.items() - if k.startswith("preprocessor__") - } - if preprocessor_params: - # Assuming Preprocessor has a set_params method - self.preprocessor.set_params(**preprocessor_params) - - return self - - def build_model( - self, - X, - y, - val_size: float = 0.2, - X_val=None, - y_val=None, - random_state: int = 101, - batch_size: int = 128, - shuffle: bool = True, - lr: float = 1e-4, - lr_patience: int = 10, - factor: float = 0.1, - weight_decay: float = 1e-06, - dataloader_kwargs={}, - ): - """ - Builds the model using the provided training data. - - Parameters - ---------- - X : DataFrame or array-like, shape (n_samples, n_features) - The training input samples. - y : array-like, shape (n_samples,) or (n_samples, n_targets) - The target values (real numbers). - val_size : float, default=0.2 - The proportion of the dataset to include in the validation split if `X_val` is None. Ignored if `X_val` is provided. - X_val : DataFrame or array-like, shape (n_samples, n_features), optional - The validation input samples. If provided, `X` and `y` are not split and this data is used for validation. 
- y_val : array-like, shape (n_samples,) or (n_samples, n_targets), optional - The validation target values. Required if `X_val` is provided. - random_state : int, default=101 - Controls the shuffling applied to the data before applying the split. - batch_size : int, default=64 - Number of samples per gradient update. - shuffle : bool, default=True - Whether to shuffle the training data before each epoch. - lr : float, default=1e-3 - Learning rate for the optimizer. - lr_patience : int, default=10 - Number of epochs with no improvement on the validation loss to wait before reducing the learning rate. - factor : float, default=0.1 - Factor by which the learning rate will be reduced. - weight_decay : float, default=0.025 - Weight decay (L2 penalty) coefficient. - dataloader_kwargs: dict, default={} - The kwargs for the pytorch dataloader class. - - - - Returns - ------- - self : object - The built classifier. - """ - if not isinstance(X, pd.DataFrame): - X = pd.DataFrame(X) - if isinstance(y, pd.Series): - y = y.values - if X_val: - if not isinstance(X_val, pd.DataFrame): - X_val = pd.DataFrame(X_val) - if isinstance(y_val, pd.Series): - y_val = y_val.values - - self.data_module = MambularDataModule( - preprocessor=self.preprocessor, - batch_size=batch_size, - shuffle=shuffle, - X_val=X_val, - y_val=y_val, - val_size=val_size, - random_state=random_state, - regression=False, - **dataloader_kwargs - ) - - self.data_module.preprocess_data( - X, y, X_val, y_val, val_size=val_size, random_state=random_state - ) - - num_classes = len(np.unique(y)) - - self.model = TaskModel( - model_class=self.base_model, - num_classes=num_classes, - config=self.config, - cat_feature_info=self.data_module.cat_feature_info, - num_feature_info=self.data_module.num_feature_info, - lr=lr, - lr_patience=lr_patience, - lr_factor=factor, - weight_decay=weight_decay, - ) - - self.built = True - - return self - - def get_number_of_params(self, requires_grad=True): - """ - Calculate the number of parameters in the model. - - Parameters - ---------- - requires_grad : bool, optional - If True, only count the parameters that require gradients (trainable parameters). - If False, count all parameters. Default is True. - - Returns - ------- - int - The total number of parameters in the model. - - Raises - ------ - ValueError - If the model has not been built prior to calling this method. - """ - if not self.built: - raise ValueError( - "The model must be built before the number of parameters can be estimated" - ) - else: - if requires_grad: - return sum( - p.numel() for p in self.model.parameters() if p.requires_grad - ) - else: - return sum(p.numel() for p in self.model.parameters()) - - def fit( - self, - X, - y, - val_size: float = 0.2, - X_val=None, - y_val=None, - max_epochs: int = 100, - random_state: int = 101, - batch_size: int = 128, - shuffle: bool = True, - patience: int = 15, - monitor: str = "val_loss", - mode: str = "min", - lr: float = 1e-4, - lr_patience: int = 10, - factor: float = 0.1, - weight_decay: float = 1e-06, - checkpoint_path="model_checkpoints", - dataloader_kwargs={}, - rebuild=True, - **trainer_kwargs - ): - """ - Trains the classification model using the provided training data. Optionally, a separate validation set can be used. - - Parameters - ---------- - X : DataFrame or array-like, shape (n_samples, n_features) - The training input samples. - y : array-like, shape (n_samples,) or (n_samples, n_targets) - The target values (real numbers). 
- val_size : float, default=0.2 - The proportion of the dataset to include in the validation split if `X_val` is None. Ignored if `X_val` is provided. - X_val : DataFrame or array-like, shape (n_samples, n_features), optional - The validation input samples. If provided, `X` and `y` are not split and this data is used for validation. - y_val : array-like, shape (n_samples,) or (n_samples, n_targets), optional - The validation target values. Required if `X_val` is provided. - max_epochs : int, default=100 - Maximum number of epochs for training. - random_state : int, default=101 - Controls the shuffling applied to the data before applying the split. - batch_size : int, default=64 - Number of samples per gradient update. - shuffle : bool, default=True - Whether to shuffle the training data before each epoch. - patience : int, default=10 - Number of epochs with no improvement on the validation loss to wait before early stopping. - monitor : str, default="val_loss" - The metric to monitor for early stopping. - mode : str, default="min" - Whether the monitored metric should be minimized (`min`) or maximized (`max`). - lr : float, default=1e-3 - Learning rate for the optimizer. - lr_patience : int, default=10 - Number of epochs with no improvement on the validation loss to wait before reducing the learning rate. - factor : float, default=0.1 - Factor by which the learning rate will be reduced. - weight_decay : float, default=0.025 - Weight decay (L2 penalty) coefficient. - checkpoint_path : str, default="model_checkpoints" - Path where the checkpoints are being saved. - dataloader_kwargs: dict, default={} - The kwargs for the pytorch dataloader class. - rebuild: bool, default=True - Whether to rebuild the model when it already was built. - **trainer_kwargs : Additional keyword arguments for PyTorch Lightning's Trainer class. - - - Returns - ------- - self : object - The fitted classifier. 
- """ - if not self.built and not rebuild: - if not isinstance(X, pd.DataFrame): - X = pd.DataFrame(X) - if isinstance(y, pd.Series): - y = y.values - if X_val: - if not isinstance(X_val, pd.DataFrame): - X_val = pd.DataFrame(X_val) - if isinstance(y_val, pd.Series): - y_val = y_val.values - - self.data_module = MambularDataModule( - preprocessor=self.preprocessor, - batch_size=batch_size, - shuffle=shuffle, - X_val=X_val, - y_val=y_val, - val_size=val_size, - random_state=random_state, - regression=False, - **dataloader_kwargs - ) - - self.data_module.preprocess_data( - X, y, X_val, y_val, val_size=val_size, random_state=random_state - ) - - num_classes = len(np.unique(y)) - - self.model = TaskModel( - model_class=self.base_model, - num_classes=num_classes, - config=self.config, - cat_feature_info=self.data_module.cat_feature_info, - num_feature_info=self.data_module.num_feature_info, - lr=lr, - lr_patience=lr_patience, - lr_factor=factor, - weight_decay=weight_decay, - ) - - early_stop_callback = EarlyStopping( - monitor=monitor, min_delta=0.00, patience=patience, verbose=False, mode=mode - ) - - checkpoint_callback = ModelCheckpoint( - monitor="val_loss", # Adjust according to your validation metric - mode="min", - save_top_k=1, - dirpath=checkpoint_path, # Specify the directory to save checkpoints - filename="best_model", - ) - - # Initialize the trainer and train the model - trainer = pl.Trainer( - max_epochs=max_epochs, - callbacks=[early_stop_callback, checkpoint_callback], - **trainer_kwargs - ) - trainer.fit(self.model, self.data_module) - - best_model_path = checkpoint_callback.best_model_path - if best_model_path: - checkpoint = torch.load(best_model_path) - self.model.load_state_dict(checkpoint["state_dict"]) - - return self - - def predict(self, X): - """ - Predicts target values for the given input samples. - - Parameters - ---------- - X : DataFrame or array-like, shape (n_samples, n_features) - The input samples for which to predict target values. - - - Returns - ------- - predictions : ndarray, shape (n_samples,) or (n_samples, n_outputs) - The predicted target values. - """ - # Ensure model and data module are initialized - if self.model is None or self.data_module is None: - raise ValueError("The model or data module has not been fitted yet.") - - # Preprocess the data using the data module - cat_tensors, num_tensors = self.data_module.preprocess_test_data(X) - - # Move tensors to appropriate device - device = next(self.model.parameters()).device - if isinstance(cat_tensors, list): - cat_tensors = [tensor.to(device) for tensor in cat_tensors] - else: - cat_tensors = cat_tensors.to(device) - - if isinstance(num_tensors, list): - num_tensors = [tensor.to(device) for tensor in num_tensors] - else: - num_tensors = num_tensors.to(device) - - # Set model to evaluation mode - self.model.eval() - - # Perform inference - with torch.no_grad(): - logits = self.model(num_features=num_tensors, cat_features=cat_tensors) - - # Check the shape of the logits to determine binary or multi-class classification - if logits.shape[1] == 1: - # Binary classification - probabilities = torch.sigmoid(logits) - predictions = (probabilities > 0.5).long().squeeze() - else: - # Multi-class classification - probabilities = torch.softmax(logits, dim=1) - predictions = torch.argmax(probabilities, dim=1) - - # Convert predictions to NumPy array and return - return predictions.cpu().numpy() - - def predict_proba(self, X): - """ - Predict class probabilities for the given input samples. 
- - Parameters - ---------- - X : array-like or pd.DataFrame of shape (n_samples, n_features) - The input samples for which to predict class probabilities. - - - Notes - ----- - The method preprocesses the input data using the same preprocessor used during training, - sets the model to evaluation mode, and then performs inference to predict the class probabilities. - Softmax is applied to the logits to obtain probabilities, which are then converted from a PyTorch tensor - to a NumPy array before being returned. - - - Examples - -------- - >>> from sklearn.metrics import accuracy_score, precision_score, f1_score, roc_auc_score - >>> # Define the metrics you want to evaluate - >>> metrics = { - ... 'Accuracy': (accuracy_score, False), - ... 'Precision': (precision_score, False), - ... 'F1 Score': (f1_score, False), - ... 'AUC Score': (roc_auc_score, True) - ... } - >>> # Assuming 'X_test' and 'y_test' are your test dataset and labels - >>> # Evaluate using the specified metrics - >>> results = classifier.evaluate(X_test, y_test, metrics=metrics) - - - Returns - ------- - probabilities : ndarray of shape (n_samples, n_classes) - Predicted class probabilities for each input sample. - - """ - # Preprocess the data - if not isinstance(X, pd.DataFrame): - X = pd.DataFrame(X) - device = next(self.model.parameters()).device - cat_tensors, num_tensors = self.data_module.preprocess_test_data(X) - if isinstance(cat_tensors, list): - cat_tensors = [tensor.to(device) for tensor in cat_tensors] - else: - cat_tensors = cat_tensors.to(device) - - if isinstance(num_tensors, list): - num_tensors = [tensor.to(device) for tensor in num_tensors] - else: - num_tensors = num_tensors.to(device) - - # Set the model to evaluation mode - self.model.eval() - - # Perform inference - with torch.no_grad(): - logits = self.model(num_features=num_tensors, cat_features=cat_tensors) - if logits.shape[1] > 1: - probabilities = torch.softmax(logits, dim=1) - else: - probabilities = torch.sigmoid(logits) - - # Convert probabilities to NumPy array and return - return probabilities.cpu().numpy() - - def evaluate(self, X, y_true, metrics=None): - """ - Evaluate the model on the given data using specified metrics. - - Parameters - ---------- - X : array-like or pd.DataFrame of shape (n_samples, n_features) - The input samples to predict. - y_true : array-like of shape (n_samples,) - The true class labels against which to evaluate the predictions. - metrics : dict - A dictionary where keys are metric names and values are tuples containing the metric function - and a boolean indicating whether the metric requires probability scores (True) or class labels (False). - - - Returns - ------- - scores : dict - A dictionary with metric names as keys and their corresponding scores as values. - - - Notes - ----- - This method uses either the `predict` or `predict_proba` method depending on the metric requirements. 
- """ - # Ensure input is in the correct format - if metrics is None: - metrics = {"Accuracy": (accuracy_score, False)} - - if not isinstance(X, pd.DataFrame): - X = pd.DataFrame(X) - - # Initialize dictionary to store results - scores = {} - - # Generate class probabilities if any metric requires them - if any(use_proba for _, use_proba in metrics.values()): - probabilities = self.predict_proba(X) - - # Generate class labels if any metric requires them - if any(not use_proba for _, use_proba in metrics.values()): - predictions = self.predict(X) - - # Compute each metric - for metric_name, (metric_func, use_proba) in metrics.items(): - if use_proba: - scores[metric_name] = metric_func(y_true, probabilities) - else: - scores[metric_name] = metric_func(y_true, predictions) - - return scores +import lightning as pl +import pandas as pd +import torch +from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint +from sklearn.base import BaseEstimator +from sklearn.metrics import accuracy_score +import warnings +from ..base_models.lightning_wrapper import TaskModel +from ..data_utils.datamodule import MambularDataModule +from ..preprocessing import Preprocessor +import numpy as np + + +class SklearnBaseClassifier(BaseEstimator): + def __init__(self, model, config, **kwargs): + preprocessor_arg_names = [ + "n_bins", + "numerical_preprocessing", + "use_decision_tree_bins", + "binning_strategy", + "task", + "cat_cutoff", + "treat_all_integers_as_numerical", + "knots", + "degree", + ] + + self.config_kwargs = { + k: v for k, v in kwargs.items() if k not in preprocessor_arg_names + } + self.config = config(**self.config_kwargs) + + preprocessor_kwargs = { + k: v for k, v in kwargs.items() if k in preprocessor_arg_names + } + + self.preprocessor = Preprocessor(**preprocessor_kwargs) + self.model = None + + # Raise a warning if task is set to 'classification' + if preprocessor_kwargs.get("task") == "regression": + warnings.warn( + "The task is set to 'regression'. The Classifier is designed for classification tasks.", + UserWarning, + ) + + self.base_model = model + self.built = False + + def get_params(self, deep=True): + """ + Get parameters for this estimator. Overrides the BaseEstimator method. + + Parameters + ---------- + deep : bool, default=True + If True, returns the parameters for this estimator and contained sub-objects that are estimators. + + Returns + ------- + params : dict + Parameter names mapped to their values. + """ + params = self.config_kwargs # Parameters used to initialize DefaultConfig + + # If deep=True, include parameters from nested components like preprocessor + if deep: + # Assuming Preprocessor has a get_params method + preprocessor_params = { + "preprocessor__" + key: value + for key, value in self.preprocessor.get_params().items() + } + params.update(preprocessor_params) + + return params + + def set_params(self, **parameters): + """ + Set the parameters of this estimator. Overrides the BaseEstimator method. + + Parameters + ---------- + **parameters : dict + Estimator parameters to be set. + + Returns + ------- + self : object + The instance with updated parameters. 
+ """ + # Update config_kwargs with provided parameters + valid_config_keys = self.config_kwargs.keys() + config_updates = {k: v for k, v in parameters.items() if k in valid_config_keys} + self.config_kwargs.update(config_updates) + + # Update the config object + for key, value in config_updates.items(): + setattr(self.config, key, value) + + # Handle preprocessor parameters (prefixed with 'preprocessor__') + preprocessor_params = { + k.split("__")[1]: v + for k, v in parameters.items() + if k.startswith("preprocessor__") + } + if preprocessor_params: + # Assuming Preprocessor has a set_params method + self.preprocessor.set_params(**preprocessor_params) + + return self + + def build_model( + self, + X, + y, + val_size: float = 0.2, + X_val=None, + y_val=None, + random_state: int = 101, + batch_size: int = 128, + shuffle: bool = True, + lr: float = 1e-4, + lr_patience: int = 10, + factor: float = 0.1, + weight_decay: float = 1e-06, + dataloader_kwargs={}, + ): + """ + Builds the model using the provided training data. + + Parameters + ---------- + X : DataFrame or array-like, shape (n_samples, n_features) + The training input samples. + y : array-like, shape (n_samples,) or (n_samples, n_targets) + The target values (real numbers). + val_size : float, default=0.2 + The proportion of the dataset to include in the validation split if `X_val` is None. Ignored if `X_val` is provided. + X_val : DataFrame or array-like, shape (n_samples, n_features), optional + The validation input samples. If provided, `X` and `y` are not split and this data is used for validation. + y_val : array-like, shape (n_samples,) or (n_samples, n_targets), optional + The validation target values. Required if `X_val` is provided. + random_state : int, default=101 + Controls the shuffling applied to the data before applying the split. + batch_size : int, default=64 + Number of samples per gradient update. + shuffle : bool, default=True + Whether to shuffle the training data before each epoch. + lr : float, default=1e-3 + Learning rate for the optimizer. + lr_patience : int, default=10 + Number of epochs with no improvement on the validation loss to wait before reducing the learning rate. + factor : float, default=0.1 + Factor by which the learning rate will be reduced. + weight_decay : float, default=0.025 + Weight decay (L2 penalty) coefficient. + dataloader_kwargs: dict, default={} + The kwargs for the pytorch dataloader class. + + + + Returns + ------- + self : object + The built classifier. 
+ """ + if not isinstance(X, pd.DataFrame): + X = pd.DataFrame(X) + if isinstance(y, pd.Series): + y = y.values + if X_val: + if not isinstance(X_val, pd.DataFrame): + X_val = pd.DataFrame(X_val) + if isinstance(y_val, pd.Series): + y_val = y_val.values + + self.data_module = MambularDataModule( + preprocessor=self.preprocessor, + batch_size=batch_size, + shuffle=shuffle, + X_val=X_val, + y_val=y_val, + val_size=val_size, + random_state=random_state, + regression=False, + **dataloader_kwargs + ) + + self.data_module.preprocess_data( + X, y, X_val, y_val, val_size=val_size, random_state=random_state + ) + + num_classes = len(np.unique(y)) + + self.model = TaskModel( + model_class=self.base_model, + num_classes=num_classes, + config=self.config, + cat_feature_info=self.data_module.cat_feature_info, + num_feature_info=self.data_module.num_feature_info, + lr=lr, + lr_patience=lr_patience, + lr_factor=factor, + weight_decay=weight_decay, + ) + + self.built = True + + return self + + def get_number_of_params(self, requires_grad=True): + """ + Calculate the number of parameters in the model. + + Parameters + ---------- + requires_grad : bool, optional + If True, only count the parameters that require gradients (trainable parameters). + If False, count all parameters. Default is True. + + Returns + ------- + int + The total number of parameters in the model. + + Raises + ------ + ValueError + If the model has not been built prior to calling this method. + """ + if not self.built: + raise ValueError( + "The model must be built before the number of parameters can be estimated" + ) + else: + if requires_grad: + return sum( + p.numel() for p in self.model.parameters() if p.requires_grad + ) + else: + return sum(p.numel() for p in self.model.parameters()) + + def fit( + self, + X, + y, + val_size: float = 0.2, + X_val=None, + y_val=None, + max_epochs: int = 100, + random_state: int = 101, + batch_size: int = 128, + shuffle: bool = True, + patience: int = 15, + monitor: str = "val_loss", + mode: str = "min", + lr: float = 1e-4, + lr_patience: int = 10, + factor: float = 0.1, + weight_decay: float = 1e-06, + checkpoint_path="model_checkpoints", + dataloader_kwargs={}, + rebuild=True, + **trainer_kwargs + ): + """ + Trains the classification model using the provided training data. Optionally, a separate validation set can be used. + + Parameters + ---------- + X : DataFrame or array-like, shape (n_samples, n_features) + The training input samples. + y : array-like, shape (n_samples,) or (n_samples, n_targets) + The target values (real numbers). + val_size : float, default=0.2 + The proportion of the dataset to include in the validation split if `X_val` is None. Ignored if `X_val` is provided. + X_val : DataFrame or array-like, shape (n_samples, n_features), optional + The validation input samples. If provided, `X` and `y` are not split and this data is used for validation. + y_val : array-like, shape (n_samples,) or (n_samples, n_targets), optional + The validation target values. Required if `X_val` is provided. + max_epochs : int, default=100 + Maximum number of epochs for training. + random_state : int, default=101 + Controls the shuffling applied to the data before applying the split. + batch_size : int, default=64 + Number of samples per gradient update. + shuffle : bool, default=True + Whether to shuffle the training data before each epoch. + patience : int, default=10 + Number of epochs with no improvement on the validation loss to wait before early stopping. 
+ monitor : str, default="val_loss" + The metric to monitor for early stopping. + mode : str, default="min" + Whether the monitored metric should be minimized (`min`) or maximized (`max`). + lr : float, default=1e-3 + Learning rate for the optimizer. + lr_patience : int, default=10 + Number of epochs with no improvement on the validation loss to wait before reducing the learning rate. + factor : float, default=0.1 + Factor by which the learning rate will be reduced. + weight_decay : float, default=0.025 + Weight decay (L2 penalty) coefficient. + checkpoint_path : str, default="model_checkpoints" + Path where the checkpoints are being saved. + dataloader_kwargs: dict, default={} + The kwargs for the pytorch dataloader class. + rebuild: bool, default=True + Whether to rebuild the model when it already was built. + **trainer_kwargs : Additional keyword arguments for PyTorch Lightning's Trainer class. + + + Returns + ------- + self : object + The fitted classifier. + """ + if (not self.built) or (self.built and rebuild): + if not isinstance(X, pd.DataFrame): + X = pd.DataFrame(X) + if isinstance(y, pd.Series): + y = y.values + if X_val: + if not isinstance(X_val, pd.DataFrame): + X_val = pd.DataFrame(X_val) + if isinstance(y_val, pd.Series): + y_val = y_val.values + + self.data_module = MambularDataModule( + preprocessor=self.preprocessor, + batch_size=batch_size, + shuffle=shuffle, + X_val=X_val, + y_val=y_val, + val_size=val_size, + random_state=random_state, + regression=False, + **dataloader_kwargs + ) + + self.data_module.preprocess_data( + X, y, X_val, y_val, val_size=val_size, random_state=random_state + ) + + num_classes = len(np.unique(y)) + + self.model = TaskModel( + model_class=self.base_model, + num_classes=num_classes, + config=self.config, + cat_feature_info=self.data_module.cat_feature_info, + num_feature_info=self.data_module.num_feature_info, + lr=lr, + lr_patience=lr_patience, + lr_factor=factor, + weight_decay=weight_decay, + ) + + early_stop_callback = EarlyStopping( + monitor=monitor, min_delta=0.00, patience=patience, verbose=False, mode=mode + ) + + checkpoint_callback = ModelCheckpoint( + monitor="val_loss", # Adjust according to your validation metric + mode="min", + save_top_k=1, + dirpath=checkpoint_path, # Specify the directory to save checkpoints + filename="best_model", + ) + + # Initialize the trainer and train the model + trainer = pl.Trainer( + max_epochs=max_epochs, + callbacks=[early_stop_callback, checkpoint_callback], + **trainer_kwargs + ) + trainer.fit(self.model, self.data_module) + + best_model_path = checkpoint_callback.best_model_path + if best_model_path: + checkpoint = torch.load(best_model_path) + self.model.load_state_dict(checkpoint["state_dict"]) + + return self + + def predict(self, X): + """ + Predicts target values for the given input samples. + + Parameters + ---------- + X : DataFrame or array-like, shape (n_samples, n_features) + The input samples for which to predict target values. + + + Returns + ------- + predictions : ndarray, shape (n_samples,) or (n_samples, n_outputs) + The predicted target values. 
+ """ + # Ensure model and data module are initialized + if self.model is None or self.data_module is None: + raise ValueError("The model or data module has not been fitted yet.") + + # Preprocess the data using the data module + cat_tensors, num_tensors = self.data_module.preprocess_test_data(X) + + # Move tensors to appropriate device + device = next(self.model.parameters()).device + if isinstance(cat_tensors, list): + cat_tensors = [tensor.to(device) for tensor in cat_tensors] + else: + cat_tensors = cat_tensors.to(device) + + if isinstance(num_tensors, list): + num_tensors = [tensor.to(device) for tensor in num_tensors] + else: + num_tensors = num_tensors.to(device) + + # Set model to evaluation mode + self.model.eval() + + # Perform inference + with torch.no_grad(): + logits = self.model(num_features=num_tensors, cat_features=cat_tensors) + + # Check the shape of the logits to determine binary or multi-class classification + if logits.shape[1] == 1: + # Binary classification + probabilities = torch.sigmoid(logits) + predictions = (probabilities > 0.5).long().squeeze() + else: + # Multi-class classification + probabilities = torch.softmax(logits, dim=1) + predictions = torch.argmax(probabilities, dim=1) + + # Convert predictions to NumPy array and return + return predictions.cpu().numpy() + + def predict_proba(self, X): + """ + Predict class probabilities for the given input samples. + + Parameters + ---------- + X : array-like or pd.DataFrame of shape (n_samples, n_features) + The input samples for which to predict class probabilities. + + + Notes + ----- + The method preprocesses the input data using the same preprocessor used during training, + sets the model to evaluation mode, and then performs inference to predict the class probabilities. + Softmax is applied to the logits to obtain probabilities, which are then converted from a PyTorch tensor + to a NumPy array before being returned. + + + Examples + -------- + >>> from sklearn.metrics import accuracy_score, precision_score, f1_score, roc_auc_score + >>> # Define the metrics you want to evaluate + >>> metrics = { + ... 'Accuracy': (accuracy_score, False), + ... 'Precision': (precision_score, False), + ... 'F1 Score': (f1_score, False), + ... 'AUC Score': (roc_auc_score, True) + ... } + >>> # Assuming 'X_test' and 'y_test' are your test dataset and labels + >>> # Evaluate using the specified metrics + >>> results = classifier.evaluate(X_test, y_test, metrics=metrics) + + + Returns + ------- + probabilities : ndarray of shape (n_samples, n_classes) + Predicted class probabilities for each input sample. 
+ + """ + # Preprocess the data + if not isinstance(X, pd.DataFrame): + X = pd.DataFrame(X) + device = next(self.model.parameters()).device + cat_tensors, num_tensors = self.data_module.preprocess_test_data(X) + if isinstance(cat_tensors, list): + cat_tensors = [tensor.to(device) for tensor in cat_tensors] + else: + cat_tensors = cat_tensors.to(device) + + if isinstance(num_tensors, list): + num_tensors = [tensor.to(device) for tensor in num_tensors] + else: + num_tensors = num_tensors.to(device) + + # Set the model to evaluation mode + self.model.eval() + + # Perform inference + with torch.no_grad(): + logits = self.model(num_features=num_tensors, cat_features=cat_tensors) + if logits.shape[1] > 1: + probabilities = torch.softmax(logits, dim=1) + else: + probabilities = torch.sigmoid(logits) + + # Convert probabilities to NumPy array and return + return probabilities.cpu().numpy() + + def evaluate(self, X, y_true, metrics=None): + """ + Evaluate the model on the given data using specified metrics. + + Parameters + ---------- + X : array-like or pd.DataFrame of shape (n_samples, n_features) + The input samples to predict. + y_true : array-like of shape (n_samples,) + The true class labels against which to evaluate the predictions. + metrics : dict + A dictionary where keys are metric names and values are tuples containing the metric function + and a boolean indicating whether the metric requires probability scores (True) or class labels (False). + + + Returns + ------- + scores : dict + A dictionary with metric names as keys and their corresponding scores as values. + + + Notes + ----- + This method uses either the `predict` or `predict_proba` method depending on the metric requirements. + """ + # Ensure input is in the correct format + if metrics is None: + metrics = {"Accuracy": (accuracy_score, False)} + + if not isinstance(X, pd.DataFrame): + X = pd.DataFrame(X) + + # Initialize dictionary to store results + scores = {} + + # Generate class probabilities if any metric requires them + if any(use_proba for _, use_proba in metrics.values()): + probabilities = self.predict_proba(X) + + # Generate class labels if any metric requires them + if any(not use_proba for _, use_proba in metrics.values()): + predictions = self.predict(X) + + # Compute each metric + for metric_name, (metric_func, use_proba) in metrics.items(): + if use_proba: + scores[metric_name] = metric_func(y_true, probabilities) + else: + scores[metric_name] = metric_func(y_true, predictions) + + return scores diff --git a/mambular/models/sklearn_base_lss.py b/mambular/models/sklearn_base_lss.py index 3298cff..e4a0444 100644 --- a/mambular/models/sklearn_base_lss.py +++ b/mambular/models/sklearn_base_lss.py @@ -1,556 +1,559 @@ -import lightning as pl -import pandas as pd -import torch -from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint -from sklearn.base import BaseEstimator -from sklearn.metrics import accuracy_score -import warnings -from ..base_models.lightning_wrapper import TaskModel -from ..data_utils.datamodule import MambularDataModule -from ..preprocessing import Preprocessor -import numpy as np -from ..utils.distributional_metrics import ( - beta_brier_score, - dirichlet_error, - gamma_deviance, - inverse_gamma_loss, - negative_binomial_deviance, - poisson_deviance, - student_t_loss, -) -from sklearn.metrics import accuracy_score, mean_squared_error -import properscoring as ps -from ..utils.distributions import ( - BetaDistribution, - CategoricalDistribution, - DirichletDistribution, - 
GammaDistribution, - InverseGammaDistribution, - NegativeBinomialDistribution, - NormalDistribution, - PoissonDistribution, - StudentTDistribution, -) - - -class SklearnBaseLSS(BaseEstimator): - def __init__(self, model, config, **kwargs): - preprocessor_arg_names = [ - "n_bins", - "numerical_preprocessing", - "use_decision_tree_bins", - "binning_strategy", - "task", - "cat_cutoff", - "treat_all_integers_as_numerical", - "knots", - "degree", - ] - - self.config_kwargs = { - k: v for k, v in kwargs.items() if k not in preprocessor_arg_names - } - self.config = config(**self.config_kwargs) - - preprocessor_kwargs = { - k: v for k, v in kwargs.items() if k in preprocessor_arg_names - } - - self.preprocessor = Preprocessor(**preprocessor_kwargs) - self.model = None - - # Raise a warning if task is set to 'classification' - if preprocessor_kwargs.get("task") == "classification": - warnings.warn( - "The task is set to 'classification'. Be aware of your preferred distribution, that this might lead to unsatisfactory results.", - UserWarning, - ) - - self.base_model = model - - def get_params(self, deep=True): - """ - Get parameters for this estimator. Overrides the BaseEstimator method. - - Parameters - ---------- - deep : bool, default=True - If True, returns the parameters for this estimator and contained sub-objects that are estimators. - - Returns - ------- - params : dict - Parameter names mapped to their values. - """ - params = self.config_kwargs # Parameters used to initialize DefaultConfig - - # If deep=True, include parameters from nested components like preprocessor - if deep: - # Assuming Preprocessor has a get_params method - preprocessor_params = { - "preprocessor__" + key: value - for key, value in self.preprocessor.get_params().items() - } - params.update(preprocessor_params) - - return params - - def set_params(self, **parameters): - """ - Set the parameters of this estimator. Overrides the BaseEstimator method. - - Parameters - ---------- - **parameters : dict - Estimator parameters to be set. - - Returns - ------- - self : object - The instance with updated parameters. - """ - # Update config_kwargs with provided parameters - valid_config_keys = self.config_kwargs.keys() - config_updates = {k: v for k, v in parameters.items() if k in valid_config_keys} - self.config_kwargs.update(config_updates) - - # Update the config object - for key, value in config_updates.items(): - setattr(self.config, key, value) - - # Handle preprocessor parameters (prefixed with 'preprocessor__') - preprocessor_params = { - k.split("__")[1]: v - for k, v in parameters.items() - if k.startswith("preprocessor__") - } - if preprocessor_params: - # Assuming Preprocessor has a set_params method - self.preprocessor.set_params(**preprocessor_params) - - return self - - def build_model( - self, - X, - y, - val_size: float = 0.2, - X_val=None, - y_val=None, - random_state: int = 101, - batch_size: int = 128, - shuffle: bool = True, - lr: float = 1e-4, - lr_patience: int = 10, - factor: float = 0.1, - weight_decay: float = 1e-06, - dataloader_kwargs={}, - ): - """ - Builds the model using the provided training data. - - Parameters - ---------- - X : DataFrame or array-like, shape (n_samples, n_features) - The training input samples. - y : array-like, shape (n_samples,) or (n_samples, n_targets) - The target values (real numbers). - val_size : float, default=0.2 - The proportion of the dataset to include in the validation split if `X_val` is None. Ignored if `X_val` is provided. 
- X_val : DataFrame or array-like, shape (n_samples, n_features), optional - The validation input samples. If provided, `X` and `y` are not split and this data is used for validation. - y_val : array-like, shape (n_samples,) or (n_samples, n_targets), optional - The validation target values. Required if `X_val` is provided. - random_state : int, default=101 - Controls the shuffling applied to the data before applying the split. - batch_size : int, default=64 - Number of samples per gradient update. - shuffle : bool, default=True - Whether to shuffle the training data before each epoch. - lr : float, default=1e-3 - Learning rate for the optimizer. - lr_patience : int, default=10 - Number of epochs with no improvement on the validation loss to wait before reducing the learning rate. - factor : float, default=0.1 - Factor by which the learning rate will be reduced. - weight_decay : float, default=0.025 - Weight decay (L2 penalty) coefficient. - dataloader_kwargs: dict, default={} - The kwargs for the pytorch dataloader class. - - Returns - ------- - self : object - The built distributional regressor. - """ - if not isinstance(X, pd.DataFrame): - X = pd.DataFrame(X) - if isinstance(y, pd.Series): - y = y.values - if X_val: - if not isinstance(X_val, pd.DataFrame): - X_val = pd.DataFrame(X_val) - if isinstance(y_val, pd.Series): - y_val = y_val.values - - self.data_module = MambularDataModule( - preprocessor=self.preprocessor, - batch_size=batch_size, - shuffle=shuffle, - X_val=X_val, - y_val=y_val, - val_size=val_size, - random_state=random_state, - regression=False, - **dataloader_kwargs - ) - - self.data_module.preprocess_data( - X, y, X_val, y_val, val_size=val_size, random_state=random_state - ) - - num_classes = len(np.unique(y)) - - self.model = TaskModel( - model_class=self.base_model, - num_classes=num_classes, - config=self.config, - cat_feature_info=self.data_module.cat_feature_info, - num_feature_info=self.data_module.num_feature_info, - lr=lr, - lr_patience=lr_patience, - lr_factor=factor, - weight_decay=weight_decay, - ) - - self.built = True - - return self - - def get_number_of_params(self, requires_grad=True): - """ - Calculate the number of parameters in the model. - - Parameters - ---------- - requires_grad : bool, optional - If True, only count the parameters that require gradients (trainable parameters). - If False, count all parameters. Default is True. - - Returns - ------- - int - The total number of parameters in the model. - - Raises - ------ - ValueError - If the model has not been built prior to calling this method. - """ - if not self.built: - raise ValueError( - "The model must be built before the number of parameters can be estimated" - ) - else: - if requires_grad: - return sum( - p.numel() for p in self.model.parameters() if p.requires_grad - ) - else: - return sum(p.numel() for p in self.model.parameters()) - - def fit( - self, - X, - y, - family, - val_size: float = 0.2, - X_val=None, - y_val=None, - max_epochs: int = 100, - random_state: int = 101, - batch_size: int = 128, - shuffle: bool = True, - patience: int = 15, - monitor: str = "val_loss", - mode: str = "min", - lr: float = 1e-4, - lr_patience: int = 10, - factor: float = 0.1, - weight_decay: float = 1e-06, - checkpoint_path="model_checkpoints", - distributional_kwargs=None, - dataloader_kwargs={}, - **trainer_kwargs - ): - """ - Trains the regression model using the provided training data. Optionally, a separate validation set can be used. 
- - Parameters - ---------- - X : DataFrame or array-like, shape (n_samples, n_features) - The training input samples. - y : array-like, shape (n_samples,) or (n_samples, n_targets) - The target values (real numbers). - family : str - The name of the distribution family to use for the loss function. Examples include 'normal' for regression tasks. - val_size : float, default=0.2 - The proportion of the dataset to include in the validation split if `X_val` is None. Ignored if `X_val` is provided. - X_val : DataFrame or array-like, shape (n_samples, n_features), optional - The validation input samples. If provided, `X` and `y` are not split and this data is used for validation. - y_val : array-like, shape (n_samples,) or (n_samples, n_targets), optional - The validation target values. Required if `X_val` is provided. - max_epochs : int, default=100 - Maximum number of epochs for training. - random_state : int, default=101 - Controls the shuffling applied to the data before applying the split. - batch_size : int, default=64 - Number of samples per gradient update. - shuffle : bool, default=True - Whether to shuffle the training data before each epoch. - patience : int, default=10 - Number of epochs with no improvement on the validation loss to wait before early stopping. - monitor : str, default="val_loss" - The metric to monitor for early stopping. - mode : str, default="min" - Whether the monitored metric should be minimized (`min`) or maximized (`max`). - lr : float, default=1e-3 - Learning rate for the optimizer. - lr_patience : int, default=10 - Number of epochs with no improvement on the validation loss to wait before reducing the learning rate. - factor : float, default=0.1 - Factor by which the learning rate will be reduced. - weight_decay : float, default=0.025 - Weight decay (L2 penalty) coefficient. - distributional_kwargs : dict, default=None - any arguments taht are specific for a certain distribution. - checkpoint_path : str, default="model_checkpoints" - Path where the checkpoints are being saved. - dataloader_kwargs: dict, default={} - The kwargs for the pytorch dataloader class. - **trainer_kwargs : Additional keyword arguments for PyTorch Lightning's Trainer class. - - - Returns - ------- - self : object - The fitted regressor. 
- """ - distribution_classes = { - "normal": NormalDistribution, - "poisson": PoissonDistribution, - "gamma": GammaDistribution, - "beta": BetaDistribution, - "dirichlet": DirichletDistribution, - "studentt": StudentTDistribution, - "negativebinom": NegativeBinomialDistribution, - "inversegamma": InverseGammaDistribution, - "categorical": CategoricalDistribution, - } - - if distributional_kwargs is None: - distributional_kwargs = {} - - if family in distribution_classes: - self.family = distribution_classes[family](**distributional_kwargs) - else: - raise ValueError("Unsupported family: {}".format(family)) - - if not isinstance(X, pd.DataFrame): - X = pd.DataFrame(X) - if isinstance(y, pd.Series): - y = y.values - if X_val: - if not isinstance(X_val, pd.DataFrame): - X_val = pd.DataFrame(X_val) - if isinstance(y_val, pd.Series): - y_val = y_val.values - - self.data_module = MambularDataModule( - preprocessor=self.preprocessor, - batch_size=batch_size, - shuffle=shuffle, - X_val=X_val, - y_val=y_val, - val_size=val_size, - random_state=random_state, - regression=True, - **dataloader_kwargs - ) - - self.data_module.preprocess_data( - X, y, X_val, y_val, val_size=val_size, random_state=random_state - ) - - self.model = TaskModel( - model_class=self.base_model, - num_classes=self.family.param_count, - family=self.family, - config=self.config, - cat_feature_info=self.data_module.cat_feature_info, - num_feature_info=self.data_module.num_feature_info, - lr=lr, - lr_patience=lr_patience, - lr_factor=factor, - weight_decay=weight_decay, - lss=True, - ) - - early_stop_callback = EarlyStopping( - monitor=monitor, min_delta=0.00, patience=patience, verbose=False, mode=mode - ) - - checkpoint_callback = ModelCheckpoint( - monitor="val_loss", # Adjust according to your validation metric - mode="min", - save_top_k=1, - dirpath=checkpoint_path, # Specify the directory to save checkpoints - filename="best_model", - ) - - # Initialize the trainer and train the model - trainer = pl.Trainer( - max_epochs=max_epochs, - callbacks=[early_stop_callback, checkpoint_callback], - **trainer_kwargs - ) - trainer.fit(self.model, self.data_module) - - best_model_path = checkpoint_callback.best_model_path - if best_model_path: - checkpoint = torch.load(best_model_path) - self.model.load_state_dict(checkpoint["state_dict"]) - - return self - - def predict(self, X, raw=False): - """ - Predicts target values for the given input samples. - - Parameters - ---------- - X : DataFrame or array-like, shape (n_samples, n_features) - The input samples for which to predict target values. - - - Returns - ------- - predictions : ndarray, shape (n_samples,) or (n_samples, n_outputs) - The predicted target values. 
- """ - # Ensure model and data module are initialized - if self.model is None or self.data_module is None: - raise ValueError("The model or data module has not been fitted yet.") - - # Preprocess the data using the data module - cat_tensors, num_tensors = self.data_module.preprocess_test_data(X) - - # Move tensors to appropriate device - device = next(self.model.parameters()).device - if isinstance(cat_tensors, list): - cat_tensors = [tensor.to(device) for tensor in cat_tensors] - else: - cat_tensors = cat_tensors.to(device) - - if isinstance(num_tensors, list): - num_tensors = [tensor.to(device) for tensor in num_tensors] - else: - num_tensors = num_tensors.to(device) - - # Set model to evaluation mode - self.model.eval() - - # Perform inference - with torch.no_grad(): - predictions = self.model(num_features=num_tensors, cat_features=cat_tensors) - - if not raw: - return self.model.family(predictions).cpu().numpy() - - # Convert predictions to NumPy array and return - else: - return predictions.cpu().numpy() - - def evaluate(self, X, y_true, metrics=None, distribution_family=None): - """ - Evaluate the model on the given data using specified metrics. - - Parameters - ---------- - X : array-like or pd.DataFrame of shape (n_samples, n_features) - The input samples to predict. - y_true : array-like of shape (n_samples,) - The true class labels against which to evaluate the predictions. - metrics : dict - A dictionary where keys are metric names and values are tuples containing the metric function - and a boolean indicating whether the metric requires probability scores (True) or class labels (False). - distribution_family : str, optional - Specifies the distribution family the model is predicting for. If None, it will attempt to infer based - on the model's settings. - - - Returns - ------- - scores : dict - A dictionary with metric names as keys and their corresponding scores as values. - - - Notes - ----- - This method uses either the `predict` or `predict_proba` method depending on the metric requirements. - """ - # Infer distribution family from model settings if not provided - if distribution_family is None: - distribution_family = getattr(self.model, "distribution_family", "normal") - - # Setup default metrics if none are provided - if metrics is None: - metrics = self.get_default_metrics(distribution_family) - - # Make predictions - predictions = self.predict(X, raw=False) - - # Initialize dictionary to store results - scores = {} - - # Compute each metric - for metric_name, metric_func in metrics.items(): - scores[metric_name] = metric_func(y_true, predictions) - - return scores - - def get_default_metrics(self, distribution_family): - """ - Provides default metrics based on the distribution family. - - Parameters - ---------- - distribution_family : str - The distribution family for which to provide default metrics. - - - Returns - ------- - metrics : dict - A dictionary of default metric functions. 
- """ - default_metrics = { - "normal": { - "MSE": lambda y, pred: mean_squared_error(y, pred[:, 0]), - "CRPS": lambda y, pred: np.mean( - [ - ps.crps_gaussian(y[i], mu=pred[i, 0], sig=np.sqrt(pred[i, 1])) - for i in range(len(y)) - ] - ), - }, - "poisson": {"Poisson Deviance": poisson_deviance}, - "gamma": {"Gamma Deviance": gamma_deviance}, - "beta": {"Brier Score": beta_brier_score}, - "dirichlet": {"Dirichlet Error": dirichlet_error}, - "studentt": {"Student-T Loss": student_t_loss}, - "negativebinom": {"Negative Binomial Deviance": negative_binomial_deviance}, - "inversegamma": {"Inverse Gamma Loss": inverse_gamma_loss}, - "categorical": {"Accuracy": accuracy_score}, - } - return default_metrics.get(distribution_family, {}) +import lightning as pl +import pandas as pd +import torch +from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint +from sklearn.base import BaseEstimator +from sklearn.metrics import accuracy_score +import warnings +from ..base_models.lightning_wrapper import TaskModel +from ..data_utils.datamodule import MambularDataModule +from ..preprocessing import Preprocessor +import numpy as np +from ..utils.distributional_metrics import ( + beta_brier_score, + dirichlet_error, + gamma_deviance, + inverse_gamma_loss, + negative_binomial_deviance, + poisson_deviance, + student_t_loss, +) +from sklearn.metrics import accuracy_score, mean_squared_error +import properscoring as ps +from ..utils.distributions import ( + BetaDistribution, + CategoricalDistribution, + DirichletDistribution, + GammaDistribution, + InverseGammaDistribution, + NegativeBinomialDistribution, + NormalDistribution, + PoissonDistribution, + StudentTDistribution, +) + + +class SklearnBaseLSS(BaseEstimator): + def __init__(self, model, config, **kwargs): + preprocessor_arg_names = [ + "n_bins", + "numerical_preprocessing", + "use_decision_tree_bins", + "binning_strategy", + "task", + "cat_cutoff", + "treat_all_integers_as_numerical", + "knots", + "degree", + ] + + self.config_kwargs = { + k: v for k, v in kwargs.items() if k not in preprocessor_arg_names + } + self.config = config(**self.config_kwargs) + + preprocessor_kwargs = { + k: v for k, v in kwargs.items() if k in preprocessor_arg_names + } + + self.preprocessor = Preprocessor(**preprocessor_kwargs) + self.model = None + + # Raise a warning if task is set to 'classification' + if preprocessor_kwargs.get("task") == "classification": + warnings.warn( + "The task is set to 'classification'. Be aware of your preferred distribution, that this might lead to unsatisfactory results.", + UserWarning, + ) + + self.base_model = model + + def get_params(self, deep=True): + """ + Get parameters for this estimator. Overrides the BaseEstimator method. + + Parameters + ---------- + deep : bool, default=True + If True, returns the parameters for this estimator and contained sub-objects that are estimators. + + Returns + ------- + params : dict + Parameter names mapped to their values. + """ + params = self.config_kwargs # Parameters used to initialize DefaultConfig + + # If deep=True, include parameters from nested components like preprocessor + if deep: + # Assuming Preprocessor has a get_params method + preprocessor_params = { + "preprocessor__" + key: value + for key, value in self.preprocessor.get_params().items() + } + params.update(preprocessor_params) + + return params + + def set_params(self, **parameters): + """ + Set the parameters of this estimator. Overrides the BaseEstimator method. 
+ + Parameters + ---------- + **parameters : dict + Estimator parameters to be set. + + Returns + ------- + self : object + The instance with updated parameters. + """ + # Update config_kwargs with provided parameters + valid_config_keys = self.config_kwargs.keys() + config_updates = {k: v for k, v in parameters.items() if k in valid_config_keys} + self.config_kwargs.update(config_updates) + + # Update the config object + for key, value in config_updates.items(): + setattr(self.config, key, value) + + # Handle preprocessor parameters (prefixed with 'preprocessor__') + preprocessor_params = { + k.split("__")[1]: v + for k, v in parameters.items() + if k.startswith("preprocessor__") + } + if preprocessor_params: + # Assuming Preprocessor has a set_params method + self.preprocessor.set_params(**preprocessor_params) + + return self + + def build_model( + self, + X, + y, + val_size: float = 0.2, + X_val=None, + y_val=None, + random_state: int = 101, + batch_size: int = 128, + shuffle: bool = True, + lr: float = 1e-4, + lr_patience: int = 10, + factor: float = 0.1, + weight_decay: float = 1e-06, + dataloader_kwargs={}, + ): + """ + Builds the model using the provided training data. + + Parameters + ---------- + X : DataFrame or array-like, shape (n_samples, n_features) + The training input samples. + y : array-like, shape (n_samples,) or (n_samples, n_targets) + The target values (real numbers). + val_size : float, default=0.2 + The proportion of the dataset to include in the validation split if `X_val` is None. Ignored if `X_val` is provided. + X_val : DataFrame or array-like, shape (n_samples, n_features), optional + The validation input samples. If provided, `X` and `y` are not split and this data is used for validation. + y_val : array-like, shape (n_samples,) or (n_samples, n_targets), optional + The validation target values. Required if `X_val` is provided. + random_state : int, default=101 + Controls the shuffling applied to the data before applying the split. + batch_size : int, default=64 + Number of samples per gradient update. + shuffle : bool, default=True + Whether to shuffle the training data before each epoch. + lr : float, default=1e-3 + Learning rate for the optimizer. + lr_patience : int, default=10 + Number of epochs with no improvement on the validation loss to wait before reducing the learning rate. + factor : float, default=0.1 + Factor by which the learning rate will be reduced. + weight_decay : float, default=0.025 + Weight decay (L2 penalty) coefficient. + dataloader_kwargs: dict, default={} + The kwargs for the pytorch dataloader class. + + Returns + ------- + self : object + The built distributional regressor. 
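+
+ Examples
+ --------
+ >>> # Illustrative: build the network without training it, e.g. to inspect
+ >>> # its size first; `lss`, `X_train` and `y_train` are placeholder names.
+ >>> lss.build_model(X_train, y_train)
+ >>> lss.get_number_of_params()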
+ """ + if not isinstance(X, pd.DataFrame): + X = pd.DataFrame(X) + if isinstance(y, pd.Series): + y = y.values + if X_val: + if not isinstance(X_val, pd.DataFrame): + X_val = pd.DataFrame(X_val) + if isinstance(y_val, pd.Series): + y_val = y_val.values + + self.data_module = MambularDataModule( + preprocessor=self.preprocessor, + batch_size=batch_size, + shuffle=shuffle, + X_val=X_val, + y_val=y_val, + val_size=val_size, + random_state=random_state, + regression=False, + **dataloader_kwargs + ) + + self.data_module.preprocess_data( + X, y, X_val, y_val, val_size=val_size, random_state=random_state + ) + + num_classes = len(np.unique(y)) + + self.model = TaskModel( + model_class=self.base_model, + num_classes=num_classes, + config=self.config, + cat_feature_info=self.data_module.cat_feature_info, + num_feature_info=self.data_module.num_feature_info, + lr=lr, + lr_patience=lr_patience, + lr_factor=factor, + weight_decay=weight_decay, + ) + + self.built = True + + return self + + def get_number_of_params(self, requires_grad=True): + """ + Calculate the number of parameters in the model. + + Parameters + ---------- + requires_grad : bool, optional + If True, only count the parameters that require gradients (trainable parameters). + If False, count all parameters. Default is True. + + Returns + ------- + int + The total number of parameters in the model. + + Raises + ------ + ValueError + If the model has not been built prior to calling this method. + """ + if not self.built: + raise ValueError( + "The model must be built before the number of parameters can be estimated" + ) + else: + if requires_grad: + return sum( + p.numel() for p in self.model.parameters() if p.requires_grad + ) + else: + return sum(p.numel() for p in self.model.parameters()) + + def fit( + self, + X, + y, + family, + val_size: float = 0.2, + X_val=None, + y_val=None, + max_epochs: int = 100, + random_state: int = 101, + batch_size: int = 128, + shuffle: bool = True, + patience: int = 15, + monitor: str = "val_loss", + mode: str = "min", + lr: float = 1e-4, + lr_patience: int = 10, + factor: float = 0.1, + weight_decay: float = 1e-06, + checkpoint_path="model_checkpoints", + distributional_kwargs=None, + dataloader_kwargs={}, + **trainer_kwargs + ): + """ + Trains the regression model using the provided training data. Optionally, a separate validation set can be used. + + Parameters + ---------- + X : DataFrame or array-like, shape (n_samples, n_features) + The training input samples. + y : array-like, shape (n_samples,) or (n_samples, n_targets) + The target values (real numbers). + family : str + The name of the distribution family to use for the loss function. Examples include 'normal' for regression tasks. + val_size : float, default=0.2 + The proportion of the dataset to include in the validation split if `X_val` is None. Ignored if `X_val` is provided. + X_val : DataFrame or array-like, shape (n_samples, n_features), optional + The validation input samples. If provided, `X` and `y` are not split and this data is used for validation. + y_val : array-like, shape (n_samples,) or (n_samples, n_targets), optional + The validation target values. Required if `X_val` is provided. + max_epochs : int, default=100 + Maximum number of epochs for training. + random_state : int, default=101 + Controls the shuffling applied to the data before applying the split. + batch_size : int, default=64 + Number of samples per gradient update. + shuffle : bool, default=True + Whether to shuffle the training data before each epoch. 
+ patience : int, default=10 + Number of epochs with no improvement on the validation loss to wait before early stopping. + monitor : str, default="val_loss" + The metric to monitor for early stopping. + mode : str, default="min" + Whether the monitored metric should be minimized (`min`) or maximized (`max`). + lr : float, default=1e-3 + Learning rate for the optimizer. + lr_patience : int, default=10 + Number of epochs with no improvement on the validation loss to wait before reducing the learning rate. + factor : float, default=0.1 + Factor by which the learning rate will be reduced. + weight_decay : float, default=0.025 + Weight decay (L2 penalty) coefficient. + distributional_kwargs : dict, default=None + any arguments taht are specific for a certain distribution. + checkpoint_path : str, default="model_checkpoints" + Path where the checkpoints are being saved. + dataloader_kwargs: dict, default={} + The kwargs for the pytorch dataloader class. + **trainer_kwargs : Additional keyword arguments for PyTorch Lightning's Trainer class. + + + Returns + ------- + self : object + The fitted regressor. + """ + distribution_classes = { + "normal": NormalDistribution, + "poisson": PoissonDistribution, + "gamma": GammaDistribution, + "beta": BetaDistribution, + "dirichlet": DirichletDistribution, + "studentt": StudentTDistribution, + "negativebinom": NegativeBinomialDistribution, + "inversegamma": InverseGammaDistribution, + "categorical": CategoricalDistribution, + } + + if distributional_kwargs is None: + distributional_kwargs = {} + + if family in distribution_classes: + self.family = distribution_classes[family](**distributional_kwargs) + else: + raise ValueError("Unsupported family: {}".format(family)) + + if (not self.built) or (self.built and rebuild): + if not isinstance(X, pd.DataFrame): + X = pd.DataFrame(X) + if isinstance(y, pd.Series): + y = y.values + if X_val: + if not isinstance(X_val, pd.DataFrame): + X_val = pd.DataFrame(X_val) + if isinstance(y_val, pd.Series): + y_val = y_val.values + + self.data_module = MambularDataModule( + preprocessor=self.preprocessor, + batch_size=batch_size, + shuffle=shuffle, + X_val=X_val, + y_val=y_val, + val_size=val_size, + random_state=random_state, + regression=True, + **dataloader_kwargs + ) + + self.data_module.preprocess_data( + X, y, X_val, y_val, val_size=val_size, random_state=random_state + ) + + self.model = TaskModel( + model_class=self.base_model, + num_classes=self.family.param_count, + family=self.family, + config=self.config, + cat_feature_info=self.data_module.cat_feature_info, + num_feature_info=self.data_module.num_feature_info, + lr=lr, + lr_patience=lr_patience, + lr_factor=factor, + weight_decay=weight_decay, + lss=True, + ) + else: + assert self.built, "The model must be built before calling the fit method." 
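+
+ # Early stopping tracks the user-supplied `monitor`/`mode`; checkpointing below
+ # always tracks "val_loss", and the best checkpoint is reloaded after training.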
+ + early_stop_callback = EarlyStopping( + monitor=monitor, min_delta=0.00, patience=patience, verbose=False, mode=mode + ) + + checkpoint_callback = ModelCheckpoint( + monitor="val_loss", # Adjust according to your validation metric + mode="min", + save_top_k=1, + dirpath=checkpoint_path, # Specify the directory to save checkpoints + filename="best_model", + ) + + # Initialize the trainer and train the model + trainer = pl.Trainer( + max_epochs=max_epochs, + callbacks=[early_stop_callback, checkpoint_callback], + **trainer_kwargs + ) + trainer.fit(self.model, self.data_module) + + best_model_path = checkpoint_callback.best_model_path + if best_model_path: + checkpoint = torch.load(best_model_path) + self.model.load_state_dict(checkpoint["state_dict"]) + + return self + + def predict(self, X, raw=False): + """ + Predicts target values for the given input samples. + + Parameters + ---------- + X : DataFrame or array-like, shape (n_samples, n_features) + The input samples for which to predict target values. + + + Returns + ------- + predictions : ndarray, shape (n_samples,) or (n_samples, n_outputs) + The predicted target values. + """ + # Ensure model and data module are initialized + if self.model is None or self.data_module is None: + raise ValueError("The model or data module has not been fitted yet.") + + # Preprocess the data using the data module + cat_tensors, num_tensors = self.data_module.preprocess_test_data(X) + + # Move tensors to appropriate device + device = next(self.model.parameters()).device + if isinstance(cat_tensors, list): + cat_tensors = [tensor.to(device) for tensor in cat_tensors] + else: + cat_tensors = cat_tensors.to(device) + + if isinstance(num_tensors, list): + num_tensors = [tensor.to(device) for tensor in num_tensors] + else: + num_tensors = num_tensors.to(device) + + # Set model to evaluation mode + self.model.eval() + + # Perform inference + with torch.no_grad(): + predictions = self.model(num_features=num_tensors, cat_features=cat_tensors) + + if not raw: + return self.model.family(predictions).cpu().numpy() + + # Convert predictions to NumPy array and return + else: + return predictions.cpu().numpy() + + def evaluate(self, X, y_true, metrics=None, distribution_family=None): + """ + Evaluate the model on the given data using specified metrics. + + Parameters + ---------- + X : array-like or pd.DataFrame of shape (n_samples, n_features) + The input samples to predict. + y_true : array-like of shape (n_samples,) + The true class labels against which to evaluate the predictions. + metrics : dict + A dictionary where keys are metric names and values are tuples containing the metric function + and a boolean indicating whether the metric requires probability scores (True) or class labels (False). + distribution_family : str, optional + Specifies the distribution family the model is predicting for. If None, it will attempt to infer based + on the model's settings. + + + Returns + ------- + scores : dict + A dictionary with metric names as keys and their corresponding scores as values. + + + Notes + ----- + This method uses either the `predict` or `predict_proba` method depending on the metric requirements. 
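+
+ Examples
+ --------
+ >>> # Illustrative usage, assuming `lss` was fitted with family="normal";
+ >>> # when `metrics` is None, the defaults for that family (MSE, CRPS) are used.
+ >>> scores = lss.evaluate(X_test, y_test, distribution_family="normal")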
+ """ + # Infer distribution family from model settings if not provided + if distribution_family is None: + distribution_family = getattr(self.model, "distribution_family", "normal") + + # Setup default metrics if none are provided + if metrics is None: + metrics = self.get_default_metrics(distribution_family) + + # Make predictions + predictions = self.predict(X, raw=False) + + # Initialize dictionary to store results + scores = {} + + # Compute each metric + for metric_name, metric_func in metrics.items(): + scores[metric_name] = metric_func(y_true, predictions) + + return scores + + def get_default_metrics(self, distribution_family): + """ + Provides default metrics based on the distribution family. + + Parameters + ---------- + distribution_family : str + The distribution family for which to provide default metrics. + + + Returns + ------- + metrics : dict + A dictionary of default metric functions. + """ + default_metrics = { + "normal": { + "MSE": lambda y, pred: mean_squared_error(y, pred[:, 0]), + "CRPS": lambda y, pred: np.mean( + [ + ps.crps_gaussian(y[i], mu=pred[i, 0], sig=np.sqrt(pred[i, 1])) + for i in range(len(y)) + ] + ), + }, + "poisson": {"Poisson Deviance": poisson_deviance}, + "gamma": {"Gamma Deviance": gamma_deviance}, + "beta": {"Brier Score": beta_brier_score}, + "dirichlet": {"Dirichlet Error": dirichlet_error}, + "studentt": {"Student-T Loss": student_t_loss}, + "negativebinom": {"Negative Binomial Deviance": negative_binomial_deviance}, + "inversegamma": {"Inverse Gamma Loss": inverse_gamma_loss}, + "categorical": {"Accuracy": accuracy_score}, + } + return default_metrics.get(distribution_family, {}) diff --git a/mambular/models/sklearn_base_regressor.py b/mambular/models/sklearn_base_regressor.py index c128914..17d3c56 100644 --- a/mambular/models/sklearn_base_regressor.py +++ b/mambular/models/sklearn_base_regressor.py @@ -1,474 +1,474 @@ -import lightning as pl -import pandas as pd -import torch -from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint -from sklearn.base import BaseEstimator -from sklearn.metrics import mean_squared_error -import warnings -from ..base_models.lightning_wrapper import TaskModel -from ..data_utils.datamodule import MambularDataModule -from ..preprocessing import Preprocessor - - -class SklearnBaseRegressor(BaseEstimator): - def __init__(self, model, config, **kwargs): - preprocessor_arg_names = [ - "n_bins", - "numerical_preprocessing", - "use_decision_tree_bins", - "binning_strategy", - "task", - "cat_cutoff", - "treat_all_integers_as_numerical", - "knots", - "degree", - ] - - self.config_kwargs = { - k: v for k, v in kwargs.items() if k not in preprocessor_arg_names - } - self.config = config(**self.config_kwargs) - - preprocessor_kwargs = { - k: v for k, v in kwargs.items() if k in preprocessor_arg_names - } - - self.preprocessor = Preprocessor(**preprocessor_kwargs) - self.model = None - - # Raise a warning if task is set to 'classification' - if preprocessor_kwargs.get("task") == "classification": - warnings.warn( - "The task is set to 'classification'. The Regressor is designed for regression tasks.", - UserWarning, - ) - - self.base_model = model - self.built = False - - def get_params(self, deep=True): - """ - Get parameters for this estimator. Overrides the BaseEstimator method. - - Parameters - ---------- - deep : bool, default=True - If True, returns the parameters for this estimator and contained sub-objects that are estimators. 
- - Returns - ------- - params : dict - Parameter names mapped to their values. - """ - params = self.config_kwargs # Parameters used to initialize DefaultConfig - - # If deep=True, include parameters from nested components like preprocessor - if deep: - # Assuming Preprocessor has a get_params method - preprocessor_params = { - "preprocessor__" + key: value - for key, value in self.preprocessor.get_params().items() - } - params.update(preprocessor_params) - - return params - - def set_params(self, **parameters): - """ - Set the parameters of this estimator. Overrides the BaseEstimator method. - - Parameters - ---------- - **parameters : dict - Estimator parameters to be set. - - Returns - ------- - self : object - The instance with updated parameters. - """ - # Update config_kwargs with provided parameters - valid_config_keys = self.config_kwargs.keys() - config_updates = {k: v for k, v in parameters.items() if k in valid_config_keys} - self.config_kwargs.update(config_updates) - - # Update the config object - for key, value in config_updates.items(): - setattr(self.config, key, value) - - # Handle preprocessor parameters (prefixed with 'preprocessor__') - preprocessor_params = { - k.split("__")[1]: v - for k, v in parameters.items() - if k.startswith("preprocessor__") - } - if preprocessor_params: - # Assuming Preprocessor has a set_params method - self.preprocessor.set_params(**preprocessor_params) - - return self - - def build_model( - self, - X, - y, - val_size: float = 0.2, - X_val=None, - y_val=None, - random_state: int = 101, - batch_size: int = 128, - shuffle: bool = True, - lr: float = 1e-4, - lr_patience: int = 10, - factor: float = 0.1, - weight_decay: float = 1e-06, - dataloader_kwargs={}, - ): - """ - Builds the model using the provided training data. - - Parameters - ---------- - X : DataFrame or array-like, shape (n_samples, n_features) - The training input samples. - y : array-like, shape (n_samples,) or (n_samples, n_targets) - The target values (real numbers). - val_size : float, default=0.2 - The proportion of the dataset to include in the validation split if `X_val` is None. Ignored if `X_val` is provided. - X_val : DataFrame or array-like, shape (n_samples, n_features), optional - The validation input samples. If provided, `X` and `y` are not split and this data is used for validation. - y_val : array-like, shape (n_samples,) or (n_samples, n_targets), optional - The validation target values. Required if `X_val` is provided. - random_state : int, default=101 - Controls the shuffling applied to the data before applying the split. - batch_size : int, default=64 - Number of samples per gradient update. - shuffle : bool, default=True - Whether to shuffle the training data before each epoch. - lr : float, default=1e-3 - Learning rate for the optimizer. - lr_patience : int, default=10 - Number of epochs with no improvement on the validation loss to wait before reducing the learning rate. - factor : float, default=0.1 - Factor by which the learning rate will be reduced. - weight_decay : float, default=0.025 - Weight decay (L2 penalty) coefficient. - dataloader_kwargs: dict, default={} - The kwargs for the pytorch dataloader class. - - - - Returns - ------- - self : object - The built regressor. 
- """ - if not isinstance(X, pd.DataFrame): - X = pd.DataFrame(X) - if isinstance(y, pd.Series): - y = y.values - if X_val: - if not isinstance(X_val, pd.DataFrame): - X_val = pd.DataFrame(X_val) - if isinstance(y_val, pd.Series): - y_val = y_val.values - - self.data_module = MambularDataModule( - preprocessor=self.preprocessor, - batch_size=batch_size, - shuffle=shuffle, - X_val=X_val, - y_val=y_val, - val_size=val_size, - random_state=random_state, - regression=True, - **dataloader_kwargs - ) - - self.data_module.preprocess_data( - X, y, X_val, y_val, val_size=val_size, random_state=random_state - ) - - self.model = TaskModel( - model_class=self.base_model, - config=self.config, - cat_feature_info=self.data_module.cat_feature_info, - num_feature_info=self.data_module.num_feature_info, - lr=lr, - lr_patience=lr_patience, - lr_factor=factor, - weight_decay=weight_decay, - ) - - self.built = True - - return self - - def get_number_of_params(self, requires_grad=True): - """ - Calculate the number of parameters in the model. - - Parameters - ---------- - requires_grad : bool, optional - If True, only count the parameters that require gradients (trainable parameters). - If False, count all parameters. Default is True. - - Returns - ------- - int - The total number of parameters in the model. - - Raises - ------ - ValueError - If the model has not been built prior to calling this method. - """ - if not self.built: - raise ValueError( - "The model must be built before the number of parameters can be estimated" - ) - else: - if requires_grad: - return sum( - p.numel() for p in self.model.parameters() if p.requires_grad - ) - else: - return sum(p.numel() for p in self.model.parameters()) - - def fit( - self, - X, - y, - val_size: float = 0.2, - X_val=None, - y_val=None, - max_epochs: int = 100, - random_state: int = 101, - batch_size: int = 128, - shuffle: bool = True, - patience: int = 15, - monitor: str = "val_loss", - mode: str = "min", - lr: float = 1e-4, - lr_patience: int = 10, - factor: float = 0.1, - weight_decay: float = 1e-06, - checkpoint_path="model_checkpoints", - dataloader_kwargs={}, - rebuild=True, - **trainer_kwargs - ): - """ - Trains the regression model using the provided training data. Optionally, a separate validation set can be used. - - Parameters - ---------- - X : DataFrame or array-like, shape (n_samples, n_features) - The training input samples. - y : array-like, shape (n_samples,) or (n_samples, n_targets) - The target values (real numbers). - val_size : float, default=0.2 - The proportion of the dataset to include in the validation split if `X_val` is None. Ignored if `X_val` is provided. - X_val : DataFrame or array-like, shape (n_samples, n_features), optional - The validation input samples. If provided, `X` and `y` are not split and this data is used for validation. - y_val : array-like, shape (n_samples,) or (n_samples, n_targets), optional - The validation target values. Required if `X_val` is provided. - max_epochs : int, default=100 - Maximum number of epochs for training. - random_state : int, default=101 - Controls the shuffling applied to the data before applying the split. - batch_size : int, default=64 - Number of samples per gradient update. - shuffle : bool, default=True - Whether to shuffle the training data before each epoch. - patience : int, default=10 - Number of epochs with no improvement on the validation loss to wait before early stopping. - monitor : str, default="val_loss" - The metric to monitor for early stopping. 
- mode : str, default="min" - Whether the monitored metric should be minimized (`min`) or maximized (`max`). - lr : float, default=1e-3 - Learning rate for the optimizer. - lr_patience : int, default=10 - Number of epochs with no improvement on the validation loss to wait before reducing the learning rate. - factor : float, default=0.1 - Factor by which the learning rate will be reduced. - weight_decay : float, default=0.025 - Weight decay (L2 penalty) coefficient. - checkpoint_path : str, default="model_checkpoints" - Path where the checkpoints are being saved. - dataloader_kwargs: dict, default={} - The kwargs for the pytorch dataloader class. - **trainer_kwargs : Additional keyword arguments for PyTorch Lightning's Trainer class. - - - Returns - ------- - self : object - The fitted regressor. - """ - if rebuild: - if not isinstance(X, pd.DataFrame): - X = pd.DataFrame(X) - if isinstance(y, pd.Series): - y = y.values - if X_val: - if not isinstance(X_val, pd.DataFrame): - X_val = pd.DataFrame(X_val) - if isinstance(y_val, pd.Series): - y_val = y_val.values - - self.data_module = MambularDataModule( - preprocessor=self.preprocessor, - batch_size=batch_size, - shuffle=shuffle, - X_val=X_val, - y_val=y_val, - val_size=val_size, - random_state=random_state, - regression=True, - **dataloader_kwargs - ) - - self.data_module.preprocess_data( - X, y, X_val, y_val, val_size=val_size, random_state=random_state - ) - - self.model = TaskModel( - model_class=self.base_model, - config=self.config, - cat_feature_info=self.data_module.cat_feature_info, - num_feature_info=self.data_module.num_feature_info, - lr=lr, - lr_patience=lr_patience, - lr_factor=factor, - weight_decay=weight_decay, - ) - - else: - assert self.built, "The model must be built before calling the fit method." - - early_stop_callback = EarlyStopping( - monitor=monitor, min_delta=0.00, patience=patience, verbose=False, mode=mode - ) - - checkpoint_callback = ModelCheckpoint( - monitor="val_loss", # Adjust according to your validation metric - mode="min", - save_top_k=1, - dirpath=checkpoint_path, # Specify the directory to save checkpoints - filename="best_model", - ) - - # Initialize the trainer and train the model - trainer = pl.Trainer( - max_epochs=max_epochs, - callbacks=[early_stop_callback, checkpoint_callback], - **trainer_kwargs - ) - trainer.fit(self.model, self.data_module) - - best_model_path = checkpoint_callback.best_model_path - if best_model_path: - checkpoint = torch.load(best_model_path) - self.model.load_state_dict(checkpoint["state_dict"]) - - return self - - def predict(self, X): - """ - Predicts target values for the given input samples. - - Parameters - ---------- - X : DataFrame or array-like, shape (n_samples, n_features) - The input samples for which to predict target values. - - - Returns - ------- - predictions : ndarray, shape (n_samples,) or (n_samples, n_outputs) - The predicted target values. 
- """ - # Ensure model and data module are initialized - if self.model is None or self.data_module is None: - raise ValueError("The model or data module has not been fitted yet.") - - # Preprocess the data using the data module - cat_tensors, num_tensors = self.data_module.preprocess_test_data(X) - - # Move tensors to appropriate device - device = next(self.model.parameters()).device - if isinstance(cat_tensors, list): - cat_tensors = [tensor.to(device) for tensor in cat_tensors] - else: - cat_tensors = cat_tensors.to(device) - - if isinstance(num_tensors, list): - num_tensors = [tensor.to(device) for tensor in num_tensors] - else: - num_tensors = num_tensors.to(device) - - # Set model to evaluation mode - self.model.eval() - - # Perform inference - with torch.no_grad(): - predictions = self.model(num_features=num_tensors, cat_features=cat_tensors) - - # Convert predictions to NumPy array and return - return predictions.cpu().numpy() - - def evaluate(self, X, y_true, metrics=None): - """ - Evaluate the model on the given data using specified metrics. - - Parameters - ---------- - X : array-like or pd.DataFrame of shape (n_samples, n_features) - The input samples to predict. - y_true : array-like of shape (n_samples,) or (n_samples, n_outputs) - The true target values against which to evaluate the predictions. - metrics : dict - A dictionary where keys are metric names and values are the metric functions. - - - Notes - ----- - This method uses the `predict` method to generate predictions and computes each metric. - - - Examples - -------- - >>> from sklearn.metrics import mean_squared_error, r2_score - >>> from sklearn.model_selection import train_test_split - >>> from mambular.models import MambularRegressor - >>> metrics = { - ... 'Mean Squared Error': mean_squared_error, - ... 'R2 Score': r2_score - ... } - >>> # Assuming 'X_test' and 'y_test' are your test dataset and labels - >>> # Evaluate using the specified metrics - >>> results = regressor.evaluate(X_test, y_test, metrics=metrics) - - - Returns - ------- - scores : dict - A dictionary with metric names as keys and their corresponding scores as values. 
- """ - if metrics is None: - metrics = {"Mean Squared Error": mean_squared_error} - - # Generate predictions using the trained model - predictions = self.predict(X) - - # Initialize dictionary to store results - scores = {} - - # Compute each metric - for metric_name, metric_func in metrics.items(): - scores[metric_name] = metric_func(y_true, predictions) - - return scores +import lightning as pl +import pandas as pd +import torch +from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint +from sklearn.base import BaseEstimator +from sklearn.metrics import mean_squared_error +import warnings +from ..base_models.lightning_wrapper import TaskModel +from ..data_utils.datamodule import MambularDataModule +from ..preprocessing import Preprocessor + + +class SklearnBaseRegressor(BaseEstimator): + def __init__(self, model, config, **kwargs): + preprocessor_arg_names = [ + "n_bins", + "numerical_preprocessing", + "use_decision_tree_bins", + "binning_strategy", + "task", + "cat_cutoff", + "treat_all_integers_as_numerical", + "knots", + "degree", + ] + + self.config_kwargs = { + k: v for k, v in kwargs.items() if k not in preprocessor_arg_names + } + self.config = config(**self.config_kwargs) + + preprocessor_kwargs = { + k: v for k, v in kwargs.items() if k in preprocessor_arg_names + } + + self.preprocessor = Preprocessor(**preprocessor_kwargs) + self.model = None + + # Raise a warning if task is set to 'classification' + if preprocessor_kwargs.get("task") == "classification": + warnings.warn( + "The task is set to 'classification'. The Regressor is designed for regression tasks.", + UserWarning, + ) + + self.base_model = model + self.built = False + + def get_params(self, deep=True): + """ + Get parameters for this estimator. Overrides the BaseEstimator method. + + Parameters + ---------- + deep : bool, default=True + If True, returns the parameters for this estimator and contained sub-objects that are estimators. + + Returns + ------- + params : dict + Parameter names mapped to their values. + """ + params = self.config_kwargs # Parameters used to initialize DefaultConfig + + # If deep=True, include parameters from nested components like preprocessor + if deep: + # Assuming Preprocessor has a get_params method + preprocessor_params = { + "preprocessor__" + key: value + for key, value in self.preprocessor.get_params().items() + } + params.update(preprocessor_params) + + return params + + def set_params(self, **parameters): + """ + Set the parameters of this estimator. Overrides the BaseEstimator method. + + Parameters + ---------- + **parameters : dict + Estimator parameters to be set. + + Returns + ------- + self : object + The instance with updated parameters. 
+ """ + # Update config_kwargs with provided parameters + valid_config_keys = self.config_kwargs.keys() + config_updates = {k: v for k, v in parameters.items() if k in valid_config_keys} + self.config_kwargs.update(config_updates) + + # Update the config object + for key, value in config_updates.items(): + setattr(self.config, key, value) + + # Handle preprocessor parameters (prefixed with 'preprocessor__') + preprocessor_params = { + k.split("__")[1]: v + for k, v in parameters.items() + if k.startswith("preprocessor__") + } + if preprocessor_params: + # Assuming Preprocessor has a set_params method + self.preprocessor.set_params(**preprocessor_params) + + return self + + def build_model( + self, + X, + y, + val_size: float = 0.2, + X_val=None, + y_val=None, + random_state: int = 101, + batch_size: int = 128, + shuffle: bool = True, + lr: float = 1e-4, + lr_patience: int = 10, + factor: float = 0.1, + weight_decay: float = 1e-06, + dataloader_kwargs={}, + ): + """ + Builds the model using the provided training data. + + Parameters + ---------- + X : DataFrame or array-like, shape (n_samples, n_features) + The training input samples. + y : array-like, shape (n_samples,) or (n_samples, n_targets) + The target values (real numbers). + val_size : float, default=0.2 + The proportion of the dataset to include in the validation split if `X_val` is None. Ignored if `X_val` is provided. + X_val : DataFrame or array-like, shape (n_samples, n_features), optional + The validation input samples. If provided, `X` and `y` are not split and this data is used for validation. + y_val : array-like, shape (n_samples,) or (n_samples, n_targets), optional + The validation target values. Required if `X_val` is provided. + random_state : int, default=101 + Controls the shuffling applied to the data before applying the split. + batch_size : int, default=64 + Number of samples per gradient update. + shuffle : bool, default=True + Whether to shuffle the training data before each epoch. + lr : float, default=1e-3 + Learning rate for the optimizer. + lr_patience : int, default=10 + Number of epochs with no improvement on the validation loss to wait before reducing the learning rate. + factor : float, default=0.1 + Factor by which the learning rate will be reduced. + weight_decay : float, default=0.025 + Weight decay (L2 penalty) coefficient. + dataloader_kwargs: dict, default={} + The kwargs for the pytorch dataloader class. + + + + Returns + ------- + self : object + The built regressor. + """ + if not isinstance(X, pd.DataFrame): + X = pd.DataFrame(X) + if isinstance(y, pd.Series): + y = y.values + if X_val: + if not isinstance(X_val, pd.DataFrame): + X_val = pd.DataFrame(X_val) + if isinstance(y_val, pd.Series): + y_val = y_val.values + + self.data_module = MambularDataModule( + preprocessor=self.preprocessor, + batch_size=batch_size, + shuffle=shuffle, + X_val=X_val, + y_val=y_val, + val_size=val_size, + random_state=random_state, + regression=True, + **dataloader_kwargs + ) + + self.data_module.preprocess_data( + X, y, X_val, y_val, val_size=val_size, random_state=random_state + ) + + self.model = TaskModel( + model_class=self.base_model, + config=self.config, + cat_feature_info=self.data_module.cat_feature_info, + num_feature_info=self.data_module.num_feature_info, + lr=lr, + lr_patience=lr_patience, + lr_factor=factor, + weight_decay=weight_decay, + ) + + self.built = True + + return self + + def get_number_of_params(self, requires_grad=True): + """ + Calculate the number of parameters in the model. 
+ + Parameters + ---------- + requires_grad : bool, optional + If True, only count the parameters that require gradients (trainable parameters). + If False, count all parameters. Default is True. + + Returns + ------- + int + The total number of parameters in the model. + + Raises + ------ + ValueError + If the model has not been built prior to calling this method. + """ + if not self.built: + raise ValueError( + "The model must be built before the number of parameters can be estimated" + ) + else: + if requires_grad: + return sum( + p.numel() for p in self.model.parameters() if p.requires_grad + ) + else: + return sum(p.numel() for p in self.model.parameters()) + + def fit( + self, + X, + y, + val_size: float = 0.2, + X_val=None, + y_val=None, + max_epochs: int = 100, + random_state: int = 101, + batch_size: int = 128, + shuffle: bool = True, + patience: int = 15, + monitor: str = "val_loss", + mode: str = "min", + lr: float = 1e-4, + lr_patience: int = 10, + factor: float = 0.1, + weight_decay: float = 1e-06, + checkpoint_path="model_checkpoints", + dataloader_kwargs={}, + rebuild=True, + **trainer_kwargs + ): + """ + Trains the regression model using the provided training data. Optionally, a separate validation set can be used. + + Parameters + ---------- + X : DataFrame or array-like, shape (n_samples, n_features) + The training input samples. + y : array-like, shape (n_samples,) or (n_samples, n_targets) + The target values (real numbers). + val_size : float, default=0.2 + The proportion of the dataset to include in the validation split if `X_val` is None. Ignored if `X_val` is provided. + X_val : DataFrame or array-like, shape (n_samples, n_features), optional + The validation input samples. If provided, `X` and `y` are not split and this data is used for validation. + y_val : array-like, shape (n_samples,) or (n_samples, n_targets), optional + The validation target values. Required if `X_val` is provided. + max_epochs : int, default=100 + Maximum number of epochs for training. + random_state : int, default=101 + Controls the shuffling applied to the data before applying the split. + batch_size : int, default=64 + Number of samples per gradient update. + shuffle : bool, default=True + Whether to shuffle the training data before each epoch. + patience : int, default=10 + Number of epochs with no improvement on the validation loss to wait before early stopping. + monitor : str, default="val_loss" + The metric to monitor for early stopping. + mode : str, default="min" + Whether the monitored metric should be minimized (`min`) or maximized (`max`). + lr : float, default=1e-3 + Learning rate for the optimizer. + lr_patience : int, default=10 + Number of epochs with no improvement on the validation loss to wait before reducing the learning rate. + factor : float, default=0.1 + Factor by which the learning rate will be reduced. + weight_decay : float, default=0.025 + Weight decay (L2 penalty) coefficient. + checkpoint_path : str, default="model_checkpoints" + Path where the checkpoints are being saved. + dataloader_kwargs: dict, default={} + The kwargs for the pytorch dataloader class. + **trainer_kwargs : Additional keyword arguments for PyTorch Lightning's Trainer class. + + + Returns + ------- + self : object + The fitted regressor. 
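+
+ Examples
+ --------
+ >>> # Illustrative use of the new `rebuild` flag: build once, then fit without
+ >>> # reconstructing the model; variable names are placeholders.
+ >>> regressor.build_model(X_train, y_train)
+ >>> regressor.fit(X_train, y_train, rebuild=False, max_epochs=50)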
+ """ + if (not self.built) or (self.built and rebuild): + if not isinstance(X, pd.DataFrame): + X = pd.DataFrame(X) + if isinstance(y, pd.Series): + y = y.values + if X_val: + if not isinstance(X_val, pd.DataFrame): + X_val = pd.DataFrame(X_val) + if isinstance(y_val, pd.Series): + y_val = y_val.values + + self.data_module = MambularDataModule( + preprocessor=self.preprocessor, + batch_size=batch_size, + shuffle=shuffle, + X_val=X_val, + y_val=y_val, + val_size=val_size, + random_state=random_state, + regression=True, + **dataloader_kwargs + ) + + self.data_module.preprocess_data( + X, y, X_val, y_val, val_size=val_size, random_state=random_state + ) + + self.model = TaskModel( + model_class=self.base_model, + config=self.config, + cat_feature_info=self.data_module.cat_feature_info, + num_feature_info=self.data_module.num_feature_info, + lr=lr, + lr_patience=lr_patience, + lr_factor=factor, + weight_decay=weight_decay, + ) + + else: + assert self.built, "The model must be built before calling the fit method." + + early_stop_callback = EarlyStopping( + monitor=monitor, min_delta=0.00, patience=patience, verbose=False, mode=mode + ) + + checkpoint_callback = ModelCheckpoint( + monitor="val_loss", # Adjust according to your validation metric + mode="min", + save_top_k=1, + dirpath=checkpoint_path, # Specify the directory to save checkpoints + filename="best_model", + ) + + # Initialize the trainer and train the model + trainer = pl.Trainer( + max_epochs=max_epochs, + callbacks=[early_stop_callback, checkpoint_callback], + **trainer_kwargs + ) + trainer.fit(self.model, self.data_module) + + best_model_path = checkpoint_callback.best_model_path + if best_model_path: + checkpoint = torch.load(best_model_path) + self.model.load_state_dict(checkpoint["state_dict"]) + + return self + + def predict(self, X): + """ + Predicts target values for the given input samples. + + Parameters + ---------- + X : DataFrame or array-like, shape (n_samples, n_features) + The input samples for which to predict target values. + + + Returns + ------- + predictions : ndarray, shape (n_samples,) or (n_samples, n_outputs) + The predicted target values. + """ + # Ensure model and data module are initialized + if self.model is None or self.data_module is None: + raise ValueError("The model or data module has not been fitted yet.") + + # Preprocess the data using the data module + cat_tensors, num_tensors = self.data_module.preprocess_test_data(X) + + # Move tensors to appropriate device + device = next(self.model.parameters()).device + if isinstance(cat_tensors, list): + cat_tensors = [tensor.to(device) for tensor in cat_tensors] + else: + cat_tensors = cat_tensors.to(device) + + if isinstance(num_tensors, list): + num_tensors = [tensor.to(device) for tensor in num_tensors] + else: + num_tensors = num_tensors.to(device) + + # Set model to evaluation mode + self.model.eval() + + # Perform inference + with torch.no_grad(): + predictions = self.model(num_features=num_tensors, cat_features=cat_tensors) + + # Convert predictions to NumPy array and return + return predictions.cpu().numpy() + + def evaluate(self, X, y_true, metrics=None): + """ + Evaluate the model on the given data using specified metrics. + + Parameters + ---------- + X : array-like or pd.DataFrame of shape (n_samples, n_features) + The input samples to predict. + y_true : array-like of shape (n_samples,) or (n_samples, n_outputs) + The true target values against which to evaluate the predictions. 
+        metrics : dict, optional
+            A dictionary where keys are metric names and values are the metric functions. If None, mean squared error is used.
+
+        Returns
+        -------
+        scores : dict
+            A dictionary with metric names as keys and their corresponding scores as values.
+
+        Notes
+        -----
+        This method uses the `predict` method to generate predictions and computes each metric.
+
+        Examples
+        --------
+        >>> from sklearn.metrics import mean_squared_error, r2_score
+        >>> from mambular.models import MambularRegressor
+        >>> metrics = {
+        ...     'Mean Squared Error': mean_squared_error,
+        ...     'R2 Score': r2_score
+        ... }
+        >>> # Assuming 'regressor' is a fitted model and 'X_test', 'y_test' are your test data
+        >>> results = regressor.evaluate(X_test, y_test, metrics=metrics)
+        """
+        if metrics is None:
+            metrics = {"Mean Squared Error": mean_squared_error}
+
+        # Generate predictions using the trained model
+        predictions = self.predict(X)
+
+        # Initialize dictionary to store results
+        scores = {}
+
+        # Compute each metric
+        for metric_name, metric_func in metrics.items():
+            scores[metric_name] = metric_func(y_true, predictions)
+
+        return scores