Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda committed Nov 4, 2024
1 parent d9a620a commit 8bbe907
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 6 deletions.
2 changes: 1 addition & 1 deletion tests/tests_pytorch/callbacks/test_prediction_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def test_prediction_writer_batch_indices(num_workers, tmp_path):
DummyPredictionWriter.write_on_batch_end = Mock()
DummyPredictionWriter.write_on_epoch_end = Mock()

dataloader = DataLoader(RandomDataset(32, 64), batch_size=4, num_workers=num_workers, persistent_workers=True)
dataloader = DataLoader(RandomDataset(32, 64), batch_size=4, num_workers=num_workers, persistent_workers=num_workers > 0)
model = BoringModel()
writer = DummyPredictionWriter("batch_and_epoch")
trainer = Trainer(default_root_dir=tmp_path, logger=False, limit_predict_batches=4, callbacks=writer)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def test_dataloader_persistent_workers_performance_warning(num_workers, tmp_path
barebones=True,
)
model = TestSpawnBoringModel(warning_expected=(num_workers > 0))
dataloader = DataLoader(RandomDataset(32, 64), num_workers=num_workers, persistent_workers=True)
dataloader = DataLoader(RandomDataset(32, 64), num_workers=num_workers, persistent_workers=num_workers > 0)
trainer.fit(model, dataloader)


Expand Down
5 changes: 1 addition & 4 deletions tests/tests_pytorch/trainer/test_dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,10 +658,7 @@ def on_train_epoch_end(self):
def test_auto_add_worker_init_fn_distributed(tmp_path, monkeypatch):
"""Test that the lightning worker_init_fn takes care of dataloaders in multi-gpu/multi-node training."""
dataset = NumpyRandomDataset()
num_workers = 2
batch_size = 2

dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, persistent_workers=True)
dataloader = DataLoader(dataset, batch_size=2, num_workers=2, persistent_workers=True)
seed_everything(0, workers=True)
trainer = Trainer(default_root_dir=tmp_path, max_epochs=1, accelerator="gpu", devices=2, strategy="ddp_spawn")
model = MultiProcessModel()
Expand Down

0 comments on commit 8bbe907

Please sign in to comment.