# example_defmo_source.py
import argparse
import os
import sys

import torch

from benchmark.benchmark_loader import *
from benchmark.loaders_helpers import *

# Requires the official DeFMO implementation, cloned from
# https://github.com/rozumden/DeFMO into ./DeFMO.
sys.path.insert(0, './DeFMO')
from models.encoder import *
from models.rendering import *
from dataloaders.loader import get_transform
from helpers.torch_helpers import renders2traj
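
# Note: the helper functions used below (extend_bbox, crop_resize, rgba2hs,
# rev_crop_resize, rev_crop_resize_traj, run_benchmark) and the np alias
# come from the wildcard benchmark imports above; this attribution is
# inferred from usage, not from the benchmark package documentation.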

# Folder with the pretrained DeFMO weights (encoder_best.pt, rendering_best.pt).
g_saved_models_folder = './DeFMO/saved_models/'


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--tbd_path", default='/cluster/home/denysr/scratch/dataset/TbD', required=False)
    parser.add_argument("--tbd3d_path", default='/cluster/home/denysr/scratch/dataset/TbD-3D', required=False)
    parser.add_argument("--falling_path", default='/cluster/home/denysr/scratch/dataset/falling_objects', required=False)
    parser.add_argument("--visualization_path", default='/cluster/home/denysr/tmp', required=False)
    # 'store_true' flags avoid the pitfall of string-valued booleans
    # (argparse would treat any non-empty string, even "False", as truthy).
    parser.add_argument("--verbose", action='store_true')
    parser.add_argument("--save_visualization", action='store_true')
    return parser.parse_args()
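
# Example invocation (dataset paths are placeholders for local copies):
#   python example_defmo_source.py --tbd_path /data/TbD \
#       --tbd3d_path /data/TbD-3D --falling_path /data/falling_objects \
#       --save_visualization --visualization_path /tmp/defmo_vis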


def main():
    args = parse_args()

    # Network input resolution: half of 640x480.
    g_resolution_x = int(640 / 2)
    g_resolution_y = int(480 / 2)
    multi_f = 5  # sub-frames averaged per split to simulate small motion blur

    gpu_id = 0
    device = torch.device("cuda:{}".format(gpu_id) if torch.cuda.is_available() else "cpu")
    print(device)
    torch.backends.cudnn.benchmark = True

    # Load the pretrained DeFMO encoder and rendering networks; mapping to
    # the selected device keeps the load working with or without a GPU.
    encoder = EncoderCNN()
    rendering = RenderingCNN()
    encoder.load_state_dict(torch.load(os.path.join(g_saved_models_folder, 'encoder_best.pt'), map_location=device))
    rendering.load_state_dict(torch.load(os.path.join(g_saved_models_folder, 'rendering_best.pt'), map_location=device))
    encoder = encoder.to(device)
    rendering = rendering.to(device)
    encoder.train(False)  # inference only
    rendering.train(False)
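
    # Both networks stay in eval mode for the whole benchmark; deblur_defmo
    # below closes over them, the device, and the resolution constants.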
    def deblur_defmo(I, B, bbox_tight, nsplits, radius, obj_dim):
        # Extend the tight bounding box so the whole blurred object fits
        # the crop, preserving the network's input aspect ratio.
        bbox = extend_bbox(bbox_tight.copy(), 4 * np.max(radius), g_resolution_y / g_resolution_x, I.shape)
        im_crop = crop_resize(I, bbox, (g_resolution_x, g_resolution_y))
        bgr_crop = crop_resize(B, bbox, (g_resolution_x, g_resolution_y))
        # Stack the preprocessed image and background along the channel axis.
        preprocess = get_transform()
        input_batch = torch.cat((preprocess(im_crop), preprocess(bgr_crop)), 0).to(device).unsqueeze(0).float()
        with torch.no_grad():
            latent = encoder(input_batch)
            # Render nsplits * multi_f sub-frames, then average each group of
            # multi_f to add small motion blur within every split.
            times = torch.linspace(0, 1, nsplits * multi_f + 1).to(device)
            renders = rendering(latent, times[None])
            renders = renders[:, :-1].reshape(1, nsplits, multi_f, 4, g_resolution_y, g_resolution_x).mean(2)
        renders_rgba = renders[0].data.cpu().detach().numpy().transpose(2, 3, 1, 0)
        # Composite the RGBA renders over the background crop, then map the
        # frames and the trajectory back to original image coordinates.
        est_hs_crop = rgba2hs(renders_rgba, bgr_crop)
        est_hs = rev_crop_resize(est_hs_crop, bbox, I)
        est_traj = renders2traj(renders, device)[0].T.cpu()
        est_traj = rev_crop_resize_traj(est_traj, bbox, (g_resolution_x, g_resolution_y))
        return est_hs, est_traj
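
    # Contract with run_benchmark (inferred from the callback signature and
    # the helpers used above): I is the blurred frame, B the background,
    # bbox_tight the object's bounding box, nsplits the temporal
    # super-resolution factor, radius/obj_dim the object size; the returns
    # are the estimated high-speed frames and the 2D trajectory in
    # input-image coordinates.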

    args.add_traj = False
    args.method_name = 'DeFMO'
    run_benchmark(args, deblur_defmo)


if __name__ == "__main__":
    main()