optim.py
import itertools

import torch


class Optimizer(object):
    """Wraps a torch.optim optimizer with optional gradient norm clipping
    and learning-rate scheduling."""

    _ARG_MAX_GRAD_NORM = 'max_grad_norm'

    def __init__(self, optim, max_grad_norm=0):
        self.optimizer = optim
        self.scheduler = None
        self.max_grad_norm = max_grad_norm

    def set_scheduler(self, scheduler):
        """Attaches a learning-rate scheduler, invoked later by update()."""
        self.scheduler = scheduler

    def step(self):
        """Performs a single optimization step, including gradient norm
        clipping if necessary."""
        if self.max_grad_norm > 0:
            # Gather parameters from every param group before clipping.
            params = itertools.chain.from_iterable(
                group['params'] for group in self.optimizer.param_groups)
            torch.nn.utils.clip_grad_norm_(params, self.max_grad_norm)
        self.optimizer.step()

    def update(self, loss, epoch):
        """Advances the scheduler, if one is set. ReduceLROnPlateau needs
        the loss value; other schedulers step unconditionally. The epoch
        argument is unused here but kept for interface compatibility."""
        if self.scheduler is None:
            pass
        elif isinstance(self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            self.scheduler.step(loss)
        else:
            self.scheduler.step()
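

# Usage sketch: a minimal, illustrative way to drive this wrapper. The model,
# learning rate, clipping threshold, and scheduler settings below are
# assumptions for demonstration, not taken from this file.
if __name__ == '__main__':
    import torch.nn as nn

    model = nn.Linear(10, 2)
    base = torch.optim.Adam(model.parameters(), lr=1e-3)
    optim = Optimizer(base, max_grad_norm=5.0)
    optim.set_scheduler(
        torch.optim.lr_scheduler.ReduceLROnPlateau(base, patience=2))

    for epoch in range(3):
        x, y = torch.randn(8, 10), torch.randn(8, 2)
        loss = nn.functional.mse_loss(model(x), y)
        base.zero_grad()                  # clear gradients on the wrapped optimizer
        loss.backward()
        optim.step()                      # clips to max_grad_norm, then steps
        optim.update(loss.item(), epoch)  # feeds the loss to ReduceLROnPlateau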