-
Notifications
You must be signed in to change notification settings - Fork 44
/
Copy pathedit_object_inpaint.py
246 lines (197 loc) · 11 KB
/
edit_object_inpaint.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
# Copyright (C) 2023, Gaussian-Grouping
# Gaussian-Grouping research group, https://github.com/lkeab/gaussian-grouping
# All rights reserved.
#
# ------------------------------------------------------------------------
# Modified from codes in Gaussian-Splatting
# GRAPHDECO research group, https://team.inria.fr/graphdeco
import torch
from scene import Scene
import os
from tqdm import tqdm
from os import makedirs
from gaussian_renderer import render
import torchvision
from utils.general_utils import safe_state
from argparse import ArgumentParser
from arguments import ModelParams, PipelineParams, OptimizationParams, get_combined_args
from gaussian_renderer import GaussianModel
import numpy as np
from PIL import Image
import cv2
from utils.loss_utils import masked_l1_loss
from random import randint
import lpips
import json
from render import feature_to_rgb, visualize_obj
from edit_object_removal import points_inside_convex_hull
def mask_to_bbox(mask):
    """Return the tight inclusive bounding box (xmin, ymin, xmax, ymax) of a 2-D boolean mask.

    Assumes the mask contains at least one True element; an all-False mask
    would make the index lookups fail.
    """
    # Indices of every row / column that contains at least one hit.
    row_hits = torch.where(torch.any(mask, dim=1))[0]
    col_hits = torch.where(torch.any(mask, dim=0))[0]
    # The extremes of those index lists are the box edges (inclusive).
    ymin, ymax = row_hits[0], row_hits[-1]
    xmin, xmax = col_hits[0], col_hits[-1]
    return xmin, ymin, xmax, ymax
def crop_using_bbox(image, bbox):
    """Crop a (C, H, W) image tensor to an inclusive (xmin, ymin, xmax, ymax) box."""
    x0, y0, x1, y1 = bbox
    # +1 because the bbox edges are inclusive while slicing is half-open.
    cropped = image[:, y0:y1 + 1, x0:x1 + 1]
    return cropped
# Function to divide image into K x K patches
def divide_into_patches(image, K):
    """Split a batched image (B, C, H, W) into a K-by-K grid of non-overlapping patches.

    Returns a tensor of shape (B, K*K, C, H//K, W//K); patches are ordered
    row-major over the grid (top-left first). H and W are assumed to be
    divisible by K — any remainder is silently dropped by unfold.
    """
    batch, channels, height, width = image.shape
    ph = height // K
    pw = width // K
    # unfold extracts each (ph, pw) tile as one column: (B, C*ph*pw, K*K).
    tiles = torch.nn.functional.unfold(image, (ph, pw), stride=(ph, pw))
    # Re-expand the flattened tile dimension, then move the grid axis forward.
    tiles = tiles.view(batch, channels, ph, pw, -1)
    return tiles.permute(0, 4, 1, 2, 3)
