ddpg.py
import torch
import torch.nn.functional as F
import torch.optim as optim

from models import Critic
from networks import Actor
from replay_buffers import BasicBuffer

class DDPGAgent:

    def __init__(self, env, gamma, tau, buffer_maxlen, critic_learning_rate, actor_learning_rate):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.env = env
        self.obs_dim = env.observation_space.shape[0]
        self.action_dim = env.action_space.shape[0]

        # hyperparameters
        self.gamma = gamma
        self.tau = tau

        # initialize actor and critic networks
        self.critic = Critic(self.obs_dim, self.action_dim).to(self.device)
        self.critic_target = Critic(self.obs_dim, self.action_dim).to(self.device)
        # networks.Actor takes no constructor arguments here; its input/output
        # sizes are presumably fixed inside that class itself
        self.actor = Actor().to(self.device)
        self.actor_target = Actor().to(self.device)

        # start the target networks as exact copies of the online networks
        for target_param, param in zip(self.critic_target.parameters(), self.critic.parameters()):
            target_param.data.copy_(param.data)

        for target_param, param in zip(self.actor_target.parameters(), self.actor.parameters()):
            target_param.data.copy_(param.data)

        # optimizers
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=critic_learning_rate)
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=actor_learning_rate)

        self.replay_buffer = BasicBuffer(buffer_maxlen)
    def get_action(self, obs):
        # add a batch dimension before the forward pass; the actor's output is
        # assumed to be batched, hence the squeeze(0) below
        state = torch.FloatTensor(obs).unsqueeze(0).to(self.device)
        action = self.actor(state)
        action = action.squeeze(0).cpu().detach().numpy()
        return action
    def update(self, batch_size):
        state_batch, action_batch, reward_batch, next_state_batch, masks = self.replay_buffer.sample(batch_size)

        state_batch = torch.FloatTensor(state_batch).to(self.device)
        action_batch = torch.FloatTensor(action_batch).to(self.device)
        reward_batch = torch.FloatTensor(reward_batch).to(self.device)
        next_state_batch = torch.FloatTensor(next_state_batch).to(self.device)
        masks = torch.FloatTensor(masks).to(self.device)

        # critic estimate for the sampled (state, action) pairs
        curr_Q = self.critic(state_batch, action_batch)

        # bootstrap target: r + gamma * Q_target(s', pi_target(s')), with the
        # bootstrap term masked out at terminal transitions (masks is assumed
        # to be 1 - done)
        next_actions = self.actor_target(next_state_batch)
        next_Q = self.critic_target(next_state_batch, next_actions.detach())
        expected_Q = reward_batch + self.gamma * next_Q * masks

        # update critic
        q_loss = F.mse_loss(curr_Q, expected_Q.detach())
        self.critic_optimizer.zero_grad()
        q_loss.backward()
        self.critic_optimizer.step()

        # update actor by ascending the critic's value of the actor's actions
        policy_loss = -self.critic(state_batch, self.actor(state_batch)).mean()
        self.actor_optimizer.zero_grad()
        policy_loss.backward()
        self.actor_optimizer.step()

        # Polyak-average the target networks toward the online networks
        for target_param, param in zip(self.actor_target.parameters(), self.actor.parameters()):
            target_param.data.copy_(param.data * self.tau + target_param.data * (1.0 - self.tau))

        for target_param, param in zip(self.critic_target.parameters(), self.critic.parameters()):
            target_param.data.copy_(param.data * self.tau + target_param.data * (1.0 - self.tau))
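

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original implementation).
# It assumes a classic Gym continuous-control environment (reset() returning
# an observation and step() returning (obs, reward, done, info)), and that
# BasicBuffer provides push(state, action, reward, next_state, mask) and
# __len__, matching the 5-tuple its sample() returns above. The environment
# name and hyperparameters below are placeholders.
if __name__ == "__main__":
    import gym

    env = gym.make("Pendulum-v1")
    agent = DDPGAgent(env, gamma=0.99, tau=1e-2, buffer_maxlen=100_000,
                      critic_learning_rate=1e-3, actor_learning_rate=1e-4)

    max_episodes, max_steps, batch_size = 100, 500, 128
    for episode in range(max_episodes):
        state = env.reset()
        episode_reward = 0.0
        for _ in range(max_steps):
            action = agent.get_action(state)
            next_state, reward, done, _ = env.step(action)
            # store mask = 1 - done so update() can drop the bootstrap term
            # at terminal transitions
            agent.replay_buffer.push(state, action, reward, next_state, 0.0 if done else 1.0)

            if len(agent.replay_buffer) > batch_size:
                agent.update(batch_size)

            state = next_state
            episode_reward += reward
            if done:
                break
        print(f"episode {episode}: return {episode_reward:.1f}")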