dqn_agent.py
import os
from copy import deepcopy
from typing import Optional, Tuple

import numpy as np
import torch
from torch.optim.lr_scheduler import _LRScheduler

from agents.base import Agent
from agents.memory.prioritized_memory import PrioritizedMemory
from agents.models.base import BaseModel
from agents.policies.base_policy import Policy
from tools.misc import set_seed, soft_update
from tools.rl_constants import Action, Experience, ExperienceBatch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class DQNAgent(Agent):
    """Interacts with and learns from the environment."""

    def __init__(self,
                 state_shape: Tuple[int, ...],
                 action_size: int,
                 model: BaseModel,
                 policy: Policy,
                 memory: PrioritizedMemory,
                 lr_scheduler: _LRScheduler,
                 optimizer: torch.optim.Optimizer,
                 batch_size: int = 32,
                 gamma: float = 0.95,
                 tau: float = 1e-3,
                 update_frequency: int = 5,
                 seed: Optional[int] = None,
                 action_repeats: int = 1,
                 gradient_clip: float = 1,
                 ):
        """Initialize an Agent object.

        Args:
            state_shape (Tuple[int, ...]): Shape of the environment state
            action_size (int): Number of possible integer actions
            model (BaseModel): Model producing Q-values from a state
            policy (Policy): Policy used to select actions from the model's output
            memory (PrioritizedMemory): Prioritized experience replay buffer
            lr_scheduler (_LRScheduler): Learning-rate scheduler for the optimizer
            optimizer (torch.optim.Optimizer): Optimizer for the online network
            batch_size (int): Number of experiences sampled per learning step
            gamma (float): Discount factor for future rewards
            tau (float): Interpolation factor for the soft target-network update
            update_frequency (int): Number of environment steps between learning steps
            seed (Optional[int]): Random seed
            action_repeats (int): Number of consecutive steps to repeat a selected action during training
            gradient_clip (float): Magnitude at which gradients are clamped
        """
        super().__init__(action_size=action_size, state_shape=state_shape)
        self.batch_size = batch_size
        self.gamma = gamma
        self.tau = tau
        self.update_frequency = update_frequency
        self.gradient_clip = gradient_clip
        self.previous_action: Optional[Action] = None
        self.action_repeats = action_repeats

        # Double DQN: a target network tracks the online network via soft updates
        self.online_qnetwork = model.to(device)
        self.target_qnetwork = deepcopy(model).to(device).eval()

        self.memory = memory
        self.losses = []
        self.policy: Policy = policy
        self.optimizer: torch.optim.Optimizer = optimizer
        self.lr_scheduler: _LRScheduler = lr_scheduler

        if seed is not None:
            set_seed(seed)
            self.online_qnetwork.set_seed(seed)
            self.target_qnetwork.set_seed(seed)
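
    # Design note: the target network changes only through `soft_update`
    # (assumed to interpolate theta_target <- tau * theta_online + (1 - tau) * theta_target),
    # which stabilizes the bootstrapped Q-targets computed in learn().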

    def set_mode(self, mode: str):
        if mode == 'train':
            self.online_qnetwork.train()
            self.target_qnetwork.train()
            self.policy.train()
        elif mode == 'eval':
            self.online_qnetwork.eval()
            self.target_qnetwork.eval()
            self.policy.eval()
        else:
            raise ValueError('Invalid mode: {}'.format(mode))

    def load(self, path_to_online_network_pth: str):
        assert os.path.exists(path_to_online_network_pth), "Path does not exist"
        self.online_qnetwork.load_state_dict(torch.load(path_to_online_network_pth))

    def preprocess_state(self, state: torch.Tensor):
        preprocessed_state = self.online_qnetwork.preprocess_state(state)
        return preprocessed_state

    def step_episode(self, episode: int, param_frequency: int = 10, *args):
        self.episode_counter += 1
        self.policy.step_episode(episode)
        self.lr_scheduler.step()
        self.memory.step_episode(episode)
        self.online_qnetwork.step_episode(episode)
        self.target_qnetwork.step_episode(episode)
        return True

    def step(self, experience: Experience, **kwargs) -> None:
        """Step the agent in response to a change in the environment."""
        # Add the experience to memory with a default priority
        self.memory.add(experience)

        if self.warmup:
            return
        else:
            self.t_step += 1
            # If enough samples are available in memory, get a random subset and learn
            if self.t_step % self.update_frequency == 0 and len(self.memory) > self.batch_size:
                experience_batch = self.memory.sample(self.batch_size)
                experience_batch = experience_batch.to(device)
                loss, errors = self.learn(experience_batch)

                with torch.no_grad():
                    if errors.min() < 0:
                        raise RuntimeError("Errors must be >= 0, found {}".format(errors.min()))
                    priorities = errors.detach().cpu().numpy()
                    self.memory.update(experience_batch.sample_idxs, priorities)

                # Perform any post-backprop updates
                self.online_qnetwork.step()
                self.target_qnetwork.step()

                self.param_capture.add('loss', loss)
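
    # Note: the priorities passed to `self.memory.update` above are the raw
    # per-sample errors returned by learn(). In standard prioritized experience
    # replay the buffer is assumed to turn a priority p_i into a sampling
    # probability P(i) = p_i**alpha / sum_k(p_k**alpha); the exact scheme
    # (alpha, epsilon offset, importance-sampling weights) is assumed to live
    # inside PrioritizedMemory.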

    def get_action(self, state: torch.Tensor, *args, **kwargs) -> Action:
        """Return the action for the given state as per the current policy.

        Args:
            state (torch.Tensor): Current environment state
        Returns:
            action (Action): The action to perform
        """
        state = state.to(device)
        state.requires_grad = False
        if not self.training:
            # Run in evaluation mode
            action: Action = self.policy.get_action(state=state, model=self.online_qnetwork)
        else:
            if self.previous_action is None or self.t_step % self.action_repeats == 0:
                # Get the action from the policy
                action: Action = self.policy.get_action(state=state, model=self.online_qnetwork)
                self.previous_action = action
            else:
                # Repeat the last action
                action: Action = self.previous_action
        return action

    def get_random_action(self, state: torch.Tensor, *args, **kwargs) -> Action:
        # Sample a uniformly random integer action in [0, action_size)
        action = np.random.randint(0, self.action_size, (1,))
        action = Action(value=action)
        return action

    def learn(self, experience_batch: ExperienceBatch) -> tuple:
        """Update value parameters using the given batch of experience tuples.

        Args:
            experience_batch (ExperienceBatch): Minibatch of experience
        Returns:
            loss (torch.Tensor), errors (torch.Tensor): The loss and the
                per-sample TD errors for the batch
        """
        # By default, calculate TD errors. Some DQN modifications (e.g. categorical DQN)
        # use custom errors/losses
        loss, errors = self.policy.compute_errors(
            self.online_qnetwork,
            self.target_qnetwork,
            experience_batch,
            gamma=self.gamma
        )
        assert errors.min() >= 0

        # Perform optimization step
        self.optimizer.zero_grad()
        loss.backward()
        for param in self.online_qnetwork.parameters():
            param.grad.data.clamp_(-self.gradient_clip, self.gradient_clip)
        self.optimizer.step()

        # Perform a soft update of the online -> target network
        soft_update(self.online_qnetwork, self.target_qnetwork, self.tau)

        return loss, errors
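

# Illustrative sketch (not part of the agent): a minimal, self-contained demo of
# the Double-DQN target and the soft target-network update that
# `policy.compute_errors` and `soft_update` are assumed to implement.
# The toy network, tensor shapes, and hyperparameters below are assumptions
# for demonstration only; they do not reflect the project's actual classes.
if __name__ == "__main__":
    torch.manual_seed(0)

    online = torch.nn.Linear(4, 2)            # toy Q-network: 4-dim state, 2 actions
    target = deepcopy(online).eval()

    states = torch.randn(8, 4)                # batch of 8 transitions
    actions = torch.randint(0, 2, (8, 1))
    rewards = torch.randn(8, 1)
    next_states = torch.randn(8, 4)
    dones = torch.zeros(8, 1)
    gamma, tau = 0.95, 1e-3

    with torch.no_grad():
        # Double DQN: the online network selects the next action,
        # the target network evaluates it
        next_actions = online(next_states).argmax(dim=1, keepdim=True)
        next_q = target(next_states).gather(1, next_actions)
        td_target = rewards + gamma * next_q * (1 - dones)

    q_expected = online(states).gather(1, actions)
    td_errors = (td_target - q_expected).abs()    # per-sample errors, e.g. for replay priorities
    loss = torch.nn.functional.mse_loss(q_expected, td_target)
    print("loss:", loss.item(), "mean |TD error|:", td_errors.mean().item())

    # Soft update: theta_target <- tau * theta_online + (1 - tau) * theta_target
    for t_param, o_param in zip(target.parameters(), online.parameters()):
        t_param.data.copy_(tau * o_param.data + (1.0 - tau) * t_param.data)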