def finetune_inpaint(opt, model_path, iteration, views, gaussians, pipeline, background, classifier, selected_obj_ids, cameras_extent, removal_thresh, finetune_iteration):
    """Finetune the Gaussians inside the selected-object region so the hole is inpainted.

    Selects the 3D Gaussians belonging to ``selected_obj_ids`` via the object
    classifier (plus a convex-hull dilation), freezes/initialises them through
    ``gaussians.inpaint_setup``, then optimises for ``finetune_iteration`` steps
    against an L1 loss outside the object mask and a patchwise LPIPS loss on the
    object's bounding-box crop. Saves a PLY checkpoint and returns the model.

    NOTE(review): the ``iteration`` parameter is immediately shadowed by the
    training loop variable; the save path below ends up using the leaked loop
    value (finetune_iteration - 1). The caller ``inpaint`` loads exactly that
    path, so this coupling appears intentional — do not "fix" one side alone.
    """
    # get 3d gaussians idx corresponding to select obj id
    with torch.no_grad():
        # Per-Gaussian class logits from the 1x1-conv classifier over the object features.
        logits3d = classifier(gaussians._objects_dc.permute(2,0,1))
        prob_obj3d = torch.softmax(logits3d,dim=0)
        # A Gaussian is selected if any of the requested object ids exceeds the threshold.
        mask = prob_obj3d[selected_obj_ids, :, :] > removal_thresh
        mask3d = mask.any(dim=0).squeeze()
        # Dilate the selection to everything inside the convex hull of the selected points,
        # so occluded interior Gaussians are included too.
        mask3d_convex = points_inside_convex_hull(gaussians._xyz.detach(),mask3d,outlier_factor=1.0)
        mask3d = torch.logical_or(mask3d,mask3d_convex)
        mask3d = mask3d.float()[:,None,None]

    # fix some gaussians: only the masked (object) Gaussians are trainable after this.
    gaussians.inpaint_setup(opt,mask3d)
    iterations = finetune_iteration
    progress_bar = tqdm(range(iterations), desc="Finetuning progress")
    # LPIPS is used as a frozen perceptual loss, not trained.
    LPIPS = lpips.LPIPS(net='vgg')
    for param in LPIPS.parameters():
        param.requires_grad = False
    LPIPS.cuda()

    for iteration in range(iterations):
        # Pick a random training view each step (stack is rebuilt every iteration).
        viewpoint_stack = views.copy()
        viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1))
        render_pkg = render(viewpoint_cam, gaussians, pipeline, background)
        image, viewspace_point_tensor, visibility_filter, radii, objects = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"], render_pkg["render_object"]

        # 2D object mask for this view; values appear to be 0..255, thresholded at 128.
        mask2d = viewpoint_cam.objects > 128
        gt_image = viewpoint_cam.original_image.cuda()
        # L1 supervision only OUTSIDE the object mask (the hole has no reliable GT).
        Ll1 = masked_l1_loss(image, gt_image, ~mask2d)

        # Perceptual loss on the object's bounding-box crop, split into a 2x2 patch grid.
        bbox = mask_to_bbox(mask2d)
        cropped_image = crop_using_bbox(image, bbox)
        cropped_gt_image = crop_using_bbox(gt_image, bbox)
        K = 2
        rendering_patches = divide_into_patches(cropped_image[None, ...], K)
        gt_patches = divide_into_patches(cropped_gt_image[None, ...], K)
        # LPIPS expects inputs in [-1, 1], hence the *2-1 rescale.
        lpips_loss = LPIPS(rendering_patches.squeeze()*2-1,gt_patches.squeeze()*2-1).mean()
        # lambda_dssim re-used here as the LPIPS weight (see config key 'lambda_dlpips').
        loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * lpips_loss
        loss.backward()

        with torch.no_grad():
            # Densification phase: only during the first 5000 steps.
            if iteration < 5000 :
                # Keep track of max radii in image-space for pruning
                gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter])
                gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter)
                if iteration % 300 == 0:
                    size_threshold = 20
                    gaussians.densify_and_prune(opt.densify_grad_threshold, 0.005, cameras_extent, size_threshold)

            gaussians.optimizer.step()
            gaussians.optimizer.zero_grad(set_to_none = True)

            if iteration % 10 == 0:
                progress_bar.set_postfix({"Loss": f"{loss:.{7}f}"})
                progress_bar.update(10)
    progress_bar.close()

    # save gaussians — ``iteration`` here is the leaked loop value (iterations - 1).
    point_cloud_path = os.path.join(model_path, "point_cloud_object_inpaint/iteration_{}".format(iteration))
    gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply"))
    return gaussians
