-
Notifications
You must be signed in to change notification settings - Fork 1
/
meta_utils.py
57 lines (50 loc) · 1.9 KB
/
meta_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import torch
from torch import nn
from torch.autograd import grad
from learn2learn.algorithms.base_learner import BaseLearner
from learn2learn.utils import clone_module, clone_parameters
from learn2learn.algorithms.meta_sgd import meta_sgd_update
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class ParamMetaSGD(BaseLearner):
    """Meta-SGD learner with a learnable per-parameter adaptation learning rate.

    Wraps ``model`` and maintains one learnable scalar learning rate for each
    of the model's parameters (all initialized to ``lr``). ``adapt()`` performs
    one fast-adaptation step using those learned rates via ``meta_sgd_update``.

    **Arguments**

    * **model** (Module) - the module to be wrapped and adapted.
    * **lr** (float, *optional*, default=1.0) - initial value for every
      per-parameter learning rate.
    * **first_order** (bool, *optional*, default=True) - whether to use a
      first-order approximation of the meta-gradients (no second-order graph).
    * **lrs** (ParameterList, *optional*, default=None) - pre-built list of
      per-parameter learning rates; used internally by ``clone()``.
    """

    def __init__(self, model, lr=1.0, first_order=True, lrs=None):
        super(ParamMetaSGD, self).__init__()
        self.module = model
        if lrs is None:
            # One learnable scalar learning rate per model parameter,
            # placed on the module-level DEVICE.
            lrs = nn.ParameterList(
                [
                    nn.Parameter(torch.Tensor([lr]).to(DEVICE))
                    for p in model.parameters()
                ]
            )
        self.lrs = lrs
        self.first_order = first_order

    def forward(self, *args, **kwargs):
        """Delegate the forward pass to the wrapped module."""
        return self.module(*args, **kwargs)

    def clone(self):
        """
        **Description**

        Akin to `MAML.clone()` but for MetaSGD: it includes a set of learnable fast-adaptation
        learning rates.
        """
        return ParamMetaSGD(
            clone_module(self.module),
            lrs=clone_parameters(self.lrs),
            first_order=self.first_order,
        )

    def adapt(self, loss, first_order=None, retain_graph=False, allow_unused=False):
        """
        **Description**

        Akin to `MAML.adapt()` but for MetaSGD: it updates the model with the learnable
        per-parameter learning rates.

        **Arguments**

        * **loss** (Tensor) - loss to differentiate with respect to the
          module's parameters.
        * **first_order** (bool, *optional*, default=None) - overrides the
          instance's ``first_order`` setting when not None.
        * **retain_graph** (bool, *optional*, default=False) - forwarded to
          ``torch.autograd.grad``; the graph is also retained whenever
          second-order gradients are requested.
        * **allow_unused** (bool, *optional*, default=False) - forwarded to
          ``torch.autograd.grad``; parameters unused by ``loss`` receive a
          zero gradient (i.e. are left unchanged by the update).
        """
        if first_order is None:
            first_order = self.first_order
        second_order = not first_order
        gradients = grad(
            loss,
            self.module.parameters(),
            retain_graph=second_order or retain_graph,
            create_graph=second_order,
            allow_unused=allow_unused,
        )
        if allow_unused:
            # torch.autograd.grad returns None for parameters unused by the
            # loss when allow_unused=True; meta_sgd_update cannot handle None,
            # so substitute zeros to leave those parameters unchanged.
            gradients = [
                g if g is not None else torch.zeros_like(p)
                for g, p in zip(gradients, self.module.parameters())
            ]
        self.module = meta_sgd_update(self.module, self.lrs, gradients)