
try: persistent_workers=True
Borda committed Nov 4, 2024
1 parent 897b2af · commit fed9783
Showing 7 changed files with 9 additions and 7 deletions.
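
The flag added throughout, `persistent_workers=True`, is a standard `torch.utils.data.DataLoader` option: it keeps worker processes alive between epochs instead of shutting them down and respawning them each time an iterator is exhausted, and it only takes effect when `num_workers > 0`. A minimal, self-contained sketch of the pattern this commit applies across the test suite (the dataset below is a placeholder, not one from the repository):

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    # Placeholder dataset; any map-style Dataset behaves the same way.
    dataset = TensorDataset(torch.randn(64, 32))

    loader = DataLoader(
        dataset,
        batch_size=4,
        num_workers=2,            # persistent_workers requires num_workers > 0
        persistent_workers=True,  # reuse workers across epochs instead of respawning them
    )

The saving is most noticeable with the "spawn" multiprocessing context (used by several of the tests below), where starting a fresh worker process is comparatively expensive.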
1 change: 1 addition & 0 deletions tests/parity_fabric/models.py
@@ -76,6 +76,7 @@ def get_dataloader(self):
             dataset,
             batch_size=self.batch_size,
             num_workers=2,
+            persistent_workers=True,
         )
 
     def get_loss_function(self):
1 change: 1 addition & 0 deletions tests/parity_pytorch/models.py
@@ -59,4 +59,5 @@ def train_dataloader(self):
             CIFAR10(root=_PATH_DATASETS, train=True, download=True, transform=self.transform),
             batch_size=32,
             num_workers=1,
+            persistent_workers=True,
         )
4 changes: 2 additions & 2 deletions tests/tests_fabric/utilities/test_data.py
@@ -638,9 +638,9 @@ def test_suggested_max_num_workers_not_triggering_torch_warning(local_world_size
 
     # The dataloader runs a check in `DataLoader.check_worker_number_rationality`
     with pytest.warns(UserWarning, match="This DataLoader will create"):
-        DataLoader(range(2), num_workers=(cpu_count + 1))
+        DataLoader(range(2), num_workers=(cpu_count + 1), persistent_workers=True)
     with no_warning_call():
-        DataLoader(range(2), num_workers=suggested_max_num_workers(local_world_size))
+        DataLoader(range(2), num_workers=suggested_max_num_workers(local_world_size), persistent_workers=True)
 
 
 def test_state():
2 changes: 1 addition & 1 deletion tests/tests_pytorch/callbacks/test_prediction_writer.py
@@ -83,7 +83,7 @@ def test_prediction_writer_batch_indices(num_workers, tmp_path):
     DummyPredictionWriter.write_on_batch_end = Mock()
     DummyPredictionWriter.write_on_epoch_end = Mock()
 
-    dataloader = DataLoader(RandomDataset(32, 64), batch_size=4, num_workers=num_workers)
+    dataloader = DataLoader(RandomDataset(32, 64), batch_size=4, num_workers=num_workers, persistent_workers=True)
     model = BoringModel()
     writer = DummyPredictionWriter("batch_and_epoch")
     trainer = Trainer(default_root_dir=tmp_path, logger=False, limit_predict_batches=4, callbacks=writer)
2 changes: 1 addition & 1 deletion tests/tests_pytorch/helpers/advanced_models.py
@@ -218,4 +218,4 @@ def configure_optimizers(self):
         return torch.optim.Adam(self.parameters(), lr=0.02)
 
     def train_dataloader(self):
-        return DataLoader(MNIST(root=_PATH_DATASETS, train=True, download=True), batch_size=128, num_workers=1)
+        return DataLoader(MNIST(root=_PATH_DATASETS, train=True, download=True), batch_size=128, num_workers=1, persistent_workers=True)
4 changes: 2 additions & 2 deletions tests/tests_pytorch/trainer/connectors/test_data_connector.py
@@ -135,7 +135,7 @@ def test_dataloader_persistent_workers_performance_warning(num_workers, tmp_path
         barebones=True,
     )
     model = TestSpawnBoringModel(warning_expected=(num_workers > 0))
-    dataloader = DataLoader(RandomDataset(32, 64), num_workers=num_workers)
+    dataloader = DataLoader(RandomDataset(32, 64), num_workers=num_workers, persistent_workers=True)
     trainer.fit(model, dataloader)
 
 
@@ -252,7 +252,7 @@ def test_update_dataloader_with_multiprocessing_context():
     """This test verifies that `use_distributed_sampler` conserves multiprocessing context."""
     train = RandomDataset(32, 64)
     context = "spawn"
-    train = DataLoader(train, batch_size=32, num_workers=2, multiprocessing_context=context, shuffle=True)
+    train = DataLoader(train, batch_size=32, num_workers=2, multiprocessing_context=context, shuffle=True, persistent_workers=True)
     new_data_loader = _update_dataloader(train, SequentialSampler(train.dataset))
     assert new_data_loader.multiprocessing_context == train.multiprocessing_context
 
2 changes: 1 addition & 1 deletion tests/tests_pytorch/trainer/test_dataloaders.py
@@ -661,7 +661,7 @@ def test_auto_add_worker_init_fn_distributed(tmp_path, monkeypatch):
     num_workers = 2
     batch_size = 2
 
-    dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)
+    dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, persistent_workers=True)
     seed_everything(0, workers=True)
     trainer = Trainer(default_root_dir=tmp_path, max_epochs=1, accelerator="gpu", devices=2, strategy="ddp_spawn")
     model = MultiProcessModel()
