feat: support adaptive optimizer in DPDL framework
  - add optim_args in privacy_engine for optimizers with extra args
  - modify the Opacus adaptive optimizer to our AdaptDPSGD-Full version
Linzh7 committed Oct 4, 2024
1 parent 02a8d99 commit 3140541
Showing 2 changed files with 15 additions and 6 deletions.
13 changes: 8 additions & 5 deletions opacus/optimizers/adaclipoptimizer.py
@@ -43,16 +43,13 @@ def __init__(
         optimizer: Optimizer,
         *,
         noise_multiplier: float,
-        target_unclipped_quantile: float,
-        clipbound_learning_rate: float,
-        max_clipbound: float,
-        min_clipbound: float,
-        unclipped_num_std: float,
         max_grad_norm: float,
         expected_batch_size: Optional[int],
         loss_reduction: str = "mean",
         generator=None,
         secure_mode: bool = False,
+        normalize_clipping: bool = False,
+        optim_args: dict = None,
     ):
         super().__init__(
             optimizer,
@@ -62,7 +59,13 @@ def __init__(
             loss_reduction=loss_reduction,
             generator=generator,
             secure_mode=secure_mode,
+            normalize_clipping=normalize_clipping,
         )
+        target_unclipped_quantile = optim_args.get('target_unclipped_quantile', 0.0)
+        clipbound_learning_rate = optim_args.get('clipbound_learning_rate', 1.0)
+        max_clipbound = optim_args.get('max_clipbound', torch.inf)
+        min_clipbound = optim_args.get('min_clipbound', -torch.inf)
+        unclipped_num_std = optim_args.get('unclipped_num_std')
         assert (
             max_clipbound > min_clipbound
         ), "max_clipbound must be larger than min_clipbound."
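
With this change, the adaptive-clipping hyperparameters are no longer individual constructor arguments: the optimizer reads them from a single optim_args dict, falling back to the defaults visible in the .get calls (quantile 0.0, learning rate 1.0, infinite clip bounds) when a key is absent. A minimal sketch of the dict a caller might build; the key names mirror the .get lookups above, while the concrete values are illustrative and not taken from the commit:

# Illustrative hyperparameters for the adaptive (AdaClip-style) clipping optimizer.
# Key names match the optim_args.get(...) lookups in the diff; values are examples only.
optim_args = {
    "target_unclipped_quantile": 0.5,  # target fraction of per-sample grads left unclipped
    "clipbound_learning_rate": 0.2,    # step size for updating the clipping bound
    "max_clipbound": 10.0,             # upper limit for the adaptive clipping bound
    "min_clipbound": 0.1,              # lower limit for the adaptive clipping bound
    "unclipped_num_std": 1.0,          # noise std for the privatized unclipped-count estimate
}

Note that unclipped_num_std is looked up without a default, so it comes back as None when omitted, and optim_args itself defaults to None, in which case the .get calls would fail; callers selecting this optimizer are presumably expected to always pass the dict.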
8 changes: 7 additions & 1 deletion opacus/privacy_engine.py
@@ -18,6 +18,7 @@
 from typing import IO, Any, BinaryIO, Dict, List, Optional, Tuple, Union
 
 import torch
+from torch import distributed as dist
 from opacus.accountants import create_accountant
 from opacus.accountants.utils import get_noise_multiplier
 from opacus.data_loader import DPDataLoader, switch_generator
@@ -111,6 +112,7 @@ def _prepare_optimizer(
         noise_generator=None,
         grad_sample_mode="hooks",
         normalize_clipping: bool = False,
+        optim_args: dict = None,
         **kwargs,
     ) -> DPOptimizer:
         if isinstance(optimizer, DPOptimizer):
@@ -294,6 +296,7 @@ def make_private(
         grad_sample_mode: str = "hooks",
         normalize_clipping: bool = False,
         total_steps: int = None,
+        optim_args: dict = None,
         **kwargs,
     ) -> Tuple[GradSampleModule, DPOptimizer, DataLoader]:
         """
@@ -375,7 +378,7 @@ def make_private(
                 "Module parameters are different than optimizer Parameters"
             )
 
-        distributed = isinstance(module, (DPDDP, DDP))
+        distributed = dist.get_world_size() > 1
 
         module = self._prepare_model(
             module,
@@ -427,6 +430,7 @@ def make_private(
             clipping=clipping,
             grad_sample_mode=grad_sample_mode,
             normalize_clipping=normalize_clipping,
+            optim_args=optim_args,
             **kwargs,
         )
 
@@ -454,6 +458,7 @@ def make_private_with_epsilon(
         grad_sample_mode: str = "hooks",
         normalize_clipping: bool = False,
         total_steps: int = None,
+        optim_args: dict = None,
         **kwargs,
     ):
         """
@@ -569,6 +574,7 @@ def make_private_with_epsilon(
             clipping=clipping,
             normalize_clipping=normalize_clipping,
             total_steps=total_steps,
+            optim_args=optim_args,
         )
 
     def get_epsilon(self, delta):
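
Taken together, the privacy_engine.py changes thread optim_args from the public make_private / make_private_with_epsilon entry points down to _prepare_optimizer, which forwards it to whichever optimizer class the clipping mode selects (the adaptive one above when clipping="adaptive"). A hedged usage sketch, assuming a model, optimizer, and data loader have already been built, and treating optim_args and normalize_clipping as this fork's additions rather than stock Opacus API:

from opacus import PrivacyEngine

privacy_engine = PrivacyEngine()

# Illustrative call: keyword names follow the diff above, values are examples only.
model, optimizer, data_loader = privacy_engine.make_private(
    module=model,
    optimizer=optimizer,
    data_loader=data_loader,
    noise_multiplier=1.0,
    max_grad_norm=1.0,        # starting clipping bound; adapted during training
    clipping="adaptive",      # routes to the AdaClip-style optimizer
    optim_args={              # forwarded unchanged to the optimizer's __init__
        "target_unclipped_quantile": 0.5,
        "clipbound_learning_rate": 0.2,
        "max_clipbound": 10.0,
        "min_clipbound": 0.1,
        "unclipped_num_std": 1.0,
    },
)

One side effect of the new distributed check: dist.get_world_size() raises unless a default process group has been initialized, so single-process runs presumably need torch.distributed to be initialized (or the call guarded with dist.is_available() and dist.is_initialized()), whereas the previous isinstance(module, (DPDDP, DDP)) test worked without one.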
