-
Notifications
You must be signed in to change notification settings - Fork 341
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Summary: TL;DR: Adding a No-Op GradSampleModule in case the grad samples are computed by functorch. The CIFAR10 example has been updated to show a typical use-case for that. The neat thing about functorch is that it directly gives the per-sample gradients with a couple of lines of code. These per-sample gradients are then manually given to `p.grad_sample` by the end-user. Pull Request resolved: #492 Reviewed By: ffuuugor Differential Revision: D39204008 Pulled By: alexandresablayrolles fbshipit-source-id: 22036e6c941522bba7749ef46f97d54f6ee8c551
- Loading branch information
1 parent
38b24dc
commit 9b855a7
Showing
5 changed files
with
92 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import torch | ||
import torch.nn as nn | ||
from opacus.grad_sample.gsm_base import AbstractGradSampleModule | ||
|
||
|
||
class GradSampleModuleNoOp(AbstractGradSampleModule): | ||
""" | ||
NoOp GradSampleModule. | ||
Only wraps the module. The main goal of this class is to provide the same API for all methods. | ||
See README.md for more details | ||
""" | ||
|
||
def __init__( | ||
self, | ||
m: nn.Module, | ||
*, | ||
batch_first=True, | ||
loss_reduction="mean", | ||
): | ||
if not batch_first: | ||
raise NotImplementedError | ||
|
||
super().__init__( | ||
m, | ||
batch_first=batch_first, | ||
loss_reduction=loss_reduction, | ||
) | ||
|
||
def forward(self, x: torch.Tensor, *args, **kwargs): | ||
return self._module.forward(x, *args, **kwargs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters