policy.py
import torch as th


class Policy(th.nn.Module):
    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, device, freeze_output_layer=False,
                 learn_h0=True, freeze_input_layer=False, freeze_bias_hidden=False, freeze_h0=False):
        super().__init__()
        self.device = device
        self.hidden_dim = hidden_dim
        self.n_layers = 1
        # Single-layer GRU with a linear read-out squashed to (0, 1) by a sigmoid.
        self.gru = th.nn.GRU(input_dim, hidden_dim, self.n_layers, batch_first=True)
        self.fc = th.nn.Linear(hidden_dim, output_dim)
        self.sigmoid = th.nn.Sigmoid()
        # Optionally freeze parameter groups so only the remaining ones are trained.
        if freeze_output_layer:
            for param in self.fc.parameters():
                param.requires_grad = False
        if freeze_input_layer:
            for name, param in self.gru.named_parameters():
                if name in ("weight_ih_l0", "bias_ih_l0"):
                    param.requires_grad = False
        if freeze_bias_hidden:
            for name, param in self.gru.named_parameters():
                if name == "bias_hh_l0":
                    param.requires_grad = False
        # The default initialization in torch isn't ideal, so set every parameter explicitly.
        for name, param in self.named_parameters():
            if name == "gru.weight_ih_l0":
                th.nn.init.xavier_uniform_(param)
            elif name == "gru.weight_hh_l0":
                th.nn.init.orthogonal_(param)
            elif name == "gru.bias_ih_l0":
                th.nn.init.zeros_(param)
            elif name == "gru.bias_hh_l0":
                th.nn.init.zeros_(param)
            elif name == "fc.weight":
                th.nn.init.xavier_uniform_(param)
            elif name == "fc.bias":
                # Large negative bias so the sigmoid outputs start near zero.
                th.nn.init.constant_(param, -5.)
            else:
                raise ValueError(f"Unexpected parameter: {name}")
        if learn_h0:
            # Learnable initial hidden state, shared across the batch.
            self.h0 = th.nn.Parameter(th.zeros(self.n_layers, 1, hidden_dim), requires_grad=True)
        if freeze_h0:
            for name, param in self.named_parameters():
                if name == "h0":
                    param.requires_grad = False
        self.to(device)

    def forward(self, x, h0):
        # TODO: noise could be added to h0 here before applying the GRU.
        # x is a single time step, so insert a sequence dimension of length 1.
        y, h = self.gru(x[:, None, :], h0)
        # hidden_noise = 1e-3
        u = self.sigmoid(self.fc(y)).squeeze(dim=1)
        return u, h

    def init_hidden(self, batch_size):
        # Return the initial hidden state for a batch: the learned h0 expanded across
        # the batch if it exists, otherwise zeros with matching dtype and device.
        if hasattr(self, 'h0'):
            hidden = self.h0.repeat(1, batch_size, 1).to(self.device)
        else:
            weight = next(self.parameters()).data
            hidden = weight.new(self.n_layers, batch_size, self.hidden_dim).zero_().to(self.device)
        return hidden
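

# A minimal usage sketch, not part of the original file: it shows how Policy can be
# instantiated and stepped once. The sizes (input_dim=12, hidden_dim=64, output_dim=3)
# and the batch size are illustrative assumptions, not values taken from the repository.
if __name__ == "__main__":
    device = th.device("cuda" if th.cuda.is_available() else "cpu")
    policy = Policy(input_dim=12, hidden_dim=64, output_dim=3, device=device)

    batch_size = 8
    h = policy.init_hidden(batch_size)             # shape (n_layers, batch, hidden_dim)
    obs = th.zeros(batch_size, 12, device=device)  # one observation per batch element
    action, h = policy(obs, h)                     # action values lie in (0, 1)
    print(action.shape, h.shape)                   # (8, 3) and (1, 8, 64)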