forked from Parskatt/DKM
-
Notifications
You must be signed in to change notification settings - Fork 0
/
demo_homography.py
48 lines (45 loc) · 1.52 KB
/
demo_homography.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import os

import cv2
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms

from dkm import DKM
from dkm.utils.utils import tensor_to_pil
H_IMG, W_IMG = 384, 512  # (height, width) that both input images are resized to


def warp_to_canvas(im1, im2, homography):
    """Stitch ``im2`` to the right of ``im1`` using a homography.

    Coordinates follow the ``F.grid_sample`` convention with
    ``align_corners=False``: ``im1`` spans [-1, 1] on both axes, so the right
    half of the double-width canvas is the strip x in [1, 3] of ``im1``'s
    frame. Every canvas pixel in that strip is mapped through ``homography``
    into ``im2``'s normalized frame and sampled bicubically.

    Args:
        im1: ``(C, h, w)`` float tensor, pasted unchanged into the left half.
        im2: ``(C, h2, w2)`` float tensor on the same device as ``im1``.
        homography: ``(3, 3)`` tensor or array-like mapping normalized ``im1``
            coordinates ``(x, y, 1)`` to normalized ``im2`` coordinates.

    Returns:
        ``(C, h, 2 * w)`` tensor. Samples that fall outside ``im2`` are zero
        (``grid_sample``'s default zero padding).
    """
    channels, h, w = im1.shape
    device = im1.device
    canvas = torch.zeros((channels, h, 2 * w), dtype=im1.dtype, device=device)
    canvas[:, :, :w] = im1

    # Pixel-centre coordinates of the right-hand strip, in im1's normalized frame.
    ys = torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=device)
    xs = torch.linspace(1 + 1 / w, 3 - 1 / w, w, device=device)
    # Homogeneous (x, y, 1) grid of shape (h, w, 3). It is built by broadcasting
    # to avoid torch.meshgrid's deprecation warning about its `indexing` default.
    grid = torch.stack(
        (
            xs[None, :].expand(h, w),
            ys[:, None].expand(h, w),
            torch.ones((h, w), device=device),
        ),
        dim=-1,
    ).double()

    H = torch.as_tensor(homography, dtype=torch.float64, device=device)
    mapped = torch.einsum("dc,hwc->hwd", H, grid)
    # Perspective divide back to inhomogeneous (x, y).
    mapped = mapped[..., :2] / mapped[..., 2:]

    canvas[:, :, w:] = F.grid_sample(
        im2[None], mapped[None].to(im2.dtype), mode="bicubic", align_corners=False
    )[0]
    return canvas


def main():
    """Run the DKM homography demo and write its outputs under ``demo/``.

    Resizes the two Amsterdam images to ``W_IMG x H_IMG`` and matches them with
    DKM. It then fits a RANSAC homography on sampled matches and saves the
    resized inputs plus the stitched panorama.

    Raises:
        RuntimeError: if OpenCV cannot estimate a homography from the matches.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Make sure the output directory exists before PIL writes into it.
    os.makedirs("demo", exist_ok=True)

    dkm_model = DKM(pretrained=True, version="mega")
    im1 = Image.open("assets/ams_hom_left.jpg").resize((W_IMG, H_IMG))
    im2 = Image.open("assets/ams_hom_right.jpg").resize((W_IMG, H_IMG))
    im1.save("demo/ams_hom_left.jpg")
    im2.save("demo/ams_hom_right.jpg")

    flow, confidence = dkm_model.match(im1, im2)
    # NOTE(review): the square root flattens the confidence before sampling,
    # presumably so the sampled matches are less concentrated. Confirm intent.
    confidence = confidence ** (1 / 2)
    good_matches, _ = dkm_model.sample(flow, confidence, 10000)

    # The matches appear to be in grid_sample-normalized [-1, 1] coordinates
    # (warp_to_canvas applies the homography in that frame). The RANSAC
    # threshold is therefore in the same units: 1.5 / W_IMG is ~0.75 px horizontally.
    homography, _ = cv2.findHomography(
        good_matches[..., :2],
        good_matches[..., 2:],
        method=cv2.RANSAC,
        confidence=0.99999,
        ransacReprojThreshold=1.5 / W_IMG,
    )
    if homography is None:
        # OpenCV returns no matrix when it cannot fit a model.
        raise RuntimeError("cv2.findHomography could not estimate a homography")

    to_tensor = transforms.ToTensor()
    stitched = warp_to_canvas(
        to_tensor(im1).to(device),
        to_tensor(im2).to(device),
        torch.as_tensor(homography, device=device),
    )
    tensor_to_pil(stitched).save("demo/stiched_ams.jpg")


if __name__ == "__main__":
    main()