-
Notifications
You must be signed in to change notification settings - Fork 2
/
test.py
78 lines (60 loc) · 2.35 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import time
import pathlib
from PIL import Image
from argparse import ArgumentParser
import torch
import faiss
from src.feature_extraction import MyResnet50, MyVGG16, RGBHistogram, LBP
from src.dataloader import get_transformation
# Image file extensions; not referenced by the other code visible in this file.
ACCEPTED_IMAGE_EXTS = [".jpg", ".png"]

# Dataset locations, given relative to the current working directory.
query_root = "./dataset/groundtruth"
image_root = "./dataset/paris"          # directory listed to map index ids to files
feature_root = "./dataset/feature"      # holds the "<extractor>.index.bin" files
evaluate_root = "./dataset/evaluation"
def get_image_list(image_root):
    """Return the entries directly under *image_root*, ordered by file name.

    Args:
        image_root: directory to list; anything ``pathlib.Path`` accepts.

    Returns:
        A list of ``pathlib.Path`` objects, one per directory entry for which
        ``exists()`` is true, sorted by ``name``. No filtering by file
        extension is performed.
    """
    directory = pathlib.Path(image_root)
    present = [entry for entry in directory.iterdir() if entry.exists()]
    present.sort(key=lambda entry: entry.name)
    return present
def main():
    """Retrieve the images most similar to a query image and print their paths.

    Command-line options:
        --feature_extractor: 'Resnet50', 'VGG16', 'RGBHistogram' or 'LBP'
            (defaults to 'VGG16').
        --device: torch device string (defaults to 'cuda:0').
        --top_k: number of neighbours requested from the index (defaults to 11).
        --test_image_path: path of the query image.

    The FAISS index is read from ``<feature_root>/<feature_extractor>.index.bin``
    and the printed paths are taken from the name-sorted listing of
    ``image_root``.
    NOTE(review): this presumes the index was built over that same listing in
    the same order -- confirm against the indexing script.
    """
    parser = ArgumentParser()
    # Not required: while the option was mandatory its default could never apply.
    parser.add_argument("--feature_extractor", required=False, type=str, default='VGG16')
    parser.add_argument("--device", required=False, type=str, default='cuda:0')
    parser.add_argument("--top_k", required=False, type=int, default=11)
    parser.add_argument("--test_image_path", required=False, type=str, default='./dataset/paris/paris_triomphe_001112.jpg')

    print('Start retrieving .......')
    start = time.time()

    args = parser.parse_args()
    device = torch.device(args.device)

    # Map the command-line name to the extractor class.
    extractors = {
        'Resnet50': MyResnet50,
        'VGG16': MyVGG16,
        'RGBHistogram': RGBHistogram,
        'LBP': LBP,
    }
    extractor_cls = extractors.get(args.feature_extractor)
    if extractor_cls is None:
        print("No matching model found")
        return
    extractor = extractor_cls(device)

    img_list = get_image_list(image_root)
    transform = get_transformation()

    # Preprocessing
    test_image_path = pathlib.Path(args.test_image_path)
    # convert() returns a new image, so the source file can be closed right away.
    with Image.open(test_image_path) as raw_image:
        pil_image = raw_image.convert('RGB')
    image_tensor = transform(pil_image)
    image_tensor = image_tensor.unsqueeze(0).to(device)
    feat = extractor.extract_features(image_tensor)

    index_path = pathlib.Path(feature_root) / (args.feature_extractor + '.index.bin')
    indexer = faiss.read_index(str(index_path))

    _, indices = indexer.search(feat, k=args.top_k)
    print(indices)

    for index in indices[0]:
        # FAISS pads with -1 when fewer than k neighbours exist; indexing the
        # list with -1 would wrongly print the last image.
        if index < 0:
            continue
        print(img_list[index])

    end = time.time()
    print('Finish in ' + str(end - start) + ' seconds')
# Run the retrieval only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()