Skip to content

Commit

Permalink
ModelBase update
Browse files Browse the repository at this point in the history
  • Loading branch information
ChanLumerico committed Aug 7, 2024
1 parent afd9d82 commit 8853650
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions luma/neural/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ def __init__(
patience: int,
deep_verbose: bool,
shuffle: bool,
random_state: bool | None
random_state: bool | None,
) -> None:
self.batch_size = batch_size
self.n_epochs = n_epochs
Expand Down Expand Up @@ -401,14 +401,14 @@ def eval(self, X: TensorLike, y: TensorLike) -> list[float]:
valid_loss.append(loss)

return valid_loss

def set_optimizer(self, optimizer: Optimizer, **params: Any) -> None:
self.model.set_optimizer(optimizer, **params)

def set_lr_scheduler(self, scheduler: Scheduler, **params: Any) -> None:
scheduler.set_params(**params)
self.lr_scheduler = scheduler

def set_loss(self, loss: Loss) -> None:
self.loss = loss

Expand All @@ -419,7 +419,7 @@ def update_lr(
train_loss: float,
valid_loss: float,
) -> None:
if mode != self.lr_scheduler.type_:
if self.lr_scheduler is None or mode != self.lr_scheduler.type_:
return

self.lr_scheduler.broadcast(
Expand Down

0 comments on commit 8853650

Please sign in to comment.