-
Notifications
You must be signed in to change notification settings - Fork 2
/
mcts_gif.py
56 lines (50 loc) · 1.47 KB
/
mcts_gif.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
from gamestate import *
from playfield import *
from playfield_controller import *
from modeling import *
import torch
import torch.nn as nn
from torch.distributions import Categorical
import numpy as np
import os
from PIL import Image
from PIL import ImageFont
import shutil
# Play games forever with the MCTS agent and keep an animated GIF of the
# best-scoring game seen so far in this run.
#
# Each game writes one JPEG frame per step into a scratch "figs" directory.
# When a game beats the best total reward so far, its frames are stitched
# into games/MCTS_value_soft_ce.gif and the reward is written to "highest".

# NOTE: torch.load unpickles the whole model object; only load checkpoints
# from a trusted source.
model = torch.load("MCTS_value_soft_ce.pth").cuda()
model.eval()

highest_reward = 0
while True:
    # This loop never ends on its own, so the usual way out is an interrupt,
    # which skips the rmtree at the bottom. Clear any leftover frame
    # directory so the next run does not die on FileExistsError.
    if os.path.isdir("figs"):
        shutil.rmtree("figs")
    os.mkdir("figs")

    pc = PlayfieldController()
    pc.update()
    count = 0  # number of frames written for this game
    tree = MCTS(model=model, pc=pc, gamma=0.95)
    total_reward = 0
    steps = 0

    # One game: stop on game over or after 500 steps, whichever comes first.
    while not pc._game_over and steps < 500:
        steps += 1
        gs = pc.gamestate()
        action = tree.search(pc, num_iter=50)

        # 0 = left, 1 = right, 2 = rotate clockwise; any other action value
        # moves nothing.
        if action == 0:
            pc.move_left()
        elif action == 1:
            pc.move_right()
        elif action == 2:
            pc.rotate_cw()

        # Re-root the tree at the child for the chosen action (presumably so
        # the next search reuses that subtree - confirm against MCTS).
        tree.root = tree.root.child_nodes[action]

        gs.plot('figs/im%d.jpg' % count)
        count += 1

        # Reward for this step is the score gained by the update.
        prev_score = pc._score
        pc.update()
        reward = pc._score - prev_score
        total_reward += reward

    print(total_reward)

    if total_reward > highest_reward:
        # The output directory may not exist on a fresh checkout.
        os.makedirs("games", exist_ok=True)

        images = []
        try:
            for c in range(count):
                images.append(Image.open("figs/im%d.jpg" % (c)))
            images[0].save(
                "games/%s.gif" % "MCTS_value_soft_ce",
                save_all=True,
                append_images=images[1:],
                loop=0,
                duration=1,
            )
        finally:
            # Release the frame files before the directory is deleted; open
            # handles leak descriptors and can block rmtree on some platforms.
            for im in images:
                im.close()

        with open("highest", "w") as f:
            f.write(str(total_reward))

        # Remember the new record so only strictly better games replace the
        # saved GIF from now on.
        highest_reward = total_reward

    shutil.rmtree("figs")