naf.py
import torch
import torch.nn as nn
from torch.optim import Adam
import numpy as np

import policy as p


def softUpdate(target, source, tau):
    # Polyak averaging: target <- (1 - tau) * target + tau * source
    for targetParam, sourceParam in zip(target.parameters(), source.parameters()):
        targetParam.data.copy_(targetParam.data * (1.0 - tau) + sourceParam.data * tau)


def hardUpdate(target, source):
    # Copy the source parameters into the target network outright
    softUpdate(target, source, 1.0)


class NAF:
    def __init__(self, gamma, tau, hiddenSize, numInputs, actionSpace, device=torch.device('cpu')):
        self.device = device
        self.actionSpace = actionSpace
        self.numInputs = numInputs
        # Online network Q and target network Q', initialised identically
        self.model = p.Policy(hiddenSize, numInputs, actionSpace, device).to(device=device)
        self.target = p.Policy(hiddenSize, numInputs, actionSpace, device).to(device=device)
        hardUpdate(self.target, self.model)
        self.gamma = gamma
        self.tau = tau
        self.optimizer = Adam(self.model.parameters(), lr=1e-4, weight_decay=1e-5)
        self.loss = nn.MSELoss(reduction='sum')

    def selectAction(self, state, actionNoise=False, useTarget=False):
        if useTarget:
            mu, _, _ = self.target((state, None))
        else:
            self.model.eval()
            mu, _, _ = self.model((state, None))
            self.model.train()
        mu = mu.data
        if actionNoise:
            # Gaussian exploration noise added to the greedy action
            mu += torch.Tensor(np.random.standard_normal(mu.shape)).to(self.device)
        return mu.clamp(-1, 1)

    def updateParameters(self, batch, device):
        # Sample a random minibatch of m transitions
        stateBatch = torch.Tensor(np.concatenate(batch.state)).to(device)
        actionBatch = torch.Tensor(np.concatenate(batch.action)).to(device)
        rewardBatch = torch.Tensor(np.concatenate(batch.reward)).to(device)
        maskBatch = torch.Tensor(np.concatenate(batch.mask)).to(device)
        nextStateBatch = torch.Tensor(np.concatenate(batch.nextState)).to(device)
        # Set y_i = r_i + gamma * V'(x_{t+1} | Q'), with the mask zeroing out terminal states
        _, _, nextStateValues = self.target((nextStateBatch, None))
        rewardBatch = rewardBatch.unsqueeze(1)
        maskBatch = maskBatch.unsqueeze(1)
        expectedStateActionValues = rewardBatch + self.gamma * maskBatch * nextStateValues
        # Update Q by minimizing the loss between Q(x_i, u_i) and y_i
        _, stateActionValues, _ = self.model((stateBatch, actionBatch))
        loss = self.loss(stateActionValues, expectedStateActionValues)
        self.optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.model.parameters(), 1)
        self.optimizer.step()
        # Soft-update the target network Q'
        softUpdate(self.target, self.model, self.tau)
        return loss.item()

    def saveModel(self, modelPath):
        torch.save(self.model.state_dict(), modelPath)

    def loadModel(self, modelPath):
        self.model.load_state_dict(torch.load(modelPath))
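
A rough usage sketch for orientation, not part of naf.py: it assumes policy.Policy accepts the constructor arguments used above, that actionSpace is a Gym-style Box space with actions in [-1, 1], that the classic Gym reset/step API is in use, and that replay batches are namedtuples with the state, action, reward, mask, and nextState fields read by updateParameters. None of these interfaces are confirmed by this file alone; the repository's own training script may differ.

# Hypothetical usage sketch; every environment name, hyperparameter, and
# interface assumption below is illustrative only.
import random
from collections import namedtuple

import gym
import numpy as np
import torch

from naf import NAF

Transition = namedtuple('Transition', ('state', 'action', 'reward', 'mask', 'nextState'))

env = gym.make('MountainCarContinuous-v0')   # assumed env with actions in [-1, 1]
device = torch.device('cpu')
agent = NAF(gamma=0.99, tau=0.001, hiddenSize=128,
            numInputs=env.observation_space.shape[0],
            actionSpace=env.action_space, device=device)

memory = []                                  # stand-in for a proper replay buffer
state = torch.Tensor([env.reset()])          # classic Gym API assumed here
for step in range(10000):
    action = agent.selectAction(state, actionNoise=True)
    nextState, reward, done, _ = env.step(action.numpy()[0])
    nextState = torch.Tensor([nextState])
    memory.append(Transition(state.numpy(), action.numpy(),
                             np.array([reward]), np.array([float(not done)]),
                             nextState.numpy()))
    state = torch.Tensor([env.reset()]) if done else nextState
    if len(memory) >= 128:
        batch = Transition(*zip(*random.sample(memory, 128)))
        agent.updateParameters(batch, device)

agent.saveModel('naf_checkpoint.pt')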