-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_sarsa.py
118 lines (88 loc) · 3.57 KB
/
test_sarsa.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
from mdp.algorithms.sarsa import SARSA
from mdp.environment.env import Environment
import numpy as np
import pygame
pygame.init()
# NOTE: env_config is imported after pygame.init(); this ordering is kept
# deliberately in case env_config relies on an initialised pygame.
from env_config import grid, actions, rewards, gw, gh

# Arrow drawn for each action index returned by argmax over a cell's Q-values.
# NOTE(review): key order presumably matches the environment's `actions`; confirm.
CONVERT_POLICY = {1: '↓', 0: '↑', 2: '→', 3: '←'}
DISPLAY_GRID = True  # open the pygame policy/utility windows after training
UTILITY_FONT_SIZE = 15
UTILITY_OFFSET = (4, 14)  # (x, y) pixel offset of the utility text inside a cell
POLICY_FONT_SIZE = 30
POLICY_OFFSET = (17, 5)  # (x, y) pixel offset of the policy arrow inside a cell
ratio = 1  # scale factor applied to font sizes and text offsets

mdp = Environment(grid, actions, rewards, gw, gh)
# NOTE(review): the 6x6 size is hard-coded; presumably it must match `grid`
# (or gw/gh) -- confirm against SARSA's constructor before changing grids.
sarsa = SARSA(n_w=6, n_h=6, n_actions=4)
q_table = sarsa.solve(mdp=mdp)

# Print the best Q-value of every cell, labelled as (column, row).
print('\n(Column, Row)')
for i, q_row in enumerate(q_table):
    for j, q_values in enumerate(q_row):
        print(f"{j, i}: {max(q_values)}")
# Display the greedy policy and then the state utilities, each in its own
# pygame window that stays open until the user closes it.
if DISPLAY_GRID:
    GREEN = (100, 200, 100)
    RED = (200, 100, 100)
    WHITE = (200, 200, 200)
    GREY = (50, 50, 50)
    BLACK = (0, 0, 0)
    # Fill colour per grid cell code; any other code is drawn WHITE.
    CELL_COLORS = {'W': GREY, 'G': GREEN, 'R': RED}

    # Greedy action (as an arrow) and its Q-value for every cell. The labels
    # are indexed [row][col] in the same way as `grid` when drawn below.
    directions = [[CONVERT_POLICY[np.argmax(cell)] for cell in row] for row in q_table]
    utilities = [["{:.3f}".format(np.max(cell)) for cell in row] for row in q_table]
    colors = [[CELL_COLORS.get(cell, WHITE) for cell in row] for row in grid]

    block_size = 50
    # Size the window from the grid (300x300 for a 6x6 grid) so that every
    # row and column fits on screen whatever the grid dimensions.
    width = max((len(row) for row in grid), default=0) * block_size
    height = len(grid) * block_size
    screen_dimensions = (width, height)
    screen_color = (0, 0, 0)
    policy_font = pygame.font.Font("assets/seguisym.ttf", int(POLICY_FONT_SIZE * ratio))
    utility_font = pygame.font.Font("assets/seguisym.ttf", int(UTILITY_FONT_SIZE * ratio))

    def _show_labels(labels, font, offset, fps=30):
        """Draw ``labels`` over the coloured grid in a window until it is closed.

        Args:
            labels: Text per cell, indexed ``labels[row][col]`` like ``grid``.
            font: pygame font used to render each label.
            offset: (x, y) pixel offset of the label inside its cell, scaled
                by ``ratio``.
            fps: Frame-rate cap. The scene is static, so redrawing it as fast
                as possible would only burn CPU.
        """
        screen = pygame.display.set_mode(screen_dimensions)
        pygame.display.set_caption('SARSA')
        clock = pygame.time.Clock()
        running = True
        while running:
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    running = False
            screen.fill(screen_color)
            for r, grid_row in enumerate(grid):
                # Walk this row's own columns so non-square grids draw correctly.
                for c, cell in enumerate(grid_row):
                    rect = pygame.Rect(c * block_size, r * block_size, block_size, block_size)
                    pygame.draw.rect(screen, colors[r][c], rect)
                    pygame.draw.rect(screen, BLACK, rect, 1)  # 1px cell border
                    if cell == 'W':
                        continue  # walls are drawn without a label
                    message = font.render(labels[r][c], True, BLACK)
                    screen.blit(message, (c * block_size + offset[0] * ratio,
                                          r * block_size + offset[1] * ratio))
            pygame.display.update()
            clock.tick(fps)

    # Display Policy
    _show_labels(directions, policy_font, POLICY_OFFSET)
    # Display Utilities
    _show_labels(utilities, utility_font, UTILITY_OFFSET)
    pygame.quit()