test.py
"""
Check model performance on the trained checkpoints
on test dataset or validation dataset
python test.py -h for more information
--------------------------------------
"""
from pretrained_models import get_model
from dataset.data import load_dataset
import torch
import argparse
import logging
from train import get_num_correct, calculate_metrics
import yaml
from tqdm import tqdm
from tabulate import tabulate
# configure logger
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
file_handler = logging.FileHandler(filename="logger_test.log")
stream_handler = logging.StreamHandler()
formatter = logging.Formatter(fmt="%(asctime)s: %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
file_handler.setFormatter(formatter)
stream_handler.setFormatter(formatter)
logger.addHandler(file_handler)
logger.addHandler(stream_handler)
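# NOTE: with both handlers attached, every logger.info() call below is
# written to logger_test.log and echoed to the console.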
def read_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch", type=int, default=32, help="test loader batch size")
    parser.add_argument("--weights", type=str, help="path to weights file")
    parser.add_argument("--model", type=str, default="resnet18", help="model name")
    parser.add_argument("--classes", type=int, default=4, help="number of classes")
    parser.add_argument("--config", type=str, default="configs/configs.yaml", help="configurations file")
    parser.add_argument("--kind", type=str, help="inference type, i.e. val, test, train..")
    parser.add_argument("--subset", action="store_true", help="whether to use a small subset")
    parser.add_argument("--colab", action="store_true", help="colab option")
    parser.add_argument("--manual_table", action="store_true", help="log a manually formatted table")
    opt = parser.parse_args()
    return opt
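# Example: `python test.py --weights best.pt --kind val` parses to
# Namespace(batch=32, weights='best.pt', model='resnet18', classes=4,
#           config='configs/configs.yaml', kind='val', subset=False,
#           colab=False, manual_table=False)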
# run inference on the batches of images for testing.
def inference(batch: int = 32,
              weights: str = "",
              model: str = "resnet18",
              args: argparse.Namespace = None):
    """
    Inference on the test set
    -------------------------
    args:
        batch: int, loader batch size
        weights: str, path to the checkpoint file
        model: str, model name
        args: argparse.Namespace, remaining CLI options
    """
    # load the trained checkpoint; it is expected to be a dict with
    # "epoch" and "model_state_dict" keys
    logger.info("Loading the trained model")
    model_info = torch.load(weights, map_location=torch.device("cpu"))
    epoch = model_info["epoch"]
    logger.info(f"Total trained epochs: {epoch}")
    model_sd = model_info["model_state_dict"]
    model = get_model(name=model, pretrained=False,
                      num_classes=args.classes, weights=model_sd)
    with open(args.config, "r") as file:
        cfg = yaml.safe_load(file)
    # apply the colab overrides before building the loader,
    # otherwise they have no effect
    if args.colab:
        cfg["general_configs"]["dataset splitted"] = "/gdrive/MyDrive/covid/data/COVID-19_Radiography_Dataset"
        cfg["DataLoader"]["num_workers"] = 2
    # load the dataset
    data_loader = load_dataset(config_file=cfg, kind=args.kind, subset=args.subset,
                               batch_size=batch)
    total_samples = len(data_loader.dataset)
    logger.info(f"Total samples in the dataset: {total_samples}")
    logger.info("Running inference on the specified dataset.")
    print()
    # select the hardware acceleration device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    model.eval()
    val_corrects = 0
    loader = tqdm(data_loader)
    precision, recall, f1_score, accuracy = 0, 0, 0, 0
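    # torch.no_grad() disables gradient tracking for the forward passes
    # below, which reduces memory use during evaluation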
    with torch.no_grad():
        for images, labels in loader:
            loader.set_description("Inference")
            images, labels = images.to(device), labels.to(device)
            val_predictions = model(images)
            val_corrects += get_num_correct(val_predictions, labels)
            preds_classes = val_predictions.argmax(dim=1)
            accs, p, r, f1 = calculate_metrics(preds_classes, labels,
                                               "all", average="macro")
            precision += p
            recall += r
            f1_score += f1
            accuracy += accs
            loader.set_postfix(acc=accs)
    # average the per-batch metrics over the whole loader
    mean_precision = precision / len(data_loader)
    mean_recall = recall / len(data_loader)
    mean_f1 = f1_score / len(data_loader)
    accuracy = accuracy / len(data_loader)
    logger.info(f"Correct predictions: {val_corrects}/{total_samples}")
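    # NOTE: these are means of per-batch macro metrics, so a smaller final
    # batch is weighted the same as a full one; with uniform batch sizes
    # the approximation error is small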
    # print the results table
    if args.manual_table:
        logger.info("Evaluation Results")
        logger.info("+-----------------------+---------+")
        logger.info("| Metric                | Value   |")
        logger.info("+-----------------------+---------+")
        logger.info(f"| Precision macro       | {mean_precision:7.3f} |")
        logger.info(f"| Recall macro          | {mean_recall:7.3f} |")
        logger.info(f"| F1 Score macro        | {mean_f1:7.3f} |")
        logger.info(f"| Accuracy              | {accuracy:7.3f} |")
        logger.info("+-----------------------+---------+")
    else:
        print("Evaluation Results")
        table_data = [
            ["Precision macro", mean_precision],
            ["Recall macro", mean_recall],
            ["F1 Score macro", mean_f1],
            ["Accuracy", accuracy]]
        table_headers = ["Metric", "Value"]
        table = tabulate(table_data, headers=table_headers, tablefmt="grid",
                         numalign="right", stralign="center")
        print(table)
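    # Illustrative "grid" output (metric values are placeholders; exact
    # widths depend on the tabulate version):
    # +--------------------+---------+
    # |       Metric       |   Value |
    # +====================+=========+
    # |  Precision macro   |   0.912 |
    # +--------------------+---------+
    # ...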
def main():
    args = read_args()
    inference(args.batch, args.weights, args.model, args)

if __name__ == "__main__":
    main()