n_step_sarsa.py

"""n-step SARSA on a grid-world environment (see Sutton & Barto,
"Reinforcement Learning: An Introduction", Chapter 7)."""
import sys

import numpy as np
import plotly.graph_objs as go
import plotly.offline as py

from envs.GridWorldEnv import GridWorld, Env
from double_q_learning import epsilon_prob
from utils import Algorithm


class NStepSarsa(Algorithm):
    """Tabular n-step SARSA with an epsilon-greedy behaviour policy."""

    def __init__(self, env: Env, n, alpha=0.1, gamma=1, epsilon=0.1):
        self.n = n
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon
        self.actions = np.arange(env.action_space.n)
        # One Q-value per (state..., action) combination.
        obs_space = [space.n for space in env.observation_space.spaces]
        self.action_values = np.zeros(obs_space + [len(self.actions)])
        self._reset()
    def _reset(self):
        self.t = 0
        self.T = sys.maxsize  # episode length; unknown until a terminal state is seen
        # Circular buffers of size n + 1, indexed by time modulo (n + 1),
        # so only the last n + 1 states, actions and rewards are kept.
        self.states_hist = [(0, 0)] * (self.n + 1)
        self.actions_hist = [0] * (self.n + 1)
        self.rewards_hist = [0] * (self.n + 1)

    def _idx(self, time):
        return time % (self.n + 1)

    def store_action(self, action, time):
        self.actions_hist[self._idx(time)] = action

    def store_state(self, state, time):
        self.states_hist[self._idx(time)] = state

    def store_reward(self, reward, time):
        self.rewards_hist[self._idx(time)] = reward

    def get_state(self, time):
        return self.states_hist[self._idx(time)]

    def get_action(self, time):
        return self.actions_hist[self._idx(time)]

    def get_reward(self, time):
        return self.rewards_hist[self._idx(time)]

    def get_key(self, time):
        # Index into action_values: the state tuple extended with the action.
        return self.get_state(time) + (self.get_action(time),)
    def action(self, state):
        # After the first step, the action for time t was already chosen
        # (and stored) while processing the previous transition.
        if self.t > 0:
            return self.get_action(self.t)
        else:
            return self._action(state)

    def _action(self, state):
        # Epsilon-greedy: the greedy action gets probability
        # 1 - epsilon + epsilon / |A|, every other action epsilon / |A|.
        greedy = self.greedy_action(state)
        probs = [epsilon_prob(greedy, action, len(self.actions), self.epsilon)
                 for action in self.actions]
        return np.random.choice(self.actions, p=probs)

    def greedy_action(self, state):
        return np.argmax(self.action_values[state])
    def on_new_state(self, state, action, reward, next_state, done):
        if self.t == 0:
            # First transition of the episode: record S_0 and A_0.
            self.store_state(state, 0)
            self.store_action(action, 0)
            self.store_reward(0, 0)
        if self.t < self.T:
            self.store_state(next_state, self.t + 1)
            self.store_reward(reward, self.t + 1)
            if done:
                self.T = self.t + 1
            else:
                self.store_action(self._action(next_state), self.t + 1)
        # tau in Sutton & Barto's notation: the time whose estimate is updated now.
        update_time = self.t - self.n + 1
        if update_time >= 0:
            update_key = self.get_key(update_time)
            returns = self.calc_returns(update_time)
            if update_time + self.n < self.T:
                # Bootstrap with Q(S_{tau+n}, A_{tau+n}).
                bootstrap_key = self.get_key(update_time + self.n)
                returns += pow(self.gamma, self.n) * self.action_values[bootstrap_key]
            self.action_values[update_key] += self.alpha * (returns - self.action_values[update_key])
        self.t += 1
        if done and update_time != self.T - 1:
            # The episode ended but updates for tau = T - n, ..., T - 2 are still
            # pending; replay the terminal transition until tau reaches T - 1.
            self.on_new_state(state, action, reward, next_state, done)
        elif done:
            self._reset()
    def calc_returns(self, update_time):
        """Discounted sum of the stored rewards:
        G = sum_{t=tau+1}^{min(tau+n, T)} gamma^(t-tau-1) * R_t."""
        return sum(pow(self.gamma, t - update_time - 1) * self.get_reward(t)
                   for t in range(update_time + 1, min(update_time + self.n, self.T) + 1))
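

# A minimal verification sketch (an addition, not part of the original module)
# that checks calc_returns by hand. `_StubSpace` and `_StubEnv` are hypothetical
# stand-ins exposing only the attributes NStepSarsa reads from a gym-style env.
class _StubSpace:
    def __init__(self, n):
        self.n = n


class _StubEnv:
    class action_space:
        n = 4

    class observation_space:
        spaces = [_StubSpace(2), _StubSpace(2)]


def _check_returns():
    # With n=2 and gamma=0.5: G = R_1 + gamma * R_2 = 1.0 + 0.5 * 2.0 = 2.0.
    algo = NStepSarsa(_StubEnv(), n=2, gamma=0.5)
    algo.store_reward(1.0, 1)
    algo.store_reward(2.0, 2)
    assert algo.calc_returns(0) == 2.0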


def generate_episode(env: Env, algo: Algorithm, render=False):
    """Run a single episode and return its length in steps."""
    done = False
    count = 0
    obs = env.reset()
    while not done:
        if render:
            env.render()
        prev_obs = obs
        action = algo.action(prev_obs)
        obs, reward, done, _ = env.step(action)
        algo.on_new_state(prev_obs, action, reward, obs, done)
        count += 1
    return count
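

# Hedged usage sketch (not in the original file): a single training run with one
# agent. As the Q-values converge, the returned episode lengths should shrink
# toward the length of the shortest path through the grid.
#
#   env = GridWorld()
#   algo = NStepSarsa(env, n=4)
#   lengths = [generate_episode(env, algo) for _ in range(100)]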


def perform_algo_eval(env, algo_supplier, ns, n_avg=100, n_ep=100, render=False):
    """For each n in `ns`, train a fresh agent for `n_ep` episodes, repeat the
    whole experiment `n_avg` times, and return the per-episode step counts
    averaged over runs (shape: len(ns) x n_ep)."""
    ret = np.zeros((len(ns), n_ep))
    for i in range(n_avg):
        for n_idx, n in enumerate(ns):
            print('Run: {} n={}'.format(i, n))
            algo = algo_supplier(n)
            for ep in range(n_ep):
                ret[n_idx][ep] += generate_episode(env, algo, render)
    return ret / n_avg


if __name__ == '__main__':
    env = GridWorld()
    ns = np.power(2, np.arange(4))  # n = 1, 2, 4, 8
    ret = perform_algo_eval(env, lambda n: NStepSarsa(env, n), ns)
    data = []
    for idx, row in enumerate(ret):
        data.append(go.Scatter(y=row, name='{}-step Sarsa'.format(ns[idx])))
    py.plot(data)
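    # A hedged alternative (an addition, not in the original script): plotly's
    # offline `plot` also accepts a target filename and can skip opening a
    # browser window, e.g.
    #   py.plot(data, filename='n_step_sarsa.html', auto_open=False)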