main.py

import os
import torch
import logging
from parse_args import parse_arguments
from metrics.utils import build_meters_dict, collect_metrics
from datasets.utils import build_dataloaders

DEVICE = torch.device('cpu')
if torch.cuda.is_available():
    DEVICE = torch.device('cuda:0')
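

# NOTE (assumption): args.experiment is expected to be a filesystem path to an
# experiment package, e.g. 'experiments/erm' (hypothetical name). It is mapped
# to the dotted module path 'experiments.erm.experiment', which must define an
# Experiment class taking (args, dataloaders).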
def load_experiment(args, dataloaders):
    module_name = '.'.join(os.path.normpath(args.experiment).split(os.sep))
    experiment_module = __import__(f'{module_name}.experiment', fromlist=[module_name])
    return experiment_module.Experiment(args, dataloaders)


def main():
    global DEVICE
    args = parse_arguments()

    # Logger setup
    logging.basicConfig(filename=os.path.join(args.log_path, 'log.txt'),
                        format='%(message)s', level=logging.INFO, filemode='a')
    logging.info(args)

    # Data setup
    dataloaders = build_dataloaders(args)

    # Experiment setup
    experiment = load_experiment(args, dataloaders)

    # Optionally resume from the latest checkpoint
    if os.path.exists(os.path.join(args.log_path, 'current.pth')):
        experiment.load(os.path.join(args.log_path, 'current.pth'))

    # Meters setup
    meters_dict = build_meters_dict(args)
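
    # NOTE (assumption): each meter in meters_dict is expected to expose a
    # compare(new_value, best_value) method; the meter named by
    # args.model_selection decides when a new best checkpoint is saved below.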

    # Training loop (skipped entirely in test mode)
    while experiment.iteration < args.max_iters and not args.test_mode:
        for data in dataloaders['train']:
            train_losses = experiment.train_iteration(data)

            # Log training losses twice per validation interval (e.g. via wandb)
            if experiment.iteration % (args.validate_every // 2) == 0:
                logging.info(train_losses)

            # Validation phase
            if experiment.iteration % args.validate_every == 0:
                predicted, target, group, avg_val_loss = experiment.evaluate(dataloaders['val'])
                metrics = collect_metrics(meters_dict, predicted, target, group)

                # Log validation metrics (e.g. via wandb)
                logging.info(f'[VAL @ {experiment.iteration}] {avg_val_loss} | {metrics}')

                # Keep the checkpoint with the best validation metric so far
                if experiment.best_metric is None or meters_dict[args.model_selection].compare(
                        metrics[args.model_selection], experiment.best_metric):
                    experiment.best_metric = metrics[args.model_selection]
                    experiment.save(os.path.join(args.log_path, 'best.pth'))
                experiment.save(os.path.join(args.log_path, 'current.pth'))

            experiment.iteration += 1
            if experiment.iteration >= args.max_iters:
                break

    # Test phase: reload the best checkpoint and evaluate on the test set
    experiment.load(os.path.join(args.log_path, 'best.pth'))
    predicted, target, group, _ = experiment.evaluate(dataloaders['test'])
    metrics = collect_metrics(meters_dict, predicted, target, group)

    # Log test metrics (e.g. via wandb)
    logging.info(f'[TEST] {metrics}')


if __name__ == '__main__':
    main()
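
# Example invocation (hypothetical flag names and values; the real arguments
# are defined in parse_args.py):
#   python main.py --experiment experiments/erm --log_path logs/erm \
#       --max_iters 10000 --validate_every 500 --model_selection accuracy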