diff --git a/CHANGELOG.md b/CHANGELOG.md
index 219360f7..746525e3 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,15 @@
 # Changelog
 
+## v1.5.2
+
+### New features
+* Add a "double_backward" function simplifying the training loop (#661)
+
+### Bug fixes
+* Fix setting of param_group for the DPOptimizer wrapper (issue 649) (#660)
+* Fix DDP optimizer for FGC: the step function incorrectly called "original_optimizer.original_optimizer" (#662)
+* Replace "opt_einsum.contract" with "torch.einsum" (#663)
+
 ## v1.5.1
 
 ### Bug fixes
diff --git a/opacus/grad_sample/conv.py b/opacus/grad_sample/conv.py
index 633782f9..e272e4c5 100644
--- a/opacus/grad_sample/conv.py
+++ b/opacus/grad_sample/conv.py
@@ -21,7 +21,6 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from opacus.utils.tensor_utils import unfold2d, unfold3d
-from opt_einsum import contract
 
 from .utils import register_grad_sampler
 
@@ -90,7 +89,7 @@ def compute_conv_grad_sample(
     ret = {}
     if layer.weight.requires_grad:
         # n=batch_sz; o=num_out_channels; p=(num_in_channels/groups)*kernel_sz
-        grad_sample = contract("noq,npq->nop", backprops, activations)
+        grad_sample = torch.einsum("noq,npq->nop", backprops, activations)
         # rearrange the above tensor and extract diagonals.
         grad_sample = grad_sample.view(
             n,
@@ -100,7 +99,7 @@
             int(layer.in_channels / layer.groups),
             np.prod(layer.kernel_size),
         )
-        grad_sample = contract("ngrg...->ngr...", grad_sample).contiguous()
+        grad_sample = torch.einsum("ngrg...->ngr...", grad_sample).contiguous()
         shape = [n] + list(layer.weight.shape)
         ret[layer.weight] = grad_sample.view(shape)
diff --git a/opacus/grad_sample/dp_rnn.py b/opacus/grad_sample/dp_rnn.py
index ce07d4a2..3fe05876 100644
--- a/opacus/grad_sample/dp_rnn.py
+++ b/opacus/grad_sample/dp_rnn.py
@@ -19,7 +19,6 @@
 import torch
 import torch.nn as nn
 from opacus.layers.dp_rnn import RNNLinear
-from opt_einsum import contract
 
 from .utils import register_grad_sampler
 
@@ -42,8 +41,8 @@ def compute_rnn_linear_grad_sample(
     activations = activations[0]
     ret = {}
     if layer.weight.requires_grad:
-        gs = contract("n...i,n...j->nij", backprops, activations)
+        gs = torch.einsum("n...i,n...j->nij", backprops, activations)
         ret[layer.weight] = gs
     if layer.bias is not None and layer.bias.requires_grad:
-        ret[layer.bias] = contract("n...k->nk", backprops)
+        ret[layer.bias] = torch.einsum("n...k->nk", backprops)
     return ret
diff --git a/opacus/grad_sample/group_norm.py b/opacus/grad_sample/group_norm.py
index 4c0e77d7..f9b8f415 100644
--- a/opacus/grad_sample/group_norm.py
+++ b/opacus/grad_sample/group_norm.py
@@ -19,7 +19,6 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from opt_einsum import contract
 
 from .utils import register_grad_sampler
 
@@ -42,7 +41,7 @@ def compute_group_norm_grad_sample(
     ret = {}
     if layer.weight.requires_grad:
         gs = F.group_norm(activations, layer.num_groups, eps=layer.eps) * backprops
-        ret[layer.weight] = contract("ni...->ni", gs)
+        ret[layer.weight] = torch.einsum("ni...->ni", gs)
     if layer.bias is not None and layer.bias.requires_grad:
-        ret[layer.bias] = contract("ni...->ni", backprops)
+        ret[layer.bias] = torch.einsum("ni...->ni", backprops)
     return ret
diff --git a/opacus/grad_sample/instance_norm.py b/opacus/grad_sample/instance_norm.py
index 5618164a..31403bdd 100644
--- a/opacus/grad_sample/instance_norm.py
+++ b/opacus/grad_sample/instance_norm.py
@@ -18,7 +18,6 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from opt_einsum import contract
 
 from .utils import register_grad_sampler
 
@@ -51,7 +50,7 @@ def compute_instance_norm_grad_sample(
     ret = {}
     if layer.weight.requires_grad:
         gs = F.instance_norm(activations, eps=layer.eps) * backprops
-        ret[layer.weight] = contract("ni...->ni", gs)
+        ret[layer.weight] = torch.einsum("ni...->ni", gs)
     if layer.bias is not None and layer.bias.requires_grad:
-        ret[layer.bias] = contract("ni...->ni", backprops)
+        ret[layer.bias] = torch.einsum("ni...->ni", backprops)
     return ret
diff --git a/opacus/grad_sample/linear.py b/opacus/grad_sample/linear.py
index cceb8cac..1b30f94a 100644
--- a/opacus/grad_sample/linear.py
+++ b/opacus/grad_sample/linear.py
@@ -18,7 +18,6 @@
 
 import torch
 import torch.nn as nn
-from opt_einsum.contract import contract
 
 from .utils import register_grad_sampler, register_norm_sampler
 
@@ -42,10 +41,10 @@ def compute_linear_grad_sample(
     activations = activations[0]
     ret = {}
     if layer.weight.requires_grad:
-        gs = contract("n...i,n...j->nij", backprops, activations)
+        gs = torch.einsum("n...i,n...j->nij", backprops, activations)
         ret[layer.weight] = gs
     if layer.bias is not None and layer.bias.requires_grad:
-        ret[layer.bias] = contract("n...k->nk", backprops)
+        ret[layer.bias] = torch.einsum("n...k->nk", backprops)
     return ret
 
 
@@ -66,23 +65,25 @@ def compute_linear_norm_sample(
 
     if backprops.dim() == 2:
         if layer.weight.requires_grad:
-            g = contract("n...i,n...i->n", backprops, backprops)
-            a = contract("n...j,n...j->n", activations, activations)
+            g = torch.einsum("n...i,n...i->n", backprops, backprops)
+            a = torch.einsum("n...j,n...j->n", activations, activations)
             ret[layer.weight] = torch.sqrt((g * a).flatten())
         if layer.bias is not None and layer.bias.requires_grad:
             ret[layer.bias] = torch.sqrt(
-                contract("n...i,n...i->n", backprops, backprops).flatten()
+                torch.einsum("n...i,n...i->n", backprops, backprops).flatten()
             )
     elif backprops.dim() == 3:
         if layer.weight.requires_grad:
-            ggT = contract("nik,njk->nij", backprops, backprops)  # batchwise g g^T
-            aaT = contract("nik,njk->nij", activations, activations)  # batchwise a a^T
-            ga = contract("n...i,n...i->n", ggT, aaT).clamp(min=0)
+            ggT = torch.einsum("nik,njk->nij", backprops, backprops)  # batchwise g g^T
+            aaT = torch.einsum(
+                "nik,njk->nij", activations, activations
+            )  # batchwise a a^T
+            ga = torch.einsum("n...i,n...i->n", ggT, aaT).clamp(min=0)
             ret[layer.weight] = torch.sqrt(ga)
         if layer.bias is not None and layer.bias.requires_grad:
-            ggT = contract("nik,njk->nij", backprops, backprops)
-            gg = contract("n...i,n...i->n", ggT, ggT).clamp(min=0)
+            ggT = torch.einsum("nik,njk->nij", backprops, backprops)
+            gg = torch.einsum("n...i,n...i->n", ggT, ggT).clamp(min=0)
             ret[layer.bias] = torch.sqrt(gg)
     return ret
 
 
diff --git a/opacus/optimizers/adaclipoptimizer.py b/opacus/optimizers/adaclipoptimizer.py
index c89a0dbf..7144f06b 100644
--- a/opacus/optimizers/adaclipoptimizer.py
+++ b/opacus/optimizers/adaclipoptimizer.py
@@ -18,7 +18,6 @@
 from typing import Callable, Optional
 
 import torch
-from opt_einsum import contract
 from torch.optim import Optimizer
 
 from .optimizer import (
@@ -107,7 +106,7 @@ def clip_and_accumulate(self):
         for p in self.params:
             _check_processed_flag(p.grad_sample)
             grad_sample = self._get_flat_grad_sample(p)
-            grad = contract("i,i...", per_sample_clip_factor, grad_sample)
+            grad = torch.einsum("i,i...", per_sample_clip_factor, grad_sample)
 
             if p.summed_grad is not None:
                 p.summed_grad += grad
diff --git a/opacus/optimizers/ddp_perlayeroptimizer.py b/opacus/optimizers/ddp_perlayeroptimizer.py
index 1421debb..c9b9bdfa 100644
--- a/opacus/optimizers/ddp_perlayeroptimizer.py
+++ b/opacus/optimizers/ddp_perlayeroptimizer.py
@@ -18,7 +18,6 @@
 from typing import Callable, List, Optional
 
 import torch
-from opt_einsum import contract
 from torch import nn
 from torch.optim import Optimizer
 
@@ -31,7 +30,7 @@ def _clip_and_accumulate_parameter(p: nn.Parameter, max_grad_norm: float):
     per_sample_norms = p.grad_sample.view(len(p.grad_sample), -1).norm(2, dim=-1)
     per_sample_clip_factor = (max_grad_norm / (per_sample_norms + 1e-6)).clamp(max=1.0)
 
-    grad = contract("i,i...", per_sample_clip_factor, p.grad_sample)
+    grad = torch.einsum("i,i...", per_sample_clip_factor, p.grad_sample)
     if p.summed_grad is not None:
         p.summed_grad += grad
     else:
diff --git a/opacus/optimizers/optimizer.py b/opacus/optimizers/optimizer.py
index 21be2147..7a22eeec 100644
--- a/opacus/optimizers/optimizer.py
+++ b/opacus/optimizers/optimizer.py
@@ -20,7 +20,6 @@
 
 import torch
 from opacus.optimizers.utils import params
-from opt_einsum.contract import contract
 from torch import nn
 from torch.optim import Optimizer
 
@@ -450,7 +449,7 @@ def clip_and_accumulate(self):
         for p in self.params:
             _check_processed_flag(p.grad_sample)
             grad_sample = self._get_flat_grad_sample(p)
-            grad = contract("i,i...", per_sample_clip_factor, grad_sample)
+            grad = torch.einsum("i,i...", per_sample_clip_factor, grad_sample)
 
             if p.summed_grad is not None:
                 p.summed_grad += grad
diff --git a/opacus/optimizers/perlayeroptimizer.py b/opacus/optimizers/perlayeroptimizer.py
index fe4fbff6..6d0029bf 100644
--- a/opacus/optimizers/perlayeroptimizer.py
+++ b/opacus/optimizers/perlayeroptimizer.py
@@ -18,7 +18,6 @@
 
 import torch
 from opacus.optimizers.utils import params
-from opt_einsum import contract
 from torch.optim import Optimizer
 
 from .optimizer import DPOptimizer, _check_processed_flag, _mark_as_processed
@@ -65,7 +64,7 @@ def clip_and_accumulate(self):
             per_sample_clip_factor = (max_grad_norm / (per_sample_norms + 1e-6)).clamp(
                 max=1.0
            )
-            grad = contract("i,i...", per_sample_clip_factor, grad_sample)
+            grad = torch.einsum("i,i...", per_sample_clip_factor, grad_sample)
 
             if p.summed_grad is not None:
                 p.summed_grad += grad
diff --git a/opacus/version.py b/opacus/version.py
index b195b092..0d218ae9 100644
--- a/opacus/version.py
+++ b/opacus/version.py
@@ -13,4 +13,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-__version__ = "1.5.1"
+__version__ = "1.5.2"
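For reference, every expression touched by the `opt_einsum.contract` → `torch.einsum` change is a two-operand contraction with a fixed subscript string, so there is no contraction path left for opt_einsum to optimize and `torch.einsum` is a drop-in replacement. The sketch below is not part of the patch; it uses made-up shapes and stand-in tensors (`per_sample_clip_factor`, `grad_sample`, `backprops`, `activations` are plain tensors here, not Opacus objects) to exercise the same einsum strings that appear in the diff.

```python
import torch

# Made-up shapes, purely for illustration: a batch of 4 per-sample
# gradients for a 3x5 weight, plus per-sample clipping factors.
per_sample_clip_factor = torch.rand(4)
grad_sample = torch.randn(4, 3, 5)

# The "i,i..." contraction from clip_and_accumulate: sum_i c_i * g_i,
# i.e. the scaled per-sample gradients summed over the batch dimension.
grad = torch.einsum("i,i...", per_sample_clip_factor, grad_sample)
expected = (per_sample_clip_factor.view(-1, 1, 1) * grad_sample).sum(dim=0)
assert torch.allclose(grad, expected, atol=1e-6)

# The Gram-matrix identity behind compute_linear_norm_sample for 3-d
# inputs: ||G^T A||_F^2 == <G G^T, A A^T>, computed per sample.
n, t, o, d = 2, 6, 3, 5
backprops = torch.randn(n, t, o)    # stand-in for per-sample output grads
activations = torch.randn(n, t, d)  # stand-in for layer inputs

per_sample_grad = torch.einsum("n...i,n...j->nij", backprops, activations)
explicit_norm = per_sample_grad.flatten(start_dim=1).norm(2, dim=1)

ggT = torch.einsum("nik,njk->nij", backprops, backprops)      # batchwise g g^T
aaT = torch.einsum("nik,njk->nij", activations, activations)  # batchwise a a^T
gram_norm = torch.sqrt(torch.einsum("n...i,n...i->n", ggT, aaT).clamp(min=0))

assert torch.allclose(explicit_norm, gram_norm, atol=1e-5)
```

If `opt_einsum` is still installed, calling `opt_einsum.contract` with the same strings and operands should return the same tensors, which is the equivalence the #663 change relies on.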