-
Notifications
You must be signed in to change notification settings - Fork 0
/
Actor.py
64 lines (53 loc) · 2.65 KB
/
Actor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import torch
from Config import ACTIVATIONS, OPTIMIZERS, NUMBER_OF_ENGINEERED_FEATURES_BEYOND_PLAYER_ID_AND_CELLS
from Hex import Hex
import Utils
class Actor:
    """Policy network for Hex: maps a board state to a distribution over cells.

    Input features are the player id, the board cells, and
    NUMBER_OF_ENGINEERED_FEATURES_BEYOND_PLAYER_ID_AND_CELLS extra features;
    the output layer has one logit per board cell.
    """

    def __init__(self, board_size, hidden_layers, learning_rate, activation, optimizer):
        """Build the MLP, its loss criterion, and its optimizer.

        board_size     -- side length of the Hex board
        hidden_layers  -- list of hidden-layer widths (must be non-empty)
        learning_rate  -- learning rate passed to the optimizer
        activation     -- key into Config.ACTIVATIONS
        optimizer      -- key into Config.OPTIMIZERS
        """
        input_size = (Hex.get_number_of_cells(board_size) + 1 +
                      NUMBER_OF_ENGINEERED_FEATURES_BEYOND_PLAYER_ID_AND_CELLS)
        modules = [torch.nn.Linear(input_size, hidden_layers[0]),
                   ACTIVATIONS[activation]()]
        # One Linear + activation pair per additional hidden layer.
        for i in range(len(hidden_layers) - 1):
            modules.append(torch.nn.Linear(hidden_layers[i], hidden_layers[i + 1]))
            modules.append(ACTIVATIONS[activation]())
        # Output layer emits raw logits; the loss applies log-softmax itself.
        modules.append(torch.nn.Linear(hidden_layers[-1], Hex.get_number_of_cells(board_size)))
        self.net = torch.nn.Sequential(*modules)
        self.criterion = self.cross_entropy_loss
        self.optimizer = OPTIMIZERS[optimizer](self.net.parameters(), lr=learning_rate)

    @staticmethod
    def cross_entropy_loss(pred, soft_targets):
        """Cross-entropy between raw logits and soft target distributions.

        pred         -- (batch, cells) tensor of raw logits
        soft_targets -- (batch, cells) tensor of target probabilities
        Returns a scalar tensor: mean over the batch of the per-example
        cross-entropy.
        """
        # Explicit dim=1 replaces the deprecated dim-less torch.nn.LogSoftmax(),
        # which for 2-D input resolved to dim=1 anyway (plus a warning).
        log_probs = torch.nn.functional.log_softmax(pred, dim=1)
        return torch.mean(torch.sum(-soft_targets * log_probs, 1))

    def forward(self, X):
        """Run the network in eval mode on input X and return its raw logits."""
        self.net.eval()
        return self.net(X)

    def train(self, replay_buffer, new_examples_count, replay_buffer_max_size, replay_buffer_minibatch_size):
        """One training pass over a minibatch drawn from the replay buffer.

        replay_buffer lists (input, target_distribution) pairs, newest last.
        All new examples are always trained on; old examples are capped at
        replay_buffer_max_size, shuffled, and used to pad the minibatch up to
        replay_buffer_minibatch_size. One optimizer step per example.
        Returns self.
        """
        new_examples = replay_buffer[-new_examples_count:]
        old_examples = replay_buffer[:-new_examples_count]
        if len(old_examples) > replay_buffer_max_size:
            # BUGFIX: keep the most recent replay_buffer_max_size old examples.
            # The original slice [:-replay_buffer_max_size] kept the OLDEST
            # examples, dropped the newest, and could retain more than the cap.
            old_examples = old_examples[-replay_buffer_max_size:]
        old_examples_shuffled = Utils.shuffle(old_examples)
        training_set = new_examples
        if len(old_examples) > 0:
            training_set = (new_examples + old_examples_shuffled)[:replay_buffer_minibatch_size]
        self.net.train()
        for example in training_set:
            self.optimizer.zero_grad()
            pred = self.net(example[0]).reshape(1, -1)
            # example[1] is the target distribution; presumably a list/ndarray —
            # torch.tensor would warn if it were already a tensor (TODO confirm).
            distr = torch.tensor(example[1])
            loss = self.criterion(pred, distr)
            loss.backward()
            self.optimizer.step()
        return self

    def save(self, board_size, episode):
        """Persist the network weights under ./models/boardsize_<n>/ (directory must exist)."""
        state_dict = self.net.state_dict()
        torch.save(state_dict, './models/boardsize_' + str(board_size) + '/net_after_episode_' + str(episode) + ".pt")

    @staticmethod
    def load_model(model_path, size, hidden_layers, learning_rate, activation, optimizer):
        """Rebuild an Actor with the given hyperparameters and load saved weights.

        Note: returns the underlying torch.nn.Sequential in eval mode,
        NOT the Actor wrapper.
        """
        actor = Actor(size, hidden_layers, learning_rate, activation, optimizer)
        model = actor.net
        model.load_state_dict(torch.load(model_path))
        model.eval()
        return model