From c3ebe77137898f31ace99db1fa7df053182459ed Mon Sep 17 00:00:00 2001 From: Gabriel Rasskin <43894452+grasskin@users.noreply.github.com> Date: Tue, 10 Oct 2023 06:10:06 -0400 Subject: [PATCH] No-op for torch.no_grad if torch<2.1.0 (#18563) * Optional no-op when @torch.nograd is not available * Depend on torch 2.0.1 * Format * New torch util * format --- keras/backend/torch/optimizers/torch_optimizer.py | 4 +++- .../torch/optimizers/torch_parallel_optimizer.py | 3 ++- keras/utils/torch_utils.py | 11 +++++++++++ requirements.txt | 5 ++--- 4 files changed, 18 insertions(+), 5 deletions(-) diff --git a/keras/backend/torch/optimizers/torch_optimizer.py b/keras/backend/torch/optimizers/torch_optimizer.py index ddc0d0a83ee..98c59cdb73e 100644 --- a/keras/backend/torch/optimizers/torch_optimizer.py +++ b/keras/backend/torch/optimizers/torch_optimizer.py @@ -1,7 +1,9 @@ import torch +from packaging.version import parse from keras import optimizers from keras.optimizers.base_optimizer import BaseOptimizer +from keras.utils import torch_utils class TorchOptimizer(BaseOptimizer): @@ -33,7 +35,7 @@ def __new__(cls, *args, **kwargs): return OPTIMIZERS[cls](*args, **kwargs) return super().__new__(cls) - @torch.no_grad + @torch_utils.no_grad def _apply_weight_decay(self, variables): if self.weight_decay is None: return diff --git a/keras/backend/torch/optimizers/torch_parallel_optimizer.py b/keras/backend/torch/optimizers/torch_parallel_optimizer.py index 19aa4cb39bb..b5bdbf41b26 100644 --- a/keras/backend/torch/optimizers/torch_parallel_optimizer.py +++ b/keras/backend/torch/optimizers/torch_parallel_optimizer.py @@ -1,10 +1,11 @@ import torch from keras.optimizers.base_optimizer import BaseOptimizer +from keras.utils import torch_utils class TorchParallelOptimizer(BaseOptimizer): - @torch.no_grad + @torch_utils.no_grad def _internal_apply_gradients(self, grads_and_vars): grads, trainable_variables = zip(*grads_and_vars) diff --git a/keras/utils/torch_utils.py b/keras/utils/torch_utils.py index cc72dd4ec52..4d7c40e17d2 100644 --- a/keras/utils/torch_utils.py +++ b/keras/utils/torch_utils.py @@ -1,5 +1,7 @@ import io +from packaging.version import parse + from keras.api_export import keras_export from keras.layers import Layer from keras.ops import convert_to_numpy @@ -145,3 +147,12 @@ def from_config(cls, config): buffer = io.BytesIO(config["module"]) config["module"] = torch.load(buffer) return cls(**config) + + +def no_grad(orig_func): + import torch + + if parse(torch.__version__) >= parse("2.1.0"): + return torch.no_grad(orig_func) + else: + return orig_func diff --git a/requirements.txt b/requirements.txt index 6f92700c345..39b351e3481 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,9 +2,8 @@ tf-nightly==2.15.0.dev20231009 # Pin a working nightly until rc0. # Torch. ---extra-index-url https://download.pytorch.org/whl/cpu -torch>=2.1.0 -torchvision>=0.16.0 +torch>=2.0.1 +torchvision>=0.15.1 # Jax. jax[cpu]