Skip to content

Commit

Permalink
Made tests faster
Browse files Browse the repository at this point in the history
  • Loading branch information
ancestor-mithril committed Dec 6, 2024
1 parent 67d8f10 commit 8256b45
Show file tree
Hide file tree
Showing 6 changed files with 13 additions and 13 deletions.
4 changes: 2 additions & 2 deletions tests/test_CosineAnnealingBS.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def test_dataloader_lengths(self):
def test_dataloader_batch_size(self):
base_batch_size = 10
total_iters = 50
n_epochs = 500
n_epochs = 200
max_batch_size = 100
dataloader = create_dataloader(self.dataset, batch_size=base_batch_size)
scheduler = CosineAnnealingBS(dataloader, total_iters=total_iters, max_batch_size=max_batch_size)
Expand All @@ -47,7 +47,7 @@ def test_dataloader_batch_size(self):
47, 49, 52, 55, 58, 61, 63, 66, 69, 72, 74, 77, 79, 81, 84, 86, 88, 90, 91, 93, 94, 96,
97, 98, 99, 99, 100, 100, 100, 100, 100, 99, 99, 98, 97, 96, 94, 93, 91, 90, 88, 86, 84,
81, 79, 77, 74, 72, 69, 66, 63, 61, 58, 55, 52, 49, 47, 44, 41, 38, 36, 33, 31, 29, 26,
24, 22, 20, 19, 17, 16, 14, 13, 12, 11, 11, 10, 10] * 5
24, 22, 20, 19, 17, 16, 14, 13, 12, 11, 11, 10, 10] * 2

self.assertEqual(batch_sizes, expected_batch_sizes)

Expand Down
4 changes: 2 additions & 2 deletions tests/test_CyclicBS.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,11 +201,11 @@ def test_graphic_exp_range(self):
base_batch_size = 100
dataloader = create_dataloader(self.dataset, batch_size=base_batch_size)
max_batch_size = 200
step_size_down = 50
step_size_down = 25
gamma = 0.9
scheduler = CyclicBS(dataloader, base_batch_size=base_batch_size, max_batch_size=max_batch_size,
step_size_down=step_size_down, mode='exp_range', gamma=gamma)
n_epochs = 10 * step_size_down
n_epochs = 6 * step_size_down

batch_sizes = get_batch_sizes_across_epochs(dataloader, scheduler, n_epochs)
plt.plot(batch_sizes)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_LambdaBS.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def test_sanity(self):
"always be equal to the inferred length except for Iterable Datasets for "
"which the __len__ could be inaccurate.")

dataloader.batch_sampler.batch_size = 526
dataloader.batch_sampler.batch_size = 256
real, inferred = iterate(dataloader)
self.assertEqual(real, inferred, "Dataloader __len__ does not return the real length. The real length should "
"always be equal to the inferred length except for Iterable Datasets for "
Expand All @@ -46,7 +46,7 @@ def test_dataloader_batch_size(self):
dataloader = create_dataloader(self.dataset, batch_size=self.base_batch_size)
fn = lambda epoch: 10 * epoch # noqa: E731
scheduler = LambdaBS(dataloader, fn)
n_epochs = 15
n_epochs = 10

batch_sizes = get_batch_sizes_across_epochs(dataloader, scheduler, n_epochs)
expected_batch_sizes = self.compute_expected_batch_sizes(n_epochs, self.base_batch_size, fn,
Expand Down
10 changes: 5 additions & 5 deletions tests/test_MultiStepBS.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,10 @@ def test_dataloader_lengths(self):

def test_dataloader_batch_size(self):
dataloader = create_dataloader(self.dataset, batch_size=self.base_batch_size)
milestones = [5, 10, 10, 12]
milestones = [5, 7, 7, 9]
gamma = 3.0
scheduler = MultiStepBS(dataloader, milestones=milestones, gamma=gamma, max_batch_size=5000, verbose=False)
n_epochs = 15
n_epochs = 10

batch_sizes = get_batch_sizes_across_epochs(dataloader, scheduler, n_epochs)
expected_batch_sizes = self.compute_expected_batch_sizes(n_epochs, self.base_batch_size, milestones, gamma,
Expand All @@ -60,7 +60,7 @@ def test_dataloader_batch_size(self):

def test_loading_and_unloading(self):
dataloader = create_dataloader(self.dataset)
milestones = [5, 10, 10, 12]
milestones = [5, 7, 7, 9]
gamma = 3.0
scheduler = MultiStepBS(dataloader, milestones=milestones, gamma=gamma, max_batch_size=5000, verbose=False)

Expand All @@ -76,10 +76,10 @@ def test_graphic(self):
warnings.filterwarnings("ignore", category=UserWarning)

dataloader = create_dataloader(self.dataset, batch_size=self.base_batch_size)
milestones = [5, 10, 10, 12]
milestones = [5, 7, 7, 9]
gamma = 3.0
scheduler = MultiStepBS(dataloader, milestones=milestones, gamma=gamma, max_batch_size=5000, verbose=False)
n_epochs = 15
n_epochs = 10

batch_sizes = get_batch_sizes_across_epochs(dataloader, scheduler, n_epochs)
plt.plot(batch_sizes)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_MultiplicativeBS.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def test_dataloader_batch_size(self):
dataloader = create_dataloader(self.dataset, batch_size=self.base_batch_size)
fn = lambda epoch: epoch / 100 + 2 # noqa: E731
scheduler = MultiplicativeBS(dataloader, fn, max_batch_size=5000, verbose=False)
n_epochs = 15
n_epochs = 10

batch_sizes = get_batch_sizes_across_epochs(dataloader, scheduler, n_epochs)
expected_batch_sizes = self.compute_expected_batch_sizes(n_epochs, self.base_batch_size, fn,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_StepBS.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def test_graphic(self):
step_size = 50
gamma = 1.1
scheduler = StepBS(dataloader, step_size=step_size, gamma=gamma)
n_epochs = 300
n_epochs = 200

batch_sizes = get_batch_sizes_across_epochs(dataloader, scheduler, n_epochs)
plt.plot(batch_sizes)
Expand Down

0 comments on commit 8256b45

Please sign in to comment.