Skip to content

Commit

Permalink
try drop pickle warning
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda committed Nov 7, 2024
1 parent 272605d commit 80a089d
Show file tree
Hide file tree
Showing 7 changed files with 10 additions and 24 deletions.
6 changes: 2 additions & 4 deletions tests/tests_pytorch/callbacks/test_early_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,13 +193,11 @@ def test_pickling():
early_stopping = EarlyStopping(monitor="foo")

early_stopping_pickled = pickle.dumps(early_stopping)
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
early_stopping_loaded = pickle.loads(early_stopping_pickled)
early_stopping_loaded = pickle.loads(early_stopping_pickled)
assert vars(early_stopping) == vars(early_stopping_loaded)

early_stopping_pickled = cloudpickle.dumps(early_stopping)
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
early_stopping_loaded = cloudpickle.loads(early_stopping_pickled)
early_stopping_loaded = cloudpickle.loads(early_stopping_pickled)
assert vars(early_stopping) == vars(early_stopping_loaded)


Expand Down
6 changes: 2 additions & 4 deletions tests/tests_pytorch/checkpointing/test_model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,13 +352,11 @@ def test_pickling(tmp_path):
ckpt = ModelCheckpoint(dirpath=tmp_path)

ckpt_pickled = pickle.dumps(ckpt)
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
ckpt_loaded = pickle.loads(ckpt_pickled)
ckpt_loaded = pickle.loads(ckpt_pickled)
assert vars(ckpt) == vars(ckpt_loaded)

ckpt_pickled = cloudpickle.dumps(ckpt)
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
ckpt_loaded = cloudpickle.loads(ckpt_pickled)
ckpt_loaded = cloudpickle.loads(ckpt_pickled)
assert vars(ckpt) == vars(ckpt_loaded)


Expand Down
3 changes: 1 addition & 2 deletions tests/tests_pytorch/core/test_metric_result_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,8 +254,7 @@ def lightning_log(fx, *args, **kwargs):
}

# make sure can be pickled
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
pickle.loads(pickle.dumps(result))
pickle.loads(pickle.dumps(result))
# make sure can be torch.loaded
filepath = str(tmp_path / "result")
torch.save(result, filepath)
Expand Down
6 changes: 2 additions & 4 deletions tests/tests_pytorch/helpers/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,7 @@ def test_pickling_dataset_mnist(dataset_cls, args):
mnist = dataset_cls(**args)

mnist_pickled = pickle.dumps(mnist)
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
pickle.loads(mnist_pickled)
pickle.loads(mnist_pickled)

mnist_pickled = cloudpickle.dumps(mnist)
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
cloudpickle.loads(mnist_pickled)
cloudpickle.loads(mnist_pickled)
7 changes: 1 addition & 6 deletions tests/tests_pytorch/loggers/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,12 +184,7 @@ def _test_loggers_pickle(tmp_path, monkeypatch, logger_class: Logger):
trainer = Trainer(max_epochs=1, logger=logger)
pkl_bytes = pickle.dumps(trainer)

with (
pytest.warns(FutureWarning, match="`weights_only=False`")
if _TORCH_EQUAL_2_4_0 or (_TORCH_GREATER_EQUAL_2_4_1 and logger_class not in (CSVLogger, TensorBoardLogger))
else nullcontext()
):
trainer2 = pickle.loads(pkl_bytes)
trainer2 = pickle.loads(pkl_bytes)
trainer2.logger.log_metrics({"acc": 1.0})

# make sure we restored properly
Expand Down
3 changes: 1 addition & 2 deletions tests/tests_pytorch/loggers/test_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,7 @@ def test_multiple_loggers_pickle(tmp_path):

trainer = Trainer(logger=[logger1, logger2])
pkl_bytes = pickle.dumps(trainer)
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
trainer2 = pickle.loads(pkl_bytes)
trainer2 = pickle.loads(pkl_bytes)
for logger in trainer2.loggers:
logger.log_metrics({"acc": 1.0}, 0)

Expand Down
3 changes: 1 addition & 2 deletions tests/tests_pytorch/loggers/test_wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,7 @@ def name(self):
assert trainer.logger.experiment, "missing experiment"
assert trainer.log_dir == logger.save_dir
pkl_bytes = pickle.dumps(trainer)
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
trainer2 = pickle.loads(pkl_bytes)
trainer2 = pickle.loads(pkl_bytes)

assert os.environ["WANDB_MODE"] == "dryrun"
assert trainer2.logger.__class__.__name__ == WandbLogger.__name__
Expand Down

0 comments on commit 80a089d

Please sign in to comment.