-
Notifications
You must be signed in to change notification settings - Fork 0
/
test.py
80 lines (63 loc) · 3.44 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import logging
from argparse import Namespace
from typing import Optional, Tuple

import faiss
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataset import Subset
from tqdm import tqdm
# Cutoffs N at which recall is reported: R@1, R@5, R@10, R@20.
# The largest cutoff is also the number of neighbours retrieved per query
# by the kNN search in get_predictions().
RECALL_VALUES = [1, 5, 10, 20]
def test(args: Namespace, eval_ds: Dataset, model: torch.nn.Module,
         output_folder: Optional[str] = None) -> Tuple[np.ndarray, str]:
    """Compute descriptors of the given dataset and compute the recalls.

    Args:
        args: Parsed options. This function reads ``num_workers``,
            ``infer_batch_size``, ``device``, ``fc_output_dim`` and
            ``save_descriptors``.
        eval_ds: Evaluation dataset. It must expose ``database_num`` and
            ``queries_num``, yield ``(images, indices)`` pairs, and provide
            ``get_positives()``. Samples ``[0, database_num)`` are treated as
            the database and the following ``queries_num`` samples as queries.
        model: Network that maps a batch of images to descriptors with
            ``args.fc_output_dim`` columns.
        output_folder: If not ``None`` and ``args.save_descriptors`` is truthy,
            the descriptor matrix is saved to ``{output_folder}/descriptors.pth``.

    Returns:
        A tuple ``(recalls, recalls_str)``: ``recalls`` holds one percentage
        per entry of ``RECALL_VALUES``; ``recalls_str`` is the same data
        formatted as ``"R@1: 12.3, R@5: 45.6, ..."``.
    """
    model = model.eval()
    database_num = eval_ds.database_num
    queries_num = eval_ds.queries_num

    # One row per dataset sample; rows are filled using the indices that the
    # dataset yields alongside each image.
    # NOTE(review): assumes len(eval_ds) == database_num + queries_num, otherwise
    # some rows of this uninitialised buffer are never written - confirm.
    all_descriptors = np.empty((len(eval_ds), args.fc_output_dim), dtype="float32")

    with torch.no_grad():
        logging.debug("Extracting database descriptors for evaluation/testing")
        _extract_descriptors(args, model, eval_ds, range(database_num),
                             args.infer_batch_size, all_descriptors)

        # Queries are processed one at a time.
        # NOTE(review): presumably because query images can differ in size and
        # therefore cannot be stacked into a batch - confirm.
        logging.debug("Extracting queries descriptors for evaluation/testing using batch size 1")
        queries_infer_batch_size = 1
        _extract_descriptors(args, model, eval_ds,
                             range(database_num, database_num + queries_num),
                             queries_infer_batch_size, all_descriptors)

    if args.save_descriptors and output_folder is not None:
        torch.save(all_descriptors, f'{output_folder}/descriptors.pth')

    predictions = get_predictions(eval_ds, args.fc_output_dim, all_descriptors)

    # For each query, check if the predictions are correct
    recalls = evaluate_predictions(predictions, eval_ds)
    # Divide by queries_num and multiply by 100, so the recalls are in percentages
    recalls = recalls / eval_ds.queries_num * 100
    recalls_str = ", ".join([f"R@{val}: {rec:.1f}" for val, rec in zip(RECALL_VALUES, recalls)])
    return recalls, recalls_str


def _extract_descriptors(args, model, eval_ds, sample_indices, batch_size, all_descriptors):
    """Run ``model`` over the selected samples of ``eval_ds`` and store the outputs.

    The descriptors of every batch are written in place into ``all_descriptors``
    at the rows given by the indices the dataset yields with the images.
    The caller is responsible for disabling gradient tracking.
    """
    subset_ds = Subset(eval_ds, list(sample_indices))
    dataloader = DataLoader(dataset=subset_ds, num_workers=args.num_workers,
                            batch_size=batch_size, pin_memory=(args.device == "cuda"))
    for images, indices in tqdm(dataloader, ncols=100):
        descriptors = model(images.to(args.device))
        all_descriptors[indices.numpy(), :] = descriptors.cpu().numpy()
def evaluate_predictions(predictions, dataset: Dataset):
    """Count, for each cutoff in RECALL_VALUES, the queries with a correct match.

    Args:
        predictions: One sequence per query holding the retrieved database
            indices, best match first.
        dataset: Object whose ``get_positives()`` returns, for each query index,
            the collection of database indices considered correct.

    Returns:
        A float array with one entry per cutoff in ``RECALL_VALUES``. Entry ``i``
        is the number of queries that have at least one positive among their
        first ``RECALL_VALUES[i]`` predictions. These are raw counts, not
        percentages.
    """
    positives_per_query = dataset.get_positives()
    recalls = np.zeros(len(RECALL_VALUES))
    for query_index, preds in enumerate(predictions):
        for i, n in enumerate(RECALL_VALUES):
            # np.isin replaces np.in1d, which is deprecated since NumPy 2.0.
            if np.any(np.isin(preds[:n], positives_per_query[query_index])):
                # A hit within the top n is also a hit for every larger cutoff
                # (relies on RECALL_VALUES being in ascending order), so credit
                # all of them at once and stop scanning this query.
                recalls[i:] += 1
                break
    return recalls
def get_predictions(eval_ds, output_dim, all_descriptors):
    """Retrieve the nearest database descriptors for every query descriptor.

    Rows ``[0, eval_ds.database_num)`` of ``all_descriptors`` are used as the
    database and all remaining rows as the queries. The search uses a faiss
    flat L2 index of dimension ``output_dim``.

    Returns:
        The index array produced by the faiss search: one row per query with
        ``max(RECALL_VALUES)`` database indices. Distances are discarded.
    """
    num_database = eval_ds.database_num
    top_k = max(RECALL_VALUES)

    # Use a kNN to find predictions
    knn_index = faiss.IndexFlatL2(output_dim)
    knn_index.add(all_descriptors[:num_database])

    logging.debug("Calculating recalls")
    _distances, neighbours = knn_index.search(all_descriptors[num_database:], top_k)
    return neighbours