This repository has been archived by the owner on Oct 9, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 2
/
train_MCTS.py
151 lines (127 loc) · 4.77 KB
/
train_MCTS.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
from playfield_controller import PlayfieldController
from gamestate import *
from playfield import *
from playfield_controller import *
import numpy.random
from numpy.random import randint
from modeling import *
from copy import deepcopy
from torch.utils.data import DataLoader
import random
import os
from torch.utils.tensorboard import SummaryWriter
######## for testing ########
import inspect
from time import time
from functools import wraps
def measure_time(f):
    '''
    Decorator: print how long each call to *f* took, in milliseconds.

    Only calls that take at least TIMING_LOGGING_THRESHOLD_MS are printed,
    so fast paths stay quiet.  The wrapped function's return value (and any
    exception) passes through unchanged.
    https://stackoverflow.com/questions/51503672/decorator-for-timeit-timeit-method
    '''
    @wraps(f)
    def _time_it(*args, **kwargs):
        # Name of the function being decorated.  The previous
        # inspect.getouterframes(...)[1] lookup reported the *caller's*
        # frame name, not f's -- f.__name__ matches the stated intent
        # and avoids a frame walk on every call.
        funcname = f.__name__
        start = int(round(time() * 1000))
        try:
            return f(*args, **kwargs)
        finally:
            end_ = int(round(time() * 1000)) - start
            TIMING_LOGGING_THRESHOLD_MS = 69  # only log slow functions
            # end_ >= threshold already implies end_ > 0, so no clamp needed.
            if end_ >= TIMING_LOGGING_THRESHOLD_MS:
                print(f"{funcname}: {end_} ms")
    return _time_it
######## testing part end ########
@measure_time
def train_dataset(datasets):
    """
    Train the global ``model`` on collected MCTS gameplay data, with early
    stopping against a held-out validation split.

    The first 3 games in ``datasets`` form the validation set and the rest
    the training set.  Training runs for at most 200 epochs; whenever the
    validation loss improves, the weights are snapshotted into the global
    ``model_temp``.  After more than 5 epochs without improvement, the best
    snapshot is restored and the model is checkpointed to ``MCTS.pth``.

    Parameters
    ----------
    datasets : list
        One entry per game; each entry is a list of (state, value) samples
        consumable by ``Gameplays``.  Assumes at least 4 games so both
        splits are non-empty -- TODO confirm caller guarantees this.

    Side effects: mutates globals ``model`` and ``model_temp`` in place and
    logs scalars to the global tensorboard ``writer`` at step ``game``.
    """
    global game
    global writer
    global model
    global model_temp
    model.train()
    total_training_loss = 0.0
    optim = torch.optim.Adam(model.parameters())
    # First three games are held out for validation, the rest are trained on.
    total_valid_data = []
    for data in datasets[:3]:
        total_valid_data += data
    valid_datasets = Gameplays(total_valid_data)
    total_data = []
    for data in datasets[3:]:
        total_data += data
    train_datasets = Gameplays(total_data)
    train_loader = DataLoader(train_datasets, batch_size=64)
    valid_loader = DataLoader(valid_datasets, batch_size=64)
    # Hoisted: clamp ceiling, reused every batch (was re-allocated per batch).
    one = torch.tensor([1.0]).to(device)
    prev_valid_loss_average = None
    c = 0            # epoch counter
    average_norm = 0.0
    counter = 0      # epochs since the last validation improvement
    while c < 200:
        temp_total_training_loss = 0.0
        temp_average_norm = 0.0
        # ---- validation pass: no gradients needed, so skip graph building ----
        total_loss = 0.0
        with torch.no_grad():
            for state, value in valid_loader:
                state = state.to(device)
                value = value.to(device)
                predicted_value = model(state)
                # Targets can exceed 1; clamp so the BCE terms stay defined.
                value = torch.min(value.float(), one)
                loss = - torch.mean(value * predicted_value.log()
                                    + (1.0 - value) * (1 - predicted_value).log())
                total_loss += loss.item()
        total_loss /= len(valid_loader)
        print(total_loss)
        valid_loss_average = total_loss
        c += 1
        if prev_valid_loss_average is None or prev_valid_loss_average > valid_loss_average:
            # Improvement: snapshot the weights and reset the patience counter.
            prev_valid_loss_average = valid_loss_average
            model_temp.load_state_dict(model.state_dict())
            counter = 0
        else:
            counter += 1
        # ---- training pass ----
        for state, value in train_loader:
            state = state.to(device)
            value = value.to(device)
            predicted_value = model(state)
            value = torch.min(value.float(), one)
            loss = - torch.mean(value * predicted_value.log()
                                + (1.0 - value) * (1 - predicted_value).log())
            optim.zero_grad()
            loss.backward()
            # Global L2 norm of all parameter gradients, for logging only.
            total_norm = 0.0
            for p in model.parameters():
                if p.grad is not None:  # frozen/unused params have no grad
                    total_norm += p.grad.data.norm(2).item() ** 2
            total_norm = total_norm ** (1. / 2)
            optim.step()
            temp_total_training_loss += loss.item()
            temp_average_norm += total_norm
        average_norm = temp_average_norm / len(train_loader)
        total_training_loss = temp_total_training_loss / len(train_loader)
        if counter > 5:
            # Patience exhausted: restore the best weights and checkpoint.
            model.load_state_dict(model_temp.state_dict())
            torch.save(model, "MCTS.pth")
            break
    writer.add_scalar('Training loss', total_training_loss, game)
    writer.add_scalar('Training Norm', average_norm, game)
    writer.add_scalar('Validation loss', prev_valid_loss_average, game)
    writer.add_scalar('Number of Epoch', c, game)
# ---- experiment setup -------------------------------------------------
# NOTE: writer, device, model, model_temp and game are read by
# train_dataset() via `global` statements, so those names must keep
# their module-level identity.
run_label = input("Enter a label for this training:")
writer = SummaryWriter("runs/" + run_label)
print("Checking if CUDA is available:", torch.cuda.is_available())
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = Model().to(device)
model_temp = Model().to(device)  # holds the best weights during training
game = 0
replay_buffer = []
# ---- endless self-play / training loop --------------------------------
while True:
    controller = PlayfieldController()
    controller.update()
    model.eval()
    search = MCTS(model=model.to(device), pc=controller, gamma=0.999)
    episode, _, reward = search.generate_a_game(num_iter=50, max_steps=500,
                                                stats_writer=(writer, game))
    replay_buffer.append(episode)
    game += 1
    # Retrain every 25 games on the freshly collected episodes, then
    # start a new buffer; checkpoint every 100 games.
    if game % 25 == 0:
        train_dataset(replay_buffer)
        replay_buffer = []
    if game % 100 == 0:
        torch.save(model, "MCTS_%d.pth" % game)