Skip to content

Commit

Permalink
add .built attr to model classed
Browse files Browse the repository at this point in the history
  • Loading branch information
AnFreTh committed Jul 16, 2024
1 parent 8af86a2 commit 2029700
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 34 deletions.
1 change: 1 addition & 0 deletions mambular/models/sklearn_base_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def __init__(self, model, config, **kwargs):
)

self.base_model = model
self.built = False

def get_params(self, deep=True):
"""
Expand Down
71 changes: 37 additions & 34 deletions mambular/models/sklearn_base_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def __init__(self, model, config, **kwargs):
)

self.base_model = model
self.built = False

def get_params(self, deep=True):
"""
Expand Down Expand Up @@ -255,6 +256,7 @@ def fit(
weight_decay: float = 1e-06,
checkpoint_path="model_checkpoints",
dataloader_kwargs={},
rebuild=True,
**trainer_kwargs
):
"""
Expand Down Expand Up @@ -306,42 +308,43 @@ def fit(
self : object
The fitted regressor.
"""
if not isinstance(X, pd.DataFrame):
X = pd.DataFrame(X)
if isinstance(y, pd.Series):
y = y.values
if X_val:
if not isinstance(X_val, pd.DataFrame):
X_val = pd.DataFrame(X_val)
if isinstance(y_val, pd.Series):
y_val = y_val.values

self.data_module = MambularDataModule(
preprocessor=self.preprocessor,
batch_size=batch_size,
shuffle=shuffle,
X_val=X_val,
y_val=y_val,
val_size=val_size,
random_state=random_state,
regression=True,
**dataloader_kwargs
)
if not self.built and not rebuild:
if not isinstance(X, pd.DataFrame):
X = pd.DataFrame(X)
if isinstance(y, pd.Series):
y = y.values
if X_val:
if not isinstance(X_val, pd.DataFrame):
X_val = pd.DataFrame(X_val)
if isinstance(y_val, pd.Series):
y_val = y_val.values

self.data_module = MambularDataModule(
preprocessor=self.preprocessor,
batch_size=batch_size,
shuffle=shuffle,
X_val=X_val,
y_val=y_val,
val_size=val_size,
random_state=random_state,
regression=False,
**dataloader_kwargs
)

self.data_module.preprocess_data(
X, y, X_val, y_val, val_size=val_size, random_state=random_state
)
self.data_module.preprocess_data(
X, y, X_val, y_val, val_size=val_size, random_state=random_state
)

self.model = TaskModel(
model_class=self.base_model,
config=self.config,
cat_feature_info=self.data_module.cat_feature_info,
num_feature_info=self.data_module.num_feature_info,
lr=lr,
lr_patience=lr_patience,
lr_factor=factor,
weight_decay=weight_decay,
)
self.model = TaskModel(
model_class=self.base_model,
config=self.config,
cat_feature_info=self.data_module.cat_feature_info,
num_feature_info=self.data_module.num_feature_info,
lr=lr,
lr_patience=lr_patience,
lr_factor=factor,
weight_decay=weight_decay,
)

early_stop_callback = EarlyStopping(
monitor=monitor, min_delta=0.00, patience=patience, verbose=False, mode=mode
Expand Down

0 comments on commit 2029700

Please sign in to comment.