def render_set(model_path, name, iteration, views, gaussians, pipeline, background, classifier):
    """Render all views of one split and write images plus a side-by-side video.

    For each view, saves: the rendering, the GT image, a 16-dim feature
    visualisation of the object features, the colourised GT object mask, and
    the colourised predicted object mask. Then stitches all five per-frame
    images horizontally into a video.

    NOTE(review): output folder is "ours{iteration}" (no underscore) — the
    usual convention elsewhere is "ours_{}"; confirm downstream tooling expects
    this exact name before changing it.
    """
    render_path = os.path.join(model_path, name, "ours{}".format(iteration), "renders")
    gts_path = os.path.join(model_path, name, "ours{}".format(iteration), "gt")
    colormask_path = os.path.join(model_path, name, "ours{}".format(iteration), "objects_feature16")
    gt_colormask_path = os.path.join(model_path, name, "ours{}".format(iteration), "gt_objects_color")
    pred_obj_path = os.path.join(model_path, name, "ours{}".format(iteration), "objects_pred")
    makedirs(render_path, exist_ok=True)
    makedirs(gts_path, exist_ok=True)
    makedirs(colormask_path, exist_ok=True)
    makedirs(gt_colormask_path, exist_ok=True)
    makedirs(pred_obj_path, exist_ok=True)

    for idx, view in enumerate(tqdm(views, desc="Rendering progress")):
        results = render(view, gaussians, pipeline, background)
        rendering = results["render"]
        rendering_obj = results["render_object"]

        # Per-pixel object class prediction from the rendered object features.
        logits = classifier(rendering_obj)
        pred_obj = torch.argmax(logits,dim=0)
        pred_obj_mask = visualize_obj(pred_obj.cpu().numpy().astype(np.uint8))

        gt_objects = view.objects
        gt_rgb_mask = visualize_obj(gt_objects.cpu().numpy().astype(np.uint8))

        rgb_mask = feature_to_rgb(rendering_obj)
        Image.fromarray(rgb_mask).save(os.path.join(colormask_path, '{0:05d}'.format(idx) + ".png"))
        Image.fromarray(gt_rgb_mask).save(os.path.join(gt_colormask_path, '{0:05d}'.format(idx) + ".png"))
        Image.fromarray(pred_obj_mask).save(os.path.join(pred_obj_path, '{0:05d}'.format(idx) + ".png"))
        gt = view.original_image[0:3, :, :]
        torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png"))
        torchvision.utils.save_image(gt, os.path.join(gts_path, '{0:05d}'.format(idx) + ".png"))

    # render_path[:-8] strips the trailing "/renders" to get the split directory.
    out_path = os.path.join(render_path[:-8],'concat')
    makedirs(out_path,exist_ok=True)
    fourcc = cv2.VideoWriter.fourcc(*'DIVX')
    # 5 panels side by side: gt | render | gt objects | pred objects | rendered features.
    # ``gt`` here is the last frame from the loop above; assumes all views share one size.
    size = (gt.shape[-1]*5,gt.shape[-2])
    fps = float(5) if 'train' in out_path else float(1)
    writer = cv2.VideoWriter(os.path.join(out_path,'result.mp4'), fourcc, fps, size)

    for file_name in sorted(os.listdir(gts_path)):
        gt = np.array(Image.open(os.path.join(gts_path,file_name)))
        rgb = np.array(Image.open(os.path.join(render_path,file_name)))
        gt_obj = np.array(Image.open(os.path.join(gt_colormask_path,file_name)))
        render_obj = np.array(Image.open(os.path.join(colormask_path,file_name)))
        pred_obj = np.array(Image.open(os.path.join(pred_obj_path,file_name)))
        result = np.hstack([gt,rgb,gt_obj,pred_obj,render_obj])
        result = result.astype('uint8')
        Image.fromarray(result).save(os.path.join(out_path,file_name))
        # OpenCV expects BGR; images were loaded as RGB, hence the channel flip.
        writer.write(result[:,:,::-1])
    writer.release()
