diff --git a/bs_scheduler/batch_size_schedulers.py b/bs_scheduler/batch_size_schedulers.py index 0e55923..0306e08 100644 --- a/bs_scheduler/batch_size_schedulers.py +++ b/bs_scheduler/batch_size_schedulers.py @@ -952,8 +952,7 @@ class IncreaseBSOnPlateau(BSScheduler): Increases the batch size when a metric has stopped improving. Models often benefit from increasing the batch size by a factor once the learning stagnates. This scheduler receives a metric value and if no improvement is seen for a given number of epochs, the batch size is increased. - Unfortunately, this class is not compatible with the other batch size schedulers as its step() function needs to - receive the metric value. + The step() function needs to receive the metric value using the `metrics` keyword argument. Args: dataloader (DataLoader): Wrapped dataloader. @@ -1068,9 +1067,9 @@ def get_new_bs(self, **kwargs) -> int: if self.last_epoch == 0: # Don't do anything at initialization. return self.batch_size - metric = kwargs.pop('metric', None) + metric = kwargs.pop('metrics', None) if metric is None: - raise TypeError("IncreaseBSOnPlateau requires passing a 'metric' keyword argument in the step() function.") + raise TypeError("IncreaseBSOnPlateau requires passing a 'metrics' keyword argument in the step() function.") current = float(metric) if self.is_better(current, self.best, self.threshold): diff --git a/tests/test_IncreaseBSOnPlateau.py b/tests/test_IncreaseBSOnPlateau.py index 2ddcee5..54c8dc4 100644 --- a/tests/test_IncreaseBSOnPlateau.py +++ b/tests/test_IncreaseBSOnPlateau.py @@ -16,7 +16,7 @@ def test_constant_metric(self): max_batch_size = 100 n_epochs = 100 - metrics = [{"metric": 0.1}] * n_epochs + metrics = [{"metrics": 0.1}] * n_epochs dataloader = create_dataloader(self.dataset, batch_size=base_batch_size) scheduler1 = IncreaseBSOnPlateau(dataloader, mode='min', threshold_mode='rel', max_batch_size=max_batch_size) @@ -55,7 +55,7 @@ def test_loading_and_unloading(self): self.reloading_scheduler(scheduler) self.torch_save_and_load(scheduler) - scheduler.step(metric=10) + scheduler.step(metrics=10) self.assertEqual(scheduler.mode, mode) self.assertEqual(scheduler.threshold_mode, threshold_mode) @@ -69,7 +69,7 @@ def test_graphic(self): max_batch_size = 100 n_epochs = 100 - metrics = [{"metric": 0.1}] * n_epochs + metrics = [{"metrics": 0.1}] * n_epochs dataloader = create_dataloader(self.dataset, batch_size=base_batch_size) scheduler = IncreaseBSOnPlateau(dataloader, mode='min', threshold_mode='rel', max_batch_size=max_batch_size)