-
Notifications
You must be signed in to change notification settings - Fork 0
/
qr_dqn_agent.py
138 lines (110 loc) · 5.91 KB
/
qr_dqn_agent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from copy import copy
from typing import Union
import numpy as np
from rl_coach.agents.dqn_agent import DQNAgentParameters, DQNNetworkParameters, DQNAlgorithmParameters
from rl_coach.agents.value_optimization_agent import ValueOptimizationAgent
from rl_coach.architectures.head_parameters import QuantileRegressionQHeadParameters
from rl_coach.core_types import StateType
from rl_coach.schedules import LinearSchedule
class QuantileRegressionDQNNetworkParameters(DQNNetworkParameters):
    """Network parameters for the QR-DQN agent.

    Starts from the DQN network defaults, replaces the inherited head list with a
    single quantile regression Q head and overrides the optimizer settings.
    """

    def __init__(self):
        super().__init__()
        # Optimizer settings: step size 5e-5 and epsilon 0.01 / 32
        # (presumably the Adam hyper-parameters reported in the QR-DQN paper -- confirm).
        self.learning_rate = 5e-5
        self.optimizer_epsilon = 0.01 / 32
        # A single quantile regression Q head replaces whatever heads the DQN defaults set.
        self.heads_parameters = [QuantileRegressionQHeadParameters()]
class QuantileRegressionDQNAlgorithmParameters(DQNAlgorithmParameters):
    r"""
    :param atoms: (int)
        the number of atoms (quantile locations) to predict for each action
    :param huber_loss_interval: (float)
        One of the huber loss parameters, referred to as :math:`\kappa` in the paper.
        It describes the interval [-k, k] in which the huber loss acts as an MSE loss.
    """
    # Raw docstring: the :math:`\kappa` role would otherwise contain the invalid
    # escape sequence "\k", which triggers a DeprecationWarning/SyntaxWarning.

    def __init__(self):
        super().__init__()
        self.atoms = 200
        self.huber_loss_interval = 1  # called kappa (k) in the paper
class QuantileRegressionDQNAgentParameters(DQNAgentParameters):
    """Agent parameters that bundle the QR-DQN algorithm, network and exploration settings."""

    def __init__(self):
        super().__init__()
        self.algorithm = QuantileRegressionDQNAlgorithmParameters()
        self.network_wrappers = {"main": QuantileRegressionDQNNetworkParameters()}

        exploration = self.exploration
        # Epsilon schedule from 1 down to 0.01; the third argument (1000000) is
        # presumably the number of decay steps -- confirm against LinearSchedule.
        exploration.epsilon_schedule = LinearSchedule(1, 0.01, 1000000)
        # epsilon used while evaluating
        exploration.evaluation_epsilon = 0.001

    @property
    def path(self):
        # "<module path>:<class name>" of the agent these parameters configure
        return 'rl_coach.agents.qr_dqn_agent:QuantileRegressionDQNAgent'
