Skip to content

Commit

Permalink
Opacus release v1.5.2 (#663)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #663

Release a new version of Opacus

Furthermore, we replace "opt_einsum.contract" by torch.einsum to avoid errors when "opt_einsum" is not available. This will not hurt the performance since torch will automatically shift to "opt_einsum" for acceleratiton when the package is available (https://pytorch.org/docs/stable/generated/torch.einsum.html)
Code pointer: https://pytorch.org/docs/stable/_modules/torch/backends/opt_einsum.html#is_available

Reviewed By: EnayatUllah

Differential Revision: D60672828

fbshipit-source-id: f8bbc0aa404e48f15ce129689a6e55af68daa5e4
  • Loading branch information
HuanyuZhang authored and facebook-github-bot committed Aug 3, 2024
1 parent eb94674 commit f1412fa
Show file tree
Hide file tree
Showing 11 changed files with 35 additions and 32 deletions.
10 changes: 10 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
# Changelog

## v1.5.2

### New features
* Add a function of "double_backward" simplifying the training loop (#661)

### Bug fixes
* Fix issue with setting of param_group for the DPOptimizer wrapper (issue 649) (#660)
* Fix issue of DDP optimizer for FGC. The step function incorrectly called "original_optimizer.original_optimizer" (#662)
* Replace "opt_einsum.contract" by "torch.einsum"(#663)

## v1.5.1

### Bug fixes
Expand Down
5 changes: 2 additions & 3 deletions opacus/grad_sample/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import torch.nn as nn
import torch.nn.functional as F
from opacus.utils.tensor_utils import unfold2d, unfold3d
from opt_einsum import contract

from .utils import register_grad_sampler

Expand Down Expand Up @@ -90,7 +89,7 @@ def compute_conv_grad_sample(
ret = {}
if layer.weight.requires_grad:
# n=batch_sz; o=num_out_channels; p=(num_in_channels/groups)*kernel_sz
grad_sample = contract("noq,npq->nop", backprops, activations)
grad_sample = torch.einsum("noq,npq->nop", backprops, activations)
# rearrange the above tensor and extract diagonals.
grad_sample = grad_sample.view(
n,
Expand All @@ -100,7 +99,7 @@ def compute_conv_grad_sample(
int(layer.in_channels / layer.groups),
np.prod(layer.kernel_size),
)
grad_sample = contract("ngrg...->ngr...", grad_sample).contiguous()
grad_sample = torch.einsum("ngrg...->ngr...", grad_sample).contiguous()
shape = [n] + list(layer.weight.shape)
ret[layer.weight] = grad_sample.view(shape)

Expand Down
5 changes: 2 additions & 3 deletions opacus/grad_sample/dp_rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import torch
import torch.nn as nn
from opacus.layers.dp_rnn import RNNLinear
from opt_einsum import contract

from .utils import register_grad_sampler

Expand All @@ -42,8 +41,8 @@ def compute_rnn_linear_grad_sample(
activations = activations[0]
ret = {}
if layer.weight.requires_grad:
gs = contract("n...i,n...j->nij", backprops, activations)
gs = torch.einsum("n...i,n...j->nij", backprops, activations)
ret[layer.weight] = gs
if layer.bias is not None and layer.bias.requires_grad:
ret[layer.bias] = contract("n...k->nk", backprops)
ret[layer.bias] = torch.einsum("n...k->nk", backprops)
return ret
5 changes: 2 additions & 3 deletions opacus/grad_sample/group_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from opt_einsum import contract

from .utils import register_grad_sampler

Expand All @@ -42,7 +41,7 @@ def compute_group_norm_grad_sample(
ret = {}
if layer.weight.requires_grad:
gs = F.group_norm(activations, layer.num_groups, eps=layer.eps) * backprops
ret[layer.weight] = contract("ni...->ni", gs)
ret[layer.weight] = torch.einsum("ni...->ni", gs)
if layer.bias is not None and layer.bias.requires_grad:
ret[layer.bias] = contract("ni...->ni", backprops)
ret[layer.bias] = torch.einsum("ni...->ni", backprops)
return ret
5 changes: 2 additions & 3 deletions opacus/grad_sample/instance_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from opt_einsum import contract

from .utils import register_grad_sampler

Expand Down Expand Up @@ -51,7 +50,7 @@ def compute_instance_norm_grad_sample(
ret = {}
if layer.weight.requires_grad:
gs = F.instance_norm(activations, eps=layer.eps) * backprops
ret[layer.weight] = contract("ni...->ni", gs)
ret[layer.weight] = torch.einsum("ni...->ni", gs)
if layer.bias is not None and layer.bias.requires_grad:
ret[layer.bias] = contract("ni...->ni", backprops)
ret[layer.bias] = torch.einsum("ni...->ni", backprops)
return ret
23 changes: 12 additions & 11 deletions opacus/grad_sample/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

import torch
import torch.nn as nn
from opt_einsum.contract import contract

from .utils import register_grad_sampler, register_norm_sampler

Expand All @@ -42,10 +41,10 @@ def compute_linear_grad_sample(
activations = activations[0]
ret = {}
if layer.weight.requires_grad:
gs = contract("n...i,n...j->nij", backprops, activations)
gs = torch.einsum("n...i,n...j->nij", backprops, activations)
ret[layer.weight] = gs
if layer.bias is not None and layer.bias.requires_grad:
ret[layer.bias] = contract("n...k->nk", backprops)
ret[layer.bias] = torch.einsum("n...k->nk", backprops)
return ret


Expand All @@ -66,23 +65,25 @@ def compute_linear_norm_sample(

if backprops.dim() == 2:
if layer.weight.requires_grad:
g = contract("n...i,n...i->n", backprops, backprops)
a = contract("n...j,n...j->n", activations, activations)
g = torch.einsum("n...i,n...i->n", backprops, backprops)
a = torch.einsum("n...j,n...j->n", activations, activations)
ret[layer.weight] = torch.sqrt((g * a).flatten())
if layer.bias is not None and layer.bias.requires_grad:
ret[layer.bias] = torch.sqrt(
contract("n...i,n...i->n", backprops, backprops).flatten()
torch.einsum("n...i,n...i->n", backprops, backprops).flatten()
)
elif backprops.dim() == 3:
if layer.weight.requires_grad:

ggT = contract("nik,njk->nij", backprops, backprops) # batchwise g g^T
aaT = contract("nik,njk->nij", activations, activations) # batchwise a a^T
ga = contract("n...i,n...i->n", ggT, aaT).clamp(min=0)
ggT = torch.einsum("nik,njk->nij", backprops, backprops) # batchwise g g^T
aaT = torch.einsum(
"nik,njk->nij", activations, activations
) # batchwise a a^T
ga = torch.einsum("n...i,n...i->n", ggT, aaT).clamp(min=0)

ret[layer.weight] = torch.sqrt(ga)
if layer.bias is not None and layer.bias.requires_grad:
ggT = contract("nik,njk->nij", backprops, backprops)
gg = contract("n...i,n...i->n", ggT, ggT).clamp(min=0)
ggT = torch.einsum("nik,njk->nij", backprops, backprops)
gg = torch.einsum("n...i,n...i->n", ggT, ggT).clamp(min=0)
ret[layer.bias] = torch.sqrt(gg)
return ret
3 changes: 1 addition & 2 deletions opacus/optimizers/adaclipoptimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from typing import Callable, Optional

import torch
from opt_einsum import contract
from torch.optim import Optimizer

from .optimizer import (
Expand Down Expand Up @@ -107,7 +106,7 @@ def clip_and_accumulate(self):
for p in self.params:
_check_processed_flag(p.grad_sample)
grad_sample = self._get_flat_grad_sample(p)
grad = contract("i,i...", per_sample_clip_factor, grad_sample)
grad = torch.einsum("i,i...", per_sample_clip_factor, grad_sample)

if p.summed_grad is not None:
p.summed_grad += grad
Expand Down
3 changes: 1 addition & 2 deletions opacus/optimizers/ddp_perlayeroptimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from typing import Callable, List, Optional

import torch
from opt_einsum import contract
from torch import nn
from torch.optim import Optimizer

Expand All @@ -31,7 +30,7 @@ def _clip_and_accumulate_parameter(p: nn.Parameter, max_grad_norm: float):
per_sample_norms = p.grad_sample.view(len(p.grad_sample), -1).norm(2, dim=-1)
per_sample_clip_factor = (max_grad_norm / (per_sample_norms + 1e-6)).clamp(max=1.0)

grad = contract("i,i...", per_sample_clip_factor, p.grad_sample)
grad = torch.einsum("i,i...", per_sample_clip_factor, p.grad_sample)
if p.summed_grad is not None:
p.summed_grad += grad
else:
Expand Down
3 changes: 1 addition & 2 deletions opacus/optimizers/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

import torch
from opacus.optimizers.utils import params
from opt_einsum.contract import contract
from torch import nn
from torch.optim import Optimizer

Expand Down Expand Up @@ -450,7 +449,7 @@ def clip_and_accumulate(self):
for p in self.params:
_check_processed_flag(p.grad_sample)
grad_sample = self._get_flat_grad_sample(p)
grad = contract("i,i...", per_sample_clip_factor, grad_sample)
grad = torch.einsum("i,i...", per_sample_clip_factor, grad_sample)

if p.summed_grad is not None:
p.summed_grad += grad
Expand Down
3 changes: 1 addition & 2 deletions opacus/optimizers/perlayeroptimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

import torch
from opacus.optimizers.utils import params
from opt_einsum import contract
from torch.optim import Optimizer

from .optimizer import DPOptimizer, _check_processed_flag, _mark_as_processed
Expand Down Expand Up @@ -65,7 +64,7 @@ def clip_and_accumulate(self):
per_sample_clip_factor = (max_grad_norm / (per_sample_norms + 1e-6)).clamp(
max=1.0
)
grad = contract("i,i...", per_sample_clip_factor, grad_sample)
grad = torch.einsum("i,i...", per_sample_clip_factor, grad_sample)

if p.summed_grad is not None:
p.summed_grad += grad
Expand Down
2 changes: 1 addition & 1 deletion opacus/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__version__ = "1.5.1"
__version__ = "1.5.2"

0 comments on commit f1412fa

Please sign in to comment.