-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsnake.py
135 lines (128 loc) · 3.57 KB
/
snake.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import pygame
import argparse
from argparse import RawTextHelpFormatter
import control
from models import Environment
def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the snake script.

    Returns:
        An ``ArgumentParser`` using ``RawTextHelpFormatter`` so explicit
        newlines inside help strings are printed as written.
    """
    parser = argparse.ArgumentParser(formatter_class=RawTextHelpFormatter)
    boolean_flag = argparse.BooleanOptionalAction

    # Mode switches (each also gets an automatic ``--no-<flag>`` form).
    parser.add_argument(
        "--train",
        dest="train",
        default=False,
        action=boolean_flag,
        # Implicit literal concatenation instead of a backslash inside the
        # string: a backslash continuation glues the next source line
        # (including any indentation) straight into the help text.
        help=(
            "Train and save Q.pkl in current directory, "
            "load it when not `train`.\n"
            "To stop training press `Ctrl-C`."
        ),
    )
    parser.add_argument(
        "--load",
        dest="load",
        default=False,
        action=boolean_flag,
        help="Continue training after loading action-value file",
    )
    parser.add_argument(
        "--debug",
        dest="debug",
        default=False,
        action=boolean_flag,
        help="Move snake by hand and see its state",
    )

    # Board geometry and display.
    parser.add_argument(
        "--delay",
        dest="delay",
        default=0.1,
        type=float,
        help="Controls snake speed in visual mode",
    )
    parser.add_argument(
        "--brick",
        dest="brick",
        default=30,
        type=int,
        help="Size of a grid cell in pixels",
    )
    parser.add_argument(
        "--x", dest="x", default=4, type=int, help="Frame `x` size in bricks"
    )
    parser.add_argument(
        "--y", dest="y", default=4, type=int, help="Frame `y` size in bricks"
    )
    parser.add_argument(
        "--grow",
        dest="grow",
        default=False,
        action=boolean_flag,
        help="Whether to grow snake on eating a target",
    )

    # Learning algorithm and its hyper-parameters.
    parser.add_argument(
        "--algo",
        default="sarsa",
        const="sarsa",  # value used when `--algo` is given without an argument
        nargs="?",
        choices=["mc", "sarsa", "ql"],
        help="algorithm (default: %(default)s)",
    )
    parser.add_argument(
        "--epsilon",
        dest="epsilon",
        default=0.1,
        type=float,
        help="Exploration strength",
    )
    parser.add_argument(
        "--alpha",
        dest="alpha",
        default=0.05,
        type=float,
        help="Temporal difference step size",
    )
    parser.add_argument(
        "--steps",
        dest="steps",
        default=4,
        type=int,
        help="Number of steps for temporal difference method (n-step sarsa)",
    )
    parser.add_argument(
        "--episodes",
        dest="episodes",
        default=0,
        type=int,
        help="Maximum number of episodes to run",
    )
    return parser


def _make_algorithm(args, game, env):
    """Instantiate the ``control`` trainer selected by ``args.algo``.

    Args:
        args: Parsed command-line namespace (reads ``algo``, ``epsilon``,
            ``steps`` and ``alpha``).
        game: Display surface, or ``None`` when no window was created.
        env: The ``Environment`` instance to train on.

    Raises:
        NotImplementedError: If ``args.algo`` is not one of the known names.
    """
    if args.algo == "mc":
        # Monte Carlo is constructed without step count or step size.
        return control.MonteCarlo(game, env, epsilon=args.epsilon)
    if args.algo == "sarsa":
        return control.Sarsa(
            game,
            env,
            n=args.steps,
            epsilon=args.epsilon,
            alpha=args.alpha,
        )
    if args.algo == "ql":
        return control.QLearning(
            game,
            env,
            n=args.steps,
            epsilon=args.epsilon,
            alpha=args.alpha,
        )
    raise NotImplementedError(args.algo)


def main(argv=None):
    """Parse the command line and run debug, training or playback mode.

    Args:
        argv: Argument list to parse; ``None`` makes argparse read
            ``sys.argv[1:]``.
    """
    args = build_parser().parse_args(argv)
    env = Environment(args.x, args.y, args.brick, args.grow)

    game = None
    # Training alone runs without a window. Debug mode takes precedence over
    # training below and is handed `game`, so make sure it gets a real window
    # even when `--train` is also set (previously it received None).
    # NOTE(review): assumes control.debug needs a display surface - confirm.
    if args.debug or not args.train:
        # Initialize game window
        pygame.display.set_caption("Snake")
        game = pygame.display.set_mode(
            (args.x * args.brick, args.y * args.brick)
        )

    if args.debug:
        control.debug(game, env)
    elif args.train:
        alg = _make_algorithm(args, game, env)
        alg.train(args.episodes, args.load)
    else:
        ctrl = control.Control(game, env, epsilon=args.epsilon)
        ctrl.follow(delay=args.delay)


if __name__ == "__main__":
    main()