This repository has been archived by the owner on Jan 29, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtest.py
81 lines (64 loc) · 2.27 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
"""
-------------------------------------------------
File Name: test.py
Author: Zhonghao Huang
Date: 2019/9/1
Description:
-------------------------------------------------
"""
import os
import argparse
import torch
import torch.nn as nn
from torch.backends import cudnn
from data import make_loader, make_loader_flip
from models import make_model
from trainers.evaluator import Evaluator
def _print_results(cmc, mAP, ranks):
    """Print mAP and the CMC score at each requested rank.

    Args:
        cmc: Indexable CMC curve; ``cmc[r - 1]`` is printed as the rank-``r``
            score (formatted as a percentage).
        mAP: Mean average precision (formatted as a percentage).
        ranks: 1-based ranks at which to report the CMC curve.
    """
    print("Results ----------")
    print("mAP: {:.1%}".format(mAP))
    print("CMC curve")
    for r in ranks:
        print("Rank-{:<3}: {:.1%}".format(r, cmc[r - 1]))
    print("------------------\n")


def test(cfg, flip):
    """Evaluate the best saved checkpoint on the query/gallery split.

    Loads ``<cfg.OUTPUT_DIR>/<ModelClassName>_best.pth`` into a freshly built
    model, runs ``Evaluator.evaluate`` and prints mAP plus CMC at ranks 1/5/10.

    Args:
        cfg: Config node. ``DEVICE``, ``DEVICE_ID`` and ``OUTPUT_DIR`` are read
            here; the node is also passed to ``make_loader``,
            ``make_loader_flip`` and ``make_model``.
        flip: When truthy, additionally report results from
            ``Evaluator.evaluate_flip`` using the loaders returned by
            ``make_loader_flip``.
    """
    # device
    num_gpus = 0
    if cfg.DEVICE == 'cuda':
        # CUDA reads CUDA_VISIBLE_DEVICES when it initialises, so this only
        # takes effect if nothing has touched CUDA yet in this process.
        os.environ['CUDA_VISIBLE_DEVICES'] = cfg.DEVICE_ID
        num_gpus = len(cfg.DEVICE_ID.split(','))
        print("Using {} GPUs.\n".format(num_gpus))
        cudnn.benchmark = True
    device = torch.device(cfg.DEVICE)

    # data (the train loader is not used here; only num_classes is needed)
    _, query_loader, gallery_loader, num_classes = make_loader(cfg)

    # model
    model = make_model(cfg, num_classes=num_classes)
    checkpoint_path = os.path.join(cfg.OUTPUT_DIR, model.__class__.__name__ + '_best.pth')
    # Load onto the CPU first so a checkpoint saved from a GPU can also be
    # loaded on a CPU-only machine; the model is moved to `device` below.
    model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
    if num_gpus > 1:
        # Wrap after loading so the state-dict keys carry no 'module.' prefix.
        model = nn.DataParallel(model)
    model = model.to(device)
    evaluator = Evaluator(model)

    ranks = [1, 5, 10]

    # Results
    cmc, mAP = evaluator.evaluate(query_loader, gallery_loader)
    _print_results(cmc, mAP, ranks)

    # Results with flip
    if flip:
        print("Results with flip --------------")
        query_flip_loader, gallery_flip_loader = make_loader_flip(cfg)
        cmc, mAP = evaluator.evaluate_flip(query_loader, gallery_loader,
                                           query_flip_loader, gallery_flip_loader)
        _print_results(cmc, mAP, ranks)

    print('Done')
if __name__ == '__main__':
    def _str2bool(value):
        """Parse a command-line boolean ('true'/'false', 'yes'/'no', '1'/'0', ...)."""
        lowered = str(value).strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y'):
            return True
        if lowered in ('0', 'false', 'f', 'no', 'n'):
            return False
        raise argparse.ArgumentTypeError("expected a boolean value, got {!r}".format(value))

    parser = argparse.ArgumentParser(description="Person Re-identification Project.")
    parser.add_argument('--config', default='./configs/sample.yaml')
    # Without `type`, argparse stores the raw string, so `--flip False` would be
    # the truthy string "False" and flip evaluation could never be disabled.
    parser.add_argument('--flip', type=_str2bool, default=True,
                        help="also report flip-augmented results (default: True)")
    args = parser.parse_args()

    # Merge the YAML file given by --config into the project config, then
    # freeze it (presumably a yacs-style CfgNode — confirm in config/).
    from config import cfg as opt
    opt.merge_from_file(args.config)
    opt.freeze()
    test(opt, args.flip)