diff --git a/tests/test_ChainedBSScheduler.py b/tests/test_ChainedBSScheduler.py index 72df1d7..0ab20b7 100644 --- a/tests/test_ChainedBSScheduler.py +++ b/tests/test_ChainedBSScheduler.py @@ -24,7 +24,7 @@ def test_dataloader_lengths(self): n_epochs = 10 epoch_lengths = simulate_n_epochs(dataloader, scheduler, n_epochs) - expected_batch_sizes = [100, 110, 121, 133, 14, 16, 18, 20, 22, 24] + expected_batch_sizes = [100, 110, 121, 133, 14, 16, 17, 19, 21, 23] expected_lengths = self.compute_epoch_lengths(expected_batch_sizes, len(self.dataset), drop_last=False) self.assertEqual(epoch_lengths, expected_lengths) @@ -37,7 +37,7 @@ def test_dataloader_batch_size(self): n_epochs = 10 batch_sizes = get_batch_sizes_across_epochs(dataloader, scheduler, n_epochs) - expected_batch_sizes = [100, 110, 121, 133, 14, 16, 18, 20, 22, 24] + expected_batch_sizes = [100, 110, 121, 133, 14, 16, 17, 19, 21, 23] self.assertEqual(batch_sizes, expected_batch_sizes)