# test_n_step.py
import numpy as np
import pprint
from operator import itemgetter
from mdp_matrix import GridWorld, WindyGridCliffMazeWorld, StochasticGridWorld
from sarsa import sarsa
from expected_sarsa import expected_sarsa
from n_step_expected_sarsa import n_step_expected_sarsa
from double_sarsa import double_sarsa
from double_expected_sarsa import double_expected_sarsa
from n_step_sarsa import n_step_sarsa
from n_step_tree_backup import n_step_tree_backup
from q_sigma_with_varianced_sigma import n_step_q_sigma
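# Test setup: a 5x5 grid world with a step reward of -1 everywhere except the
# two terminal cells (flat indices 2 and 23, i.e. grid positions (0, 2) and
# (4, 3)), which pay +1. The script learns with n-step Q(sigma) and compares
# its greedy policy against one-step SARSA.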
start_state = [0, 0]
test_rewards = [[i, j, -1] for i in range(5) for j in range(5)]
test_rewards[2] = [0, 2, 1]
test_rewards[23] = [4, 3, 1]
def transform_to_actions(f):
    """Map greedy action indices to compass labels for display."""
    actions = ["S", "E", "N", "W"]
    return np.array([actions[x] for x in f])
rewards = np.ones((5, 5))
for x, y, r in test_rewards:
    rewards[x][y] = r
gw = GridWorld(5, test_rewards, terminal_states=[2, 23])
# print(gw.T)
# Q, ave_reward, max_reward, rewards_per_episode, Q_variances = n_step_sarsa(gw, 20000, alpha=.1, n=10)
Q, ave_reward, max_reward, rewards_per_episode, Q_variances = n_step_q_sigma(gw, 20000, alpha=.7, n=4)
# One-step SARSA baseline; results go into separate variables so the
# Q(sigma) outputs above are not overwritten.
sarsa_Q, sarsa_ave_reward, sarsa_max_reward, sarsa_rewards_per_episode, sarsa_Q_variances = sarsa(gw, 20000, alpha=.1)
print "REWARDS"
print np.reshape(np.array(rewards), (5,5))
print np.reshape(transform_to_actions(np.argmax(Q, 1)), (5,5))
print "SARSA"
print np.reshape(transform_to_actions(np.argmax(sarsa_Q, 1)), (5,5))
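
# Optional follow-up (not part of the original script): a minimal sketch of how
# the two learners could be compared visually, assuming rewards_per_episode and
# sarsa_rewards_per_episode are flat sequences of per-episode returns and that
# matplotlib is available. The 100-episode window is an arbitrary smoothing
# choice for illustration.
def smooth(xs, window=100):
    """Simple moving average over per-episode returns."""
    xs = np.asarray(xs, dtype=float)
    kernel = np.ones(window) / window
    return np.convolve(xs, kernel, mode="valid")

try:
    import matplotlib.pyplot as plt
    plt.plot(smooth(rewards_per_episode), label="n-step Q(sigma)")
    plt.plot(smooth(sarsa_rewards_per_episode), label="SARSA")
    plt.xlabel("Episode")
    plt.ylabel("Smoothed return")
    plt.legend()
    plt.show()
except ImportError:
    print("matplotlib not installed; skipping learning-curve plot")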