-
Notifications
You must be signed in to change notification settings - Fork 71
/
visualize_detector.py
84 lines (72 loc) · 2.88 KB
/
visualize_detector.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
# --------------------------------------------------------
# ImageNet-21K Pretraining for The Masses
# Copyright 2021 Alibaba MIIL (c)
# Licensed under MIT License [see the LICENSE file for details]
# Written by Tal Ridnik
# --------------------------------------------------------
import os
import urllib
import urllib.request
from argparse import Namespace

import timm
import torch
from PIL import Image
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform

from src_files.semantic.semantics import ImageNet21kSemanticSoftmax
############### Downloading metadata ##############
# Fetch the semantic tree file (once) and build the semantic-softmax processor from it.
print("downloading metadata...")
url, filename = (
    "https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/resources/fall11/imagenet21k_miil_tree.pth",
    "imagenet21k_miil_tree.pth")
if not os.path.isfile(filename):
    # Download under a temporary name and rename on success, so an interrupted
    # transfer never leaves a truncated file that later runs would treat as a
    # complete download.
    partial_filename = filename + ".part"
    urllib.request.urlretrieve(url, partial_filename)
    os.replace(partial_filename, filename)
# The processor reads the tree location from `args.tree_path`.
args = Namespace()
args.tree_path = filename
semantic_softmax_processor = ImageNet21kSemanticSoftmax(args)
print("done")
############### Loading (ViT) model from timm package ##############
# Build the pretrained ViT in inference mode, then derive the matching
# preprocessing pipeline from the model's own data configuration.
print("initilizing model...")
model = timm.create_model(
    'vit_base_patch16_224_miil_in21k',
    pretrained=True,
)
model.eval()
config = resolve_data_config({}, model=model)
transform = create_transform(**config)
print("done")
############## Loading sample image ##############
# Fetch the sample picture (once), decode it as RGB and turn it into a
# single-image batch for the model.
print("downloading sample image...")
url, filename = ("https://github.com/pytorch/hub/raw/master/images/dog.jpg", "dog.jpg")
if not os.path.isfile(filename):
    # Download under a temporary name and rename on success, so an interrupted
    # transfer never leaves a truncated image that later runs would reuse.
    partial_filename = filename + ".part"
    urllib.request.urlretrieve(url, partial_filename)
    os.replace(partial_filename, filename)
# convert() returns a fully loaded copy, so the file can be closed right away.
with Image.open(filename) as image_file:
    img = image_file.convert('RGB')
tensor = transform(img).unsqueeze(0)  # transform and add batch dimension
print("done")
############## Doing semantic inference ##############
# Run the model once, split its logits into per-level groups and keep the
# description of the top-1 class of every group whose probability is high enough.
print("doing semantic infernce...")
# Minimum top-1 probability required for a group's prediction to be reported.
confidence_threshold = 0.5
labels = []
with torch.no_grad():
    logits = model(tensor)
    semantic_logit_list = semantic_softmax_processor.split_logits_to_semantic_logits(logits)

    # Scan the hierarchy levels (presumably one logits group per level -- confirm
    # against ImageNet21kSemanticSoftmax).
    for i, logits_i in enumerate(semantic_logit_list):
        # Softmax over the first (and only) sample of the batch.
        probabilities = torch.nn.functional.softmax(logits_i[0], dim=0)
        top1_prob, top1_id = torch.topk(probabilities, 1)

        if top1_prob.item() > confidence_threshold:
            # Map the position inside this group to a class number, then to its
            # name and human-readable description.
            top_class_number = semantic_softmax_processor.hierarchy_indices_list[i][top1_id[0]]
            top_class_name = semantic_softmax_processor.tree['class_list'][top_class_number]
            top_class_description = semantic_softmax_processor.tree['class_description'][top_class_name]
            labels.append(top_class_description)

print("labels found {}.".format(labels))
############## Visualization ##############
# Draw the sample image with the detected labels as its title.
import matplotlib
import os
import numpy as np

# The backend is selected before pyplot is imported: an interactive Tk window
# on Windows, the non-interactive Agg renderer everywhere else.
use_interactive_backend = os.name == 'nt'
if use_interactive_backend:
    matplotlib.use('TkAgg')
else:
    matplotlib.use('Agg')
import matplotlib.pyplot as plt

plt.imshow(img)
plt.axis('off')
plt.title('Semantic labels found: \n {}'.format(np.array(labels)))
if use_interactive_backend:
    plt.show()
else:
    # Agg cannot open a window, so plt.show() would display nothing; write the
    # figure to disk instead so the result is actually visible to the user.
    output_path = 'semantic_labels.png'
    plt.savefig(output_path, bbox_inches='tight')
    print("visualization saved to {}".format(output_path))