-
Notifications
You must be signed in to change notification settings - Fork 2
/
play.py
143 lines (128 loc) · 5.56 KB
/
play.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import argparse
import os
import pickle
import sys
import numpy as np
import matplotlib.pylab as plt
from tictactoe.agent import Qlearner, SARSAlearner
from tictactoe.teacher import Teacher
from tictactoe.game import Game
def plot_agent_reward(rewards):
    """Plot the agent's cumulative reward as a function of training episode.

    Args:
        rewards: sequence of per-episode rewards recorded by the agent.
    """
    cumulative = np.cumsum(rewards)
    plt.plot(cumulative)
    plt.title('Agent Cumulative Reward vs. Iteration')
    plt.ylabel('Reward')
    plt.xlabel('Episode')
    plt.show()
class GameLearning(object):
    """
    A class that holds the state of the learning process. Learning
    agents are created/loaded here, and a count is kept of the
    games that have been played.
    """
    def __init__(self, args, alpha=0.5, gamma=0.9, epsilon=0.1):
        """Create or load a learning agent.

        Args:
            args: parsed command-line namespace; must provide
                ``agent_type`` ('q' or 's'), ``load`` and ``plot`` flags.
            alpha: learning rate for a newly created agent.
            gamma: discount factor for a newly created agent.
            epsilon: exploration rate for a newly created agent.
        """
        self.games_played = 0
        # Remember the agent type so later methods (e.g. beginTeaching)
        # do not depend on the module-level 'args' global, which only
        # exists when this file is run as a script.
        self.agent_type = args.agent_type
        self.qlearner_agent_path = './trained_agents/qlearner_agent.pkl'
        self.sarsa_agent_path = './trained_agents/sarsa_agent.pkl'

        if args.load:
            # Load a previously trained agent of the requested type.
            path = (self.qlearner_agent_path if self.agent_type == 'q'
                    else self.sarsa_agent_path)
            try:
                # Context manager guarantees the file is closed even if
                # unpickling fails part-way through.
                with open(path, 'rb') as f:
                    self.agent = pickle.load(f)
            except IOError:
                print("The agent file does not exist. Quitting.")
                sys.exit(0)
            # If plotting, show plot and quit
            if args.plot:
                plot_agent_reward(self.agent.rewards)
                sys.exit(0)
        else:
            # check if agent state file already exists, and ask user whether to overwrite if so
            if ((self.agent_type == "q" and os.path.isfile(self.qlearner_agent_path)) or
                    (self.agent_type == "s" and os.path.isfile(self.sarsa_agent_path))):
                while True:
                    response = input("An agent state is already saved for this type. "
                                     "Are you sure you want to overwrite? [y/n]: ")
                    if response == 'y' or response == 'yes':
                        break
                    elif response == 'n' or response == 'no':
                        print("OK. Quitting.")
                        sys.exit(0)
                    else:
                        print("Invalid input. Please choose 'y' or 'n'.")
            # Create a fresh learner of the requested type.
            if self.agent_type == "q":
                self.agent = Qlearner(alpha, gamma, epsilon)
            else:
                self.agent = SARSAlearner(alpha, gamma, epsilon)

    def beginPlaying(self):
        """ Loop through game iterations with a human player. """
        print("Welcome to Tic-Tac-Toe. You are 'X' and the computer is 'O'.")

        def play_again():
            # Prompt until the user gives a valid yes/no answer.
            print("Games played: %i" % self.games_played)
            while True:
                play = input("Do you want to play again? [y/n]: ")
                if play == 'y' or play == 'yes':
                    return True
                elif play == 'n' or play == 'no':
                    return False
                else:
                    print("Invalid input. Please choose 'y' or 'n'.")

        while True:
            game = Game(self.agent)
            game.start()
            self.games_played += 1
            if not play_again():
                print("OK. Quitting.")
                break

    def beginTeaching(self, episodes):
        """ Loop through game iterations with a teaching agent.

        Args:
            episodes: total number of training games to play.
        """
        teacher = Teacher()
        # Train for alotted number of episodes
        while self.games_played < episodes:
            game = Game(self.agent, teacher=teacher)
            game.start()
            self.games_played += 1
            # Monitor progress
            if self.games_played % 1000 == 0:
                print("Games played: %i" % self.games_played)
        plot_agent_reward(self.agent.rewards)
        # Use the type captured in __init__ rather than the global 'args'
        # (the original read 'args' here, which is only defined when the
        # module runs as a script).
        if self.agent_type == "q":
            self.agent.save_agent(self.qlearner_agent_path)
        elif self.agent_type == "s":
            self.agent.save_agent(self.sarsa_agent_path)
if __name__ == "__main__":
# Parse command line arguments
parser = argparse.ArgumentParser(description="Play Tic-Tac-Toe.")
parser.add_argument('-a', "--agent_type", type=str, default="q",
help="Specify the computer agent learning algorithm. "
"AGENT_TYPE='q' for Q-learning and ='s' for Sarsa-learning")
parser.add_argument("-l", "--load", action="store_true",
help="whether to load trained agent")
parser.add_argument("-t", "--teacher_episodes", default=None, type=int,
help="employ teacher agent who knows the optimal "
"strategy and will play for TEACHER_EPISODES games")
parser.add_argument("-p", "--plot", action="store_true",
help="whether to plot reward vs. episode of stored agent "
"and quit")
args = parser.parse_args()
assert args.agent_type == 'q' or args.agent_type == 's', \
"learner type must be either 'q' or 's'."
if args.plot:
assert args.load, "Must load an agent to plot reward."
assert args.teacher_episodes is None, \
"Cannot plot and teach concurrently; must chose one or the other."
gl = GameLearning(args)
if args.teacher_episodes is not None:
gl.beginTeaching(args.teacher_episodes)
else:
gl.beginPlaying()