Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for adaptive optimizer with DPDL #4

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 42 additions & 19 deletions opacus/optimizers/adaclipoptimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,17 +43,18 @@ def __init__(
optimizer: Optimizer,
*,
noise_multiplier: float,
target_unclipped_quantile: float,
clipbound_learning_rate: float,
max_clipbound: float,
min_clipbound: float,
unclipped_num_std: float,
max_grad_norm: float,
expected_batch_size: Optional[int],
loss_reduction: str = "mean",
generator=None,
secure_mode: bool = False,
normalize_clipping: bool = False,
optim_args: dict = None,
):

assert(normalize_clipping == True), "Let us focus on the normalized version first"
max_grad_norm = 1.0

super().__init__(
optimizer,
noise_multiplier=noise_multiplier,
Expand All @@ -62,12 +63,21 @@ def __init__(
loss_reduction=loss_reduction,
generator=generator,
secure_mode=secure_mode,
normalize_clipping=normalize_clipping,
optim_args=optim_args,
)
assert (
max_clipbound > min_clipbound
), "max_clipbound must be larger than min_clipbound."

target_unclipped_quantile = optim_args.get('target_unclipped_quantile', 0.0)
clipbound_learning_rate = optim_args.get('clipbound_learning_rate', 1.0)
count_threshold = optim_args.get('count_threshold', 1.0)
max_clipbound = optim_args.get('max_clipbound', torch.inf)
min_clipbound = optim_args.get('min_clipbound', -torch.inf)
unclipped_num_std = optim_args.get('unclipped_num_std')
assert (max_clipbound > min_clipbound), "max_clipbound must be larger than min_clipbound."
self.clipbound = max_grad_norm # let we set the init value of clip bound to 1
self.target_unclipped_quantile = target_unclipped_quantile
self.clipbound_learning_rate = clipbound_learning_rate
self.count_threshold = count_threshold
self.max_clipbound = max_clipbound
self.min_clipbound = min_clipbound
self.unclipped_num_std = unclipped_num_std
Expand All @@ -92,15 +102,26 @@ def clip_and_accumulate(self):
g.view(len(g), -1).norm(2, dim=-1) for g in self.grad_samples
]
per_sample_norms = torch.stack(per_param_norms, dim=1).norm(2, dim=1)
per_sample_clip_factor = (self.max_grad_norm / (per_sample_norms + 1e-6)).clamp(
max=1.0
)

#print(f"max per_param_norms before clipping: {per_sample_norms.max().item()}")

# Create a mask to determine which gradients need to be clipped based on the clipbound
clip_mask = per_sample_norms > self.clipbound
per_sample_clip_factor = torch.where(
clip_mask,
self.max_grad_norm / (per_sample_norms + 1e-6),
torch.tensor(self.max_grad_norm / self.clipbound, device=per_sample_norms.device)
).clamp(max=1.0)

# Print max per_param_norms after clipping
clipped_per_sample_norms = per_sample_norms * per_sample_clip_factor
#print(f"max per_param_norms after clipping: {clipped_per_sample_norms.max().item()}")

# the two lines below are the only changes
# relative to the parent DPOptimizer class.
self.sample_size += len(per_sample_clip_factor)
self.unclipped_num += (
len(per_sample_clip_factor) - (per_sample_clip_factor < 1).sum()
len(per_sample_norms) - (per_sample_norms < self.clipbound * self.count_threshold).sum()
)

for p in self.params:
Expand All @@ -127,24 +148,26 @@ def add_noise(self):
self.unclipped_num = float(self.unclipped_num)
self.unclipped_num += unclipped_num_noise

def update_max_grad_norm(self):
def update_clipbound(self):
"""
Update clipping bound based on unclipped fraction
"""
unclipped_frac = self.unclipped_num / self.sample_size
self.max_grad_norm *= torch.exp(
self.clipbound *= torch.exp(
-self.clipbound_learning_rate
* (unclipped_frac - self.target_unclipped_quantile)
)
if self.max_grad_norm > self.max_clipbound:
self.max_grad_norm = self.max_clipbound
elif self.max_grad_norm < self.min_clipbound:
self.max_grad_norm = self.min_clipbound
if self.clipbound > self.max_clipbound:
self.clipbound = self.max_clipbound
elif self.clipbound < self.min_clipbound:
self.clipbound = self.min_clipbound

#print(f"self.clipbound: {self.clipbound}")

def pre_step(
self, closure: Optional[Callable[[], float]] = None
) -> Optional[float]:
pre_step_full = super().pre_step()
if pre_step_full:
self.update_max_grad_norm()
self.update_clipbound()
return pre_step_full
1 change: 1 addition & 0 deletions opacus/optimizers/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ def __init__(
generator=None,
secure_mode: bool = False,
normalize_clipping: bool = False,
optim_args: dict = None,
):
"""

Expand Down
9 changes: 8 additions & 1 deletion opacus/privacy_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from typing import IO, Any, BinaryIO, Dict, List, Optional, Tuple, Union

import torch
from torch import distributed as dist
from opacus.accountants import create_accountant
from opacus.accountants.utils import get_noise_multiplier
from opacus.data_loader import DPDataLoader, switch_generator
Expand Down Expand Up @@ -111,6 +112,7 @@ def _prepare_optimizer(
noise_generator=None,
grad_sample_mode="hooks",
normalize_clipping: bool = False,
optim_args: dict = None,
**kwargs,
) -> DPOptimizer:
if isinstance(optimizer, DPOptimizer):
Expand All @@ -137,6 +139,7 @@ def _prepare_optimizer(
generator=generator,
secure_mode=self.secure_mode,
normalize_clipping=normalize_clipping,
optim_args=optim_args,
**kwargs,
)

Expand Down Expand Up @@ -294,6 +297,7 @@ def make_private(
grad_sample_mode: str = "hooks",
normalize_clipping: bool = False,
total_steps: int = None,
optim_args: dict = None,
**kwargs,
) -> Tuple[GradSampleModule, DPOptimizer, DataLoader]:
"""
Expand Down Expand Up @@ -375,7 +379,7 @@ def make_private(
"Module parameters are different than optimizer Parameters"
)

distributed = isinstance(module, (DPDDP, DDP))
distributed = dist.get_world_size() > 1

module = self._prepare_model(
module,
Expand Down Expand Up @@ -427,6 +431,7 @@ def make_private(
clipping=clipping,
grad_sample_mode=grad_sample_mode,
normalize_clipping=normalize_clipping,
optim_args=optim_args,
**kwargs,
)

Expand Down Expand Up @@ -454,6 +459,7 @@ def make_private_with_epsilon(
grad_sample_mode: str = "hooks",
normalize_clipping: bool = False,
total_steps: int = None,
optim_args: dict = None,
**kwargs,
):
"""
Expand Down Expand Up @@ -569,6 +575,7 @@ def make_private_with_epsilon(
clipping=clipping,
normalize_clipping=normalize_clipping,
total_steps=total_steps,
optim_args=optim_args,
)

def get_epsilon(self, delta):
Expand Down
Loading