This repository has been archived by the owner on Jan 29, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path train.py
91 lines (72 loc) · 2.43 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
"""
-------------------------------------------------
File Name: train.py
Author: Zhonghao Huang
Date: 2019/9/1
Description:
-------------------------------------------------
"""
import os
import argparse
import torch
import torch.nn as nn
from torch.backends import cudnn
from trainers.trainer import Trainer
from data import make_loader
from logger import make_logger
from models import make_model, make_loss
from optimizer import make_optimizer
from scheduler import make_scheduler
def _create_output_dir(output_dir):
    """Create a fresh run directory, refusing to reuse an existing path.

    Args:
        output_dir: path of the directory to create (missing parents are
            created as well).

    Raises:
        KeyError: if ``output_dir`` already exists. ``KeyError`` is kept
            (rather than ``FileExistsError``) so existing callers that catch
            it keep working; the original ``FileExistsError`` is chained as
            the cause.
    """
    try:
        # A single makedirs call both checks and creates, closing the window
        # between a separate os.path.exists() check and the creation.
        os.makedirs(output_dir)
    except FileExistsError as err:
        # One formatted message instead of a (prefix, path) args tuple.
        raise KeyError("Existing path: {}".format(output_dir)) from err


def _count_device_ids(device_ids):
    """Return how many non-empty entries a comma-separated id string holds.

    ``"0,1"`` -> 2, ``"0,1,"`` -> 2 (trailing comma ignored), ``""`` -> 0.
    """
    return sum(1 for dev_id in device_ids.split(',') if dev_id.strip())


def train(cfg):
    """Build every training component from ``cfg`` and run the training loop.

    Args:
        cfg: configuration object. Attributes read directly here are
            ``OUTPUT_DIR``, ``DEVICE``, ``DEVICE_ID`` (only when ``DEVICE`` is
            ``'cuda'``), ``SOLVER.MAX_EPOCHS``, ``SOLVER.PRINT_FREQ`` and
            ``SOLVER.EVAL_PERIOD``; the ``make_*`` factories may read more.

    Raises:
        KeyError: if ``cfg.OUTPUT_DIR`` already exists, so outputs of a
            previous run are never overwritten.
    """
    # output: a brand-new directory plus a text dump of the config used
    output_dir = cfg.OUTPUT_DIR
    _create_output_dir(output_dir)
    with open(os.path.join(output_dir, 'config.yaml'), 'w', encoding='utf-8') as f_out:
        print(cfg, file=f_out)

    # logger
    logger = make_logger("project", output_dir, 'log')

    # device
    num_gpus = 0
    if cfg.DEVICE == 'cuda':
        # os.environ only accepts str values; the id list may come in as a number.
        device_ids = str(cfg.DEVICE_ID)
        # Restrict which GPUs this process can see. NOTE(review): this only takes
        # effect if CUDA has not been initialised yet in this process — confirm.
        os.environ['CUDA_VISIBLE_DEVICES'] = device_ids
        num_gpus = _count_device_ids(device_ids)
        logger.info("Using {} GPUs.\n".format(num_gpus))
    # Let cuDNN benchmark and pick convolution algorithms for the input shapes seen.
    cudnn.benchmark = True
    device = torch.device(cfg.DEVICE)

    # data
    train_loader, query_loader, gallery_loader, num_classes = make_loader(cfg)

    # model: replicated across the visible GPUs when more than one is listed.
    # NOTE(review): the model is not moved to `device` here; presumably Trainer
    # does that — confirm in trainers.trainer.
    model = make_model(cfg, num_classes=num_classes)
    if num_gpus > 1:
        model = nn.DataParallel(model)

    # solver
    criterion = make_loss(cfg, num_classes)
    optimizer = make_optimizer(cfg, model)
    scheduler = make_scheduler(cfg, optimizer)

    # do_train
    trainer = Trainer(model=model,
                      optimizer=optimizer,
                      criterion=criterion,
                      logger=logger,
                      scheduler=scheduler,
                      device=device)
    trainer.run(start_epoch=0,
                total_epoch=cfg.SOLVER.MAX_EPOCHS,
                train_loader=train_loader,
                query_loader=query_loader,
                gallery_loader=gallery_loader,
                print_freq=cfg.SOLVER.PRINT_FREQ,
                eval_period=cfg.SOLVER.EVAL_PERIOD,
                out_dir=output_dir)
    print('Done.')
if __name__ == '__main__':
    # Command-line interface: one optional path to the YAML experiment config.
    arg_parser = argparse.ArgumentParser(description="Person Re-identification Project.")
    arg_parser.add_argument('--config', default='./configs/sample.yaml')
    config_path = arg_parser.parse_args().config

    # The project config is imported only after the command line has been
    # parsed, so e.g. `--help` is handled before this import runs.
    from config import cfg as opt

    # Overlay the YAML file onto the config, then freeze it before training.
    # NOTE(review): freeze() presumably makes the config read-only (yacs-style)
    # — confirm in the config module.
    opt.merge_from_file(config_path)
    opt.freeze()

    train(opt)