-
Notifications
You must be signed in to change notification settings - Fork 0
/
dinov2_extractor.py
84 lines (77 loc) · 3.05 KB
/
dinov2_extractor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
# %%
import torch
from torch import nn
import torch.nn.functional as F
from typing import Literal
# %%
# Extract features from a Dino-v2 model
_DINO_V2_MODELS = Literal["dinov2_vits14", "dinov2_vitb14", \
"dinov2_vitl14", "dinov2_vitg14"]
_DINO_FACETS = Literal["query", "key", "value", "token"]
class DinoV2ExtractFeatures:
"""
Extract features from an intermediate layer in Dino-v2
"""
def __init__(self, dino_model: _DINO_V2_MODELS, layer: int,
facet: _DINO_FACETS="token", use_cls=False,
norm_descs=True, device: str = "cpu") -> None:
"""
Parameters:
- dino_model: The DINO-v2 model to use
- layer: The layer to extract features from
- facet: "query", "key", or "value" for the attention
facets. "token" for the output of the layer.
- use_cls: If True, the CLS token (first item) is also
included in the returned list of descriptors.
Otherwise, only patch descriptors are used.
- norm_descs: If True, the descriptors are normalized
- device: PyTorch device to use
"""
self.vit_type: str = dino_model
self.dino_model: nn.Module = torch.hub.load(
'facebookresearch/dinov2', dino_model)
self.device = torch.device(device)
self.dino_model = self.dino_model.eval().to(self.device)
self.layer: int = layer
self.facet = facet
if self.facet == "token":
self.fh_handle = self.dino_model.blocks[self.layer].\
register_forward_hook(
self._generate_forward_hook())
else:
self.fh_handle = self.dino_model.blocks[self.layer].\
attn.qkv.register_forward_hook(
self._generate_forward_hook())
self.use_cls = use_cls
self.norm_descs = norm_descs
# Hook data
self._hook_out = None
def _generate_forward_hook(self):
def _forward_hook(module, inputs, output):
self._hook_out = output
return _forward_hook
def __call__(self, img: torch.Tensor) -> torch.Tensor:
"""
Parameters:
- img: The input image
"""
with torch.no_grad():
res = self.dino_model(img)
if self.use_cls:
res = self._hook_out
else:
res = self._hook_out[:, 1:, ...]
if self.facet in ["query", "key", "value"]:
d_len = res.shape[2] // 3
if self.facet == "query":
res = res[:, :, :d_len]
elif self.facet == "key":
res = res[:, :, d_len:2*d_len]
else:
res = res[:, :, 2*d_len:]
if self.norm_descs:
res = F.normalize(res, dim=-1)
self._hook_out = None # Reset the hook
return res
def __del__(self):
self.fh_handle.remove()