forked from Parskatt/DKM
-
Notifications
You must be signed in to change notification settings - Fork 0
/
demo_match.py
25 lines (22 loc) · 899 Bytes
/
demo_match.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import os

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image

from dkm import DKM
from dkm.utils.utils import tensor_to_pil
# Demo: match two Sacre Coeur images with DKM, warp the support image through
# the predicted flow, and fade low-confidence pixels toward white.

# PIL's Image.resize takes (width, height).
IMAGE_SIZE = (512, 384)

# Ensure the output directory exists before any of the .save() calls below.
os.makedirs("demo", exist_ok=True)

dkm_model = DKM(pretrained=True, version="mega_synthetic")

# convert("RGB") guarantees exactly three channels (no alpha / grayscale),
# which the HWC -> CHW permute further down relies on. The context managers
# close the source files once the resized copies have been made.
with Image.open("assets/sacre_coeur_multimodal_query.jpg") as src:
    im1 = src.convert("RGB").resize(IMAGE_SIZE)
with Image.open("assets/sacre_coeur_multimodal_support.jpg") as src:
    im2 = src.convert("RGB").resize(IMAGE_SIZE)
im1.save("demo/sacre_coeur_query.jpg")
im2.save("demo/sacre_coeur_support.jpg")

flow, confidence = dkm_model.match(im1, im2)

# Soften the confidence with a square root, then rescale so the most confident
# pixel gets weight 1. (Assumes confidence is non-negative, presumably in
# [0, 1] -- not checked here.) The clamp avoids a 0/0 -> NaN image when the
# model reports zero confidence everywhere.
confidence = confidence ** (1 / 2)
c_b = confidence / confidence.max().clamp(min=1e-8)

# im2 as a float CHW tensor in [0, 1]. It is placed on the same device as the
# flow grid, because grid_sample requires input and grid on one device; this
# also lets the demo run without CUDA.
x2 = (torch.tensor(np.array(im2)) / 255).to(flow.device).permute(2, 0, 1)

# The last two channels of `flow` are used as grid_sample's sampling grid,
# i.e. normalized [-1, 1] (x, y) coordinates into im2. Presumably there is one
# per im1 pixel -- TODO confirm the flow layout against DKM.match.
im2_transfer_rgb = F.grid_sample(
    x2[None], flow[..., 2:][None], mode="bicubic", align_corners=False
)[0]

# The background takes its shape from the warped image (set by the flow grid),
# not from x2, so the blend still broadcasts if the flow resolution differs
# from IMAGE_SIZE.
white_im = torch.ones_like(im2_transfer_rgb)

# Confidence-weighted blend toward white. Bicubic sampling can overshoot
# [0, 1], so clamp before converting to an 8-bit image.
blended = (c_b * im2_transfer_rgb + (1 - c_b) * white_im).clamp(0, 1)
tensor_to_pil(blended, unnormalize=False).save("demo/sacre_coeur_warped.jpg")