-
Notifications
You must be signed in to change notification settings - Fork 0
/
runner.py
52 lines (39 loc) · 1.82 KB
/
runner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import numpy as np
import random
import torch
import yaml
from tqdm import trange
from utils.data import Dataset
from utils.model import instantiate_model
from utils.logger import get_logger
if __name__ == '__main__':
args = yaml.load(open('./configs/runner.yml', 'r'), Loader=yaml.FullLoader)
seed = args["seed"]
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
train_file, test_file, valid_file, embedding_source = args["train_file"], args["test_file"], args["valid_file"], args["embedding_source"]
dataset = Dataset(batch_size=args["batch_size"], seq_len=args["seq_len"])
dataset.load_data(train_file, test_file, valid_file, embedding_source)
vocab_size = len(dataset.vocab)
embeddings = dataset.embeddings
model = instantiate_model(args["model"], vocab_size, embeddings)
# tqdm_range = trange(args["epochs"], desc="Epoch", unit="epoch")
print("=" * 54)
for epoch in range(args["epochs"]):
_, _ = model.run_epoch(dataset.train_iterator, dataset.valid_iterator)
print("-" * 54)
train_acc = model.evaluate(dataset.train_iterator)
val_acc = model.evaluate(dataset.valid_iterator)
print("Epoch %3d ended | train acc.: %3.2f | val acc.: %.2f" % ((epoch + 1), train_acc * 100, val_acc * 100))
# tqdm_range.set_postfix_str("Loss: %g" % loss)
print("=" * 54)
train_acc = model.evaluate(dataset.train_iterator)
valid_acc = model.evaluate(dataset.valid_iterator)
test_acc = model.evaluate(dataset.test_iterator)
logger = get_logger()
logger.info('Training Accuracy: %' + '%3.2f.' % (train_acc * 100))
logger.info('Validation Accuracy: %' + '%3.2f.' % (valid_acc * 100))
logger.info('Test Accuracy: %' + '%3.2f.' % (test_acc * 100))