From 80a089ddcc5db2a267aca26674c2a7233ca2f84f Mon Sep 17 00:00:00 2001 From: Jirka B Date: Thu, 7 Nov 2024 12:00:20 +0000 Subject: [PATCH] try drop pickle warning --- tests/tests_pytorch/callbacks/test_early_stopping.py | 6 ++---- tests/tests_pytorch/checkpointing/test_model_checkpoint.py | 6 ++---- tests/tests_pytorch/core/test_metric_result_integration.py | 3 +-- tests/tests_pytorch/helpers/test_datasets.py | 6 ++---- tests/tests_pytorch/loggers/test_all.py | 7 +------ tests/tests_pytorch/loggers/test_logger.py | 3 +-- tests/tests_pytorch/loggers/test_wandb.py | 3 +-- 7 files changed, 10 insertions(+), 24 deletions(-) diff --git a/tests/tests_pytorch/callbacks/test_early_stopping.py b/tests/tests_pytorch/callbacks/test_early_stopping.py index b7e52ee549bcc..221718425d7eb 100644 --- a/tests/tests_pytorch/callbacks/test_early_stopping.py +++ b/tests/tests_pytorch/callbacks/test_early_stopping.py @@ -193,13 +193,11 @@ def test_pickling(): early_stopping = EarlyStopping(monitor="foo") early_stopping_pickled = pickle.dumps(early_stopping) - with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext(): - early_stopping_loaded = pickle.loads(early_stopping_pickled) + early_stopping_loaded = pickle.loads(early_stopping_pickled) assert vars(early_stopping) == vars(early_stopping_loaded) early_stopping_pickled = cloudpickle.dumps(early_stopping) - with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext(): - early_stopping_loaded = cloudpickle.loads(early_stopping_pickled) + early_stopping_loaded = cloudpickle.loads(early_stopping_pickled) assert vars(early_stopping) == vars(early_stopping_loaded) diff --git a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py index 97d8d3c4d0e4a..31f6db8b98272 100644 --- a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py +++ b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py @@ -352,13 +352,11 @@ def test_pickling(tmp_path): ckpt = ModelCheckpoint(dirpath=tmp_path) ckpt_pickled = pickle.dumps(ckpt) - with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext(): - ckpt_loaded = pickle.loads(ckpt_pickled) + ckpt_loaded = pickle.loads(ckpt_pickled) assert vars(ckpt) == vars(ckpt_loaded) ckpt_pickled = cloudpickle.dumps(ckpt) - with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext(): - ckpt_loaded = cloudpickle.loads(ckpt_pickled) + ckpt_loaded = cloudpickle.loads(ckpt_pickled) assert vars(ckpt) == vars(ckpt_loaded) diff --git a/tests/tests_pytorch/core/test_metric_result_integration.py b/tests/tests_pytorch/core/test_metric_result_integration.py index ef340d1e17ea9..6e7fa7310e115 100644 --- a/tests/tests_pytorch/core/test_metric_result_integration.py +++ b/tests/tests_pytorch/core/test_metric_result_integration.py @@ -254,8 +254,7 @@ def lightning_log(fx, *args, **kwargs): } # make sure can be pickled - with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext(): - pickle.loads(pickle.dumps(result)) + pickle.loads(pickle.dumps(result)) # make sure can be torch.loaded filepath = str(tmp_path / "result") torch.save(result, filepath) diff --git a/tests/tests_pytorch/helpers/test_datasets.py b/tests/tests_pytorch/helpers/test_datasets.py index 98d77a6d9a8ad..f6d7fae4c86c5 100644 --- a/tests/tests_pytorch/helpers/test_datasets.py +++ b/tests/tests_pytorch/helpers/test_datasets.py @@ -44,9 +44,7 @@ def test_pickling_dataset_mnist(dataset_cls, args): mnist = dataset_cls(**args) mnist_pickled = pickle.dumps(mnist) - with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext(): - pickle.loads(mnist_pickled) + pickle.loads(mnist_pickled) mnist_pickled = cloudpickle.dumps(mnist) - with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext(): - cloudpickle.loads(mnist_pickled) + cloudpickle.loads(mnist_pickled) diff --git a/tests/tests_pytorch/loggers/test_all.py b/tests/tests_pytorch/loggers/test_all.py index c5b07562afb0a..480df336af6ea 100644 --- a/tests/tests_pytorch/loggers/test_all.py +++ b/tests/tests_pytorch/loggers/test_all.py @@ -184,12 +184,7 @@ def _test_loggers_pickle(tmp_path, monkeypatch, logger_class: Logger): trainer = Trainer(max_epochs=1, logger=logger) pkl_bytes = pickle.dumps(trainer) - with ( - pytest.warns(FutureWarning, match="`weights_only=False`") - if _TORCH_EQUAL_2_4_0 or (_TORCH_GREATER_EQUAL_2_4_1 and logger_class not in (CSVLogger, TensorBoardLogger)) - else nullcontext() - ): - trainer2 = pickle.loads(pkl_bytes) + trainer2 = pickle.loads(pkl_bytes) trainer2.logger.log_metrics({"acc": 1.0}) # make sure we restored properly diff --git a/tests/tests_pytorch/loggers/test_logger.py b/tests/tests_pytorch/loggers/test_logger.py index de0028000cd9f..3732a45c5e81c 100644 --- a/tests/tests_pytorch/loggers/test_logger.py +++ b/tests/tests_pytorch/loggers/test_logger.py @@ -124,8 +124,7 @@ def test_multiple_loggers_pickle(tmp_path): trainer = Trainer(logger=[logger1, logger2]) pkl_bytes = pickle.dumps(trainer) - with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext(): - trainer2 = pickle.loads(pkl_bytes) + trainer2 = pickle.loads(pkl_bytes) for logger in trainer2.loggers: logger.log_metrics({"acc": 1.0}, 0) diff --git a/tests/tests_pytorch/loggers/test_wandb.py b/tests/tests_pytorch/loggers/test_wandb.py index 4e3fbb287a1f9..ddaa289172844 100644 --- a/tests/tests_pytorch/loggers/test_wandb.py +++ b/tests/tests_pytorch/loggers/test_wandb.py @@ -162,8 +162,7 @@ def name(self): assert trainer.logger.experiment, "missing experiment" assert trainer.log_dir == logger.save_dir pkl_bytes = pickle.dumps(trainer) - with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext(): - trainer2 = pickle.loads(pkl_bytes) + trainer2 = pickle.loads(pkl_bytes) assert os.environ["WANDB_MODE"] == "dryrun" assert trainer2.logger.__class__.__name__ == WandbLogger.__name__