# Quantile Regression Deep Q Network - https://arxiv.org/pdf/1710.10044v1.pdf
class QuantileRegressionDQNAgent(ValueOptimizationAgent):
    """Quantile Regression DQN (QR-DQN) agent.

    For every action the network predicts ``atoms`` quantile locations. Each
    quantile is given the same probability mass ``1 / atoms``, so an action's
    Q value is the mean of its quantile locations.
    """

    def __init__(self, agent_parameters, parent: Union['LevelManager', 'CompositeAgent']=None):
        super().__init__(agent_parameters, parent)
        # uniform probability mass 1/N for each of the N quantiles
        self.quantile_probabilities = np.ones(self.ap.algorithm.atoms) / float(self.ap.algorithm.atoms)

    @property
    def is_on_policy(self) -> bool:
        return False

    def get_q_values(self, quantile_values):
        """Convert quantile locations to Q values.

        :param quantile_values: array whose last axis indexes the quantiles (atoms)
        :return: the array with its last axis reduced to the probability-weighted
            sum of the quantiles, i.e. their mean
        """
        return np.dot(quantile_values, self.quantile_probabilities)

    def get_all_q_values_for_states(self, states: StateType):
        """Return per-action Q values for ``states``.

        Returns None when the exploration policy does not require action values.
        The raw prediction's format is (batch, actions, atoms).
        """
        if self.exploration_policy.requires_action_values():
            quantile_values = self.get_prediction(states)
            actions_q_values = self.get_q_values(quantile_values)
        else:
            actions_q_values = None
        return actions_q_values

    def get_all_q_values_for_states_and_softmax_probabilities(self, states: StateType):
        """Return ``(q_values, softmax_probabilities)`` for ``states``.

        Both are None when the exploration policy does not require action values.
        The quantile prediction's format is (batch, actions, atoms).
        """
        actions_q_values, softmax_probabilities = None, None
        if self.exploration_policy.requires_action_values():
            # copy() so appending the first head's softmax tensor does not mutate
            # the network's own outputs list
            outputs = copy(self.networks['main'].online_network.outputs)
            outputs.append(self.networks['main'].online_network.output_heads[0].softmax)
            quantile_values, softmax_probabilities = self.get_prediction(states, outputs)
            actions_q_values = self.get_q_values(quantile_values)
        return actions_q_values, softmax_probabilities

    @staticmethod
    def _rank_matched_quantile_midpoints(quantile_locations):
        """Return, for every quantile location, the midpoint tau^hat matching its rank.

        :param quantile_locations: (batch, atoms) array of predicted quantile
            locations, in whatever (not necessarily sorted) order they were produced
        :return: (batch, atoms) array whose entry [b, k] is (2r + 1) / (2N), where
            N is the number of atoms and r is the zero-based rank of
            quantile_locations[b, k] within row b
        """
        num_atoms = quantile_locations.shape[-1]
        cumulative_probabilities = np.arange(num_atoms + 1) / float(num_atoms)  # tau_i
        midpoints = 0.5 * (cumulative_probabilities[1:] + cumulative_probabilities[:-1])  # tau^hat_i
        # argsort maps rank -> position. Arg-sorting that permutation inverts it,
        # giving position -> rank, so each location gets the midpoint of its own rank.
        # (Gathering midpoints with a single argsort applies the permutation in the
        # wrong direction and mismatches locations and midpoints.)
        ranks = np.argsort(np.argsort(quantile_locations, axis=-1), axis=-1)
        return midpoints[ranks]

    def learn_from_batch(self, batch):
        """Run one QR-DQN training step on ``batch``.

        :return: (total_loss, losses, unclipped_grads), the first three values
            returned by the main network's train_and_sync_networks
        """
        network_keys = self.ap.network_wrappers['main'].input_embedders_parameters.keys()

        # get the quantiles of the next states (target network) and current states (online network)
        next_state_quantiles, current_quantiles = self.networks['main'].parallel_prediction([
            (self.networks['main'].target_network, batch.next_states(network_keys)),
            (self.networks['main'].online_network, batch.states(network_keys))
        ])

        # add Q value samples for logging
        self.q_values.add_sample(self.get_q_values(current_quantiles))

        # greedy next-state actions according to the target network's Q values
        target_actions = np.argmax(self.get_q_values(next_state_quantiles), axis=1)

        # Bellman update applied to every quantile of the chosen next action; the
        # (1 - game_overs) factor drops the bootstrap term for terminal transitions
        batch_idx = list(range(batch.size))
        td_targets = batch.rewards(True) + (1.0 - batch.game_overs(True)) * self.ap.algorithm.discount \
            * next_state_quantiles[batch_idx, target_actions]

        # (batch index, action) pairs locating the taken actions within the prediction
        actions_locations = [[b, a] for b, a in zip(batch_idx, batch.actions())]

        # Pair every predicted quantile of the taken action with the midpoint of its rank.
        # NOTE(review): assumes the loss head pairs midpoint k with quantile location k
        # of the taken action -- confirm against QuantileRegressionQHead.
        quantile_midpoints = self._rank_matched_quantile_midpoints(
            current_quantiles[batch_idx, batch.actions()])

        # train
        result = self.networks['main'].train_and_sync_networks({
            **batch.states(network_keys),
            'output_0_0': actions_locations,
            'output_0_1': quantile_midpoints,
        }, td_targets)
        total_loss, losses, unclipped_grads = result[:3]
        return total_loss, losses, unclipped_grads