No-Op GradSampleModule (#492)
Summary:
TL;DR: Adds a no-op GradSampleModule for the case where the grad samples are computed by functorch. The CIFAR10 example has been updated to show a typical use case.

The neat thing about functorch is that it gives the per-sample gradients directly with a couple of lines of code. The end user then manually assigns these per-sample gradients to `p.grad_sample`.
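
As an illustration (not part of this commit; the model, optimizer, and data below are made-up toy names), the intended flow looks roughly like this: tell `make_private` to use `grad_sample_mode="no_op"`, compute per-sample gradients with functorch, attach them to `p.grad_sample`, and let the DP optimizer do the clipping and noising. See the updated `examples/cifar10.py` below for the real version.

```python
# Hedged sketch only; toy model/data, mirrors the pattern in examples/cifar10.py.
import torch
import torch.nn as nn
from functorch import grad_and_value, make_functional, vmap
from opacus import PrivacyEngine

model = nn.Linear(16, 2)  # toy model (assumption)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()
dataset = torch.utils.data.TensorDataset(torch.randn(64, 16), torch.randint(0, 2, (64,)))
data_loader = torch.utils.data.DataLoader(dataset, batch_size=8)

privacy_engine = PrivacyEngine()
model, optimizer, data_loader = privacy_engine.make_private(
    module=model,
    optimizer=optimizer,
    data_loader=data_loader,
    noise_multiplier=1.0,
    max_grad_norm=1.0,
    grad_sample_mode="no_op",  # Opacus will not compute grad samples itself
)

# functorch gives a vectorized per-sample gradient function in a few lines
fmodel, _fparams = make_functional(model)

def compute_loss(params, sample, target):
    prediction = fmodel(params, sample.unsqueeze(0))
    return criterion(prediction, target.unsqueeze(0))

compute_per_sample_grads = vmap(grad_and_value(compute_loss), in_dims=(None, 0, 0))
params = list(model.parameters())

for images, target in data_loader:
    optimizer.zero_grad()
    per_sample_grads, _per_sample_losses = compute_per_sample_grads(params, images, target)
    for p, g in zip(params, per_sample_grads):
        p.grad_sample = g.detach()  # hand the per-sample gradients to Opacus
    # ordinary forward/backward, used here only for the loss value and p.grad
    loss = criterion(model(images), target)
    loss.backward()
    optimizer.step()  # DPOptimizer clips p.grad_sample, adds noise, and steps
```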

Pull Request resolved: #492

Reviewed By: ffuuugor

Differential Revision: D39204008

Pulled By: alexandresablayrolles

fbshipit-source-id: 22036e6c941522bba7749ef46f97d54f6ee8c551
Alex Sablayrolles authored and facebook-github-bot committed Sep 9, 2022
1 parent 38b24dc commit 9b855a7
Showing 5 changed files with 92 additions and 10 deletions.
2 changes: 2 additions & 0 deletions .circleci/config.yml
@@ -162,6 +162,8 @@ commands:
             pip install tensorboard
             python examples/cifar10.py --lr 0.1 --sigma 1.5 -c 10 --batch-size 2000 --epochs 10 --data-root runs/cifar10/data --log-dir runs/cifar10/logs --device <<parameters.device>>
             python -c "import torch; model = torch.load('model_best.pth.tar'); exit(0) if (model['best_acc1']>0.4 and model['best_acc1']<0.49) else exit(1)"
+            python examples/cifar10.py --lr 0.1 --sigma 1.5 -c 10 --batch-size 2000 --epochs 10 --data-root runs/cifar10/data --log-dir runs/cifar10/logs --device <<parameters.device>> --grad_sample_mode no_op
+            python -c "import torch; model = torch.load('model_best.pth.tar'); exit(0) if (model['best_acc1']>0.4 and model['best_acc1']<0.49) else exit(1)"
           when: always
       - store_test_results:
           path: runs/cifar10/test-reports
50 changes: 41 additions & 9 deletions examples/cifar10.py
@@ -138,25 +138,55 @@ def train(args, model, train_loader, optimizer, privacy_engine, epoch, device):
     losses = []
     top1_acc = []

+    if args.grad_sample_mode == "no_op":
+        from functorch import grad_and_value, make_functional, vmap
+
+        # Functorch prepare
+        fmodel, _fparams = make_functional(model)
+
+        def compute_loss_stateless_model(params, sample, target):
+            batch = sample.unsqueeze(0)
+            targets = target.unsqueeze(0)
+
+            predictions = fmodel(params, batch)
+            loss = criterion(predictions, targets)
+            return loss
+
+        ft_compute_grad = grad_and_value(compute_loss_stateless_model)
+        ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, 0, 0))
+        # Using model.parameters() instead of fparams
+        # as fparams seems to not point to the dynamically updated parameters
+        params = list(model.parameters())
+
     for i, (images, target) in enumerate(tqdm(train_loader)):

         images = images.to(device)
         target = target.to(device)

         # compute output
         output = model(images)
-        loss = criterion(output, target)
-        preds = np.argmax(output.detach().cpu().numpy(), axis=1)
-        labels = target.detach().cpu().numpy()

-        # measure accuracy and record loss
-        acc1 = accuracy(preds, labels)
+        if args.grad_sample_mode == "no_op":
+            per_sample_grads, per_sample_losses = ft_compute_sample_grad(
+                params, images, target
+            )
+            per_sample_grads = [g.detach() for g in per_sample_grads]
+            loss = torch.mean(per_sample_losses)
+            for (p, g) in zip(params, per_sample_grads):
+                p.grad_sample = g
+        else:
+            loss = criterion(output, target)
+        preds = np.argmax(output.detach().cpu().numpy(), axis=1)
+        labels = target.detach().cpu().numpy()

-        losses.append(loss.item())
-        top1_acc.append(acc1)
+        # measure accuracy and record loss
+        acc1 = accuracy(preds, labels)
+        top1_acc.append(acc1)

-        # compute gradient and do SGD step
-        loss.backward()

+        # compute gradient and do SGD step
+        loss.backward()
+        losses.append(loss.item())

         # make sure we take a step after processing the last mini-batch in the
         # epoch to ensure we start the next epoch with a clean state
@@ -331,6 +361,7 @@ def main():
             noise_multiplier=args.sigma,
             max_grad_norm=max_grad_norm,
             clipping=clipping,
+            grad_sample_mode=args.grad_sample_mode,
         )

     # Store some logs
@@ -388,6 +419,7 @@ def parse_args():

 def parse_args():
     parser = argparse.ArgumentParser(description="PyTorch CIFAR10 DP Training")
+    parser.add_argument("--grad_sample_mode", type=str, default="hooks")
     parser.add_argument(
         "-j",
         "--workers",
45 changes: 45 additions & 0 deletions opacus/grad_sample/gsm_no_op.py
@@ -0,0 +1,45 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn
from opacus.grad_sample.gsm_base import AbstractGradSampleModule


class GradSampleModuleNoOp(AbstractGradSampleModule):
    """
    NoOp GradSampleModule.
    Only wraps the module. The main goal of this class is to provide the same API for all methods.
    See README.md for more details
    """

    def __init__(
        self,
        m: nn.Module,
        *,
        batch_first=True,
        loss_reduction="mean",
    ):
        if not batch_first:
            raise NotImplementedError

        super().__init__(
            m,
            batch_first=batch_first,
            loss_reduction=loss_reduction,
        )

    def forward(self, x: torch.Tensor, *args, **kwargs):
        return self._module.forward(x, *args, **kwargs)
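
For context, a toy sketch (the module and tensors below are assumptions, not from the diff) of what this wrapper does and does not do: forward calls pass straight through to the wrapped module, and nothing ever fills `p.grad_sample`, so the caller has to attach per-sample gradients itself (e.g. computed with functorch, as in the CIFAR10 example above).

```python
import torch
import torch.nn as nn
from opacus.grad_sample.gsm_no_op import GradSampleModuleNoOp

net = nn.Linear(4, 2)  # toy module (assumption)
wrapped = GradSampleModuleNoOp(net)

x = torch.randn(8, 4)
out = wrapped(x)  # delegated to net.forward, shape (8, 2); no hooks are registered

# Unlike the hook-based GradSampleModule, the wrapper never populates grad_sample;
# the caller supplies per-sample gradients before the DP optimizer step, e.g.:
for p in wrapped.parameters():
    p.grad_sample = torch.zeros(x.shape[0], *p.shape)  # placeholder values
```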
3 changes: 3 additions & 0 deletions opacus/grad_sample/utils.py
@@ -20,6 +20,7 @@
 from .grad_sample_module import GradSampleModule
 from .gsm_base import AbstractGradSampleModule
 from .gsm_exp_weights import GradSampleModuleExpandedWeights
+from .gsm_no_op import GradSampleModuleNoOp


 def register_grad_sampler(
@@ -69,6 +70,8 @@ def get_gsm_class(grad_sample_mode: str) -> Type[AbstractGradSampleModule]:
         return GradSampleModule
     elif grad_sample_mode == "ew":
         return GradSampleModuleExpandedWeights
+    elif grad_sample_mode == "no_op":
+        return GradSampleModuleNoOp
     else:
         raise ValueError(
             f"Unexpected grad_sample_mode: {grad_sample_mode}. "
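
A quick illustration of the dispatch above (a sketch using only names from this diff): `get_gsm_class` maps the mode string to the wrapper class, so the new `"no_op"` mode resolves to `GradSampleModuleNoOp`.

```python
from opacus.grad_sample.gsm_no_op import GradSampleModuleNoOp
from opacus.grad_sample.utils import get_gsm_class

assert get_gsm_class("no_op") is GradSampleModuleNoOp
# "hooks" -> GradSampleModule, "ew" -> GradSampleModuleExpandedWeights,
# and any other string raises ValueError, as in the branch above.
```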
2 changes: 1 addition & 1 deletion opacus/optimizers/optimizer.py
@@ -395,7 +395,7 @@ def clip_and_accumulate(self):
         """

         per_param_norms = [
-            g.norm(2, dim=tuple(range(1, g.ndim))) for g in self.grad_samples
+            g.reshape(len(g), -1).norm(2, dim=-1) for g in self.grad_samples
         ]
         per_sample_norms = torch.stack(per_param_norms, dim=1).norm(2, dim=1)
         per_sample_clip_factor = (self.max_grad_norm / (per_sample_norms + 1e-6)).clamp(
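
For the optimizer change above, a small sanity check (toy tensor, not from the commit) showing that flattening each per-sample gradient before taking the L2 norm gives the same per-sample norms as the previous reduction over all non-batch dimensions:

```python
import torch

g = torch.randn(5, 3, 4)  # (batch, *param_shape) toy per-sample gradients
old = g.norm(2, dim=tuple(range(1, g.ndim)))  # previous form: norm over non-batch dims
new = g.reshape(len(g), -1).norm(2, dim=-1)   # new form: flatten, then norm
assert torch.allclose(old, new)               # both are per-sample L2 norms, shape (5,)
```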