def inpaint(dataset : ModelParams, iteration : int, pipeline : PipelineParams, skip_train : bool, skip_test : bool, opt : OptimizationParams, select_obj_id : int, removal_thresh : float, finetune_iteration: int):
    """Load a trained Gaussian checkpoint, inpaint the selected object, and render results.

    Pipeline: (1) load the scene and the pretrained object classifier,
    (2) finetune the object region via ``finetune_inpaint``, (3) reload the
    scene from the inpainted checkpoint and render train/test splits.
    """
    # 1. load gaussian checkpoint
    gaussians = GaussianModel(dataset.sh_degree)
    scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False)
    num_classes = dataset.num_classes
    print("Num classes: ",num_classes)
    # The object classifier is a 1x1 conv trained alongside the Gaussians.
    classifier = torch.nn.Conv2d(gaussians.num_objects, num_classes, kernel_size=1)
    classifier.cuda()
    classifier.load_state_dict(torch.load(os.path.join(dataset.model_path,"point_cloud","iteration_"+str(scene.loaded_iter),"classifier.pth")))
    bg_color = [1,1,1] if dataset.white_background else [0, 0, 0]
    background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")

    # 2. inpaint selected object
    gaussians = finetune_inpaint(opt, dataset.model_path, scene.loaded_iter, scene.getTrainCameras(), gaussians, pipeline, background, classifier, select_obj_id, scene.cameras_extent, removal_thresh, finetune_iteration)

    # 3. render new result — re-point the dataset at the plain image/mask folders,
    # then load the checkpoint saved by finetune_inpaint (its loop saves at
    # finetune_iteration - 1, which this load string must match exactly).
    dataset.object_path = 'object_mask'
    dataset.images = 'images'
    scene = Scene(dataset, gaussians, load_iteration='_object_inpaint/iteration_'+str(finetune_iteration-1), shuffle=False)
    with torch.no_grad():
        if not skip_train:
            render_set(dataset.model_path, "train", scene.loaded_iter, scene.getTrainCameras(), gaussians, pipeline, background, classifier)
        if not skip_test:
            render_set(dataset.model_path, "test", scene.loaded_iter, scene.getTestCameras(), gaussians, pipeline, background, classifier)
if __name__ == "__main__":
    # Set up command line argument parser
    parser = ArgumentParser(description="Testing script parameters")
    model = ModelParams(parser, sentinel=True)
    opt = OptimizationParams(parser)
    pipeline = PipelineParams(parser)
    parser.add_argument("--iteration", default=-1, type=int)
    parser.add_argument("--skip_train", action="store_true")
    parser.add_argument("--skip_test", action="store_true")
    parser.add_argument("--quiet", action="store_true")
    parser.add_argument("--config_file", type=str, default="config/object_removal/bear.json", help="Path to the configuration file")
    # get_combined_args merges CLI args with the ones stored in the model directory.
    args = get_combined_args(parser)
    print("Rendering " + args.model_path)

    # Read and parse the configuration file
    try:
        with open(args.config_file, 'r') as file:
            config = json.load(file)
    except FileNotFoundError:
        print(f"Error: Configuration file '{args.config_file}' not found.")
        exit(1)
    except json.JSONDecodeError as e:
        print(f"Error: Failed to parse the JSON configuration file: {e}")
        exit(1)

    # Config values override/extend the parsed args; note the key mismatches:
    # "r" feeds args.resolution, and "lambda_dlpips" feeds args.lambda_dssim
    # (re-used as the LPIPS weight during finetuning).
    args.num_classes = config.get("num_classes", 200)
    args.removal_thresh = config.get("removal_thresh", 0.3)
    args.select_obj_id = config.get("select_obj_id", [34])
    args.images = config.get("images", "images")
    args.object_path = config.get("object_path", "object_mask")
    args.resolution = config.get("r", 1)
    args.lambda_dssim = config.get("lambda_dlpips", 0.5)
    args.finetune_iteration = config.get("finetune_iteration", 10_000)

    # Initialize system state (RNG)
    safe_state(args.quiet)

    inpaint(model.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test, opt.extract(args), args.select_obj_id, args.removal_thresh, args.finetune_iteration)