utils.py
import random
from collections import deque

import torch
import torch.nn.functional as F
from torch import nn


class ReplayBuffer:
    def __init__(self, size: int):
        """Replay buffer initialisation

        Args:
            size: maximum number of objects stored by the replay buffer
        """
        self.size = size
        self.buffer = deque([], size)

    def push(self, transition) -> list:
        """Push an object to the replay buffer

        Args:
            transition: object to be stored in the replay buffer. Can be of any type

        Returns:
            The current memory of the buffer (any iterable object e.g. list)
        """
        self.buffer.append(transition)
        return self.buffer

    def sample(self, batch_size: int) -> list:
        """Get a random sample from the replay buffer

        Args:
            batch_size: size of sample

        Returns:
            iterable (e.g. list) with objects sampled from the buffer without replacement
        """
        return random.sample(self.buffer, batch_size)
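

# A minimal usage sketch of ReplayBuffer (illustrative only): the buffer size,
# the (state, action, reward, next_state, done) tuple layout and the batch size
# below are assumptions, not something this module prescribes.
def _example_replay_buffer_usage():
    buffer = ReplayBuffer(size=10_000)
    for _ in range(100):
        # store one dummy transition; push() accepts objects of any type
        transition = (torch.zeros(4), 0, 1.0, torch.zeros(4), False)
        buffer.push(transition)
    # draw a mini-batch without replacement once enough transitions are stored
    batch = buffer.sample(batch_size=32)
    return batch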


class DQN(nn.Module):
    def __init__(self, layer_sizes: list[int]):
        """DQN initialisation

        Args:
            layer_sizes: list with size of each layer as elements
        """
        super().__init__()
        self.layers = nn.ModuleList(
            [nn.Linear(layer_sizes[i], layer_sizes[i + 1]) for i in range(len(layer_sizes) - 1)]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through the DQN

        Args:
            x: input to the DQN

        Returns:
            outputted value by the DQN
        """
        # ReLU is applied after every layer, including the output layer,
        # so the returned Q-value estimates are non-negative
        for layer in self.layers:
            x = F.relu(layer(x))
        return x
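

# A minimal sketch of building and querying the DQN (illustrative only): the
# layer sizes below assume a CartPole-like task with 4 state dimensions and
# 2 actions; they are an assumption, not part of this module.
def _example_dqn_usage():
    dqn = DQN(layer_sizes=[4, 64, 64, 2])  # 4 inputs -> two hidden layers -> 2 Q-values
    state = torch.zeros(4)
    q_values = dqn(state)  # one Q-value estimate per action
    return q_values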


def greedy_action(dqn: DQN, state: torch.Tensor) -> int:
    """Select action according to a given DQN

    Args:
        dqn: the DQN that selects the action
        state: state at which the action is chosen

    Returns:
        Greedy action according to DQN
    """
    return int(torch.argmax(dqn(state)))


def epsilon_greedy(epsilon: float, dqn: DQN, state: torch.Tensor) -> int:
    """Sample an epsilon-greedy action according to a given DQN

    Args:
        epsilon: parameter for epsilon-greedy action selection
        dqn: the DQN that selects the action
        state: state at which the action is chosen

    Returns:
        Sampled epsilon-greedy action
    """
    q_values = dqn(state)
    num_actions = q_values.shape[0]
    greedy_act = int(torch.argmax(q_values))
    p = float(torch.rand(1))
    if p > epsilon:
        # exploit: pick the action with the highest Q-value
        return greedy_act
    else:
        # explore: pick an action uniformly at random (may coincide with the greedy one)
        return random.randint(0, num_actions - 1)
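

# A minimal sketch of epsilon-greedy action selection with the helpers above
# (illustrative only): the network shape, the state and the epsilon value are assumptions.
def _example_action_selection():
    dqn = DQN(layer_sizes=[4, 64, 2])
    state = torch.zeros(4)
    # with probability 1 - epsilon take the greedy action, otherwise a uniform one
    action = epsilon_greedy(epsilon=0.1, dqn=dqn, state=state)
    # greedy_action always returns the argmax of the Q-values
    best_action = greedy_action(dqn, state)
    return action, best_action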


def update_target(target_dqn: DQN, policy_dqn: DQN):
    """Update target network parameters using policy network.

    Does not return anything but modifies the target network passed as parameter

    Args:
        target_dqn: target network to be modified in-place
        policy_dqn: the policy DQN whose parameters are copied into the target network
    """
    target_dqn.load_state_dict(policy_dqn.state_dict())


def loss(
    policy_dqn: DQN,
    target_dqn: DQN,
    states: torch.Tensor,
    actions: torch.Tensor,
    rewards: torch.Tensor,
    next_states: torch.Tensor,
    dones: torch.Tensor,
    ddqn: bool = False,
) -> torch.Tensor:
    """Calculate Bellman error loss

    Args:
        policy_dqn: policy DQN
        target_dqn: target DQN
        states: batched state tensor
        actions: batched action tensor
        rewards: batched rewards tensor
        next_states: batched next states tensor
        dones: batched Boolean tensor, True when episode terminates
        ddqn: if True, compute the Double DQN loss instead of the vanilla DQN loss

    Returns:
        Float scalar tensor with loss value
    """
    # Bellman targets are treated as constants, so no gradients are tracked here
    with torch.no_grad():
        if ddqn:
            # Double DQN: the policy network chooses the next action,
            # the target network evaluates it
            ddqn_actions = policy_dqn(next_states).max(1).indices.reshape(-1, 1)
            next_values = target_dqn(next_states).gather(1, ddqn_actions).reshape(-1)
        else:
            # Vanilla DQN: the target network both chooses and evaluates the next action
            next_values = target_dqn(next_states).max(1).values
        # zero out the bootstrap term for terminal transitions
        bellman_targets = (~dones).reshape(-1) * next_values + rewards.reshape(-1)
    q_values = policy_dqn(states).gather(1, actions).reshape(-1)
    return ((q_values - bellman_targets) ** 2).mean()
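

# A minimal sketch of one DQN training step wiring loss() and update_target()
# together (illustrative only): environment interaction is omitted and the batch
# consists of dummy tensors; shapes, optimiser, learning rate and target-update
# frequency are assumptions, not something this module prescribes.
def _example_training_step():
    policy_dqn = DQN(layer_sizes=[4, 64, 2])
    target_dqn = DQN(layer_sizes=[4, 64, 2])
    update_target(target_dqn, policy_dqn)  # start with identical parameters
    optimiser = torch.optim.SGD(policy_dqn.parameters(), lr=1e-3)

    batch_size = 32
    states = torch.zeros(batch_size, 4)
    actions = torch.zeros(batch_size, 1, dtype=torch.long)  # gather() expects shape (batch, 1)
    rewards = torch.ones(batch_size)
    next_states = torch.zeros(batch_size, 4)
    dones = torch.zeros(batch_size, dtype=torch.bool)

    optimiser.zero_grad()
    batch_loss = loss(policy_dqn, target_dqn, states, actions, rewards, next_states, dones, ddqn=True)
    batch_loss.backward()
    optimiser.step()

    # the target network is typically refreshed only every few steps or episodes
    update_target(target_dqn, policy_dqn)
    return float(batch_loss)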