-
Notifications
You must be signed in to change notification settings - Fork 0
/
model_wrapper.py
140 lines (125 loc) · 5.04 KB
/
model_wrapper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
# Wrappers to produce final models
# %%
import torch
from torch import nn
import einops as ein
# Internal
from dinov2_extractor import DinoV2ExtractFeatures
from utilities import VLAD
# %%
class AnyLocVladDinov2(nn.Module):
    """
    Wrapper around the AnyLoc-VLAD-DINOv2 model in the paper for
    the domain vocabularies (default).
    It combines the DINOv2 ViT feature extraction and the VLAD
    descriptor construction in a single module.

    Parameters:
    - c_centers:    Precomputed VLAD cluster centers (vocabulary).
                    They are moved to `device` and loaded into the
                    VLAD module. Presumably shaped [num_c, dino_dim]
                    - confirm against the VLAD implementation.
    - dino_model:   DINOv2 model name, forwarded to
                    `DinoV2ExtractFeatures`.
    - layer:        Layer index, forwarded to `DinoV2ExtractFeatures`.
    - facet:        Facet name, forwarded to `DinoV2ExtractFeatures`.
    - num_c:        Number of VLAD clusters.
    - device:       Device (or device string) used for the extractor,
                    the cluster centers and the returned descriptors.

    Device moves go through the `to` / `cuda` / `cpu` overrides below.
    They rebuild the extractor on the new device and reload the
    cluster centers there, then return `self` (like `nn.Module.to`).
    """
    def __init__(self, c_centers: torch.Tensor,
            dino_model: str = "dinov2_vitg14", layer: int = 31,
            facet: str = "value", num_c: int = 32,
            device: torch.device = "cpu") -> None:
        super().__init__()
        # DINOv2 feature extractor
        self.dino_model = dino_model
        self.layer = layer
        self.facet = facet
        self.device = torch.device(device)
        self.dino_extractor = self._get_dino_extractor()
        # VLAD aggregation with the given vocabulary
        self.vlad = VLAD(num_c)
        self._load_centers(c_centers)

    # Extractor
    def _get_dino_extractor(self):
        """Build a new DINOv2 feature extractor on `self.device`."""
        return DinoV2ExtractFeatures(
            dino_model=self.dino_model, layer=self.layer,
            facet=self.facet, device=self.device)

    # Vocabulary
    def _load_centers(self, c_centers: torch.Tensor) -> None:
        """Place `c_centers` on `self.device` and load them into VLAD."""
        self.vlad.c_centers = c_centers.to(self.device)
        self.vlad.fit(None)  # Load the database (vocabulary/c_centers)

    # Input validation (explicit raises: asserts vanish under -O)
    @staticmethod
    def _check_image(x: torch.Tensor) -> None:
        """
        Raise ValueError unless `x` is a [b, 3, h, w] batch whose
        height and width are multiples of 14 (the DINOv2 patch size
        required for patching).
        """
        if x.ndim != 4:
            raise ValueError(
                "Expected image(s) of shape [b, c, h, w], got "
                f"{tuple(x.shape)}")
        _, c, h, w = x.shape
        if c != 3:
            raise ValueError("Image(s) must be RGB!")
        if h % 14 != 0 or w % 14 != 0:
            raise ValueError(
                "Height and width should be multiple of 14 (for "
                "patching)")

    # Move DINO model and the vocabulary to device
    def to(self, device: torch.device):
        """
        Rebuild the extractor on `device` and reload the cluster
        centers there. Returns `self` so `m = m.to(dev)` works.
        """
        self.device = torch.device(device)
        self.dino_extractor = self._get_dino_extractor()
        # Without this, the features would live on the new device
        # while the vocabulary stayed on the old one.
        self._load_centers(self.vlad.c_centers)
        return self

    # Wrappers around `to`
    def cuda(self):
        """Move the model to CUDA; returns `self`."""
        return self.to("cuda")

    def cpu(self):
        """Move the model to the CPU; returns `self`."""
        return self.to("cpu")

    # Forward pass
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute VLAD global descriptors for a batch of images.

        Parameters:
        - x:    Image batch [b, 3, h, w] with h and w multiples of 14.

        Returns the output of `VLAD.generate_multi`, on `self.device`.
        Raises ValueError for images that fail validation.
        """
        self._check_image(x)
        # Extract features
        feats = self.dino_extractor(x.to(self.device))  # [b, (nH*nW), dino_dim]
        gds = self.vlad.generate_multi(feats)
        return gds.to(self.device)
# %%
class AnyLocVladNoCacheDinov2(nn.Module):
    """
    Wrapper around the AnyLoc-VLAD-DINOv2 model without the VLAD
    cluster centers. This is useful for using DINOv2 as a feature
    extractor, and then using VLAD for the clustering.
    If you want to use a cache (cluster centers already computed),
    then use `AnyLocVladDinov2` class instead.

    Usage: `extract` features from images, `fit` the VLAD cluster
    centers on those descriptors, then call the module (`forward`)
    to get VLAD descriptors.

    Parameters:
    - dino_model:   DINOv2 model name, forwarded to
                    `DinoV2ExtractFeatures`.
    - layer:        Layer index, forwarded to `DinoV2ExtractFeatures`.
    - facet:        Facet name, forwarded to `DinoV2ExtractFeatures`.
    - num_c:        Number of VLAD clusters.
    - device:       Device (or device string) used for the extractor,
                    the fitted cluster centers and the outputs.
    """
    def __init__(self, dino_model: str = "dinov2_vitg14",
            layer: int = 31, facet: str = "value",
            num_c: int = 32, device: torch.device = "cpu") -> None:
        super().__init__()
        # DINOv2 feature extractor
        self.dino_model = dino_model
        self.layer = layer
        self.facet = facet
        self.device = torch.device(device)
        self.dino_extractor = self._get_dino_extractor()
        # VLAD module (no vocabulary until `fit` is called)
        self.vlad = VLAD(num_c)
        self.clusters_fitted = False    # Set by `fit`; `forward` needs it

    # Extractor
    def _get_dino_extractor(self):
        """Build a new DINOv2 feature extractor on `self.device`."""
        return DinoV2ExtractFeatures(
            dino_model=self.dino_model, layer=self.layer,
            facet=self.facet, device=self.device)

    # Input validation (explicit raises: asserts vanish under -O)
    @staticmethod
    def _check_image(x: torch.Tensor) -> None:
        """
        Raise ValueError unless `x` is a [b, 3, h, w] batch whose
        height and width are multiples of 14 (the DINOv2 patch size
        required for patching).
        """
        if x.ndim != 4:
            raise ValueError(
                "Expected image(s) of shape [b, c, h, w], got "
                f"{tuple(x.shape)}")
        _, c, h, w = x.shape
        if c != 3:
            raise ValueError("Image(s) must be RGB!")
        if h % 14 != 0 or w % 14 != 0:
            raise ValueError(
                "Height and width should be multiple of 14 (for "
                "patching)")

    # Move the DINO model and cluster centers to another device
    def to(self, device: torch.device):
        """
        Rebuild the extractor on `device` and, if fitted, move the
        cluster centers there. Returns `self` so `m = m.to(dev)`
        works.
        """
        self.device = torch.device(device)
        self.dino_extractor = self._get_dino_extractor()
        if self.clusters_fitted:
            self.vlad.c_centers = self.vlad.c_centers.to(self.device)
            # Reload the database (vocabulary/c_centers). Presumably
            # this also moves any VLAD state derived from the centers
            # (confirm against VLAD.fit).
            self.vlad.fit(None)
        return self

    # Wrappers around `to`
    def cuda(self):
        """Move the model to CUDA; returns `self`."""
        return self.to("cuda")

    def cpu(self):
        """Move the model to the CPU; returns `self`."""
        return self.to("cpu")

    # Extract image features using backbone
    def extract(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run the DINOv2 extractor on a batch of images.

        Parameters:
        - x:    Image batch [b, 3, h, w] with h and w multiples of 14.

        Returns the extractor output on `self.device`
        ([b, (nH*nW), dino_dim] per the original author).
        Raises ValueError for images that fail validation.
        """
        self._check_image(x)
        feats = self.dino_extractor(x.to(self.device))  # [b, (nH*nW), dino_dim]
        return feats.to(self.device)

    # Get cluster centers from descriptors
    def fit(self, x: torch.Tensor) -> None:
        """
        Fit the VLAD cluster centers on descriptors `x`, shaped
        (num_descs, desc_dim), and mark the module as ready.
        NOTE(review): the centers presumably follow `x`'s device; call
        `to(...)` afterwards if `x` is not on `self.device`.
        """
        self.vlad.fit(x)
        self.clusters_fitted = True

    # Forward pass
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute VLAD global descriptors for a batch of images.

        Returns the output of `VLAD.generate_multi`, on `self.device`.
        Raises ValueError if `fit` was not called yet, or if the
        images fail validation.
        """
        if not self.clusters_fitted:
            raise ValueError(
                "Cluster centers unavailable. Call 'fit'.")
        feats = self.extract(x)
        gds = self.vlad.generate_multi(feats)
        return gds.to(self.device)