diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..1521ff5 --- /dev/null +++ b/.gitignore @@ -0,0 +1,23 @@ +*.pyc +.vscode +output +build +diff_rasterization/diff_rast.egg-info +diff_rasterization/dist +tensorboard_3d +screenshots +data/ +data +argument/ +scripts/ +*.mp4 +checkpoint-500/ +*.bin +__MACOSX/ +**/__MACOSX/ +*.ply +olddata +*.out +*.zip +output_flow/ +exp_data/ diff --git a/README.md b/README.md index b943740..7eb7c03 100644 --- a/README.md +++ b/README.md @@ -1 +1,75 @@ -# 4DGen +# 4DGen: Grounded 4D Content Generation with Spatial-temporal Consistency + +[[Project Page]](https://vita-group.github.io/4DGen/) | [[Video]](https://www.youtube.com/watch?v=-bXyBKdpQ1o) + +## Setup + +Please follow [3D-GS](https://github.com/graphdeco-inria/gaussian-splatting) to install the related packages. + +```bash +conda env create -f environment.yml +conda activate 4DGen +pip install -r requirements.txt + +# 3D Gaussian Splatting modules, skip if you already installed them +# a modified gaussian splatting (+ depth, alpha rendering) +git clone --recursive https://github.com/ashawkey/diff-gaussian-rasterization +pip install ./diff-gaussian-rasterization +pip install ./simple-knn + +# install kaolin for chamfer distance (optional) +# https://kaolin.readthedocs.io/en/latest/notes/installation.html +# CHANGE the torch and CUDA toolkit version if yours are different +pip install kaolin -f https://nvidia-kaolin.s3.us-east-2.amazonaws.com/torch-1.12.1_cu116.html +``` + +## Data Preparation + +We release our collected data on Google Drive. + +Each test case contains two folders: `{name}_pose0` and `{name}_sync`. `pose0` refers to the monocular video sequence, and `sync` refers to the pseudo labels generated by SyncDreamer. + +We recommend using [Practical-RIFE](https://github.com/hzwer/Practical-RIFE) if you need to introduce more frames into your video sequence. + +To preprocess the images into RGBA format, use `preprocess.py` or `preprocess_sync.py`: + +```bash +# for monocular image sequence +python preprocess.py --path xxx +# for images generated by syncdreamer +python preprocess_sync.py --path xxx +``` + +## Training + +```bash +python train.py --configs arguments/i2v.py -e rose +``` + +## Rendering + +```bash +python render.py --skip_train --configs arguments/ours/i2v_xdj.py --skip_test --model_path "./output/xxxx/" +``` + +## Acknowledgement + +This work is built on many amazing research works and open-source projects; thanks a lot to all the authors for sharing! + +- https://github.com/dreamgaussian/dreamgaussian +- https://github.com/hustvl/4DGaussians +- https://github.com/graphdeco-inria/gaussian-splatting +- https://github.com/graphdeco-inria/diff-gaussian-rasterization +- https://github.com/threestudio-project/threestudio + +## Citation +If you find this repository/work helpful in your research, please consider citing the paper and starring the repo ⭐. + +``` +@article{yin20234dgen, + title={4DGen: Grounded 4D Content Generation with Spatial-temporal Consistency}, + author={}, + journal={arXiv preprint}, + year={2023} +} +``` diff --git a/arguments/__init__.py b/arguments/__init__.py new file mode 100644 index 0000000..1b1f030 --- /dev/null +++ b/arguments/__init__.py @@ -0,0 +1,172 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file.
+# +# For inquiries contact george.drettakis@inria.fr +# + +from argparse import ArgumentParser, Namespace +import sys +import os + +class GroupParams: + pass + +class ParamGroup: + def __init__(self, parser: ArgumentParser, name : str, fill_none = False): + group = parser.add_argument_group(name) + for key, value in vars(self).items(): + shorthand = False + if key.startswith("_"): + shorthand = True + key = key[1:] + t = type(value) + value = value if not fill_none else None + if shorthand: + if t == bool: + group.add_argument("--" + key, ("-" + key[0:1]), default=value, action="store_true") + else: + group.add_argument("--" + key, ("-" + key[0:1]), default=value, type=t) + else: + if t == bool: + group.add_argument("--" + key, default=value, action="store_true") + else: + group.add_argument("--" + key, default=value, type=t) + + def extract(self, args): + group = GroupParams() + for arg in vars(args).items(): + if arg[0] in vars(self) or ("_" + arg[0]) in vars(self): + setattr(group, arg[0], arg[1]) + return group + +class ModelParams(ParamGroup): + def __init__(self, parser, sentinel=False): + self.frame_num = 8 + self.sh_degree = 0 # NOTE: we don't need sh + self._source_path = "" + self._model_path = "" + self._images = "images" + self._resolution = -1 + self._white_background = True + self.data_device = "cuda" + self.eval = True + self.render_process=False + self.name="panda" + self.rife=False + self.imagedream=False + self.static=False + super().__init__(parser, "Loading Parameters", sentinel) + + def extract(self, args): + g = super().extract(args) + g.source_path = os.path.abspath(g.source_path) + return g + +class PipelineParams(ParamGroup): + def __init__(self, parser): + self.convert_SHs_python = False + self.compute_cov3D_python = False + self.debug = False + super().__init__(parser, "Pipeline Parameters") +class ModelHiddenParams(ParamGroup): + def __init__(self, parser): + self.net_width = 64 + self.timebase_pe = 4 + self.defor_depth = 1 + self.posebase_pe = 10 + self.scale_rotation_pe = 2 + self.opacity_pe = 2 + self.timenet_width = 64 + self.timenet_output = 32 + self.bounds = 1.6 + self.plane_tv_weight = 0.0001 + self.time_smoothness_weight = 0.01 + self.l1_time_planes = 0.0001 + self.grid_merge = 'mul' + self.kplanes_config = { + 'grid_dimensions': 2, + 'input_coordinate_dim': 4, + 'output_coordinate_dim': 32, + 'resolution': [64, 64, 64, 25] + } + self.multires = [1, 2, 4, 8] + self.no_grid=False + self.no_ds=False + self.no_dr=False + self.no_do=True + self.no_dc=True + + + super().__init__(parser, "ModelHiddenParams") + +class OptimizationParams(ParamGroup): + def __init__(self, parser): + self.dataloader=False + self.iterations = 30_000 + self.coarse_iterations = 3000 + self.static_iterations = 700 + self.position_lr_init = 0.00016 + self.position_lr_final = 0.0000016 + self.position_lr_delay_mult = 0.01 + self.position_lr_max_steps = 20_000 + self.deformation_lr_init = 0.00016 + self.deformation_lr_final = 0.000016 + self.deformation_lr_delay_mult = 0.01 + self.grid_lr_init = 0.0016 + self.grid_lr_final = 0.00016 + + self.feature_lr = 0.0025 + self.opacity_lr = 0.05 + self.scaling_lr = 0.005 + self.rotation_lr = 0.001 + self.percent_dense = 0.01 + self.lambda_dssim = 0 + self.lambda_pts = 0 + self.lambda_zero123 = 0.5 + self.lambda_lpips = 0 + self.fine_rand_rate=1 + self.weight_constraint_init= 1 + self.weight_constraint_after = 0.2 + self.weight_decay_iteration = 5000 + self.opacity_reset_interval = 3000 + self.densification_interval = 100 + self.densify_from_iter 
= 500 + self.densify_until_iter = 15_000 + self.densify_grad_threshold_coarse = 0.0002 + self.densify_grad_threshold_fine_init = 0.0002 + self.densify_grad_threshold_after = 0.0002 + self.pruning_from_iter = 500 + self.pruning_interval = 100 + self.pruning_interval_fine = 100 + self.opacity_threshold_coarse = 0.005 + self.opacity_threshold_fine_init = 0.005 + self.opacity_threshold_fine_after = 0.005 + + super().__init__(parser, "Optimization Parameters") + +def get_combined_args(parser : ArgumentParser): + cmdlne_string = sys.argv[1:] + cfgfile_string = "Namespace()" + args_cmdline = parser.parse_args(cmdlne_string) + + try: + cfgfilepath = os.path.join(args_cmdline.model_path, "cfg_args") + print("Looking for config file in", cfgfilepath) + with open(cfgfilepath) as cfg_file: + print("Config file found: {}".format(cfgfilepath)) + cfgfile_string = cfg_file.read() + except TypeError: + print("Config file not found at") + pass + args_cfgfile = eval(cfgfile_string) + + merged_dict = vars(args_cfgfile).copy() + for k,v in vars(args_cmdline).items(): + if v != None: + merged_dict[k] = v + return Namespace(**merged_dict) diff --git a/arguments/i2v.py b/arguments/i2v.py new file mode 100644 index 0000000..403ea54 --- /dev/null +++ b/arguments/i2v.py @@ -0,0 +1,51 @@ +OptimizationParams = dict( + static_iterations = 1099, + coarse_iterations = 1000, + iterations = 3000, # don't set it to 0 !!! + position_lr_max_steps = 3000, + position_lr_delay_mult = 1, #1, + pruning_interval = 100, + pruning_interval_fine = 100000, + percent_dense = 0.01, + densify_grad_threshold_fine_init = 0.5, + densify_grad_threshold_coarse = 0.01, + densify_grad_threshold_after = 0.1, + densification_interval = 100, + opacity_reset_interval = 100, # not used + lambda_lpips = 2, + lambda_dssim = 2, + lambda_pts = 0, + lambda_zero123 = 0.5, # default 0.5 + fine_rand_rate = 0.8 +) + +ModelParams = dict( + frame_num = 16, + name="rose", + rife=False, +) + +ModelHiddenParams = dict( + grid_merge = 'cat', + # grid_merge = 'mul', + multires = [1, 2, 4, 8 ], + defor_depth = 2, + net_width = 256, + plane_tv_weight = 0, + time_smoothness_weight = 0, + l1_time_planes = 0, + weight_decay_iteration=0, + bounds=2, + no_ds=True, + # no_dr=True, + no_do=True, + no_dc=True, + kplanes_config = { + 'grid_dimensions': 2, + 'input_coordinate_dim': 4, + 'output_coordinate_dim': 32, + #'resolution': [32,32,32,32], + 'resolution': [64, 64, 64, 64] + # 'resolution': [64, 64, 64, 150] + } +) diff --git a/environment.yml b/environment.yml new file mode 100644 index 0000000..f20799b --- /dev/null +++ b/environment.yml @@ -0,0 +1,78 @@ +name: 4DGen +channels: + - defaults +dependencies: + - anaconda/noarch::mesa-libegl-cos6-x86_64==11.0.7=4 + - conda-forge/linux-64::_libgcc_mutex==0.1=conda_forge + - defaults/linux-64::blas==1.0=mkl + - anaconda/linux-64::ca-certificates==2023.01.10=h06a4308_0 + - conda-forge/linux-64::ld_impl_linux-64==2.40=h41732ed_0 + - conda-forge/linux-64::libstdcxx-ng==13.2.0=h7e041cc_2 + - pytorch/noarch::pytorch-mutex==1.0=cuda + - anaconda/linux-64::certifi==2022.12.7=py37h06a4308_0 + - anaconda/linux-64::openssl==1.1.1s=h7f8727e_0 + - conda-forge/linux-64::_openmp_mutex==4.5=2_kmp_llvm + - conda-forge/linux-64::libgcc-ng==13.2.0=h807b86a_2 + - conda-forge/linux-64::bzip2==1.0.8=h7f98852_4 + - conda-forge/linux-64::cudatoolkit==11.6.2=hfc3e2af_12 + - conda-forge/linux-64::gmp==6.2.1=h58526e2_0 + - conda-forge/linux-64::icu==73.2=h59595ed_0 + - conda-forge/linux-64::jpeg==9e=h0b41bf4_3 + - 
conda-forge/linux-64::lame==3.100=h166bdaf_1003 + - conda-forge/linux-64::lerc==4.0.0=h27087fc_0 + - conda-forge/linux-64::libdeflate==1.14=h166bdaf_0 + - conda-forge/linux-64::libffi==3.3=h58526e2_2 + - conda-forge/linux-64::libiconv==1.17=h166bdaf_0 + - conda-forge/linux-64::libwebp-base==1.3.2=hd590300_0 + - conda-forge/linux-64::libzlib==1.2.13=hd590300_5 + - conda-forge/linux-64::ncurses==6.4=hcb278e6_0 + - conda-forge/linux-64::nettle==3.6=he412f7d_0 + - conda-forge/linux-64::pthread-stubs==0.4=h36c2ea0_1001 + - conda-forge/linux-64::xorg-libxau==1.0.11=hd590300_0 + - conda-forge/linux-64::xorg-libxdmcp==1.1.3=h7f98852_0 + - conda-forge/linux-64::xz==5.2.6=h166bdaf_0 + - conda-forge/linux-64::gnutls==3.6.13=h85f3911_1 + - conda-forge/linux-64::libpng==1.6.39=h753d276_0 + - conda-forge/linux-64::libsqlite==3.43.0=h2797004_0 + - conda-forge/linux-64::libxcb==1.13=h7f98852_1004 + - conda-forge/linux-64::libxml2==2.11.5=h232c23b_1 + - conda-forge/linux-64::readline==8.2=h8228510_1 + - conda-forge/linux-64::tk==8.6.13=h2797004_0 + - conda-forge/linux-64::zlib==1.2.13=hd590300_5 + - conda-forge/linux-64::zstd==1.5.5=hfc55251_0 + - conda-forge/linux-64::freetype==2.12.1=h267a509_2 + - conda-forge/linux-64::libhwloc==2.9.3=default_h554bfaf_1009 + - conda-forge/linux-64::libtiff==4.4.0=h82bc61c_5 + - conda-forge/linux-64::llvm-openmp==16.0.6=h4dfa4b3_0 + - conda-forge/linux-64::openh264==2.1.1=h780b84a_0 + - conda-forge/linux-64::sqlite==3.43.0=h2c6b66d_0 + - pytorch/linux-64::ffmpeg==4.3=hf484d3e_0 + - conda-forge/linux-64::lcms2==2.14=h6ed2654_0 + - conda-forge/linux-64::openjpeg==2.5.0=h7d73246_1 + - defaults/linux-64::python==3.7.13=haa1d7c7_1 + - conda-forge/linux-64::tbb==2021.10.0=h00ab1b0_1 + - conda-forge/noarch::charset-normalizer==3.2.0=pyhd8ed1ab_0 + - conda-forge/noarch::colorama==0.4.6=pyhd8ed1ab_0 + - conda-forge/noarch::idna==3.4=pyhd8ed1ab_0 + - conda-forge/linux-64::mkl==2021.4.0=h8d4b97c_729 + - conda-forge/linux-64::python_abi==3.7=2_cp37m + - conda-forge/noarch::setuptools==68.2.2=pyhd8ed1ab_0 + - conda-forge/noarch::six==1.16.0=pyh6c4a22f_0 + - conda-forge/noarch::typing_extensions==4.7.1=pyha770c72_0 + - conda-forge/noarch::wheel==0.41.2=pyhd8ed1ab_0 + - conda-forge/linux-64::brotli-python==1.0.9=py37hd23a5d3_7 + - conda-forge/linux-64::mkl-service==2.4.0=py37h402132d_0 + - conda-forge/linux-64::pillow==9.2.0=py37h850a105_2 + - conda-forge/noarch::pip==22.3.1=pyhd8ed1ab_0 + - conda-forge/linux-64::pysocks==1.7.1=py37h89c1867_5 + - pytorch/linux-64::pytorch==1.12.1=py3.7_cuda11.6_cudnn8.3.2_0 + - conda-forge/noarch::tqdm==4.66.1=pyhd8ed1ab_0 + - defaults/linux-64::numpy-base==1.21.5=py37ha15fc14_3 + - conda-forge/noarch::urllib3==2.0.5=pyhd8ed1ab_0 + - conda-forge/noarch::requests==2.31.0=pyhd8ed1ab_0 + - conda-forge/linux-64::mkl_fft==1.3.1=py37h3e078e5_1 + - conda-forge/linux-64::mkl_random==1.2.2=py37h219a48f_0 + - defaults/linux-64::numpy==1.21.5=py37h6c91a56_3 + - conda-forge/noarch::plyfile==0.8.1=pyhd8ed1ab_0 + - pytorch/linux-64::torchaudio==0.12.1=py37_cu116 + - pytorch/linux-64::torchvision==0.13.1=py37_cu116 diff --git a/gaussian_renderer/__init__.py b/gaussian_renderer/__init__.py new file mode 100644 index 0000000..1dba1dd --- /dev/null +++ b/gaussian_renderer/__init__.py @@ -0,0 +1,232 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. 
+# +# For inquiries contact george.drettakis@inria.fr +# + +import torch +import math +from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer +from scene.gaussian_model import GaussianModel +from utils.sh_utils import eval_sh + +def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, time=torch.tensor([[0]]), scaling_modifier = 1.0, override_color = None, stage=None, render_flow=False, return_pts=False): + # print(scaling_modifier) + assert scaling_modifier == 1 + if stage is None: + raise NotImplementedError + """ + Render the scene. + + Background tensor (bg_color) must be on GPU! + """ + + # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means + screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0 + try: + screenspace_points.retain_grad() + except Exception as e: + # print(e) + pass + + # Set up rasterization configuration + + tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) + tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) + + raster_settings = GaussianRasterizationSettings( + image_height=int(viewpoint_camera.image_height), + image_width=int(viewpoint_camera.image_width), + tanfovx=tanfovx, + tanfovy=tanfovy, + bg=bg_color, + scale_modifier=scaling_modifier, + viewmatrix=viewpoint_camera.world_view_transform.cuda(), + projmatrix=viewpoint_camera.full_proj_transform.cuda(), + sh_degree=pc.active_sh_degree, + campos=viewpoint_camera.camera_center.cuda(), + prefiltered=False, + debug=pipe.debug + ) + + rasterizer = GaussianRasterizer(raster_settings=raster_settings) + + # means3D = pc.get_xyz + # add deformation to each points + # deformation = pc.get_deformation + means3D = pc.get_xyz + try: + assert time.item() >= 0 and time.item() <= 1 + time = time.to(means3D.device).repeat(means3D.shape[0],1) + except: + assert time >= 0 and time <= 1 + time = torch.tensor([time]).to(means3D.device).repeat(means3D.shape[0],1) + # time = time / 16 # in range of [0, 1] + + means2D = screenspace_points + opacity = pc._opacity + color=pc._features_dc + color=color[:,0,:] + + + + # If precomputed 3d covariance is provided, use it. If not, then it will be computed from + # scaling / rotation by the rasterizer. 
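+ # NOTE: in the "static" stage the Gaussians below are rasterized unchanged; in the "coarse"/"fine" stages the deformation network predicts per-point offsets (dx, ds, dr, do, dc) for position, scale, rotation, opacity and color at the queried timestamp.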
+ scales = None + rotations = None + cov3D_precomp = None + + dx = None + if pipe.compute_cov3D_python: + cov3D_precomp = pc.get_covariance(scaling_modifier) + else: + # scales = pc.get_scaling + scales = pc._scaling + if scales.shape[-1] == 1: + scales = scales.repeat(1, 3) + #scales = torch.ones_like(scales ) * 0.03 + # rotations = pc.get_rotation + rotations = pc._rotation + deformation_point = pc._deformation_table + # print('color render:',color.shape) #[40000, 1, 3]->[40000, 3] + # print('rotations render:',rotations.shape) #[40000, 4] + + if stage == "static": # or time.sum() == 0: + # if stage == "static" or time.sum() == 0: + means3D_deform, scales_deform, rotations_deform, opacity_deform,color_deform = means3D, scales, rotations, opacity,color + elif stage in ["coarse", 'fine']: + means3D_deform, scales_deform, rotations_deform, opacity_deform, color_deform = pc._deformation(means3D[deformation_point], scales[deformation_point], rotations[deformation_point], opacity[deformation_point], color[deformation_point], time[deformation_point]) + dx = (means3D_deform - means3D[deformation_point]) + ds = (scales_deform - scales[deformation_point]) + dr = (rotations_deform - rotations[deformation_point]) + do = (opacity_deform - opacity[deformation_point]) + dc = (color_deform - color[deformation_point]) + else: + # deprecated + means3D_deform, scales_deform, rotations_deform, opacity_deform,color_deform = pc._deformation(means3D[deformation_point].detach(), scales[deformation_point].detach(), rotations[deformation_point].detach(), opacity[deformation_point].detach(),color[deformation_point].detach(), time[deformation_point].detach()) + dx = (means3D_deform - means3D[deformation_point].detach()) + ds = (scales_deform - scales[deformation_point].detach()) + dr = (rotations_deform - rotations[deformation_point].detach()) + do = (opacity_deform - opacity[deformation_point].detach()) + #dc=0 + dc=(color_deform - color[deformation_point].detach()) + + means3D_final = torch.zeros_like(means3D) + rotations_final = torch.zeros_like(rotations) + scales_final = torch.zeros_like(scales) + opacity_final = torch.zeros_like(opacity) + color_final= torch.zeros_like(color) + means3D_final[deformation_point] = means3D_deform + rotations_final[deformation_point] = rotations_deform + scales_final[deformation_point] = scales_deform + opacity_final[deformation_point] = opacity_deform + + # print('color_final shape before',color_final.shape) + + # print('color_final shape',color_final.shape) + # print('color_deform shape',color_deform.shape) + # print('deformation_point shape',deformation_point.shape) + color_final[deformation_point] = color_deform + + means3D_final[~deformation_point] = means3D[~deformation_point] + rotations_final[~deformation_point] = rotations[~deformation_point] + scales_final[~deformation_point] = scales[~deformation_point] + opacity_final[~deformation_point] = opacity[~deformation_point] + color_final[~deformation_point] = color[~deformation_point] + color_final=torch.unsqueeze(color_final, 1) #[40000, 3]->[40000, 1, 3] + + scales_final = pc.scaling_activation(scales_final) + #scales_final = torch.ones_like(scales_final ) * 0.01 + rotations_final = pc.rotation_activation(rotations_final) + opacity = pc.opacity_activation(opacity) + #color without activation + + # print(opacity.max()) + # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors + # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer. 
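+ # NOTE: with sh_degree = 0 (the default in ModelParams) there are no higher-order SH coefficients, so the `shs` tensor assembled below is essentially the deformed per-Gaussian DC color `color_final`.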
+ shs = None + colors_precomp = None + + #pc._features_dc=color_final #update color + + if override_color is None: + if pipe.convert_SHs_python: + shs_view = pc.get_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree+1)**2) + dir_pp = (pc.get_xyz - viewpoint_camera.camera_center.cuda().repeat(pc.get_features.shape[0], 1)) + dir_pp_normalized = dir_pp/dir_pp.norm(dim=1, keepdim=True) + sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized) + colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0) + else: + # print('shs=============') + #shs = pc.get_features + # dc=pc.get_features_dc + # print('pc.get_features_dc devide',pc.get_features_dc.device) + dc=color_final + #print('color_final devide',dc.device) + rest=pc.get_features_rest + shs=torch.cat((dc, rest), dim=1) + else: + colors_precomp = override_color + + #colors_precomp=color_final #not sure + # print('colors_precomp shape:',colors_precomp.shape) + # print('color_final shape:',color_final.shape) + + # Rasterize visible Gaussians to image, obtain their radii (on screen). + rendered_image, radii, depth, alpha = rasterizer( + means3D = means3D_final, + means2D = means2D, + shs = shs, + colors_precomp = colors_precomp, + opacities = opacity, + scales = scales_final, + rotations = rotations_final, + cov3D_precomp = cov3D_precomp) + + + # Those Gaussians that were frustum culled or had a radius of 0 were not visible. + # They will be excluded from value updates used in the splitting criteria. + res = { + "render": rendered_image, + "viewspace_points": screenspace_points, + "visibility_filter" : radii > 0, + "radii": radii, + "alpha": alpha, + "depth":depth, + } + # print(dx, time.sum(), stage) + if dx is not None: + res['dx'] = dx #.mean() + res['ds'] = ds #.mean() + res['dr'] = dr #.mean() + res['do'] = do #.mean() + res['dc'] = dc + + if render_flow and stage == 'coarse': + flow_screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0 + try: + flow_screenspace_points.retain_grad() + except: + pass + rendered_flow, _, _, _ = rasterizer( + means3D = means3D_final, + means2D = flow_screenspace_points, + shs = None, + colors_precomp = dx, + opacities = opacity, + scales = scales_final, + rotations = rotations_final, + cov3D_precomp = cov3D_precomp + ) + res['rendered_flow'] = rendered_flow + if return_pts: + res['means3D'] = means3D_final + res['means2D'] = means2D + res['opacity_final'] = opacity_final + return res + diff --git a/gaussian_renderer/network_gui.py b/gaussian_renderer/network_gui.py new file mode 100644 index 0000000..df2f9da --- /dev/null +++ b/gaussian_renderer/network_gui.py @@ -0,0 +1,86 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. 
+# +# For inquiries contact george.drettakis@inria.fr +# + +import torch +import traceback +import socket +import json +from scene.cameras import MiniCam + +host = "127.0.0.1" +port = 6009 + +conn = None +addr = None + +listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + +def init(wish_host, wish_port): + global host, port, listener + host = wish_host + port = wish_port + listener.bind((host, port)) + listener.listen() + listener.settimeout(0) + +def try_connect(): + global conn, addr, listener + try: + conn, addr = listener.accept() + print(f"\nConnected by {addr}") + conn.settimeout(None) + except Exception as inst: + pass + +def read(): + global conn + messageLength = conn.recv(4) + messageLength = int.from_bytes(messageLength, 'little') + message = conn.recv(messageLength) + return json.loads(message.decode("utf-8")) + +def send(message_bytes, verify): + global conn + if message_bytes != None: + conn.sendall(message_bytes) + conn.sendall(len(verify).to_bytes(4, 'little')) + conn.sendall(bytes(verify, 'ascii')) + +def receive(): + message = read() + + width = message["resolution_x"] + height = message["resolution_y"] + + if width != 0 and height != 0: + try: + do_training = bool(message["train"]) + fovy = message["fov_y"] + fovx = message["fov_x"] + znear = message["z_near"] + zfar = message["z_far"] + do_shs_python = bool(message["shs_python"]) + do_rot_scale_python = bool(message["rot_scale_python"]) + keep_alive = bool(message["keep_alive"]) + scaling_modifier = message["scaling_modifier"] + world_view_transform = torch.reshape(torch.tensor(message["view_matrix"]), (4, 4)).cuda() + world_view_transform[:,1] = -world_view_transform[:,1] + world_view_transform[:,2] = -world_view_transform[:,2] + full_proj_transform = torch.reshape(torch.tensor(message["view_projection_matrix"]), (4, 4)).cuda() + full_proj_transform[:,1] = -full_proj_transform[:,1] + custom_cam = MiniCam(width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform) + except Exception as e: + print("") + traceback.print_exc() + raise e + return custom_cam, do_training, do_shs_python, do_rot_scale_python, keep_alive, scaling_modifier + else: + return None, None, None, None, None, None \ No newline at end of file diff --git a/guidance/animatediff_utils.py b/guidance/animatediff_utils.py new file mode 100644 index 0000000..9922db1 --- /dev/null +++ b/guidance/animatediff_utils.py @@ -0,0 +1,336 @@ +from transformers import CLIPTextModel, CLIPTokenizer, logging +from diffusers import ( + AutoencoderKL, + UNet2DConditionModel, + DDIMScheduler, + StableDiffusionPipeline, +) +import torchvision.transforms.functional as TF + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +import sys +sys.path.append('./') + + +import os +from omegaconf import OmegaConf +from einops import rearrange +import sys +import argparse +import numpy as np +import torch +from torch import nn +import torch.nn.functional as F +from torchvision.utils import save_image +from torchvision import io +from tqdm import tqdm +from datetime import datetime +import random +import imageio +from pathlib import Path +import shutil +import logging +from diffusers.utils.import_utils import is_xformers_available +# from diffusers import StableDiffusionPipeline +from transformers import CLIPTextModel, CLIPTokenizer +from transformers import logging as transformers_logging +transformers_logging.set_verbosity_error() # disable warning +from animatediff.pipelines.pipeline_old import AnimationPipeline 
+# from animatediff.pipelines.pipeline_animation import AnimationPipeline +from diffusers import AutoencoderKL, UNet2DConditionModel +from diffusers import DDIMScheduler +from animatediff.models.unet import UNet3DConditionModel +from animatediff.utils.util import load_weights +from animatediff.utils.util import save_videos_grid + +class AnimateDiff(nn.Module): + def __init__(self, device='cuda',use_textual_inversion=False): + inference_config=OmegaConf.load("animatediff/configs/inference/inference-v2.yaml") + pretrained_model_path="animatediff/animatediff_models/StableDiffusion/stable-diffusion-v1-5" + # pretrained_model_path="runwayml/stable-diffusion-v1-5" + self.pretrained_model_path = pretrained_model_path + tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer") + text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder") + if use_textual_inversion: + inversion_path = None # TODO: CHANGE this! + text_encoder = CLIPTextModel.from_pretrained(inversion_path, subfolder="checkpoint-500") + else: + text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder") + vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae") + unet = UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs)) + vae.requires_grad_(False) + text_encoder.requires_grad_(False) + unet.requires_grad_(False) + if is_xformers_available(): unet.enable_xformers_memory_efficient_attention() + else: assert False + self.device = device = torch.device(device) + # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + pipeline = AnimationPipeline( + vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, + scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)), + ).to(device) + motion_module_path = "./animatediff/animatediff_models/Motion_Module/mm_sd_v15_v2.ckpt" + dreambooth_model_path = "./animatediff/animatediff_models/DreamBooth_LoRA/rcnzCartoon3d_v20.safetensors" + + self.pipeline = load_weights( + pipeline, + motion_module_path = motion_module_path, + dreambooth_model_path = dreambooth_model_path, + ).to(device) + # unet = unet.to(device) + # vae = vae.to(device) + # text_encoder = text_encoder.to(device) + self.scheduler = self.pipeline.scheduler + self.alphas = self.scheduler.alphas_cumprod.to(self.device) + # self.scheduler = DDIMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler", torch_dtype= torch.float32) + self.rgb_to_latent = torch.from_numpy(np.array([[ 1.69810224, -0.28270747, -2.55163474, -0.78083445], + [-0.02986101, 4.91430525, 2.23158593, 3.02981481], + [-0.05746497, -3.04784101, 0.0448761 , -3.22913725]])).float().cuda(non_blocking=True) # 3 x 4 + self.latent_to_rgb = torch.from_numpy(np.array([ + [ 0.298, 0.207, 0.208], # L1 + [ 0.187, 0.286, 0.173], # L2 + [-0.158, 0.189, 0.264], # L3 + [-0.184, -0.271, -0.473], # L4 + ])).float().cuda(non_blocking=True) # 4 x 3 + + def load_text_encoder(self, use_textual_inversion=False): + if use_textual_inversion: + inversion_path="." 
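+ # NOTE: "." is a placeholder; inversion_path should point at the directory holding the textual-inversion weights (its `checkpoint-500` subfolder is loaded below).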
+ text_encoder = CLIPTextModel.from_pretrained(inversion_path, subfolder="checkpoint-500") + else: + text_encoder = CLIPTextModel.from_pretrained(self.pretrained_model_path, subfolder="text_encoder") + return text_encoder + + @torch.no_grad() + def prepare_text_emb(self, prompt=None, neg_prompt=None): + #example + if prompt is None: + prompt = "a panda dancing" + # prompt = "a dancing" + if neg_prompt is None: + neg_prompt = "color distortion,color shift,green light,semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, immutable, unchanging, stable, fixed, permant, unvarying, stationary, constant, steady, motionless, inactive, still, rooted, set" + # neg_prompt = "color distortion,color shift,green light,semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, missing arms, missing legs, extra arms, extra legs" + text_embeddings = self.pipeline._encode_prompt( + [prompt], self.device, num_videos_per_prompt=1, do_classifier_free_guidance=True, negative_prompt=[neg_prompt], + ) + return text_embeddings + + @torch.no_grad() + def prepare_text_emb_inversion(self, prompt=None, neg_prompt=None, inversion_prompt=None): + #example + if inversion_prompt is None: + inversion_prompt = 'a dancing' + if prompt is None: + prompt = "a panda dancing" + if neg_prompt is None: + neg_prompt = "color distortion,color shift,green light,semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, missing arms, missing legs, extra arms, extra legs" + text_embeddings = self.pipeline._encode_prompt( + [prompt], self.device, num_videos_per_prompt=1, do_classifier_free_guidance=True, negative_prompt=[neg_prompt], + ) + # self.pipeline.text_encoder = self.load_text_encoder(use_textual_inversion=True) + text_embeddings_inversion = self.pipeline._encode_prompt( + [inversion_prompt], self.device, num_videos_per_prompt=1, do_classifier_free_guidance=True, negative_prompt=[neg_prompt], + ) + return text_embeddings, text_embeddings_inversion + + def get_cfg(self, noisy_latents, text_embeddings, guidance_scale, t): + latent_model_input = torch.cat([noisy_latents] * 2) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + # print('latent_model_input', latent_model_input.shape) + noise_pred = self.pipeline.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + return noise_pred + + def train_step(self, pred_rgb, text_embeddings, guidance_scale=30, as_latent=False): + # shape = [1, 4, 8, 64, 64] #b,c,t,h,w b=1, c=4 beacause vae encode + # latents_vsd = torch.randn(shape).to(device) #input + if not as_latent: + print('diff input rgb', pred_rgb.shape) + frame_size=pred_rgb.shape[0] + # latents = (pred_rgb.permute(0, 2, 3, 1) @ self.rgb_to_latent).permute(3, 0, 1, 2) + + # latents = F.interpolate(latents, (64, 64), mode='bilinear', align_corners=False).unsqueeze(0) + # print('latents', latents.shape) + latents = self.pipeline.vae.encode(pred_rgb * 2 - 
1).latent_dist.sample() # [8, 4, 64, 64] + print('latents shape',latents.shape) + # randn_noise=torch.rand_like(latents[0]).to(latents.device) + # for i in range(1,frame_size): + # i=torch.tensor(i,device=self.device).long() + # latents[i]=self.scheduler.add_noise(latents[i], randn_noise, i*100) + latents = latents.unsqueeze(0).permute(0, 2, 1, 3, 4) * 0.18215 #[1, 4, 8, 32, 32]) + + #image+guassian+guassian... + + else: + latents = pred_rgb + # latents = rearrange(latents, "b c) f h w -> (b f) c h w") + + print('latents', latents.shape, latents.requires_grad) + with torch.no_grad(): + noise = torch.randn_like(latents) + # t=torch.tensor(100).to(device) # + t = torch.randint( + 50, 950, (latents.shape[0],), device=self.device + ).long() + # print('time shape', t.shape) + noisy_latents = self.scheduler.add_noise(latents, noise, t) + noise_pred = self.get_cfg(noisy_latents, text_embeddings, guidance_scale, t) + noise_diff = noise_pred - noise + # noise_pred=self.pipeline(noisy_lantents=noisy_latents, + # t=t, + # prompt=prompt, + # negative_prompt= n_prompt, + # ) + # print('noise pred shape:',noise_pred.shape) #([1, 4, 8, 64, 64]) + w = (1 - self.alphas[t]).view(noise.shape[0], 1, 1, 1, 1) + grad = w * (noise_diff) + grad = torch.nan_to_num(grad) + + # if not as_latent: + # # grad: [1, 4, 16, 64, 64] + # print(grad.shape) + # # norm = torch.norm(grad, dim=(1)) + # norm = torch.norm(grad, dim=(1, 2)) + # print(norm) + # thres = torch.ones_like(norm).detach() * 1 + # # grad = torch.minimum(norm, thres) * F.normalize(grad, dim=(1)) + # grad = torch.minimum(norm, thres) * F.normalize(grad, dim=(1, 2)) + + target = (latents - grad).detach() + loss = 0.5 * F.mse_loss(latents.float(), target, reduction='sum') + return loss + + def train_step_inversion(self, pred_rgb, text_embeddings, text_embeddings_inversion, guidance_scale=30, as_latent=False): + # shape = [1, 4, 8, 64, 64] #b,c,t,h,w b=1, c=4 beacause vae encode + # latents_vsd = torch.randn(shape).to(device) #input + if not as_latent: + print('diff input rgb', pred_rgb.shape) + # latents = (pred_rgb.permute(0, 2, 3, 1) @ self.rgb_to_latent).permute(3, 0, 1, 2) + + # latents = F.interpolate(latents, (64, 64), mode='bilinear', align_corners=False).unsqueeze(0) + # print('latents', latents.shape) + latents = self.pipeline.vae.encode(pred_rgb * 2 - 1).latent_dist.sample() # [8, 4, 64, 64] + latents = latents.unsqueeze(0).permute(0, 2, 1, 3, 4) * 0.18215 + else: + latents = pred_rgb + # latents = rearrange(latents, "b c) f h w -> (b f) c h w") + + print('latents', latents.shape, latents.requires_grad) + with torch.no_grad(): + noise = torch.randn_like(latents) + # t=torch.tensor(100).to(device) # + t = torch.randint( + 50, 950, (latents.shape[0],), device=self.device + ).long() + # print('time shape', t.shape) + noisy_latents = self.scheduler.add_noise(latents, noise, t) + noise_pred_original = self.get_cfg(noisy_latents, text_embeddings, guidance_scale, t) + noise_pred_inversion = self.get_cfg(noisy_latents, text_embeddings_inversion, guidance_scale, t) + noise_diff = noise_pred_inversion - noise_pred_original + print('noise pred shape:',noise_diff.shape) #([1, 4, 8, 64, 64]) + w = (1 - self.alphas[t]).view(noise.shape[0], 1, 1, 1, 1) + grad = w * (noise_diff) + grad = torch.nan_to_num(grad) + + target = (latents - grad).detach() + loss = 0.5 * F.mse_loss(latents.float(), target, reduction='sum') + return loss + + @torch.no_grad() + def sample(self, text_embeddings, guidance_scale=7.5): + # latents = 
self.pipeline.vae.encode(pred_rgb).latent_dist.mode() + latents = torch.randn([1, 4, 16, 64, 64], device=self.device) * self.scheduler.init_noise_sigma + # noise = torch.randn_like(latents) + # t=torch.tensor(100).to(device) # + # t = torch.randint( + # 0, self.diffusion_model.num_timesteps, (pred_rgb.shape[0],), device=self.device + # ).long() + # print('time shape', t.shape) + from tqdm import tqdm + extra_step_kwargs = {} + # if accepts_eta: + extra_step_kwargs["eta"] = 0.0 + for i, t in enumerate(tqdm(self.scheduler.timesteps)): + latent_model_input = torch.cat([latents] * 2) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + noise_pred = self.pipeline.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + latents = self.scheduler.step(noise_pred, t, latents, eta=0.0).prev_sample + latents = 1 / 0.18215 * latents + print('output', latents.shape) + latents = rearrange(latents, "b c f h w -> (b f) c h w") + imgs = self.pipeline.vae.decode(latents).sample + imgs = (imgs / 2 + 0.5).clamp(0, 1) + return imgs + + @torch.no_grad() + def decode_latent(self, x): + latents = 1 / 0.18215 * x + latents = rearrange(latents, "b c f h w -> (b f) c h w") + return (self.pipeline.vae.decode(latents).sample / 2 + 0.5).clamp(0, 1) + +if __name__ == '__main__': + torch.manual_seed(16931037867122267877) + anim = AnimateDiff() + text_emb = anim.prepare_text_emb() + t2i = False + if t2i: + anim.scheduler.set_timesteps(50) + # pred_rgb = torch.randn((8, 3, 256, 256)).cuda() + # pred_rgb = torch.randn((1, 4, 8, 64, 64)) + res = anim.sample(text_emb) + print(res.shape) + res = res.permute(0, 2, 3, 1) + res = res.detach().cpu().numpy() + res = (res * 255).astype(np.uint8) + print(res.shape) + imageio.mimwrite('a.mp4', res, fps=16, quality=8, macro_block_size=1) + sds = True + if sds: + prefix = 'inversion_sds_latent_0.01' + # prefix = 'sds_rgb' + from PIL import Image + from torchvision.transforms import ToTensor + rgb0 = Image.open('data/panda_static/1.png').resize((256, 256)) + rgb0 = ToTensor()(rgb0).cuda().unsqueeze(0) + # print('rgb0', rgb0.shape) + # anim.scheduler.set_timesteps() + rgb_tensor = torch.randn((1, 4, 16, 32, 32)).cuda() * anim.scheduler.init_noise_sigma + # rgb_tensor = torch.randn((15, 3, 256, 256)).clamp(0, 1).cuda() + # rgb_tensor = torch.cat([rgb0.clone()] * 15).cuda() + # rgb_tensor[0] = rgb0 + rgb_tensor.requires_grad = True + # optim = torch.optim.AdamW([rgb_tensor], lr=0.05) + optim = torch.optim.Adam([rgb_tensor], lr=0.01) + from tqdm import tqdm + for i in tqdm(range(2000)): + # rgb_tensor[0] = rgb_tensor[0] * 0 + rgb0 + # loss = anim.train_step(torch.cat([rgb0, rgb_tensor], dim=0), text_emb, as_latent=False) + loss = anim.train_step(rgb_tensor, text_emb, as_latent=True) + # loss = anim.train_step(rgb_tensor, text_emb, as_latent=True) + loss.backward() + print('grad', rgb_tensor.grad.shape) + optim.step() + optim.zero_grad() + if i % 100 == 0: + res = anim.decode_latent(rgb_tensor).permute(0, 2, 3, 1) + # res = torch.cat([rgb0, rgb_tensor], dim=0).permute(0, 2, 3, 1) + res = res.detach().cpu().numpy() + res = (res * 255).astype(np.uint8) + print(res.shape) + imageio.mimwrite(f'{prefix}_{i}.mp4', res, fps=16, quality=8, macro_block_size=1) + res = anim.decode_latent(rgb_tensor).permute(0, 2, 3, 1) + res = rgb_tensor.permute(0, 2, 3, 1) + res = res.detach().cpu().numpy() + res = (res * 
255).astype(np.uint8) + print(res.shape) + imageio.mimwrite(f'{prefix}.mp4', res, fps=16, quality=8, macro_block_size=1) + + # anim.train_step(pred_rgb, text_emb) diff --git a/guidance/clip.py b/guidance/clip.py new file mode 100644 index 0000000..730f9d1 --- /dev/null +++ b/guidance/clip.py @@ -0,0 +1,134 @@ +import torch +import torch.nn as nn +import os +import torchvision.transforms as T +import torchvision.transforms.functional as TF + +# import clip +from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTokenizer +from torchvision import transforms + +import torch.nn.functional as F + + +def spherical_dist_loss(x, y): + x = F.normalize(x, dim=-1) + y = F.normalize(y, dim=-1) + # print(x.shape, y.shape) + return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2) + +class CLIP(nn.Module): + def __init__(self, device, clip_name = 'openai/clip-vit-base-patch32'): + super().__init__() + + self.device = device + + clip_name = clip_name + self.feature_extractor = CLIPFeatureExtractor.from_pretrained(clip_name) + self.clip_model = CLIPModel.from_pretrained(clip_name).cuda() + self.tokenizer = CLIPTokenizer.from_pretrained(clip_name) + + self.normalize = transforms.Normalize(mean=self.feature_extractor.image_mean, std=self.feature_extractor.image_std) + + self.resize = transforms.Resize(224) + + # image augmentation + # self.aug = T.Compose([ + # T.Resize((224, 224)), + # T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), + # ]) + + + def get_text_embeds(self, prompt, neg_prompt=None, dir=None): + + clip_text_input = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ).input_ids.cuda() + text_z = self.clip_model.get_text_features(clip_text_input) + # text = clip.tokenize(prompt).to(self.device) + # text_z = self.clip_model.encode_text(text) + text_z = text_z / text_z.norm(dim=-1, keepdim=True) + + return text_z + + def set_epoch(self, epoch): + pass + + def get_img_embeds(self, img): + # img = self.aug(img) + assert len(img.shape) == 4 + img = self.resize(img) + img = self.normalize(img) + # print(img.shape) + image_z = self.clip_model.get_image_features(img) + image_z = image_z / image_z.norm(dim=-1, keepdim=True) # normalize features + # print(image_z.shape, 'clip image embed') + return image_z + + + def train_step(self, text_z, pred_rgb, image_ref_clip, **kwargs): + + pred_rgb = self.resize(pred_rgb) + pred_rgb = self.normalize(pred_rgb) + + image_z = self.clip_model.get_image_features(pred_rgb) + image_z = image_z / image_z.norm(dim=-1, keepdim=True) # normalize features + + # print(image_z.shape, text_z.shape) + loss = spherical_dist_loss(image_z, image_ref_clip) + + # loss = - (image_z * text_z).sum(-1).mean() + + return loss + + def text_loss(self, text_z, pred_rgb): + + pred_rgb = self.resize(pred_rgb) + pred_rgb = self.normalize(pred_rgb) + + image_z = self.clip_model.get_image_features(pred_rgb) + image_z = image_z / image_z.norm(dim=-1, keepdim=True) # normalize features + + # print(image_z.shape, text_z.shape) + loss = spherical_dist_loss(image_z, text_z) + + # loss = - (image_z * text_z).sum(-1).mean() + + return loss + + def img_loss(self, img_ref_z, pred_rgb): + # pred_rgb = self.aug(pred_rgb) + pred_rgb = self.resize(pred_rgb) + pred_rgb = self.normalize(pred_rgb) + + image_z = self.clip_model.get_image_features(pred_rgb) + image_z = image_z / image_z.norm(dim=-1, keepdim=True) # normalize features + + # loss = - (image_z * 
img_ref_z).sum(-1).mean() + loss = spherical_dist_loss(image_z, img_ref_z) + + return loss + + def img_img_loss(self, gt_rgb, pred_rgb): + # pred_rgb = self.aug(pred_rgb) + pred_rgb = self.resize(pred_rgb) + pred_rgb = self.normalize(pred_rgb) + + image_z = self.clip_model.get_image_features(pred_rgb) + image_z = image_z / image_z.norm(dim=-1, keepdim=True) # normalize features + + # loss = - (image_z * img_ref_z).sum(-1).mean() + loss = spherical_dist_loss(image_z, img_ref_z) + + return loss + +if __name__ == '__main__': + clip = CLIP('cuda') + im = torch.randn((1, 3, 512, 512)).cuda() + res = clip.get_img_embeds(im) + print(res.shape) + \ No newline at end of file diff --git a/guidance/imagedream_utils.py b/guidance/imagedream_utils.py new file mode 100644 index 0000000..f8465f0 --- /dev/null +++ b/guidance/imagedream_utils.py @@ -0,0 +1,334 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.transforms.functional as TF + +from imagedream.camera_utils import get_camera, convert_opengl_to_blender, normalize_camera +from imagedream.model_zoo import build_model +from imagedream.ldm.models.diffusion.ddim import DDIMSampler + +from diffusers import DDIMScheduler + +class ImageDream(nn.Module): + def __init__( + self, + device, + model_name='sd-v2.1-base-4view-ipmv', + ckpt_path=None, + t_range=[0.02, 0.98], + ): + super().__init__() + + self.device = device + self.model_name = model_name + self.ckpt_path = ckpt_path + + self.model = build_model(self.model_name, ckpt_path=self.ckpt_path).eval().to(self.device) + self.model.device = device + for p in self.model.parameters(): + p.requires_grad_(False) + + self.dtype = torch.float32 + + self.num_train_timesteps = 1000 + self.min_step = int(self.num_train_timesteps * t_range[0]) + self.max_step = int(self.num_train_timesteps * t_range[1]) + + self.image_embeddings = {} + self.embeddings = {} + + self.scheduler = DDIMScheduler.from_pretrained( + "stabilityai/stable-diffusion-2-1-base", subfolder="scheduler", torch_dtype=self.dtype + ) + + @torch.no_grad() + def get_image_text_embeds(self, image, prompts, negative_prompts): + + image = F.interpolate(image, (256, 256), mode='bilinear', align_corners=False) + image_pil = TF.to_pil_image(image[0]) + image_embeddings = {} + ww = self.model.get_learned_image_conditioning(image_pil).repeat(5,1,1) + image_embeddings['pos'] = ww # [5, 257, 1280] + image_embeddings['neg'] = torch.zeros_like(ww) + + image_embeddings['ip_img'] = self.encode_imgs(image) + image_embeddings['neg_ip_img'] = torch.zeros_like(image_embeddings['ip_img']) + + pos_embeds = self.encode_text(prompts).repeat(5,1,1) + neg_embeds = self.encode_text(negative_prompts).repeat(5,1,1) + embeddings = {} + embeddings['pos'] = pos_embeds + embeddings['neg'] = neg_embeds + return image_embeddings, embeddings + + @torch.no_grad() + def prepare_embeds(self, image_li, prompts, negative_prompts): + return [self.get_image_text_embeds(image_li[idx:idx + 1], prompts, negative_prompts) for idx in range(len(image_li))] + # return [self.get_image_text_embeds(x, prompts, negative_prompts) for x in image_li] + + def encode_text(self, prompt): + # prompt: [str] + embeddings = self.model.get_learned_conditioning(prompt).to(self.device) + return embeddings + + @torch.no_grad() + def refine(self, pred_rgb, camera, + guidance_scale=5, steps=50, strength=0.8, + ): + + batch_size = pred_rgb.shape[0] + real_batch_size = batch_size // 4 + pred_rgb_256 = F.interpolate(pred_rgb, (256, 256), mode='bilinear', 
align_corners=False) + latents = self.encode_imgs(pred_rgb_256.to(self.dtype)) + + self.scheduler.set_timesteps(steps) + init_step = int(steps * strength) + latents = self.scheduler.add_noise(latents, torch.randn_like(latents), self.scheduler.timesteps[init_step]) + + camera = camera[:, [0, 2, 1, 3]] # to blender convention (flip y & z axis) + camera[:, 1] *= -1 + camera = normalize_camera(camera).view(batch_size, 16) + + # extra view + camera = camera.view(real_batch_size, 4, 16) + camera = torch.cat([camera, torch.zeros_like(camera[:, :1])], dim=1) # [rB, 5, 16] + camera = camera.view(real_batch_size * 5, 16) + + camera = camera.repeat(2, 1) + embeddings = torch.cat([self.embeddings['neg'].repeat(real_batch_size, 1, 1), self.embeddings['pos'].repeat(real_batch_size, 1, 1)], dim=0) + image_embeddings = torch.cat([self.image_embeddings['neg'].repeat(real_batch_size, 1, 1), self.image_embeddings['pos'].repeat(real_batch_size, 1, 1)], dim=0) + ip_img_embeddings= torch.cat([self.image_embeddings['neg_ip_img'].repeat(real_batch_size, 1, 1, 1), self.image_embeddings['ip_img'].repeat(real_batch_size, 1, 1, 1)], dim=0) + + context = { + "context": embeddings, + "ip": image_embeddings, + "ip_img": ip_img_embeddings, + "camera": camera, + "num_frames": 4 + 1 + } + + for i, t in enumerate(self.scheduler.timesteps[init_step:]): + + # extra view + + latents = latents.view(real_batch_size, 4, 4, 32, 32) + latents = torch.cat([latents, torch.zeros_like(latents[:, :1])], dim=1).view(-1, 4, 32, 32) + latent_model_input = torch.cat([latents] * 2) + + tt = torch.cat([t.unsqueeze(0).repeat(real_batch_size * 5)] * 2).to(self.device) + + noise_pred = self.model.apply_model(latent_model_input, tt, context) + + noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2) + + # remove extra view + noise_pred_uncond = noise_pred_uncond.reshape(real_batch_size, 5, 4, 32, 32)[:, :-1].reshape(-1, 4, 32, 32) + noise_pred_cond = noise_pred_cond.reshape(real_batch_size, 5, 4, 32, 32)[:, :-1].reshape(-1, 4, 32, 32) + latents = latents.reshape(real_batch_size, 5, 4, 32, 32)[:, :-1].reshape(-1, 4, 32, 32) + + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) + + latents = self.scheduler.step(noise_pred, t, latents).prev_sample + + imgs = self.decode_latents(latents) # [1, 3, 512, 512] + return imgs + + def train_step( + self, + cur_embedding, + pred_rgb, # [B, C, H, W] + camera, # [B, 4, 4] + step_ratio=None, + guidance_scale=5, + as_latent=False, + ): + image_embeddings_cur, embeddings_cur = cur_embedding + batch_size = pred_rgb.shape[0] + real_batch_size = batch_size // 4 + pred_rgb = pred_rgb.to(self.dtype) + + if as_latent: + latents = F.interpolate(pred_rgb, (32, 32), mode="bilinear", align_corners=False) * 2 - 1 + else: + # interp to 256x256 to be fed into vae. + pred_rgb_256 = F.interpolate(pred_rgb, (256, 256), mode="bilinear", align_corners=False) + # encode image into latents with vae, requires grad! 
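+ # (this encoding is done outside torch.no_grad, so the SDS loss can back-propagate through the VAE encoder to the rendered images)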
+ latents = self.encode_imgs(pred_rgb_256) + + if step_ratio is not None: + # dreamtime-like + # t = self.max_step - (self.max_step - self.min_step) * np.sqrt(step_ratio) + t = np.round((1 - step_ratio) * self.num_train_timesteps).clip(self.min_step, self.max_step) + t = torch.full((batch_size,), t, dtype=torch.long, device=self.device) + else: + t = torch.randint(self.min_step, self.max_step + 1, (real_batch_size,), dtype=torch.long, device=self.device).repeat(4) + + camera = camera[:, [0, 2, 1, 3]] # to blender convention (flip y & z axis) + camera[:, 1] *= -1 + camera = normalize_camera(camera).view(batch_size, 16) + + # extra view + camera = camera.view(real_batch_size, 4, 16) + camera = torch.cat([camera, torch.zeros_like(camera[:, :1])], dim=1) # [rB, 5, 16] + camera = camera.view(real_batch_size * 5, 16) + + camera = camera.repeat(2, 1) + embeddings = torch.cat([embeddings_cur['neg'].repeat(real_batch_size, 1, 1), embeddings_cur['pos'].repeat(real_batch_size, 1, 1)], dim=0) + image_embeddings = torch.cat([image_embeddings_cur['neg'].repeat(real_batch_size, 1, 1), image_embeddings_cur['pos'].repeat(real_batch_size, 1, 1)], dim=0) + ip_img_embeddings= torch.cat([image_embeddings_cur['neg_ip_img'].repeat(real_batch_size, 1, 1, 1), image_embeddings_cur['ip_img'].repeat(real_batch_size, 1, 1, 1)], dim=0) + + context = { + "context": embeddings, + "ip": image_embeddings, + "ip_img": ip_img_embeddings, + "camera": camera, + "num_frames": 4 + 1 + } + + # predict the noise residual with unet, NO grad! + with torch.no_grad(): + # add noise + noise = torch.randn_like(latents) + latents_noisy = self.model.q_sample(latents, t, noise) # [B=4, 4, 32, 32] + # extra view + t = t.view(real_batch_size, 4) + t = torch.cat([t, t[:, :1]], dim=1).view(-1) + latents_noisy = latents_noisy.view(real_batch_size, 4, 4, 32, 32) + latents_noisy = torch.cat([latents_noisy, torch.zeros_like(latents_noisy[:, :1])], dim=1).view(-1, 4, 32, 32) + # pred noise + latent_model_input = torch.cat([latents_noisy] * 2) + tt = torch.cat([t] * 2) + + # import kiui + # kiui.lo(latent_model_input, t, context['context'], context['camera']) + + noise_pred = self.model.apply_model(latent_model_input, tt, context) + + # perform guidance (high scale from paper!) 
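+ # classifier-free guidance: noise_pred = eps_uncond + guidance_scale * (eps_cond - eps_uncond), computed on the duplicated batch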
+ noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2) + + # remove extra view + noise_pred_uncond = noise_pred_uncond.reshape(real_batch_size, 5, 4, 32, 32)[:, :-1].reshape(-1, 4, 32, 32) + noise_pred_cond = noise_pred_cond.reshape(real_batch_size, 5, 4, 32, 32)[:, :-1].reshape(-1, 4, 32, 32) + + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) + + grad = (noise_pred - noise) + grad = torch.nan_to_num(grad) + + target = (latents - grad).detach() + loss = F.mse_loss(latents.float(), target, reduction='sum') / latents.shape[0] + # loss = 0.5 * F.mse_loss(latents.float(), target, reduction='sum') / latents.shape[0] + + return loss + + def decode_latents(self, latents): + imgs = self.model.decode_first_stage(latents) + imgs = ((imgs + 1) / 2).clamp(0, 1) + return imgs + + def encode_imgs(self, imgs): + # imgs: [B, 3, 256, 256] + imgs = 2 * imgs - 1 + latents = self.model.get_first_stage_encoding(self.model.encode_first_stage(imgs)) + return latents # [B, 4, 32, 32] + + @torch.no_grad() + def prompt_to_img( + self, + image, + prompts, + negative_prompts="", + height=256, + width=256, + num_inference_steps=50, + guidance_scale=5.0, + latents=None, + elevation=0, + azimuth_start=0, + ): + if isinstance(prompts, str): + prompts = [prompts] + + if isinstance(negative_prompts, str): + negative_prompts = [negative_prompts] + + real_batch_size = len(prompts) + batch_size = len(prompts) * 5 + + # Text embeds -> img latents + sampler = DDIMSampler(self.model) + shape = [4, height // 8, width // 8] + + c_ = {"context": self.encode_text(prompts).repeat(5,1,1)} + uc_ = {"context": self.encode_text(negative_prompts).repeat(5,1,1)} + + # image embeddings + image = F.interpolate(image, (256, 256), mode='bilinear', align_corners=False) + image_pil = TF.to_pil_image(image[0]) + image_embeddings = self.model.get_learned_image_conditioning(image_pil).repeat(5,1,1).to(self.device) + c_["ip"] = image_embeddings + uc_["ip"] = torch.zeros_like(image_embeddings) + + ip_img = self.encode_imgs(image) + c_["ip_img"] = ip_img + uc_["ip_img"] = torch.zeros_like(ip_img) + + camera = get_camera(4, elevation=elevation, azimuth_start=azimuth_start, extra_view=True) + camera = camera.repeat(real_batch_size, 1).to(self.device) + + c_["camera"] = uc_["camera"] = camera + c_["num_frames"] = uc_["num_frames"] = 5 + + kiui.lo(image_embeddings, ip_img, camera) + + latents, _ = sampler.sample(S=num_inference_steps, conditioning=c_, + batch_size=batch_size, shape=shape, + verbose=False, + unconditional_guidance_scale=guidance_scale, + unconditional_conditioning=uc_, + eta=0, x_T=None) + + # Img latents -> imgs + imgs = self.decode_latents(latents) # [4, 3, 256, 256] + + kiui.lo(latents, imgs) + + # Img to Numpy + imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy() + imgs = (imgs * 255).round().astype("uint8") + + return imgs + + +if __name__ == "__main__": + import argparse + import matplotlib.pyplot as plt + import kiui + + parser = argparse.ArgumentParser() + parser.add_argument("image", type=str) + parser.add_argument("prompt", type=str) + parser.add_argument("--negative", default="", type=str) + parser.add_argument("--steps", type=int, default=30) + opt = parser.parse_args() + + device = torch.device("cuda") + + sd = ImageDream(device) + + image = kiui.read_image(opt.image, mode='tensor') + image = image.permute(2, 0, 1).unsqueeze(0).to(device) + + while True: + imgs = sd.prompt_to_img(image, opt.prompt, opt.negative, num_inference_steps=opt.steps) + + grid = np.concatenate([ + 
np.concatenate([imgs[0], imgs[1]], axis=1), + np.concatenate([imgs[2], imgs[3]], axis=1), + ], axis=0) + + # visualize image + plt.imshow(grid) + plt.show() diff --git a/guidance/sd_utils.py b/guidance/sd_utils.py new file mode 100644 index 0000000..39209c8 --- /dev/null +++ b/guidance/sd_utils.py @@ -0,0 +1,419 @@ +from transformers import CLIPTextModel, CLIPTokenizer, logging +from diffusers import ( + AutoencoderKL, + UNet2DConditionModel, + PNDMScheduler, + DDIMScheduler, + StableDiffusionPipeline, +) +from diffusers.utils.import_utils import is_xformers_available + +from typing import List + +# suppress partial model loading warning +logging.set_verbosity_error() + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def seed_everything(seed): + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + # torch.backends.cudnn.deterministic = True + # torch.backends.cudnn.benchmark = True + + +class StableDiffusion(nn.Module): + def __init__( + self, + device, + fp16=True, + vram_O=False, + sd_version="2.1", + hf_key=None, + t_range=[0.02, 0.98], + ): + super().__init__() + + self.device = device + self.sd_version = sd_version + + if hf_key is not None: + print(f"[INFO] using hugging face custom model key: {hf_key}") + model_key = hf_key + elif self.sd_version == "2.1": + model_key = "stabilityai/stable-diffusion-2-1-base" + elif self.sd_version == "2.0": + model_key = "stabilityai/stable-diffusion-2-base" + elif self.sd_version == "1.5": + model_key = "runwayml/stable-diffusion-v1-5" + else: + raise ValueError( + f"Stable-diffusion version {self.sd_version} not supported." + ) + + self.dtype = torch.float16 if fp16 else torch.float32 + + # Create model + pipe = StableDiffusionPipeline.from_pretrained( + model_key, torch_dtype=self.dtype + ) + + if vram_O: + pipe.enable_sequential_cpu_offload() + pipe.enable_vae_slicing() + pipe.unet.to(memory_format=torch.channels_last) + pipe.enable_attention_slicing(1) + # pipe.enable_model_cpu_offload() + else: + pipe.to(device) + + self.vae = pipe.vae + self.tokenizer = pipe.tokenizer + self.text_encoder = pipe.text_encoder + self.unet = pipe.unet + + self.scheduler = DDIMScheduler.from_pretrained( + model_key, subfolder="scheduler", torch_dtype=self.dtype + ) + + del pipe + + self.num_train_timesteps = self.scheduler.config.num_train_timesteps + self.min_step = int(self.num_train_timesteps * t_range[0]) + self.max_step = int(self.num_train_timesteps * t_range[1]) + self.alphas = self.scheduler.alphas_cumprod.to(self.device) # for convenience + + self.embeddings = None + + @torch.no_grad() + def get_text_embeds(self, prompts, negative_prompts): + pos_embeds = self.encode_text(prompts) # [1, 77, 768] + neg_embeds = self.encode_text(negative_prompts) + self.embeddings = torch.cat([neg_embeds, pos_embeds], dim=0) # [2, 77, 768] + + def encode_text(self, prompt): + # prompt: [str] + inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ) + embeddings = self.text_encoder(inputs.input_ids.to(self.device))[0] + return embeddings + + @torch.no_grad() + def refine(self, pred_rgb, + guidance_scale=100, steps=50, strength=0.8, + ): + + batch_size = pred_rgb.shape[0] + pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False) + latents = self.encode_imgs(pred_rgb_512.to(self.dtype)) + # latents = torch.randn((1, 4, 64, 64), device=self.device, dtype=self.dtype) + + self.scheduler.set_timesteps(steps) + init_step = 
int(steps * strength) + latents = self.scheduler.add_noise(latents, torch.randn_like(latents), self.scheduler.timesteps[init_step]) + + for i, t in enumerate(self.scheduler.timesteps[init_step:]): + + latent_model_input = torch.cat([latents] * 2) + + noise_pred = self.unet( + latent_model_input, t, encoder_hidden_states=self.embeddings, + ).sample + + noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) + + latents = self.scheduler.step(noise_pred, t, latents).prev_sample + + imgs = self.decode_latents(latents) # [1, 3, 512, 512] + return imgs + + def train_step( + self, + pred_rgb, + step_ratio=None, + guidance_scale=100, + as_latent=False, + ): + + batch_size = pred_rgb.shape[0] + pred_rgb = pred_rgb.to(self.dtype) + + if as_latent: + latents = F.interpolate(pred_rgb, (64, 64), mode="bilinear", align_corners=False) * 2 - 1 + else: + # interp to 512x512 to be fed into vae. + pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode="bilinear", align_corners=False) + # encode image into latents with vae, requires grad! + latents = self.encode_imgs(pred_rgb_512) + + if step_ratio is not None: + # dreamtime-like + # t = self.max_step - (self.max_step - self.min_step) * np.sqrt(step_ratio) + t = np.round((1 - step_ratio) * self.num_train_timesteps).clip(self.min_step, self.max_step) + t = torch.full((batch_size,), t, dtype=torch.long, device=self.device) + else: + t = torch.randint(self.min_step, self.max_step + 1, (batch_size,), dtype=torch.long, device=self.device) + + # w(t), sigma_t^2 + w = (1 - self.alphas[t]).view(batch_size, 1, 1, 1) + + # predict the noise residual with unet, NO grad! + with torch.no_grad(): + # add noise + noise = torch.randn_like(latents) + latents_noisy = self.scheduler.add_noise(latents, noise, t) + # pred noise + latent_model_input = torch.cat([latents_noisy] * 2) + tt = torch.cat([t] * 2) + + noise_pred = self.unet( + latent_model_input, tt, encoder_hidden_states=self.embeddings.repeat(batch_size, 1, 1) + ).sample + + # perform guidance (high scale from paper!) + noise_pred_uncond, noise_pred_pos = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_pos - noise_pred_uncond + ) + + grad = w * (noise_pred - noise) + grad = torch.nan_to_num(grad) + + # seems important to avoid NaN... + # grad = grad.clamp(-1, 1) + + target = (latents - grad).detach() + loss = 0.5 * F.mse_loss(latents.float(), target, reduction='sum') / latents.shape[0] + + return loss + + @torch.no_grad() + def produce_latents( + self, + height=512, + width=512, + num_inference_steps=50, + guidance_scale=7.5, + latents=None, + ): + if latents is None: + latents = torch.randn( + ( + self.embeddings.shape[0] // 2, + self.unet.in_channels, + height // 8, + width // 8, + ), + device=self.device, + ) + + self.scheduler.set_timesteps(num_inference_steps) + + for i, t in enumerate(self.scheduler.timesteps): + # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes. 
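+ # self.embeddings stores [negative, positive] text embeddings, so chunk(2) below yields (uncond, cond) for classifier-free guidance: eps = eps_uncond + scale * (eps_cond - eps_uncond).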
+ latent_model_input = torch.cat([latents] * 2) + # predict the noise residual + noise_pred = self.unet( + latent_model_input, t, encoder_hidden_states=self.embeddings + ).sample + + # perform guidance + noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_cond - noise_pred_uncond + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents).prev_sample + + return latents + + def decode_latents(self, latents): + latents = 1 / self.vae.config.scaling_factor * latents + + imgs = self.vae.decode(latents).sample + imgs = (imgs / 2 + 0.5).clamp(0, 1) + + return imgs + + def encode_imgs(self, imgs): + # imgs: [B, 3, H, W] + + imgs = 2 * imgs - 1 + + posterior = self.vae.encode(imgs).latent_dist + latents = posterior.sample() * self.vae.config.scaling_factor + + return latents + + def prompt_to_img( + self, + prompts, + negative_prompts="", + height=512, + width=512, + num_inference_steps=50, + guidance_scale=7.5, + latents=None, + ): + if isinstance(prompts, str): + prompts = [prompts] + + if isinstance(negative_prompts, str): + negative_prompts = [negative_prompts] + + # Prompts -> text embeds + self.get_text_embeds(prompts, negative_prompts) + + # Text embeds -> img latents + latents = self.produce_latents( + height=height, + width=width, + latents=latents, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + ) # [1, 4, 64, 64] + + # Img latents -> imgs + imgs = self.decode_latents(latents) # [1, 3, 512, 512] + + # Img to Numpy + imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy() + imgs = (imgs * 255).round().astype("uint8") + + return imgs + + @torch.no_grad() + def generate_img( + self, + emb, + height=512, + width=512, + num_inference_steps=50, + guidance_scale=7.5, + latents=None, + ): + neg_prompt = self.encode_text([""]) + self.embeddings = torch.cat([neg_prompt, emb.unsqueeze(0)], dim=0) # + # Text embeds -> img latents + latents = self.produce_latents( + height=height, + width=width, + latents=latents, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + ) # [1, 4, 64, 64] + + # Img latents -> imgs + imgs = self.decode_latents(latents) # [1, 3, 512, 512] + + # Img to Numpy + imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy() + imgs = (imgs * 255).round().astype("uint8") + + return imgs + + +def window_score(x, gamma: float = 0.6) -> torch.Tensor: + # return torch.exp(-torch.abs(gamma*x)) + return torch.cos(gamma*x) + + +# Collect similar info from attentive features for neglected concept +def sim_correction(embeddings: torch.Tensor, + correction_indices: List[int], + scores: torch.Tensor, + window: bool = True) -> torch.Tensor: + """ Embeddings shape (77, 768), computes similarity between embeddings, combine using similarity scores""" + ntk, dim = embeddings.shape + device = embeddings.device + + for i, tk in enumerate(correction_indices): + alpha = scores[i] + v = embeddings[tk].clone() + + sim = v.unsqueeze(0) * embeddings # nth,dim 77,768 + sim = torch.relu(sim) # 77,768 + + ind = torch.lt(sim, 0.5) # relu is not needed in this case + sim[ind] = 0. + sim[:tk] = 0. 
# 77, 768 + sim /= max(sim.max(), 1e-6) + + if window: + ws = window_score(torch.arange(0, ntk - tk).to(device), gamma=0.8) + ws = ws.unsqueeze(-1) # 77 - tk,1 + sim[tk:] = ws * sim[tk:] # 77, 768 + + successor = torch.sum(sim * embeddings, dim=0) + embeddings[tk] = (1 - alpha) * embeddings[tk] + alpha * successor + embeddings[tk] *= v.norm() / embeddings[tk].norm() + + return embeddings + +if __name__ == "__main__": + import argparse + import matplotlib.pyplot as plt + + parser = argparse.ArgumentParser() + parser.add_argument("prompt", type=str) + parser.add_argument("--negative", default="", type=str) + parser.add_argument( + "--sd_version", + type=str, + default="1.5", + choices=["1.5", "2.0", "2.1"], + help="stable diffusion version", + ) + parser.add_argument( + "--hf_key", + type=str, + default=None, + help="hugging face Stable diffusion model key", + ) + parser.add_argument("--fp16", action="store_true", help="use float16 for training") + parser.add_argument( + "--vram_O", action="store_true", help="optimization for low VRAM usage" + ) + parser.add_argument("-H", type=int, default=512) + parser.add_argument("-W", type=int, default=512) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--steps", type=int, default=50) + opt = parser.parse_args() + + seed_everything(opt.seed) + + device = torch.device("cuda") + + sd = StableDiffusion(device, opt.fp16, opt.vram_O, opt.sd_version, opt.hf_key) + + # imgs = sd.prompt_to_img(opt.prompt, opt.negative, opt.H, opt.W, opt.steps) + + # visualize image + # plt.imshow(imgs[0]) + # plt.show() + ww = sd.encode_text('a photo of a cat and a dog') + # ww = sd.encode_text('A teddy bear with a yellow bird') + token_indices = [5, 8] + cor_scores1 = [0.3, 0] + # from IPython import embed + # embed() + res = sim_correction(embeddings=ww[0], correction_indices=token_indices, scores=torch.tensor(cor_scores1, device=device)) + + imgs = sd.generate_img(res, opt.H, opt.W, opt.steps) + from PIL import Image + for i in range(len(imgs)): + Image.fromarray(imgs[i]).save(f'b_{i}.png') + imgs = sd.generate_img(ww[0], opt.H, opt.W, opt.steps) + from PIL import Image + for i in range(len(imgs)): + Image.fromarray(imgs[i]).save(f'c_{i}.png') diff --git a/guidance/zero123_utils.py b/guidance/zero123_utils.py new file mode 100644 index 0000000..65b3824 --- /dev/null +++ b/guidance/zero123_utils.py @@ -0,0 +1,244 @@ +from transformers import CLIPTextModel, CLIPTokenizer, logging +from diffusers import ( + AutoencoderKL, + UNet2DConditionModel, + DDIMScheduler, + StableDiffusionPipeline, +) +import torchvision.transforms.functional as TF + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import os +import sys +sys.path.append('./') + +from zero123 import Zero123Pipeline + + +class Zero123(nn.Module): + def __init__(self, device, fp16=True, t_range=[0.2, 0.6], zero123_path='ashawkey/stable-zero123-diffusers'): + super().__init__() + + self.device = device + self.fp16 = fp16 + self.dtype = torch.float16 if fp16 else torch.float32 + self.pipe = Zero123Pipeline.from_pretrained( + zero123_path, + variant="fp16_ema" if self.fp16 else None, + torch_dtype=self.dtype, + ).to(self.device) + + # for param in self.pipe.parameters(): + # param.requires_grad = False + + self.pipe.image_encoder.eval() + self.pipe.vae.eval() + self.pipe.unet.eval() + self.pipe.clip_camera_projection.eval() + + self.vae = self.pipe.vae + self.unet = self.pipe.unet + + self.pipe.set_progress_bar_config(disable=True) + + self.scheduler = 
DDIMScheduler.from_config(self.pipe.scheduler.config) + self.num_train_timesteps = self.scheduler.config.num_train_timesteps + + self.min_step = int(self.num_train_timesteps * t_range[0]) + self.max_step = int(self.num_train_timesteps * t_range[1]) + self.alphas = self.scheduler.alphas_cumprod.to(self.device) # for convenience + + # embeddings = None + self.use_stable_zero123 = 'stable' in zero123_path + + def get_cam_embeddings(self, polar, azimuth, radius): + if self.use_stable_zero123: + T = np.stack([np.deg2rad(polar), np.sin(np.deg2rad(azimuth)), np.cos(np.deg2rad(azimuth)), np.deg2rad(np.full_like(polar, 90))], axis=-1) + # 90 because pose0 is fixed + # https://github.com/threestudio-project/threestudio/pull/356/files#diff-7cab41ca8761951def6987763141c5cfe7b1e3c0d174ac3cb0f5b4ca8ec8309aR220 + else: + # original zero123 camera embedding + T = np.stack([np.deg2rad(polar), np.sin(np.deg2rad(azimuth)), np.cos(np.deg2rad(azimuth)), radius], axis=-1) + T = torch.from_numpy(T).unsqueeze(0).unsqueeze(0).to(dtype=self.dtype, device=self.device) # [8, 1, 4] + # print(T.shape) + # T = torch.from_numpy(T).unsqueeze(1).to(dtype=self.dtype, device=self.device) # [8, 1, 4] + return T + + @torch.no_grad() + def get_img_embeds(self, x): + # x: image tensor in [0, 1] + x = F.interpolate(x, (256, 256), mode='bilinear', align_corners=False) + x_pil = [TF.to_pil_image(image) for image in x] + x_clip = self.pipe.feature_extractor(images=x_pil, return_tensors="pt").pixel_values.to(device=self.device, dtype=self.dtype) + c = self.pipe.image_encoder(x_clip).image_embeds + v = self.encode_imgs(x.to(self.dtype)) / self.vae.config.scaling_factor + embeddings = [c, v] + return embeddings + + @torch.no_grad() + def get_img_embeds_pil(self, x, x_pil): + #x: image tensor in [0, 1] + x = F.interpolate(x, (256, 256), mode='bilinear', align_corners=False) + x_pil = [TF.to_pil_image(image) for image in x] + x_clip = self.pipe.feature_extractor(images=x_pil, return_tensors="pt").pixel_values.to(device=self.device, dtype=self.dtype) + c = self.pipe.image_encoder(x_clip).image_embeds + v = self.encode_imgs(x.to(self.dtype)) / self.vae.config.scaling_factor + return c, v + + + @torch.no_grad() + def get_vis_image(self, pred_rgb_256, latents_noisy, t, noise_pred): + # print(pred_rgb_256.shape, latents_noisy.shape, t.shape, noise_pred.shape) + with torch.no_grad(): + # visualize predicted denoised image + result_hopefully_less_noisy_image = self.decode_latents(self.pred_x0(latents_noisy, t, noise_pred)) + + # visualize noisier image + result_noisier_image = self.decode_latents(latents_noisy) + + # all 3 input images are [1, 3, H, W], e.g. 
[1, 3, 512, 512] + viz_images = torch.cat([pred_rgb_256, result_noisier_image, result_hopefully_less_noisy_image],dim=-1) + return viz_images + + def pred_x0(self, sample, timestep, model_output): + alpha_prod_t = self.alphas[timestep].to(self.device).view(-1, 1, 1, 1) + + beta_prod_t = 1 - alpha_prod_t + # print('alpha_prod_t', alpha_prod_t.shape) + if self.scheduler.config.prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + elif self.scheduler.config.prediction_type == "sample": + pred_original_sample = model_output + elif self.scheduler.config.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + # predict V + model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + else: + raise ValueError( + f"prediction_type given as {self.scheduler.config.prediction_type} must be one of `epsilon`, `sample`," + " or `v_prediction`" + ) + + return pred_original_sample + + def train_step(self, pred_rgb, polar, azimuth, radius, embeddings, step_ratio=None, guidance_scale=2, as_latent=False): + # pred_rgb: tensor [1, 3, H, W] in [0, 1] + + batch_size = pred_rgb.shape[0] + + if as_latent: + latents = F.interpolate(pred_rgb, (32, 32), mode='bilinear', align_corners=False) * 2 - 1 + else: + pred_rgb_256 = F.interpolate(pred_rgb, (256, 256), mode='bilinear', align_corners=False) + latents = self.encode_imgs(pred_rgb_256.to(self.dtype)) + + if step_ratio is not None: + # dreamtime-like + # t = self.max_step - (self.max_step - self.min_step) * np.sqrt(step_ratio) + t = np.round((1 - step_ratio) * self.num_train_timesteps).clip(self.min_step, self.max_step) + t = torch.full((batch_size,), t, dtype=torch.long, device=self.device) + else: + t = torch.randint(self.min_step, self.max_step + 1, (batch_size,), dtype=torch.long, device=self.device) + + w = (1 - self.alphas[t]).view(batch_size, 1, 1, 1) + + with torch.no_grad(): + noise = torch.randn_like(latents) + latents_noisy = self.scheduler.add_noise(latents, noise, t) + + x_in = torch.cat([latents_noisy] * 2) + t_in = torch.cat([t] * 2) + + T = self.get_cam_embeddings(polar, azimuth, radius) + # T = np.stack([np.deg2rad(polar), np.sin(np.deg2rad(azimuth)), np.cos(np.deg2rad(azimuth)), radius], axis=-1) + # T = torch.from_numpy(T).unsqueeze(0).unsqueeze(0).to(self.dtype).to(self.device) # [8, 1, 4] + # T = torch.from_numpy(T).unsqueeze(1).to(self.dtype).to(self.device) # [8, 1, 4] + # print('embeddings[0].repeat(batch_size, 1, 1) ',embeddings[0].repeat(batch_size, 1, 1).shape) #[4, 1, 768] + # print('T ',T.shape) #[1, 1, 4] + cc_emb = torch.cat([embeddings[0].repeat(batch_size, 1, 1), T.repeat(batch_size,1,1)], dim=-1) + cc_emb = self.pipe.clip_camera_projection(cc_emb) + cc_emb = torch.cat([cc_emb, torch.zeros_like(cc_emb)], dim=0) + + vae_emb = embeddings[1].repeat(batch_size, 1, 1, 1) + vae_emb = torch.cat([vae_emb, torch.zeros_like(vae_emb)], dim=0) + + noise_pred = self.unet( + torch.cat([x_in, vae_emb], dim=1), + t_in.to(self.unet.dtype), + encoder_hidden_states=cc_emb, + ).sample + + noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) + + grad = w * (noise_pred - noise) + grad = torch.nan_to_num(grad) + + target = (latents - grad).detach() + loss = F.mse_loss(latents.float(), target, reduction='sum') + + im = self.get_vis_image(pred_rgb_256[:4], latents_noisy[:4], t[:4], noise_pred[:4]) + + 
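+ # 'im' concatenates the rendered input, the decoded noisy latents, and the one-step denoised prediction side by side for logging (this assumes as_latent=False, so that pred_rgb_256 is defined).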
return loss, im + + + def decode_latents(self, latents): + latents = 1 / self.vae.config.scaling_factor * latents + + imgs = self.vae.decode(latents.to(self.dtype)).sample + imgs = (imgs / 2 + 0.5).clamp(0, 1) + + return imgs + + def encode_imgs(self, imgs, mode=False): + # imgs: [B, 3, H, W] + + imgs = 2 * imgs - 1 + + posterior = self.vae.encode(imgs).latent_dist + if mode: + latents = posterior.mode() + else: + latents = posterior.sample() + latents = latents * self.vae.config.scaling_factor + + return latents + + +if __name__ == '__main__': + import cv2 + import argparse + import numpy as np + import matplotlib.pyplot as plt + + parser = argparse.ArgumentParser() + + parser.add_argument('input', type=str) + parser.add_argument('--polar', type=float, default=0, help='delta polar angle in [-90, 90]') + parser.add_argument('--azimuth', type=float, default=0, help='delta azimuth angle in [-180, 180]') + parser.add_argument('--radius', type=float, default=0, help='delta camera radius multiplier in [-0.5, 0.5]') + + opt = parser.parse_args() + + device = torch.device('cuda') + + print(f'[INFO] loading image from {opt.input} ...') + image = cv2.imread(opt.input, cv2.IMREAD_UNCHANGED) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + image = cv2.resize(image, (256, 256), interpolation=cv2.INTER_AREA) + image = image.astype(np.float32) / 255.0 + image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).contiguous().to(device) + + print(f'[INFO] loading model ...') + zero123 = Zero123(device) + + print(f'[INFO] running embed ...') + emb=zero123.get_img_embeds(image) + print(f'[INFO] running model ...') + while True: + outputs = zero123.refine(image, polar=[opt.polar], azimuth=[opt.azimuth], radius=[opt.radius], embeddings=emb,strength=0) + plt.imshow(outputs.float().cpu().numpy().transpose(0, 2, 3, 1)[0]) + plt.show() diff --git a/guidance/zero123_xl_utils.py b/guidance/zero123_xl_utils.py new file mode 100644 index 0000000..1e441ee --- /dev/null +++ b/guidance/zero123_xl_utils.py @@ -0,0 +1,277 @@ +from transformers import CLIPTextModel, CLIPTokenizer, logging +from diffusers import ( + AutoencoderKL, + UNet2DConditionModel, + DDIMScheduler, + StableDiffusionPipeline, +) +import torchvision.transforms.functional as TF + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import os +import sys +sys.path.append('./') + +from zero123 import Zero123Pipeline + + +class Zero123(nn.Module): + def __init__(self, device, fp16=True, t_range=[0.2, 0.6]): + super().__init__() + + self.device = device + self.fp16 = fp16 + self.dtype = torch.float16 if fp16 else torch.float32 + zero123_path="bennyguo/zero123-xl-diffusers" + self.pipe = Zero123Pipeline.from_pretrained( + zero123_path, + variant="fp16_ema" if self.fp16 else None, + torch_dtype=self.dtype, + ).to(self.device) + + # for param in self.pipe.parameters(): + # param.requires_grad = False + + self.pipe.image_encoder.eval() + self.pipe.vae.eval() + self.pipe.unet.eval() + self.pipe.clip_camera_projection.eval() + + self.vae = self.pipe.vae + self.unet = self.pipe.unet + + self.pipe.set_progress_bar_config(disable=True) + + self.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config) + self.num_train_timesteps = self.scheduler.config.num_train_timesteps + + self.min_step = int(self.num_train_timesteps * t_range[0]) + self.max_step = int(self.num_train_timesteps * t_range[1]) + self.alphas = self.scheduler.alphas_cumprod.to(self.device) # for convenience + + # embeddings = None + + @torch.no_grad() + 
def get_img_embeds(self, x): + # x: image tensor in [0, 1] + x = F.interpolate(x, (256, 256), mode='bilinear', align_corners=False) + x_pil = [TF.to_pil_image(image) for image in x] + x_clip = self.pipe.feature_extractor(images=x_pil, return_tensors="pt").pixel_values.to(device=self.device, dtype=self.dtype) + c = self.pipe.image_encoder(x_clip).image_embeds + v = self.encode_imgs(x.to(self.dtype)) / self.vae.config.scaling_factor + embeddings = [c, v] + return embeddings + + @torch.no_grad() + def get_img_embeds_pil(self, x, x_pil): + #x: image tensor in [0, 1] + x = F.interpolate(x, (256, 256), mode='bilinear', align_corners=False) + x_pil = [TF.to_pil_image(image) for image in x] + x_clip = self.pipe.feature_extractor(images=x_pil, return_tensors="pt").pixel_values.to(device=self.device, dtype=self.dtype) + c = self.pipe.image_encoder(x_clip).image_embeds + v = self.encode_imgs(x.to(self.dtype)) / self.vae.config.scaling_factor + return c, v + + + @torch.no_grad() + def get_vis_image(self, pred_rgb_256, latents_noisy, t, noise_pred): + # print(pred_rgb_256.shape, latents_noisy.shape, t.shape, noise_pred.shape) + with torch.no_grad(): + # visualize predicted denoised image + result_hopefully_less_noisy_image = self.decode_latents(self.pred_x0(latents_noisy, t, noise_pred)) + + # visualize noisier image + result_noisier_image = self.decode_latents(latents_noisy) + + # all 3 input images are [1, 3, H, W], e.g. [1, 3, 512, 512] + viz_images = torch.cat([pred_rgb_256, result_noisier_image, result_hopefully_less_noisy_image],dim=-1) + return viz_images + + @torch.no_grad() + def refine(self, pred_rgb, polar, azimuth, radius, embeddings, + guidance_scale=5, steps=50, strength=0.8, + ): + + batch_size = pred_rgb.shape[0] + + self.scheduler.set_timesteps(steps) + + if strength == 0: + init_step = 0 + latents = torch.randn((1, 4, 32, 32), device=self.device, dtype=self.dtype) + else: + init_step = int(steps * strength) + pred_rgb_256 = F.interpolate(pred_rgb, (256, 256), mode='bilinear', align_corners=False) + latents = self.encode_imgs(pred_rgb_256.to(self.dtype)) + latents = self.scheduler.add_noise(latents, torch.randn_like(latents), self.scheduler.timesteps[init_step]) + + T = np.stack([np.deg2rad(polar), np.sin(np.deg2rad(azimuth)), np.cos(np.deg2rad(azimuth)), radius], axis=-1) #(1,4) + T = torch.from_numpy(T).unsqueeze(1).to(self.dtype).to(self.device) # [1, 1, 4] + cc_emb = torch.cat([embeddings[0].repeat(batch_size, 1, 1), T], dim=-1) #embeddings[0] shape [1,768] + cc_emb = self.pipe.clip_camera_projection(cc_emb) + cc_emb = torch.cat([cc_emb, torch.zeros_like(cc_emb)], dim=0) + + vae_emb = embeddings[1].repeat(batch_size, 1, 1, 1) + vae_emb = torch.cat([vae_emb, torch.zeros_like(vae_emb)], dim=0) + + for i, t in enumerate(self.scheduler.timesteps[init_step:]): + print('step:',i) + x_in = torch.cat([latents] * 2) + t_in = torch.cat([t.view(1)] * 2).to(self.device) + + noise_pred = self.unet( + torch.cat([x_in, vae_emb], dim=1), + t_in.to(self.unet.dtype), + encoder_hidden_states=cc_emb, + ).sample + + noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) + + latents = self.scheduler.step(noise_pred, t, latents).prev_sample + + imgs = self.decode_latents(latents) # [1, 3, 256, 256] + return imgs + + + def pred_x0(self, sample, timestep, model_output): + alpha_prod_t = self.alphas[timestep].to(self.device).view(-1, 1, 1, 1) + + beta_prod_t = 1 - alpha_prod_t + # print('alpha_prod_t', 
alpha_prod_t.shape) + if self.scheduler.config.prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + elif self.scheduler.config.prediction_type == "sample": + pred_original_sample = model_output + elif self.scheduler.config.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + # predict V + model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + else: + raise ValueError( + f"prediction_type given as {self.scheduler.config.prediction_type} must be one of `epsilon`, `sample`," + " or `v_prediction`" + ) + + return pred_original_sample + + def train_step(self, pred_rgb, polar, azimuth, radius, embeddings, step_ratio=None, guidance_scale=2, as_latent=False): + # pred_rgb: tensor [1, 3, H, W] in [0, 1] + + batch_size = pred_rgb.shape[0] + + if as_latent: + latents = F.interpolate(pred_rgb, (32, 32), mode='bilinear', align_corners=False) * 2 - 1 + else: + pred_rgb_256 = F.interpolate(pred_rgb, (256, 256), mode='bilinear', align_corners=False) + latents = self.encode_imgs(pred_rgb_256.to(self.dtype)) + + if step_ratio is not None: + # dreamtime-like + # t = self.max_step - (self.max_step - self.min_step) * np.sqrt(step_ratio) + t = np.round((1 - step_ratio) * self.num_train_timesteps).clip(self.min_step, self.max_step) + t = torch.full((batch_size,), t, dtype=torch.long, device=self.device) + else: + t = torch.randint(self.min_step, self.max_step + 1, (batch_size,), dtype=torch.long, device=self.device) + + w = (1 - self.alphas[t]).view(batch_size, 1, 1, 1) + + with torch.no_grad(): + noise = torch.randn_like(latents) + latents_noisy = self.scheduler.add_noise(latents, noise, t) + + x_in = torch.cat([latents_noisy] * 2) + t_in = torch.cat([t] * 2) + + T = np.stack([np.deg2rad(polar), np.sin(np.deg2rad(azimuth)), np.cos(np.deg2rad(azimuth)), radius], axis=-1) + T = torch.from_numpy(T).unsqueeze(0).unsqueeze(0).to(self.dtype).to(self.device) # [8, 1, 4] + # T = torch.from_numpy(T).unsqueeze(1).to(self.dtype).to(self.device) # [8, 1, 4] + # print('embeddings[0].repeat(batch_size, 1, 1) ',embeddings[0].repeat(batch_size, 1, 1).shape) #[4, 1, 768] + # print('T ',T.shape) #[1, 1, 4] + cc_emb = torch.cat([embeddings[0].repeat(batch_size, 1, 1), T.repeat(batch_size,1,1)], dim=-1) + cc_emb = self.pipe.clip_camera_projection(cc_emb) + cc_emb = torch.cat([cc_emb, torch.zeros_like(cc_emb)], dim=0) + + vae_emb = embeddings[1].repeat(batch_size, 1, 1, 1) + vae_emb = torch.cat([vae_emb, torch.zeros_like(vae_emb)], dim=0) + + noise_pred = self.unet( + torch.cat([x_in, vae_emb], dim=1), + t_in.to(self.unet.dtype), + encoder_hidden_states=cc_emb, + ).sample + + noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) + + grad = w * (noise_pred - noise) + grad = torch.nan_to_num(grad) + + target = (latents - grad).detach() + loss = 0.5 * F.mse_loss(latents.float(), target, reduction='sum') + + im = self.get_vis_image(pred_rgb_256[:4], latents_noisy[:4], t[:4], noise_pred[:4]) + + return loss, im + + + def decode_latents(self, latents): + latents = 1 / self.vae.config.scaling_factor * latents + + imgs = self.vae.decode(latents.to(self.dtype)).sample + imgs = (imgs / 2 + 0.5).clamp(0, 1) + + return imgs + + def encode_imgs(self, imgs, mode=False): + # imgs: [B, 3, H, W] + + imgs = 2 * imgs - 1 + + posterior = self.vae.encode(imgs).latent_dist + if mode: + 
latents = posterior.mode() + else: + latents = posterior.sample() + latents = latents * self.vae.config.scaling_factor + + return latents + + +if __name__ == '__main__': + import cv2 + import argparse + import numpy as np + import matplotlib.pyplot as plt + + parser = argparse.ArgumentParser() + + parser.add_argument('input', type=str) + parser.add_argument('--polar', type=float, default=0, help='delta polar angle in [-90, 90]') + parser.add_argument('--azimuth', type=float, default=0, help='delta azimuth angle in [-180, 180]') + parser.add_argument('--radius', type=float, default=0, help='delta camera radius multiplier in [-0.5, 0.5]') + + opt = parser.parse_args() + + device = torch.device('cuda') + + print(f'[INFO] loading image from {opt.input} ...') + image = cv2.imread(opt.input, cv2.IMREAD_UNCHANGED) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + image = cv2.resize(image, (256, 256), interpolation=cv2.INTER_AREA) + image = image.astype(np.float32) / 255.0 + image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).contiguous().to(device) + + print(f'[INFO] loading model ...') + zero123 = Zero123(device) + + print(f'[INFO] running embed ...') + emb=zero123.get_img_embeds(image) + print(f'[INFO] running model ...') + while True: + outputs = zero123.refine(image, polar=[opt.polar], azimuth=[opt.azimuth], radius=[opt.radius], embeddings=emb,strength=0) + plt.imshow(outputs.float().cpu().numpy().transpose(0, 2, 3, 1)[0]) + plt.show() diff --git a/preprocess.py b/preprocess.py new file mode 100644 index 0000000..570d3c3 --- /dev/null +++ b/preprocess.py @@ -0,0 +1,81 @@ +import os +import glob +import sys +import cv2 +import argparse +import numpy as np +import matplotlib.pyplot as plt + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision import transforms +from PIL import Image +import rembg + +if __name__ == '__main__': + + parser = argparse.ArgumentParser() + parser.add_argument('--path', type=str, help="path to image (png, jpeg, etc.)") + parser.add_argument('--model', default='u2net', type=str, help="rembg model, see https://github.com/danielgatis/rembg#models") + parser.add_argument('--size', default=512, type=int, help="output resolution") + parser.add_argument('--border_ratio', default=0.2, type=float, help="output border ratio") + parser.add_argument('--recenter', type=bool, default=False, help="recenter, potentially not helpful for multiview zero123") + opt = parser.parse_args() + + session = rembg.new_session(model_name=opt.model) + + if os.path.isdir(opt.path): + print(f'[INFO] processing directory {opt.path}...') + files = glob.glob(f'{opt.path}/*') + out_dir = opt.path + else: # isfile + files = [opt.path] + out_dir = os.path.dirname(opt.path) + + savedir=opt.path+'/rgba/' + os.makedirs(savedir,exist_ok=True) + for file in files: + if file.endswith('jpg') or file.endswith('png'): + out_base = os.path.basename(file).split('.')[0] + + out_rgba = os.path.join(savedir,out_base + '.png') + + # load image + print(f'[INFO] loading image {file}...') + image = cv2.imread(file, cv2.IMREAD_UNCHANGED) + + # carve background + print(f'[INFO] background removal...') + carved_image = rembg.remove(image, session=session) # [H, W, 4] + mask = carved_image[..., -1] > 0 + + # recenter + if opt.recenter: + # pass + # print('???') + print(f'[INFO] recenter...') + final_rgba = np.zeros((opt.size, opt.size, 4), dtype=np.uint8) + + coords = np.nonzero(mask) + x_min, x_max = coords[0].min(), coords[0].max() + y_min, y_max = coords[1].min(), coords[1].max() 
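+ # the mask's tight bounding box is rescaled below so the object occupies (1 - border_ratio) of the output canvas, centered in a size x size RGBA image.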
+ h = x_max - x_min + w = y_max - y_min + desired_size = int(opt.size * (1 - opt.border_ratio)) + scale = desired_size / max(h, w) + h2 = int(h * scale) + w2 = int(w * scale) + x2_min = (opt.size - h2) // 2 + x2_max = x2_min + h2 + y2_min = (opt.size - w2) // 2 + y2_max = y2_min + w2 + print(x_min, x_max, x2_min, x2_max) + final_rgba[x2_min:x2_max, y2_min:y2_max] = cv2.resize(carved_image[x_min:x_max, y_min:y_max], (w2, h2), interpolation=cv2.INTER_AREA) + + else: + final_rgba = carved_image + + # write image + cv2.imwrite(out_rgba, final_rgba) + print('out path:',out_rgba) diff --git a/preprocess_sync.py b/preprocess_sync.py new file mode 100644 index 0000000..a785fe5 --- /dev/null +++ b/preprocess_sync.py @@ -0,0 +1,71 @@ +import os +import glob +import sys +import cv2 +import argparse +import numpy as np +import matplotlib.pyplot as plt + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision import transforms +from PIL import Image +import rembg + + +def chop_image_into_16(image): + # Assuming 'image' is a cv2 image + height, width, _ = image.shape + + # Calculating the width of each slice + slice_width = width // 16 + + # Slicing the image into 16 pieces + slices = [image[:, i*slice_width:(i+1)*slice_width] for i in range(16)] + + return slices + +if __name__ == '__main__': + + parser = argparse.ArgumentParser() + parser.add_argument('--path', type=str, help="path to image (png, jpeg, etc.)") + parser.add_argument('--model', default='u2net', type=str, help="rembg model, see https://github.com/danielgatis/rembg#models") + parser.add_argument('--size', default=512, type=int, help="output resolution") + parser.add_argument('--border_ratio', default=0.2, type=float, help="output border ratio") + parser.add_argument('--recenter', type=bool, default=False, help="recenter, potentially not helpful for multiview zero123") + opt = parser.parse_args() + + session = rembg.new_session(model_name=opt.model) + + if os.path.isdir(opt.path): + print(f'[INFO] processing directory {opt.path}...') + files = glob.glob(f'{opt.path}/*') + out_dir = opt.path + else: # isfile + files = [opt.path] + out_dir = os.path.dirname(opt.path) + + os.makedirs(out_dir,exist_ok=True) + for file in files: + if file.endswith('jpg') or file.endswith('png') and not '_rgba.png' in file: + out_base = os.path.basename(file).split('.')[0] + + # load image + print(f'[INFO] loading image {file}...') + image = cv2.imread(file, cv2.IMREAD_UNCHANGED) + + slices = chop_image_into_16(image) + + for idx, image in enumerate(slices): + # carve background + print(f'[INFO] background removal...') + carved_image = rembg.remove(image, session=session) # [H, W, 4] + mask = carved_image[..., -1] > 0 + # else: + final_rgba = carved_image + + # write image + out_rgba = os.path.join(opt.path, out_base + f'_{idx}_rgba.png') + cv2.imwrite(out_rgba, final_rgba) + print('out path:',out_rgba) diff --git a/render.py b/render.py new file mode 100644 index 0000000..d5e9023 --- /dev/null +++ b/render.py @@ -0,0 +1,134 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. 
+# +# For inquiries contact george.drettakis@inria.fr +# +import imageio +import numpy as np +import torch +from scene import Scene +import os +import cv2 +from tqdm import tqdm +from os import makedirs +from gaussian_renderer import render +import torchvision +from utils.general_utils import safe_state +from argparse import ArgumentParser +from arguments import ModelParams, PipelineParams,OptimizationParams, get_combined_args, ModelHiddenParams +from gaussian_renderer import GaussianModel +from time import time +to8b = lambda x : (255*np.clip(x.cpu().numpy(),0,1)).astype(np.uint8) +def render_set(model_path, name, iteration, views, gaussians, pipeline, background,multiview_video, fname='video_rgb.mp4'): + render_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders") + gts_path = os.path.join(model_path, name, "ours_{}".format(iteration), "gt") + + makedirs(render_path, exist_ok=True) + makedirs(gts_path, exist_ok=True) + render_images = [] + gt_list = [] + render_list = [] + print(len(views)) + + # for idx, view in enumerate(tqdm(views, desc="Rendering progress")): + # for idx in tqdm(range (100)): + fnum = 100 + # fnum = 12 + for idx in tqdm(range (fnum)): + view = views[idx] + if idx == 0:time1 = time() + #ww = torch.tensor([idx / 12]).unsqueeze(0) + ww = torch.tensor([idx / fnum]).unsqueeze(0) + # ww = torch.tensor([idx / 100]).unsqueeze(0) + + if multiview_video: + rendering = render(view['cur_cam'], gaussians, pipeline, background, time=ww, stage='fine')["render"] + else: + rendering = render(view['pose0_cam'], gaussians, pipeline, background, time=ww, stage='fine')["render"] + render_images.append(to8b(rendering).transpose(1,2,0)) + render_list.append(rendering) + time2=time() + print("FPS:",(len(views)-1)/(time2-time1)) + print('Len', len(render_images)) + imageio.mimwrite(os.path.join(model_path, name, "ours_{}".format(iteration), fname), render_images, fps=8, quality=8) + + +def render_set_timefix(model_path, name, iteration, views, gaussians, pipeline, background,multiview_video, fname='video_rgb.mp4',time_fix=-1): + render_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders") + gts_path = os.path.join(model_path, name, "ours_{}".format(iteration), "gt") + + makedirs(render_path, exist_ok=True) + makedirs(gts_path, exist_ok=True) + render_images = [] + gt_list = [] + render_list = [] + print(len(views)) + + # for idx, view in enumerate(tqdm(views, desc="Rendering progress")): + for idx in tqdm(range (12)): + #for idx in tqdm(range (100)): + view = views[idx] + if idx == 0:time1 = time() + # ww = torch.tensor([idx / 16]).unsqueeze(0) + ww = torch.tensor([idx / 100]).unsqueeze(0) + if time_fix!=-1: + ww=torch.tensor([time_fix/16]).unsqueeze(0) + if multiview_video: + rendering = render(view['cur_cam'], gaussians, pipeline, background, time=ww, stage='fine')["render"] + + render_images.append(to8b(rendering).transpose(1,2,0)) + render_list.append(rendering) + time2=time() + print("FPS:",(len(views)-1)/(time2-time1)) + print('Len', len(render_images)) + imageio.mimwrite(os.path.join(model_path, name, "ours_{}".format(iteration), fname), render_images, fps=7, quality=8) + + +def render_sets(dataset : ModelParams, hyperparam, opt,iteration : int, pipeline : PipelineParams, skip_train : bool, skip_test : bool, skip_video: bool,multiview_video: bool): + with torch.no_grad(): + gaussians = GaussianModel(dataset.sh_degree, hyperparam) + scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False) + + bg_color = [1,1,1] if 
dataset.white_background else [0, 0, 0] + background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") + + if not skip_train: + render_set(dataset.model_path, "train", scene.loaded_iter, scene.getTrainCameras(), gaussians, pipeline, background,multiview_video) + + if not skip_test: + render_set(dataset.model_path, "test", scene.loaded_iter, scene.getTestCameras(), gaussians, pipeline, background,multiview_video) + if not skip_video: + #origin + render_set(dataset.model_path,"video",scene.loaded_iter,scene.getVideoCameras(),gaussians,pipeline,background,multiview_video=True, fname='multiview.mp4') + render_set(dataset.model_path,"video",scene.loaded_iter,scene.getVideoCameras(),gaussians,pipeline,background,multiview_video=False, fname='pose0.mp4') + +if __name__ == "__main__": + # Set up command line argument parser + parser = ArgumentParser(description="Testing script parameters") + model = ModelParams(parser) + op = OptimizationParams(parser) + pipeline = PipelineParams(parser) + hyperparam = ModelHiddenParams(parser) + parser.add_argument("--iteration", default=-1, type=int) + parser.add_argument("--skip_train", action="store_true") + parser.add_argument("--skip_test", action="store_true") + parser.add_argument("--quiet", action="store_true") + parser.add_argument("--skip_video", action="store_true") + parser.add_argument('--multiview_video',default=False,action="store_true") + parser.add_argument("--configs", type=str) + args = get_combined_args(parser) + print("Rendering " , args.model_path) + if args.configs: + import mmcv + from utils.params_utils import merge_hparams + config = mmcv.Config.fromfile(args.configs) + args = merge_hparams(args, config) + # Initialize system state (RNG) + safe_state(args.quiet) + + render_sets(model.extract(args), hyperparam.extract(args), op.extract(args),args.iteration, pipeline.extract(args), args.skip_train, args.skip_test, args.skip_video,args.multiview_video) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..6968554 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,12 @@ +mmcv==1.6.0 +matplotlib +argparse +lpips +plyfile +imageio-ffmpeg +transformers +diffusers +accelerate +tensorboard +imageio +opencv-python diff --git a/scene/__init__.py b/scene/__init__.py new file mode 100644 index 0000000..67b2bcb --- /dev/null +++ b/scene/__init__.py @@ -0,0 +1,102 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +import os +import random +import json +from utils.system_utils import searchForMaxIteration +from scene.dataset_readers import sceneLoadTypeCallbacks +from scene.gaussian_model import GaussianModel +# from scene.dataset import FourDGSdataset +from scene.i2v_dataset import FourDGSdataset, ImageDreamdataset +# from scene.rife_sync_dataset import FourDGSdataset +from arguments import ModelParams +from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON +from torch.utils.data import Dataset +import numpy as np + +class Scene: + + gaussians : GaussianModel + + def __init__(self, args : ModelParams, gaussians : GaussianModel, load_iteration=None,shuffle=True, resolution_scales=[1.0], load_coarse=False): + """b + :param path: Path to colmap scene main folder. 
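+ :param load_iteration: iteration to load a trained model from; -1 selects the latest saved point_cloud iteration.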
+ """ + self.model_path = args.model_path + self.loaded_iter = None + self.gaussians = gaussians + + if load_iteration: + if load_iteration == -1: + self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud")) + else: + self.loaded_iter = load_iteration + print("Loading trained model at iteration {}".format(self.loaded_iter)) + + self.train_cameras = {} + self.test_cameras = {} + self.video_cameras = {} + self.cameras_extent = 1 # scene_info.nerf_normalization["radius"] + + print("Loading Training Cameras") + if args.imagedream: + ds = ImageDreamdataset + else: + ds = FourDGSdataset + self.train_camera = ds(split='train', frame_num=args.frame_num,name=args.name,rife=args.rife,static=args.static) + print("Loading Test Cameras") + self.maxtime = self.train_camera.pose0_num + self.test_camera = ds(split='test', frame_num=args.frame_num,name=args.name,rife=args.rife,static=args.static) + print("Loading Video Cameras") + self.video_cameras = ds(split='video', frame_num=args.frame_num,name=args.name,rife=args.rife,static=args.static) + xyz_max = [2.5, 2.5, 2.5] + xyz_min = [-2.5, -2.5, -2.5] + self.gaussians._deformation.deformation_net.grid.set_aabb(xyz_max,xyz_min) + # assert not self.loaded_iter + if self.loaded_iter: + self.gaussians.load_ply(os.path.join(self.model_path, + "point_cloud", + "iteration_" + str(self.loaded_iter), + "point_cloud.ply")) + self.gaussians.load_model(os.path.join(self.model_path, + "point_cloud", + "iteration_" + str(self.loaded_iter), + )) + else: + # TODO: accept argparse + num_pts = int(2.5e4) + + + # random init + self.gaussians.random_init(num_pts, 1, radius=0.5) + # point cloud init + + # cloud_path='./data/eagle1_1.ply' # + + # 4 is not used + # self.gaussians.load_3studio_ply(cloud_path, spatial_lr_scale=1, time_line=self.maxtime, step=1, position_scale=1, load_color=False) ## imagedream + + def save(self, iteration, stage): + if stage == "coarse": + point_cloud_path = os.path.join(self.model_path, "point_cloud/coarse_iteration_{}".format(iteration)) + + else: + point_cloud_path = os.path.join(self.model_path, "point_cloud/iteration_{}".format(iteration)) + self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply")) + self.gaussians.save_deformation(point_cloud_path) + def getTrainCameras(self, scale=1.0): + return self.train_camera + + def getTestCameras(self, scale=1.0): + return self.test_camera + def getVideoCameras(self, scale=1.0): + return self.video_cameras \ No newline at end of file diff --git a/scene/cam_utils.py b/scene/cam_utils.py new file mode 100644 index 0000000..1416fa4 --- /dev/null +++ b/scene/cam_utils.py @@ -0,0 +1,62 @@ +import numpy as np +from scipy.spatial.transform import Rotation as R + +import torch + +def dot(x, y): + if isinstance(x, np.ndarray): + return np.sum(x * y, -1, keepdims=True) + else: + return torch.sum(x * y, -1, keepdim=True) + + +def length(x, eps=1e-20): + if isinstance(x, np.ndarray): + return np.sqrt(np.maximum(np.sum(x * x, axis=-1, keepdims=True), eps)) + else: + return torch.sqrt(torch.clamp(dot(x, x), min=eps)) + + +def safe_normalize(x, eps=1e-20): + return x / length(x, eps) + + +def look_at(campos, target, opengl=True): + # campos: [N, 3], camera/eye position + # target: [N, 3], object to look at + # return: [N, 3, 3], rotation matrix + if not opengl: + # camera forward aligns with -z + forward_vector = safe_normalize(target - campos) + up_vector = np.array([0, 1, 0], dtype=np.float32) + right_vector = safe_normalize(np.cross(forward_vector, up_vector)) + 
up_vector = safe_normalize(np.cross(right_vector, forward_vector)) + else: + # camera forward aligns with +z + forward_vector = safe_normalize(campos - target) + up_vector = np.array([0, 1, 0], dtype=np.float32) + right_vector = safe_normalize(np.cross(up_vector, forward_vector)) + up_vector = safe_normalize(np.cross(forward_vector, right_vector)) + R = np.stack([right_vector, up_vector, forward_vector], axis=1) + return R + + +# elevation & azimuth to pose (cam2world) matrix +def orbit_camera(elevation, azimuth, radius=1, is_degree=True, target=None, opengl=True): + # radius: scalar + # elevation: scalar, in (-90, 90), from +y to -y is (-90, 90) + # azimuth: scalar, in (-180, 180), from +z to +x is (0, 90) + # return: [4, 4], camera pose matrix + if is_degree: + elevation = np.deg2rad(elevation) + azimuth = np.deg2rad(azimuth) + x = radius * np.cos(elevation) * np.sin(azimuth) + y = - radius * np.sin(elevation) + z = radius * np.cos(elevation) * np.cos(azimuth) + if target is None: + target = np.zeros([3], dtype=np.float32) + campos = np.array([x, y, z]) + target # [3] + T = np.eye(4, dtype=np.float32) + T[:3, :3] = look_at(campos, target, opengl) + T[:3, 3] = campos + return T diff --git a/scene/cameras.py b/scene/cameras.py new file mode 100644 index 0000000..0abb65e --- /dev/null +++ b/scene/cameras.py @@ -0,0 +1,77 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +import torch +from torch import nn +import numpy as np +from utils.graphics_utils import getWorld2View2, getProjectionMatrix + +class Camera(nn.Module): + def __init__(self, R, T, FoVx, FoVy, image, gt_alpha_mask, + image_name, uid, + trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda", time = 0 + ): + super(Camera, self).__init__() + + self.uid = uid + # self.colmap_id = colmap_id + self.R = R + self.T = T + self.FoVx = FoVx + self.FoVy = FoVy + self.image_name = image_name + self.time = time + try: + self.data_device = torch.device(data_device) + except Exception as e: + print(e) + print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" ) + self.data_device = torch.device("cuda") + self.original_image = image.clamp(0.0, 1.0) + # .to(self.data_device) + self.image_width = self.original_image.shape[2] + self.image_height = self.original_image.shape[1] + + if gt_alpha_mask is not None: + self.original_image *= gt_alpha_mask + # .to(self.data_device) + else: + self.original_image *= torch.ones((1, self.image_height, self.image_width)) + # , device=self.data_device) + + + self.zfar = 100.0 + self.znear = 0.01 + + self.trans = trans + self.scale = scale + + self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1) + # .cuda() + self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0,1) + # .cuda() + self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0) + self.camera_center = self.world_view_transform.inverse()[3, :3] + +class MiniCam: + def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform, time): + self.image_width = width + self.image_height = height + self.FoVy = fovy + self.FoVx = fovx + self.znear = znear + self.zfar = zfar + 
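+ # the view/projection matrices are kept in the transposed (row-vector) layout used elsewhere in this codebase; the camera center is read from the last row of the inverse view matrix below.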
self.world_view_transform = world_view_transform + self.full_proj_transform = full_proj_transform + view_inv = torch.inverse(self.world_view_transform) + self.camera_center = view_inv[3][:3] + self.time = time + diff --git a/scene/colmap_loader.py b/scene/colmap_loader.py new file mode 100644 index 0000000..0f32d23 --- /dev/null +++ b/scene/colmap_loader.py @@ -0,0 +1,282 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +import numpy as np +import collections +import struct + +CameraModel = collections.namedtuple( + "CameraModel", ["model_id", "model_name", "num_params"]) +Camera = collections.namedtuple( + "Camera", ["id", "model", "width", "height", "params"]) +BaseImage = collections.namedtuple( + "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"]) +Point3D = collections.namedtuple( + "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"]) +CAMERA_MODELS = { + CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3), + CameraModel(model_id=1, model_name="PINHOLE", num_params=4), + CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4), + CameraModel(model_id=3, model_name="RADIAL", num_params=5), + CameraModel(model_id=4, model_name="OPENCV", num_params=8), + CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8), + CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12), + CameraModel(model_id=7, model_name="FOV", num_params=5), + CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4), + CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5), + CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12) +} +CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model) + for camera_model in CAMERA_MODELS]) +CAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model) + for camera_model in CAMERA_MODELS]) + + +def qvec2rotmat(qvec): + return np.array([ + [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2, + 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], + 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]], + [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], + 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2, + 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]], + [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], + 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], + 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]]) + +def rotmat2qvec(R): + Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat + K = np.array([ + [Rxx - Ryy - Rzz, 0, 0, 0], + [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0], + [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0], + [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0 + eigvals, eigvecs = np.linalg.eigh(K) + qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)] + if qvec[0] < 0: + qvec *= -1 + return qvec + +class Image(BaseImage): + def qvec2rotmat(self): + return qvec2rotmat(self.qvec) + +def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"): + """Read and unpack the next bytes from a binary file. + :param fid: + :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc. + :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. + :param endian_character: Any of {@, =, <, >, !} + :return: Tuple of read and unpacked values. 
+ """ + data = fid.read(num_bytes) + return struct.unpack(endian_character + format_char_sequence, data) + +def read_points3D_text(path): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadPoints3DText(const std::string& path) + void Reconstruction::WritePoints3DText(const std::string& path) + """ + xyzs = None + rgbs = None + errors = None + with open(path, "r") as fid: + while True: + line = fid.readline() + if not line: + break + line = line.strip() + if len(line) > 0 and line[0] != "#": + elems = line.split() + xyz = np.array(tuple(map(float, elems[1:4]))) + rgb = np.array(tuple(map(int, elems[4:7]))) + error = np.array(float(elems[7])) + if xyzs is None: + xyzs = xyz[None, ...] + rgbs = rgb[None, ...] + errors = error[None, ...] + else: + xyzs = np.append(xyzs, xyz[None, ...], axis=0) + rgbs = np.append(rgbs, rgb[None, ...], axis=0) + errors = np.append(errors, error[None, ...], axis=0) + return xyzs, rgbs, errors + +def read_points3D_binary(path_to_model_file): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadPoints3DBinary(const std::string& path) + void Reconstruction::WritePoints3DBinary(const std::string& path) + """ + + + with open(path_to_model_file, "rb") as fid: + num_points = read_next_bytes(fid, 8, "Q")[0] + + xyzs = np.empty((num_points, 3)) + rgbs = np.empty((num_points, 3)) + errors = np.empty((num_points, 1)) + + for p_id in range(num_points): + binary_point_line_properties = read_next_bytes( + fid, num_bytes=43, format_char_sequence="QdddBBBd") + xyz = np.array(binary_point_line_properties[1:4]) + rgb = np.array(binary_point_line_properties[4:7]) + error = np.array(binary_point_line_properties[7]) + track_length = read_next_bytes( + fid, num_bytes=8, format_char_sequence="Q")[0] + track_elems = read_next_bytes( + fid, num_bytes=8*track_length, + format_char_sequence="ii"*track_length) + xyzs[p_id] = xyz + rgbs[p_id] = rgb + errors[p_id] = error + return xyzs, rgbs, errors + +def read_intrinsics_text(path): + """ + Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py + """ + cameras = {} + with open(path, "r") as fid: + while True: + line = fid.readline() + if not line: + break + line = line.strip() + if len(line) > 0 and line[0] != "#": + elems = line.split() + camera_id = int(elems[0]) + model = elems[1] + assert model == "PINHOLE", "While the loader support other types, the rest of the code assumes PINHOLE" + width = int(elems[2]) + height = int(elems[3]) + params = np.array(tuple(map(float, elems[4:]))) + cameras[camera_id] = Camera(id=camera_id, model=model, + width=width, height=height, + params=params) + return cameras + +def read_extrinsics_binary(path_to_model_file): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadImagesBinary(const std::string& path) + void Reconstruction::WriteImagesBinary(const std::string& path) + """ + images = {} + with open(path_to_model_file, "rb") as fid: + num_reg_images = read_next_bytes(fid, 8, "Q")[0] + for _ in range(num_reg_images): + binary_image_properties = read_next_bytes( + fid, num_bytes=64, format_char_sequence="idddddddi") + image_id = binary_image_properties[0] + qvec = np.array(binary_image_properties[1:5]) + tvec = np.array(binary_image_properties[5:8]) + camera_id = binary_image_properties[8] + image_name = "" + current_char = read_next_bytes(fid, 1, "c")[0] + while current_char != b"\x00": # look for the ASCII 0 entry + image_name += current_char.decode("utf-8") + current_char = read_next_bytes(fid, 1, "c")[0] + num_points2D = 
read_next_bytes(fid, num_bytes=8, + format_char_sequence="Q")[0] + x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D, + format_char_sequence="ddq"*num_points2D) + xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])), + tuple(map(float, x_y_id_s[1::3]))]) + point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3]))) + images[image_id] = Image( + id=image_id, qvec=qvec, tvec=tvec, + camera_id=camera_id, name=image_name, + xys=xys, point3D_ids=point3D_ids) + return images + + +def read_intrinsics_binary(path_to_model_file): + """ + see: src/base/reconstruction.cc + void Reconstruction::WriteCamerasBinary(const std::string& path) + void Reconstruction::ReadCamerasBinary(const std::string& path) + """ + cameras = {} + with open(path_to_model_file, "rb") as fid: + num_cameras = read_next_bytes(fid, 8, "Q")[0] + for _ in range(num_cameras): + camera_properties = read_next_bytes( + fid, num_bytes=24, format_char_sequence="iiQQ") + camera_id = camera_properties[0] + model_id = camera_properties[1] + model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name + width = camera_properties[2] + height = camera_properties[3] + num_params = CAMERA_MODEL_IDS[model_id].num_params + params = read_next_bytes(fid, num_bytes=8*num_params, + format_char_sequence="d"*num_params) + cameras[camera_id] = Camera(id=camera_id, + model=model_name, + width=width, + height=height, + params=np.array(params)) + assert len(cameras) == num_cameras + return cameras + + +def read_extrinsics_text(path): + """ + Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py + """ + images = {} + with open(path, "r") as fid: + while True: + line = fid.readline() + if not line: + break + line = line.strip() + if len(line) > 0 and line[0] != "#": + elems = line.split() + image_id = int(elems[0]) + qvec = np.array(tuple(map(float, elems[1:5]))) + tvec = np.array(tuple(map(float, elems[5:8]))) + camera_id = int(elems[8]) + image_name = elems[9] + elems = fid.readline().split() + xys = np.column_stack([tuple(map(float, elems[0::3])), + tuple(map(float, elems[1::3]))]) + point3D_ids = np.array(tuple(map(int, elems[2::3]))) + images[image_id] = Image( + id=image_id, qvec=qvec, tvec=tvec, + camera_id=camera_id, name=image_name, + xys=xys, point3D_ids=point3D_ids) + return images + + +def read_colmap_bin_array(path): + """ + Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_dense.py + + :param path: path to the colmap binary file. 
+ :return: nd array with the floating point values in the value + """ + with open(path, "rb") as fid: + width, height, channels = np.genfromtxt(fid, delimiter="&", max_rows=1, + usecols=(0, 1, 2), dtype=int) + fid.seek(0) + num_delimiter = 0 + byte = fid.read(1) + while True: + if byte == b"&": + num_delimiter += 1 + if num_delimiter >= 3: + break + byte = fid.read(1) + array = np.fromfile(fid, np.float32) + array = array.reshape((width, height, channels), order="F") + return np.transpose(array, (1, 0, 2)).squeeze() diff --git a/scene/dataset.py b/scene/dataset.py new file mode 100644 index 0000000..a97f754 --- /dev/null +++ b/scene/dataset.py @@ -0,0 +1,58 @@ +from torch.utils.data import Dataset +# from scene.cameras import Camera +import numpy as np +from utils.general_utils import PILtoTorch +from utils.graphics_utils import fov2focal, focal2fov +import torch +from utils.camera_utils import loadCam +from utils.graphics_utils import focal2fov + +from torchvision.transforms import ToTensor +from PIL import Image +import glob +from scene.cam_utils import orbit_camera +import math + +def getProjectionMatrix(znear, zfar, fovX, fovY): + tanHalfFovY = math.tan((fovY / 2)) + tanHalfFovX = math.tan((fovX / 2)) + + P = torch.zeros(4, 4) + + z_sign = 1.0 + + P[0, 0] = 1 / tanHalfFovX + P[1, 1] = 1 / tanHalfFovY + P[3, 2] = z_sign + P[2, 2] = z_sign * zfar / (zfar - znear) + P[2, 3] = -(zfar * znear) / (zfar - znear) + return P + + +class MiniCam: + def __init__(self, c2w, width, height, fovy, fovx, znear, zfar): + # c2w (pose) should be in NeRF convention. + + self.image_width = width + self.image_height = height + self.FoVy = fovy + self.FoVx = fovx + self.znear = znear + self.zfar = zfar + + w2c = np.linalg.inv(c2w) + + # rectify... + w2c[1:3, :3] *= -1 + w2c[:3, 3] *= -1 + + self.world_view_transform = torch.tensor(w2c).transpose(0, 1)#.cuda() + self.projection_matrix = ( + getProjectionMatrix( + znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy + ) + .transpose(0, 1) + # .cuda() + ) + self.full_proj_transform = self.world_view_transform @ self.projection_matrix + self.camera_center = -torch.tensor(c2w[:3, 3])#.cuda() diff --git a/scene/dataset_readers.py b/scene/dataset_readers.py new file mode 100644 index 0000000..ce6aba3 --- /dev/null +++ b/scene/dataset_readers.py @@ -0,0 +1,481 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. 
+# +# For inquiries contact george.drettakis@inria.fr +# + +import os +import sys +from PIL import Image +from typing import NamedTuple +from scene.colmap_loader import read_extrinsics_text, read_intrinsics_text, qvec2rotmat, \ + read_extrinsics_binary, read_intrinsics_binary, read_points3D_binary, read_points3D_text +from scene.hyper_loader import Load_hyper_data, format_hyper_data +import torchvision.transforms as transforms +import copy +from utils.graphics_utils import getWorld2View2, focal2fov, fov2focal +import numpy as np +import torch +import json +from pathlib import Path +from plyfile import PlyData, PlyElement +from utils.sh_utils import SH2RGB +from scene.gaussian_model import BasicPointCloud +from utils.general_utils import PILtoTorch +from tqdm import tqdm +class CameraInfo(NamedTuple): + uid: int + R: np.array + T: np.array + FovY: np.array + FovX: np.array + image: np.array + image_path: str + image_name: str + width: int + height: int + time : float + +class SceneInfo(NamedTuple): + point_cloud: BasicPointCloud + train_cameras: list + test_cameras: list + video_cameras: list + nerf_normalization: dict + ply_path: str + maxtime: int + +def getNerfppNorm(cam_info): + def get_center_and_diag(cam_centers): + cam_centers = np.hstack(cam_centers) + avg_cam_center = np.mean(cam_centers, axis=1, keepdims=True) + center = avg_cam_center + dist = np.linalg.norm(cam_centers - center, axis=0, keepdims=True) + diagonal = np.max(dist) + return center.flatten(), diagonal + + cam_centers = [] + + for cam in cam_info: + W2C = getWorld2View2(cam.R, cam.T) + C2W = np.linalg.inv(W2C) + cam_centers.append(C2W[:3, 3:4]) + + center, diagonal = get_center_and_diag(cam_centers) + radius = diagonal * 1.1 + + translate = -center + + return {"translate": translate, "radius": radius} + +def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder): + cam_infos = [] + for idx, key in enumerate(cam_extrinsics): + sys.stdout.write('\r') + # the exact output you're looking for: + sys.stdout.write("Reading camera {}/{}".format(idx+1, len(cam_extrinsics))) + sys.stdout.flush() + + extr = cam_extrinsics[key] + intr = cam_intrinsics[extr.camera_id] + height = intr.height + width = intr.width + + uid = intr.id + R = np.transpose(qvec2rotmat(extr.qvec)) + T = np.array(extr.tvec) + + if intr.model in ["SIMPLE_PINHOLE", "SIMPLE_RADIAL"]: + focal_length_x = intr.params[0] + FovY = focal2fov(focal_length_x, height) + FovX = focal2fov(focal_length_x, width) + elif intr.model=="PINHOLE": + focal_length_x = intr.params[0] + focal_length_y = intr.params[1] + FovY = focal2fov(focal_length_y, height) + FovX = focal2fov(focal_length_x, width) + elif intr.model == "OPENCV": + focal_length_x = intr.params[0] + focal_length_y = intr.params[1] + FovY = focal2fov(focal_length_y, height) + FovX = focal2fov(focal_length_x, width) + else: + assert False, "Colmap camera model not handled: only undistorted datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!" 
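+
+        # The FoV values above are recovered from the pinhole focal length(s); focal2fov is
+        # assumed to implement the standard relation fov = 2 * atan(size / (2 * focal)),
+        # evaluated with the image height for FovY and the image width for FovX.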
+ + image_path = os.path.join(images_folder, os.path.basename(extr.name)) + image_name = os.path.basename(image_path).split(".")[0] + image = Image.open(image_path) + image = PILtoTorch(image,None) + cam_info = CameraInfo(uid=uid, R=R, T=T, FovY=FovY, FovX=FovX, image=image, + image_path=image_path, image_name=image_name, width=width, height=height, + time = 0) + cam_infos.append(cam_info) + sys.stdout.write('\n') + return cam_infos + +def fetchPly(path): + plydata = PlyData.read(path) + vertices = plydata['vertex'] + positions = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T + colors = np.vstack([vertices['red'], vertices['green'], vertices['blue']]).T / 255.0 + normals = np.vstack([vertices['nx'], vertices['ny'], vertices['nz']]).T + return BasicPointCloud(points=positions, colors=colors, normals=normals) + +def storePly(path, xyz, rgb): + # Define the dtype for the structured array + dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'), + ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'), + ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')] + + normals = np.zeros_like(xyz) + + elements = np.empty(xyz.shape[0], dtype=dtype) + attributes = np.concatenate((xyz, normals, rgb), axis=1) + elements[:] = list(map(tuple, attributes)) + + # Create the PlyData object and write to file + vertex_element = PlyElement.describe(elements, 'vertex') + ply_data = PlyData([vertex_element]) + ply_data.write(path) + +def readColmapSceneInfo(path, images, eval, llffhold=8): + try: + cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.bin") + cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.bin") + cam_extrinsics = read_extrinsics_binary(cameras_extrinsic_file) + cam_intrinsics = read_intrinsics_binary(cameras_intrinsic_file) + except: + cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.txt") + cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.txt") + cam_extrinsics = read_extrinsics_text(cameras_extrinsic_file) + cam_intrinsics = read_intrinsics_text(cameras_intrinsic_file) + + reading_dir = "images" if images == None else images + cam_infos_unsorted = readColmapCameras(cam_extrinsics=cam_extrinsics, cam_intrinsics=cam_intrinsics, images_folder=os.path.join(path, reading_dir)) + cam_infos = sorted(cam_infos_unsorted.copy(), key = lambda x : x.image_name) + + if eval: + train_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold != 0] + test_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold == 0] + else: + train_cam_infos = cam_infos + test_cam_infos = [] + + nerf_normalization = getNerfppNorm(train_cam_infos) + + ply_path = os.path.join(path, "sparse/0/points3D.ply") + bin_path = os.path.join(path, "sparse/0/points3D.bin") + txt_path = os.path.join(path, "sparse/0/points3D.txt") + if not os.path.exists(ply_path): + print("Converting point3d.bin to .ply, will happen only the first time you open the scene.") + try: + xyz, rgb, _ = read_points3D_binary(bin_path) + except: + xyz, rgb, _ = read_points3D_text(txt_path) + storePly(ply_path, xyz, rgb) + + try: + pcd = fetchPly(ply_path) + + except: + pcd = None + + scene_info = SceneInfo(point_cloud=pcd, + train_cameras=train_cam_infos, + test_cameras=test_cam_infos, + video_cameras=train_cam_infos, + maxtime=0, + nerf_normalization=nerf_normalization, + ply_path=ply_path) + return scene_info +def generateCamerasFromTransforms(path, template_transformsfile, extension, maxtime): + trans_t = lambda t : torch.Tensor([ + [1,0,0,0], + [0,1,0,0], + [0,0,1,t], + [0,0,0,1]]).float() + + 
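+    # trans_t above and rot_phi / rot_theta below are composed by pose_spherical
+    # into a camera-to-world pose on a sphere: translate along z by `radius`,
+    # tilt by elevation `phi`, rotate by azimuth `theta`, then apply a fixed
+    # axis flip into the blender/NeRF convention.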
rot_phi = lambda phi : torch.Tensor([ + [1,0,0,0], + [0,np.cos(phi),-np.sin(phi),0], + [0,np.sin(phi), np.cos(phi),0], + [0,0,0,1]]).float() + + rot_theta = lambda th : torch.Tensor([ + [np.cos(th),0,-np.sin(th),0], + [0,1,0,0], + [np.sin(th),0, np.cos(th),0], + [0,0,0,1]]).float() + def pose_spherical(theta, phi, radius): + c2w = trans_t(radius) + c2w = rot_phi(phi/180.*np.pi) @ c2w + c2w = rot_theta(theta/180.*np.pi) @ c2w + c2w = torch.Tensor(np.array([[-1,0,0,0],[0,0,1,0],[0,1,0,0],[0,0,0,1]])) @ c2w + return c2w + cam_infos = [] + # generate render poses and times + render_poses = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180,180,40+1)[:-1]], 0) + render_times = torch.linspace(0,maxtime,render_poses.shape[0]) + with open(os.path.join(path, template_transformsfile)) as json_file: + template_json = json.load(json_file) + fovx = template_json["camera_angle_x"] + # load a single image to get image info. + for idx, frame in enumerate(template_json["frames"]): + cam_name = os.path.join(path, frame["file_path"] + extension) + image_path = os.path.join(path, cam_name) + image_name = Path(cam_name).stem + image = Image.open(image_path) + im_data = np.array(image.convert("RGBA")) + image = PILtoTorch(image,(800,800)) + break + # format information + for idx, (time, poses) in enumerate(zip(render_times,render_poses)): + time = time/maxtime + matrix = np.linalg.inv(np.array(poses)) + R = -np.transpose(matrix[:3,:3]) + R[:,0] = -R[:,0] + T = -matrix[:3, 3] + fovy = focal2fov(fov2focal(fovx, image.shape[1]), image.shape[2]) + FovY = fovy + FovX = fovx + cam_infos.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=image, + image_path=None, image_name=None, width=image.shape[1], height=image.shape[2], + time = time)) + return cam_infos +def readCamerasFromTransforms(path, transformsfile, white_background, extension=".png", mapper = {}): + cam_infos = [] + + with open(os.path.join(path, transformsfile)) as json_file: + contents = json.load(json_file) + fovx = contents["camera_angle_x"] + + frames = contents["frames"] + for idx, frame in enumerate(frames): + cam_name = os.path.join(path, frame["file_path"] + extension) + time = mapper[frame["time"]] + matrix = np.linalg.inv(np.array(frame["transform_matrix"])) + R = -np.transpose(matrix[:3,:3]) + R[:,0] = -R[:,0] + T = -matrix[:3, 3] + + image_path = os.path.join(path, cam_name) + image_name = Path(cam_name).stem + image = Image.open(image_path) + + im_data = np.array(image.convert("RGBA")) + + bg = np.array([1,1,1]) if white_background else np.array([0, 0, 0]) + + norm_data = im_data / 255.0 + arr = norm_data[:,:,:3] * norm_data[:, :, 3:4] + bg * (1 - norm_data[:, :, 3:4]) + image = Image.fromarray(np.array(arr*255.0, dtype=np.byte), "RGB") + image = PILtoTorch(image,(800,800)) + fovy = focal2fov(fov2focal(fovx, image.shape[1]), image.shape[2]) + FovY = fovy + FovX = fovx + + cam_infos.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=image, + image_path=image_path, image_name=image_name, width=image.shape[1], height=image.shape[2], + time = time)) + + return cam_infos +def read_timeline(path): + with open(os.path.join(path, "transforms_train.json")) as json_file: + train_json = json.load(json_file) + with open(os.path.join(path, "transforms_test.json")) as json_file: + test_json = json.load(json_file) + time_line = [frame["time"] for frame in train_json["frames"]] + [frame["time"] for frame in test_json["frames"]] + time_line = set(time_line) + time_line = list(time_line) + time_line.sort() 
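+    # Normalize timestamps to [0, 1] by dividing by the largest value, e.g.
+    # frame times [0, 5, 10] map to {0: 0.0, 5: 0.5, 10: 1.0}.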
+ timestamp_mapper = {} + max_time_float = max(time_line) + for index, time in enumerate(time_line): + # timestamp_mapper[time] = index + timestamp_mapper[time] = time/max_time_float + + return timestamp_mapper, max_time_float +def readNerfSyntheticInfo(path, white_background, eval, extension=".png"): + timestamp_mapper, max_time = read_timeline(path) + print("Reading Training Transforms") + train_cam_infos = readCamerasFromTransforms(path, "transforms_train.json", white_background, extension, timestamp_mapper) + print("Reading Test Transforms") + test_cam_infos = readCamerasFromTransforms(path, "transforms_test.json", white_background, extension, timestamp_mapper) + print("Generating Video Transforms") + video_cam_infos = generateCamerasFromTransforms(path, "transforms_train.json", extension, max_time) + if not eval: + train_cam_infos.extend(test_cam_infos) + test_cam_infos = [] + + nerf_normalization = getNerfppNorm(train_cam_infos) + + ply_path = os.path.join(path, "points3d.ply") + # Since this data set has no colmap data, we start with random points + num_pts = 2000 + print(f"Generating random point cloud ({num_pts})...") + + # We create random points inside the bounds of the synthetic Blender scenes + xyz = np.random.random((num_pts, 3)) * 2.6 - 1.3 + shs = np.random.random((num_pts, 3)) / 255.0 + pcd = BasicPointCloud(points=xyz, colors=SH2RGB(shs), normals=np.zeros((num_pts, 3))) + storePly(ply_path, xyz, SH2RGB(shs) * 255) + try: + pcd = fetchPly(ply_path) + except: + pcd = None + + scene_info = SceneInfo(point_cloud=pcd, + train_cameras=train_cam_infos, + test_cameras=test_cam_infos, + video_cameras=video_cam_infos, + nerf_normalization=nerf_normalization, + ply_path=ply_path, + maxtime=max_time + ) + return scene_info +def format_infos(dataset,split): + # loading + cameras = [] + image = dataset[0][0] + if split == "train": + for idx in tqdm(range(len(dataset))): + image_path = None + image_name = f"{idx}" + time = dataset.image_times[idx] + # matrix = np.linalg.inv(np.array(pose)) + R,T = dataset.load_pose(idx) + FovX = focal2fov(dataset.focal[0], image.shape[1]) + FovY = focal2fov(dataset.focal[0], image.shape[2]) + cameras.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=image, + image_path=image_path, image_name=image_name, width=image.shape[2], height=image.shape[1], + time = time)) + + return cameras + + +def readHyperDataInfos(datadir,use_bg_points,eval): + train_cam_infos = Load_hyper_data(datadir,0.5,use_bg_points,split ="train") + test_cam_infos = Load_hyper_data(datadir,0.5,use_bg_points,split="test") + + train_cam = format_hyper_data(train_cam_infos,"train") + max_time = train_cam_infos.max_time + video_cam_infos = copy.deepcopy(test_cam_infos) + video_cam_infos.split="video" + + ply_path = os.path.join(datadir, "points.npy") + + xyz = np.load(ply_path,allow_pickle=True) + xyz -= train_cam_infos.scene_center + xyz *= train_cam_infos.coord_scale + xyz = xyz.astype(np.float32) + shs = np.random.random((xyz.shape[0], 3)) / 255.0 + pcd = BasicPointCloud(points=xyz, colors=SH2RGB(shs), normals=np.zeros((xyz.shape[0], 3))) + + + nerf_normalization = getNerfppNorm(train_cam) + + scene_info = SceneInfo(point_cloud=pcd, + train_cameras=train_cam_infos, + test_cameras=test_cam_infos, + video_cameras=video_cam_infos, + nerf_normalization=nerf_normalization, + ply_path=ply_path, + maxtime=max_time + ) + + return scene_info +def format_render_poses(poses,data_infos): + cameras = [] + tensor_to_pil = transforms.ToPILImage() + len_poses = len(poses) + times = 
[i/len_poses for i in range(len_poses)] + image = data_infos[0][0] + for idx, p in tqdm(enumerate(poses)): + # image = None + image_path = None + image_name = f"{idx}" + time = times[idx] + pose = np.eye(4) + pose[:3,:] = p[:3,:] + # matrix = np.linalg.inv(np.array(pose)) + R = pose[:3,:3] + R = - R + R[:,0] = -R[:,0] + T = -pose[:3,3].dot(R) + FovX = focal2fov(data_infos.focal[0], image.shape[2]) + FovY = focal2fov(data_infos.focal[0], image.shape[1]) + cameras.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=image, + image_path=image_path, image_name=image_name, width=image.shape[2], height=image.shape[1], + time = time)) + return cameras + + +def readdynerfInfo(datadir,use_bg_points,eval): + # loading all the data follow hexplane format + ply_path = os.path.join(datadir, "points3d.ply") + + from scene.neural_3D_dataset_NDC import Neural3D_NDC_Dataset + train_dataset = Neural3D_NDC_Dataset( + datadir, + "train", + 1.0, + time_scale=1, + scene_bbox_min=[-2.5, -2.0, -1.0], + scene_bbox_max=[2.5, 2.0, 1.0], + eval_index=0, + ) + test_dataset = Neural3D_NDC_Dataset( + datadir, + "test", + 1.0, + time_scale=1, + scene_bbox_min=[-2.5, -2.0, -1.0], + scene_bbox_max=[2.5, 2.0, 1.0], + eval_index=0, + ) + train_cam_infos = format_infos(train_dataset,"train") + + # test_cam_infos = format_infos(test_dataset,"test") + val_cam_infos = format_render_poses(test_dataset.val_poses,test_dataset) + nerf_normalization = getNerfppNorm(train_cam_infos) + # create pcd + # if not os.path.exists(ply_path): + # Since this data set has no colmap data, we start with random points + num_pts = 2000 + print(f"Generating random point cloud ({num_pts})...") + threshold = 3 + # xyz_max = np.array([1.5*threshold, 1.5*threshold, 1.5*threshold]) + # xyz_min = np.array([-1.5*threshold, -1.5*threshold, -3*threshold]) + xyz_max = np.array([1.5*threshold, 1.5*threshold, 1.5*threshold]) + xyz_min = np.array([-1.5*threshold, -1.5*threshold, -1.5*threshold]) + # We create random points inside the bounds of the synthetic Blender scenes + xyz = (np.random.random((num_pts, 3)))* (xyz_max-xyz_min) + xyz_min + print("point cloud initialization:",xyz.max(axis=0),xyz.min(axis=0)) + shs = np.random.random((num_pts, 3)) / 255.0 + pcd = BasicPointCloud(points=xyz, colors=SH2RGB(shs), normals=np.zeros((num_pts, 3))) + storePly(ply_path, xyz, SH2RGB(shs) * 255) + try: + # xyz = np.load + pcd = fetchPly(ply_path) + except: + pcd = None + scene_info = SceneInfo(point_cloud=pcd, + train_cameras=train_dataset, + test_cameras=test_dataset, + video_cameras=val_cam_infos, + nerf_normalization=nerf_normalization, + ply_path=ply_path, + maxtime=300 + ) + return scene_info +sceneLoadTypeCallbacks = { + "Colmap": readColmapSceneInfo, + "Blender" : readNerfSyntheticInfo, + "dynerf" : readdynerfInfo, + "nerfies": readHyperDataInfos, # NeRFies & HyperNeRF dataset proposed by [https://github.com/google/hypernerf/releases/tag/v0.1] +} diff --git a/scene/deformation.py b/scene/deformation.py new file mode 100644 index 0000000..099ef6e --- /dev/null +++ b/scene/deformation.py @@ -0,0 +1,255 @@ +import functools +import math +import os +import time +# from tkinter import W + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.cpp_extension import load +import torch.nn.init as init +from collections import OrderedDict +from scene.hexplane import HexPlaneField + +class Deformation(nn.Module): + def __init__(self, D=8, W=256, input_ch=27, input_ch_time=9, skips=[], args=None): + 
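+        """Spatio-temporal deformation field.
+
+        Features are queried from a HexPlane grid at (x, y, z, t) and decoded by
+        small MLP heads into per-Gaussian offsets for position, scale, rotation,
+        opacity and color; see create_net() and forward_dynamic() below. forward()
+        expects per-Gaussian tensors (positions, scales, rotations, opacity, color)
+        plus a time tensor and returns the deformed counterparts.
+        """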
super(Deformation, self).__init__() + self.D = D + self.W = W + self.input_ch = input_ch + self.input_ch_time = input_ch_time + self.skips = skips + self.grid_merge = args.grid_merge + + self.no_grid = args.no_grid + self.grid = HexPlaneField(args.bounds, args.kplanes_config, args.multires, grid_merge=args.grid_merge) + self.pos_deform, self.scales_deform, self.rotations_deform, self.opacity_deform ,self.color_deform= self.create_net() + # self.pos_deform.fc1.weight.data.zero_() + # self.pos_deform.fc1.bias.data.zero_() + # self.scales_deform.fc1.weight.data.zero_() + # self.scales_deform.fc1.bias.data.zero_() + # self.rotations_deform.fc1.weight.data.zero_() + # self.rotations_deform.fc1.bias.data.zero_() + # self.opacity_deform.fc1.weight.data.zero_() + # self.opacity_deform.fc1.bias.data.zero_() + # self.color_deform.fc1.weight.data.zero_() + # self.color_deform.fc1.bias.data.zero_() + + self.args = args + def create_net(self): + + mlp_out_dim = 0 + if self.no_grid: + self.feature_out = [nn.Linear(4,self.W)] + else: + if self.grid_merge == 'cat': + self.feature_out = [nn.Linear(mlp_out_dim + self.grid.feat_dim * 6, self.W)] + else: + self.feature_out = [nn.Linear(mlp_out_dim + self.grid.feat_dim, self.W)] + + for i in range(self.D-1): + self.feature_out.append(nn.ReLU()) + self.feature_out.append(nn.Linear(self.W,self.W)) + self.feature_out = nn.Sequential(*self.feature_out) + output_dim = self.W + # pose, scale, rotation, opacity + return \ + nn.Sequential( + OrderedDict([ + ('act0', nn.ReLU()), + ('fc2', nn.Linear(self.W, self.W)), + ('act3', nn.ReLU()), + ('fc1', nn.Linear(self.W, 3)), + ]) + # nn.ReLU(),nn.Linear(self.W,self.W),nn.ReLU(),nn.Linear(self.W, 3) + ),\ + nn.Sequential( + OrderedDict([ + ('act0', nn.ReLU()), + ('fc2', nn.Linear(self.W, self.W)), + ('act3', nn.ReLU()), + ('fc1', nn.Linear(self.W, 1)), + ]) + # nn.ReLU(),nn.Linear(self.W,self.W),nn.ReLU(),nn.Linear(self.W, 1) + ),\ + nn.Sequential( + OrderedDict([ + ('act0', nn.ReLU()), + ('fc2', nn.Linear(self.W, self.W)), + ('act3', nn.ReLU()), + ('fc1', nn.Linear(self.W, 4)), + ]) + # nn.ReLU(),nn.Linear(self.W,self.W),nn.ReLU(),nn.Linear(self.W, 4) + ), \ + nn.Sequential( + OrderedDict([ + ('act0', nn.ReLU()), + ('fc2', nn.Linear(self.W, self.W)), + ('act3', nn.ReLU()), + ('fc1', nn.Linear(self.W, 1)), + ]) + # nn.ReLU(),nn.Linear(self.W,self.W),nn.ReLU(),nn.Linear(self.W, 1) + ),\ + nn.Sequential( + OrderedDict([ + ('act0', nn.ReLU()), + ('fc2', nn.Linear(self.W, self.W)), + ('act3', nn.ReLU()), + ('fc1', nn.Linear(self.W, 3)), + ('act4',nn.Tanh()) + ]) + # nn.ReLU(),nn.Linear(self.W,self.W),nn.ReLU(),nn.Linear(self.W, 1) + ) + + def query_time(self, rays_pts_emb, scales_emb, rotations_emb, time_emb): + + if self.no_grid: + h = torch.cat([rays_pts_emb[:,:3],time_emb[:,:1]],-1) + else: + grid_feature = self.grid(rays_pts_emb[:,:3], time_emb[:,:1]) + + h = grid_feature + + h = self.feature_out(h) + + return h + + def forward(self, rays_pts_emb, scales_emb=None, rotations_emb=None, opacity = None,color=None, time_emb=None): + # if time_emb.sum() == 0: + # # if time_emb is None: + # return self.forward_static(rays_pts_emb[:,:3], scales_emb, rotations_emb, opacity, time_emb) + # # return self.forward_static(rays_pts_emb[:,:3]) + # else: + return self.forward_dynamic(rays_pts_emb, scales_emb, rotations_emb, opacity, color,time_emb) + + def forward_static(self, pts, scales, rotations, opacity, time): + # def forward_static(self, rays_pts_emb): + return pts, scales, rotations, opacity + # print('??????? 
forward_static') + # grid_feature = self.grid(rays_pts_emb[:,:3]) + # dx = self.static_mlp(grid_feature) + # return rays_pts_emb[:, :3] + dx + def forward_dynamic(self,rays_pts_emb, scales_emb, rotations_emb, opacity_emb,color_emb, time_emb): + hidden = self.query_time(rays_pts_emb, scales_emb, rotations_emb, time_emb).float() + dx = self.pos_deform(hidden) + pts = rays_pts_emb[:, :3] + dx + # print(scales_emb.shape, rotations_emb.shape, opacity_emb.shape, time_emb.shape) + # print('no_ds', self.args.no_ds, self.args.no_dr, self.args.no_do) + if self.args.no_ds: + scales = scales_emb[:,:3] + else: + ds = self.scales_deform(hidden) + scales = scales_emb[:,:3] + ds + + if self.args.no_dr: + rotations = rotations_emb[:,:4] + else: + #print('dr======================================') + dr = self.rotations_deform(hidden) #[40000, 4] + rotations = rotations_emb[:,:4] + dr #([40000, 3]+[40000, 4]=[40000, 4] + # print('rotations_emb[:,:3] shape===',rotations_emb[:,:3].shape) + # print('dr shape=======',dr.shape) + # print('rotations shape=======',rotations.shape) + + if self.args.no_do: + opacity = opacity_emb[:,:1] + else: + do = self.opacity_deform(hidden) + opacity = opacity_emb[:,:1] + do + + if self.args.no_dc: + # print('no dc======================================') + color=color_emb[:,:3] + else: + # print('dc======================================') + # print('hidden shape=======',hidden.shape) + dc = self.color_deform(hidden) #[40000, 256]->[40000, 3] + color = color_emb[:,:3] + dc #[40000, 3]+[40000, 3] + # print('color_emb[:,:3] shape===',color_emb[:,:3].shape) + # print('dc shape=======',dc.shape) + # print('color shape=======',color.shape) + # hidden shape======= torch.Size([40000, 256]) [11/11 13:55:16] + # color_emb[:,:3] shape=== torch.Size([40000, 1, 3]) [11/11 13:55:16] + # dc shape======= torch.Size([40000, 3]) [11/11 13:55:16] + # color shape======= torch.Size([40000, 40000, 3]) [11/11 13:55:16] + # color_final shape torch.Size([40000, 1, 3]) [11/11 13:55:16] + # color_deform shape torch.Size([40000, 40000, 3]) [11/11 13:55:16] + # deformation_point shape torch.Size([40000]) [11/11 13:55:16] + # + do + # print("deformation value:","pts:",torch.abs(dx).mean(),"rotation:",torch.abs(dr).mean()) + + return pts, scales, rotations, opacity,color + def get_mlp_parameters(self): + parameter_list = [] + for name, param in self.named_parameters(): + if "grid" not in name: + parameter_list.append(param) + return parameter_list + def get_grid_parameters(self): + return list(self.grid.parameters() ) + # + list(self.timegrid.parameters()) +class deform_network(nn.Module): + def __init__(self, args) : + super(deform_network, self).__init__() + net_width = args.net_width + timebase_pe = args.timebase_pe + defor_depth= args.defor_depth + posbase_pe= args.posebase_pe + scale_rotation_pe = args.scale_rotation_pe + opacity_pe = args.opacity_pe + timenet_width = args.timenet_width + timenet_output = args.timenet_output + times_ch = 2*timebase_pe+1 + # self.timenet = nn.Sequential( + # nn.Linear(times_ch, timenet_width), nn.ReLU(), + # nn.Linear(timenet_width, timenet_output)) + self.deformation_net = Deformation(W=net_width, D=defor_depth, input_ch=(4+3)+((4+3)*scale_rotation_pe)*2, input_ch_time=timenet_output, args=args) + # self.register_buffer('time_poc', torch.FloatTensor([(2**i) for i in range(timebase_pe)])) + self.register_buffer('pos_poc', torch.FloatTensor([(2**i) for i in range(posbase_pe)])) + self.register_buffer('rotation_scaling_poc', torch.FloatTensor([(2**i) for i in 
range(scale_rotation_pe)])) + self.register_buffer('opacity_poc', torch.FloatTensor([(2**i) for i in range(opacity_pe)])) + self.apply(initialize_weights) + # print(self) + + def forward(self, point, scales=None, rotations=None, opacity=None,color=None, times_sel=None): + # raise NotImplementedError + # print('>>>>> time', times_sel) + if times_sel is not None: + means3D_, scales_, rotations_, opacity_ ,color_= self.forward_dynamic(point, scales, rotations, opacity,color, times_sel) + # return means3D_, scales, rotations, opacity + return means3D_, scales_, rotations_, opacity_,color_ + else: + raise NotImplementedError + return self.forward_static(point) + + + def forward_static(self, points): + points = self.deformation_net(points) + return points + def forward_dynamic(self, point, scales=None, rotations=None, opacity=None,color=None, times_sel=None): + # times_emb = poc_fre(times_sel, self.time_poc) + + means3D, scales, rotations, opacity,color = self.deformation_net( point, + scales, + rotations, + opacity, + color, + # times_feature, + times_sel) + return means3D, scales, rotations, opacity,color + def get_mlp_parameters(self): + return self.deformation_net.get_mlp_parameters() + # + list(self.timenet.parameters()) + def get_grid_parameters(self): + return self.deformation_net.get_grid_parameters() + +def initialize_weights(m): + pass + # if isinstance(m, nn.Linear): + # init.constant_(m.weight, 0) + # # init.xavier_uniform_(m.weight,gain=1) + # if m.bias is not None: + # # init.xavier_uniform_(m.weight,gain=1) + # init.constant_(m.bias, 0) diff --git a/scene/gaussian_model.py b/scene/gaussian_model.py new file mode 100644 index 0000000..f56607e --- /dev/null +++ b/scene/gaussian_model.py @@ -0,0 +1,625 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. 
+# +# For inquiries contact george.drettakis@inria.fr +# + +import torch +import numpy as np +from utils.general_utils import inverse_sigmoid, get_expon_lr_func, build_rotation +from torch import nn +import os +from utils.system_utils import mkdir_p +from plyfile import PlyData, PlyElement +from random import randint +from utils.sh_utils import RGB2SH +from simple_knn._C import distCUDA2 +from utils.graphics_utils import BasicPointCloud +from utils.general_utils import strip_symmetric, build_scaling_rotation +from scene.deformation import deform_network +from scene.regulation import compute_plane_smoothness + + +def sh2rgb(x): + return x * 0.28209479177387814 + 0.5 + +class GaussianModel: + + def setup_functions(self): + def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation): + L = build_scaling_rotation(scaling_modifier * scaling, rotation) + actual_covariance = L @ L.transpose(1, 2) + symm = strip_symmetric(actual_covariance) + return symm + + self.scaling_activation = torch.exp + self.scaling_inverse_activation = torch.log + + self.covariance_activation = build_covariance_from_scaling_rotation + + self.opacity_activation = torch.sigmoid + self.inverse_opacity_activation = inverse_sigmoid + + self.rotation_activation = torch.nn.functional.normalize + + + def __init__(self, sh_degree : int, args): + self.active_sh_degree = 0 + self.max_sh_degree = sh_degree + self._xyz = torch.empty(0) + # self._deformation = torch.empty(0) + self._deformation = deform_network(args) + # self.grid = TriPlaneGrid() + self._features_dc = torch.empty(0) + self._features_rest = torch.empty(0) + self._scaling = torch.empty(0) + self._rotation = torch.empty(0) + self._opacity = torch.empty(0) + self.max_radii2D = torch.empty(0) + self.xyz_gradient_accum = torch.empty(0) + self.denom = torch.empty(0) + self.optimizer = None + self.percent_dense = 0 + self.spatial_lr_scale = 0 + self._deformation_table = torch.empty(0) + self.setup_functions() + + def capture(self): + return ( + self.active_sh_degree, + self._xyz, + self._deformation.state_dict(), + self._deformation_table, + # self.grid, + self._features_dc, + self._features_rest, + self._scaling, + self._rotation, + self._opacity, + self.max_radii2D, + self.xyz_gradient_accum, + self.denom, + self.optimizer.state_dict(), + self.spatial_lr_scale, + ) + + def restore(self, model_args, training_args): + (self.active_sh_degree, + self._xyz, + self._deformation_table, + self._deformation, + # self.grid, + self._features_dc, + self._features_rest, + self._scaling, + self._rotation, + self._opacity, + self.max_radii2D, + xyz_gradient_accum, + denom, + opt_dict, + self.spatial_lr_scale) = model_args + self.training_setup(training_args) + self.xyz_gradient_accum = xyz_gradient_accum + self.denom = denom + self.optimizer.load_state_dict(opt_dict) + + @property + def get_scaling(self): + #return self._scaling + + return self.scaling_activation(self._scaling) + + @property + def get_rotation(self): + #return self._rotation + return self.rotation_activation(self._rotation) + + @property + def get_xyz(self): + return self._xyz + + @property + def get_features(self): + features_dc = self._features_dc + features_rest = self._features_rest + return torch.cat((features_dc, features_rest), dim=1) + + @property + def get_features_dc(self): + features_dc = self._features_dc + return features_dc + + @property + def get_features_rest(self): + features_rest = self._features_rest + return features_rest + + + + @property + def get_opacity(self): + return 
self.opacity_activation(self._opacity) + + def get_covariance(self, scaling_modifier = 1): + return self.covariance_activation(self.get_scaling, scaling_modifier, self._rotation) + + def oneupSHdegree(self): + if self.active_sh_degree < self.max_sh_degree: + self.active_sh_degree += 1 + + def load_colmap_ply(self, path, spatial_lr_scale=1, time_line=4): + # https://github.com/graphdeco-inria/gaussian-splatting/blob/f11001b46c5c73a0a7d553353c898efd68412abe/scene/dataset_readers.py#L107 + plydata = PlyData.read(path) + vertices = plydata['vertex'] + positions = np.vstack([vertices['y'], vertices['z'], vertices['x']]).T # [N, 3] + # positions = positions[::2] + print('Loaded points from ply ', positions.shape) + colors = np.zeros_like(positions) + 0.5 + pcd = BasicPointCloud(points=positions, colors=colors, normals=None) + self.create_from_pcd(pcd, spatial_lr_scale=spatial_lr_scale, time_line=time_line) + + def load_3studio_ply(self, path, spatial_lr_scale=1, time_line=4, step=1, position_scale=1, load_color=True): + # https://github.com/graphdeco-inria/gaussian-splatting/blob/f11001b46c5c73a0a7d553353c898efd68412abe/scene/dataset_readers.py#L107 + plydata = PlyData.read(path) + vertices = plydata['vertex'] + positions = np.vstack([vertices['x'], vertices['z'], -vertices['y']]).T # [N, 3] # image dream axis + # positions = np.vstack([vertices['y'], vertices['z'], vertices['x']]).T # [N, 3] # 3studio coord is this + positions = positions[::step] * position_scale + print('Loaded points from ply ', positions.shape) + # positions = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T # [N, 3] + # positions=np.concatenate([-positions[:,0:1],-positions[:,1:2],-positions[:,2:3]],1)#*train_dataset.scale_factor) + if load_color: + colors = np.vstack([vertices['red'], vertices['green'], vertices['blue']]).T / 255.0 + else: + colors = np.zeros_like(positions) + 0.5 + + colors = colors[::step] + # normals = np.vstack([vertices['nx'], vertices['ny'], vertices['nz']]).T + pcd = BasicPointCloud(points=positions, colors=colors, normals=None) + self.create_from_pcd(pcd, spatial_lr_scale=spatial_lr_scale, time_line=time_line) + + + def random_init(self, num_pts, lr=10, radius=1): + phis = np.random.random((num_pts,)) * 2 * np.pi + costheta = np.random.random((num_pts,)) * 2 - 1 + thetas = np.arccos(costheta) + mu = np.random.random((num_pts,)) + radius = radius * np.cbrt(mu) + x = radius * np.sin(thetas) * np.cos(phis) + y = radius * np.sin(thetas) * np.sin(phis) + z = radius * np.cos(thetas) + xyz = np.stack((x, y, z), axis=1) + # xyz = np.random.random((num_pts, 3)) * 2.6 - 1.3 + + shs = np.random.random((num_pts, 3)) / 255.0 + pcd = BasicPointCloud( + points=xyz, colors=sh2rgb(shs), normals=np.zeros((num_pts, 3)) + ) + self.create_from_pcd(pcd, lr, 4) # 4 not used + + + def create_from_pcd(self, pcd : BasicPointCloud, spatial_lr_scale : float, time_line: int): + self.spatial_lr_scale = spatial_lr_scale + fused_point_cloud = torch.tensor(np.asarray(pcd.points)).float().cuda() + fused_color = RGB2SH(torch.tensor(np.asarray(pcd.colors)).float().cuda()) + features = torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2)).float().cuda() + features[:, :3, 0 ] = fused_color + features[:, 3:, 1:] = 0.0 + + print("Number of points at initialisation : ", fused_point_cloud.shape[0]) + + dist2 = torch.clamp_min(distCUDA2(torch.from_numpy(np.asarray(pcd.points)).float().cuda()), 0.0000001) + # scales = torch.log(torch.sqrt(dist2))[...,None].repeat(1, 1) + scales = 
torch.log(torch.sqrt(dist2))[...,None].repeat(1, 3) + #scales = torch.ones_like(scales ) * 0.03 + rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda") + rots[:, 0] = 1 + + opacities = inverse_sigmoid(0.1 * torch.ones((fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda")) + + self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True)) + self._deformation = self._deformation.to("cuda") + # self.grid = self.grid.to("cuda") + self._features_dc = nn.Parameter(features[:,:,0:1].transpose(1, 2).contiguous().requires_grad_(True)) + self._features_rest = nn.Parameter(features[:,:,1:].transpose(1, 2).contiguous().requires_grad_(True)) + self._scaling = nn.Parameter(scales.requires_grad_(True)) + self._rotation = nn.Parameter(rots.requires_grad_(True)) + self._opacity = nn.Parameter(opacities.requires_grad_(True)) + self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda") + self._deformation_table = torch.gt(torch.ones((self.get_xyz.shape[0]),device="cuda"),0) + def training_setup(self, training_args): + self.percent_dense = training_args.percent_dense + self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") + self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") + self._deformation_accum = torch.zeros((self.get_xyz.shape[0],3),device="cuda") + + + l = [ + {'params': [self._xyz], 'lr': training_args.position_lr_init * self.spatial_lr_scale, "name": "xyz"}, + {'params': list(self._deformation.get_mlp_parameters()), 'lr': training_args.deformation_lr_init * self.spatial_lr_scale, "name": "deformation"}, + {'params': list(self._deformation.get_grid_parameters()), 'lr': training_args.grid_lr_init * self.spatial_lr_scale, "name": "grid"}, + {'params': [self._features_dc], 'lr': training_args.feature_lr, "name": "f_dc"}, + {'params': [self._features_rest], 'lr': training_args.feature_lr / 20.0, "name": "f_rest"}, + {'params': [self._opacity], 'lr': training_args.opacity_lr, "name": "opacity"}, + {'params': [self._scaling], 'lr': training_args.scaling_lr, "name": "scaling"}, + {'params': [self._rotation], 'lr': training_args.rotation_lr, "name": "rotation"} + + ] + + self.optimizer = torch.optim.Adam(l, lr=0.0) + # self.optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15) + self.xyz_scheduler_args = get_expon_lr_func(lr_init=training_args.position_lr_init*self.spatial_lr_scale, + lr_final=training_args.position_lr_final*self.spatial_lr_scale, + lr_delay_mult=training_args.position_lr_delay_mult, + max_steps=training_args.position_lr_max_steps) + self.deformation_scheduler_args = get_expon_lr_func(lr_init=training_args.deformation_lr_init*self.spatial_lr_scale, + lr_final=training_args.deformation_lr_final*self.spatial_lr_scale, + lr_delay_mult=training_args.deformation_lr_delay_mult, + max_steps=training_args.position_lr_max_steps) + self.grid_scheduler_args = get_expon_lr_func(lr_init=training_args.grid_lr_init*self.spatial_lr_scale, + lr_final=training_args.grid_lr_final*self.spatial_lr_scale, + lr_delay_mult=training_args.deformation_lr_delay_mult, + max_steps=training_args.position_lr_max_steps) + + def update_learning_rate(self, iteration): + ''' Learning rate scheduling per step ''' + for param_group in self.optimizer.param_groups: + if param_group["name"] == "xyz": + lr = self.xyz_scheduler_args(iteration) + param_group['lr'] = lr + # return lr + if "grid" in param_group["name"]: + lr = self.grid_scheduler_args(iteration) + param_group['lr'] = lr + # return lr + elif param_group["name"] == "deformation": + lr = 
self.deformation_scheduler_args(iteration) + param_group['lr'] = lr + # return lr + + def construct_list_of_attributes(self): + l = ['x', 'y', 'z', 'nx', 'ny', 'nz'] + # All channels except the 3 DC + for i in range(self._features_dc.shape[1]*self._features_dc.shape[2]): + l.append('f_dc_{}'.format(i)) + for i in range(self._features_rest.shape[1]*self._features_rest.shape[2]): + l.append('f_rest_{}'.format(i)) + l.append('opacity') + for i in range(self._scaling.shape[1]): + l.append('scale_{}'.format(i)) + for i in range(self._rotation.shape[1]): + l.append('rot_{}'.format(i)) + return l + # def compute_deformation(self,time): + + # deform = self._deformation[:,:,:time].sum(dim=-1) + # xyz = self._xyz + deform + # return xyz + # def save_ply_dynamic(path): + # for time in range(self._deformation.shape(-1)): + # xyz = self.compute_deformation(time) + def load_model(self, path): + print("loading model from exists{}".format(path)) + weight_dict = torch.load(os.path.join(path,"deformation.pth"),map_location="cuda") + self._deformation.load_state_dict(weight_dict) + self._deformation = self._deformation.to("cuda") + self._deformation_table = torch.gt(torch.ones((self.get_xyz.shape[0]),device="cuda"),0) + self._deformation_accum = torch.zeros((self.get_xyz.shape[0],3),device="cuda") + if os.path.exists(os.path.join(path, "deformation_table.pth")): + self._deformation_table = torch.load(os.path.join(path, "deformation_table.pth"),map_location="cuda") + if os.path.exists(os.path.join(path, "deformation_accum.pth")): + self._deformation_accum = torch.load(os.path.join(path, "deformation_accum.pth"),map_location="cuda") + self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda") + # print(self._deformation.deformation_net.grid.) + def save_deformation(self, path): + torch.save(self._deformation.state_dict(),os.path.join(path, "deformation.pth")) + torch.save(self._deformation_table,os.path.join(path, "deformation_table.pth")) + torch.save(self._deformation_accum,os.path.join(path, "deformation_accum.pth")) + def save_ply(self, path): + mkdir_p(os.path.dirname(path)) + + xyz = self._xyz.detach().cpu().numpy() + normals = np.zeros_like(xyz) + f_dc = self._features_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy() + f_rest = self._features_rest.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy() + opacities = self._opacity.detach().cpu().numpy() + scale = self._scaling.detach().cpu().numpy() + rotation = self._rotation.detach().cpu().numpy() + + dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()] + + elements = np.empty(xyz.shape[0], dtype=dtype_full) + attributes = np.concatenate((xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1) + elements[:] = list(map(tuple, attributes)) + el = PlyElement.describe(elements, 'vertex') + PlyData([el]).write(path) + + def reset_opacity(self): + opacities_new = inverse_sigmoid(torch.min(self.get_opacity, torch.ones_like(self.get_opacity)*0.01)) + optimizable_tensors = self.replace_tensor_to_optimizer(opacities_new, "opacity") + self._opacity = optimizable_tensors["opacity"] + + def load_ply(self, path): + plydata = PlyData.read(path) + + xyz = np.stack((np.asarray(plydata.elements[0]["x"]), + np.asarray(plydata.elements[0]["y"]), + np.asarray(plydata.elements[0]["z"])), axis=1) + opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis] + + features_dc = np.zeros((xyz.shape[0], 3, 1)) + features_dc[:, 0, 0] = 
np.asarray(plydata.elements[0]["f_dc_0"]) + features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"]) + features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"]) + + extra_f_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("f_rest_")] + extra_f_names = sorted(extra_f_names, key = lambda x: int(x.split('_')[-1])) + assert len(extra_f_names)==3*(self.max_sh_degree + 1) ** 2 - 3 + features_extra = np.zeros((xyz.shape[0], len(extra_f_names))) + for idx, attr_name in enumerate(extra_f_names): + features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name]) + # Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC) + features_extra = features_extra.reshape((features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1)) + + scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")] + scale_names = sorted(scale_names, key = lambda x: int(x.split('_')[-1])) + scales = np.zeros((xyz.shape[0], len(scale_names))) + for idx, attr_name in enumerate(scale_names): + scales[:, idx] = np.asarray(plydata.elements[0][attr_name]) + + rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot")] + rot_names = sorted(rot_names, key = lambda x: int(x.split('_')[-1])) + rots = np.zeros((xyz.shape[0], len(rot_names))) + for idx, attr_name in enumerate(rot_names): + rots[:, idx] = np.asarray(plydata.elements[0][attr_name]) + + self._xyz = nn.Parameter(torch.tensor(xyz, dtype=torch.float, device="cuda").requires_grad_(True)) + self._features_dc = nn.Parameter(torch.tensor(features_dc, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True)) + self._features_rest = nn.Parameter(torch.tensor(features_extra, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True)) + self._opacity = nn.Parameter(torch.tensor(opacities, dtype=torch.float, device="cuda").requires_grad_(True)) + self._scaling = nn.Parameter(torch.tensor(scales, dtype=torch.float, device="cuda").requires_grad_(True)) + self._rotation = nn.Parameter(torch.tensor(rots, dtype=torch.float, device="cuda").requires_grad_(True)) + self.active_sh_degree = self.max_sh_degree + + def replace_tensor_to_optimizer(self, tensor, name): + optimizable_tensors = {} + for group in self.optimizer.param_groups: + if group["name"] == name: + stored_state = self.optimizer.state.get(group['params'][0], None) + stored_state["exp_avg"] = torch.zeros_like(tensor) + stored_state["exp_avg_sq"] = torch.zeros_like(tensor) + + del self.optimizer.state[group['params'][0]] + group["params"][0] = nn.Parameter(tensor.requires_grad_(True)) + self.optimizer.state[group['params'][0]] = stored_state + + optimizable_tensors[group["name"]] = group["params"][0] + return optimizable_tensors + + def _prune_optimizer(self, mask): + optimizable_tensors = {} + for group in self.optimizer.param_groups: + if len(group["params"]) > 1: + continue + stored_state = self.optimizer.state.get(group['params'][0], None) + if stored_state is not None: + stored_state["exp_avg"] = stored_state["exp_avg"][mask] + stored_state["exp_avg_sq"] = stored_state["exp_avg_sq"][mask] + + del self.optimizer.state[group['params'][0]] + group["params"][0] = nn.Parameter((group["params"][0][mask].requires_grad_(True))) + self.optimizer.state[group['params'][0]] = stored_state + + optimizable_tensors[group["name"]] = group["params"][0] + else: + group["params"][0] = nn.Parameter(group["params"][0][mask].requires_grad_(True)) + optimizable_tensors[group["name"]] = 
group["params"][0] + return optimizable_tensors + + def prune_points(self, mask): + valid_points_mask = ~mask + optimizable_tensors = self._prune_optimizer(valid_points_mask) + + self._xyz = optimizable_tensors["xyz"] + self._features_dc = optimizable_tensors["f_dc"] + self._features_rest = optimizable_tensors["f_rest"] + self._opacity = optimizable_tensors["opacity"] + self._scaling = optimizable_tensors["scaling"] + self._rotation = optimizable_tensors["rotation"] + self._deformation_accum = self._deformation_accum[valid_points_mask] + self.xyz_gradient_accum = self.xyz_gradient_accum[valid_points_mask] + self._deformation_table = self._deformation_table[valid_points_mask] + self.denom = self.denom[valid_points_mask] + self.max_radii2D = self.max_radii2D[valid_points_mask] + + def cat_tensors_to_optimizer(self, tensors_dict): + optimizable_tensors = {} + for group in self.optimizer.param_groups: + if len(group["params"])>1:continue + assert len(group["params"]) == 1 + extension_tensor = tensors_dict[group["name"]] + stored_state = self.optimizer.state.get(group['params'][0], None) + if stored_state is not None: + + stored_state["exp_avg"] = torch.cat((stored_state["exp_avg"], torch.zeros_like(extension_tensor)), dim=0) + stored_state["exp_avg_sq"] = torch.cat((stored_state["exp_avg_sq"], torch.zeros_like(extension_tensor)), dim=0) + + del self.optimizer.state[group['params'][0]] + group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True)) + self.optimizer.state[group['params'][0]] = stored_state + + optimizable_tensors[group["name"]] = group["params"][0] + else: + group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True)) + optimizable_tensors[group["name"]] = group["params"][0] + + return optimizable_tensors + + def densification_postfix(self, new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation, new_deformation_table): + d = {"xyz": new_xyz, + "f_dc": new_features_dc, + "f_rest": new_features_rest, + "opacity": new_opacities, + "scaling" : new_scaling, + "rotation" : new_rotation, + # "deformation": new_deformation + } + + optimizable_tensors = self.cat_tensors_to_optimizer(d) + self._xyz = optimizable_tensors["xyz"] + self._features_dc = optimizable_tensors["f_dc"] + self._features_rest = optimizable_tensors["f_rest"] + self._opacity = optimizable_tensors["opacity"] + self._scaling = optimizable_tensors["scaling"] + self._rotation = optimizable_tensors["rotation"] + # self._deformation = optimizable_tensors["deformation"] + + self._deformation_table = torch.cat([self._deformation_table,new_deformation_table],-1) + self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") + self._deformation_accum = torch.zeros((self.get_xyz.shape[0], 3), device="cuda") + self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") + self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda") + + def densify_and_split(self, grads, grad_threshold, scene_extent, N=2): + n_init_points = self.get_xyz.shape[0] + # Extract points that satisfy the gradient condition + padded_grad = torch.zeros((n_init_points), device="cuda") + padded_grad[:grads.shape[0]] = grads.squeeze() + print('split', padded_grad.mean(), grad_threshold) + selected_pts_mask = torch.where(padded_grad >= grad_threshold, True, False) + selected_pts_mask = torch.logical_and(selected_pts_mask, + torch.max(self.get_scaling, dim=1).values > 
self.percent_dense*scene_extent) + if not selected_pts_mask.any(): + return + stds = self.get_scaling[selected_pts_mask].repeat(N,1) + means =torch.zeros((stds.size(0), 3),device="cuda") + samples = torch.normal(mean=means, std=stds) + rots = build_rotation(self._rotation[selected_pts_mask]).repeat(N,1,1) + new_xyz = torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1) + self.get_xyz[selected_pts_mask].repeat(N, 1) + new_scaling = self.scaling_inverse_activation(self.get_scaling[selected_pts_mask].repeat(N,1) / (0.8*N)) + new_rotation = self._rotation[selected_pts_mask].repeat(N,1) + new_features_dc = self._features_dc[selected_pts_mask].repeat(N,1,1) + new_features_rest = self._features_rest[selected_pts_mask].repeat(N,1,1) + new_opacity = self._opacity[selected_pts_mask].repeat(N,1) + new_deformation_table = self._deformation_table[selected_pts_mask].repeat(N) + self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacity, new_scaling, new_rotation, new_deformation_table) + + prune_filter = torch.cat((selected_pts_mask, torch.zeros(N * selected_pts_mask.sum(), device="cuda", dtype=bool))) + self.prune_points(prune_filter) + + def densify_and_clone(self, grads, grad_threshold, scene_extent): + # Extract points that satisfy the gradient condition + print('clone', torch.norm(grads, dim=-1).mean(), grad_threshold) + selected_pts_mask = torch.where(torch.norm(grads, dim=-1) >= grad_threshold, True, False) + selected_pts_mask = torch.logical_and(selected_pts_mask, + torch.max(self.get_scaling, dim=1).values <= self.percent_dense*scene_extent) + + new_xyz = self._xyz[selected_pts_mask] + # - 0.001 * self._xyz.grad[selected_pts_mask] + new_features_dc = self._features_dc[selected_pts_mask] + new_features_rest = self._features_rest[selected_pts_mask] + new_opacities = self._opacity[selected_pts_mask] + new_scaling = self._scaling[selected_pts_mask] + new_rotation = self._rotation[selected_pts_mask] + new_deformation_table = self._deformation_table[selected_pts_mask] + + self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation, new_deformation_table) + def prune(self, max_grad, min_opacity, extent, max_screen_size): + prune_mask = (self.get_opacity < min_opacity).squeeze() + # prune_mask_2 = torch.logical_and(self.get_opacity <= inverse_sigmoid(0.101 , dtype=torch.float, device="cuda"), self.get_opacity >= inverse_sigmoid(0.999 , dtype=torch.float, device="cuda")) + # prune_mask = torch.logical_or(prune_mask, prune_mask_2) + # deformation_sum = abs(self._deformation).sum(dim=-1).mean(dim=-1) + # deformation_mask = (deformation_sum < torch.quantile(deformation_sum, torch.tensor([0.5]).to("cuda"))) + # prune_mask = prune_mask & deformation_mask + if max_screen_size: + big_points_vs = self.max_radii2D > max_screen_size + big_points_ws = self.get_scaling.max(dim=1).values > 0.1 * extent + prune_mask = torch.logical_or(prune_mask, big_points_vs) + + prune_mask = torch.logical_or(torch.logical_or(prune_mask, big_points_vs), big_points_ws) + self.prune_points(prune_mask) + + torch.cuda.empty_cache() + def densify(self, max_grad, min_opacity, extent, max_screen_size): + grads = self.xyz_gradient_accum / self.denom + grads[grads.isnan()] = 0.0 + + self.densify_and_clone(grads, max_grad, extent) + self.densify_and_split(grads, max_grad, extent) + def standard_constaint(self): + + means3D = self._xyz.detach() + scales = self._scaling.detach() + rotations = self._rotation.detach() + opacity = self._opacity.detach() + color=self._ + 
time = torch.tensor(0).to("cuda").repeat(means3D.shape[0],1) + means3D_deform, scales_deform, rotations_deform, _ = self._deformation(means3D, scales, rotations, opacity, time) + position_error = (means3D_deform - means3D)**2 + rotation_error = (rotations_deform - rotations)**2 + scaling_erorr = (scales_deform - scales)**2 + return position_error.mean() + rotation_error.mean() + scaling_erorr.mean() + + + def add_densification_stats(self, viewspace_point_tensor, update_filter): + self.xyz_gradient_accum[update_filter] += torch.norm(viewspace_point_tensor[update_filter,:2], dim=-1, keepdim=True) + self.denom[update_filter] += 1 + @torch.no_grad() + # def update_deformation_table(self,threshold): + # # print("origin deformation point nums:",self._deformation_table.sum()) + # self._deformation_table = torch.gt(self._deformation_accum.max(dim=-1).values/100,threshold) + def print_deformation_weight_grad(self): + for name, weight in self._deformation.named_parameters(): + if weight.requires_grad: + if weight.grad is None: + print(name," :",weight.grad) + else: + if weight.grad.mean() != 0: + print(name," :",weight.grad.mean(), weight.grad.min(), weight.grad.max()) + print("-"*50) + def _plane_regulation(self): + multi_res_grids = self._deformation.deformation_net.grid.grids + total = 0 + # model.grids is 6 x [1, rank * F_dim, reso, reso] + for grids in multi_res_grids: + if len(grids) == 3: + time_grids = [] + else: + time_grids = [0,1,3] + for grid_id in time_grids: + total += compute_plane_smoothness(grids[grid_id]) + return total + def _time_regulation(self): + multi_res_grids = self._deformation.deformation_net.grid.grids + total = 0 + # model.grids is 6 x [1, rank * F_dim, reso, reso] + for grids in multi_res_grids: + if len(grids) == 3: + time_grids = [] + else: + time_grids =[2, 4, 5] + for grid_id in time_grids: + total += compute_plane_smoothness(grids[grid_id]) + return total + def _l1_regulation(self): + # model.grids is 6 x [1, rank * F_dim, reso, reso] + multi_res_grids = self._deformation.deformation_net.grid.grids + + total = 0.0 + for grids in multi_res_grids: + if len(grids) == 3: + continue + else: + # These are the spatiotemporal grids + spatiotemporal_grids = [2, 4, 5] + for grid_id in spatiotemporal_grids: + total += torch.abs(1 - grids[grid_id]).mean() + return total + def compute_regulation(self, time_smoothness_weight, l1_time_planes_weight, plane_tv_weight): + return plane_tv_weight * self._plane_regulation() + time_smoothness_weight * self._time_regulation() + l1_time_planes_weight * self._l1_regulation() diff --git a/scene/hexplane.py b/scene/hexplane.py new file mode 100644 index 0000000..668732b --- /dev/null +++ b/scene/hexplane.py @@ -0,0 +1,221 @@ +import itertools +import logging as log +from typing import Optional, Union, List, Dict, Sequence, Iterable, Collection, Callable + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def get_normalized_directions(directions): + """SH encoding must be in the range [0, 1] + + Args: + directions: batch of directions + """ + return (directions + 1.0) / 2.0 + + +def normalize_aabb(pts, aabb): + return (pts - aabb[0]) * (2.0 / (aabb[1] - aabb[0])) - 1.0 +def grid_sample_wrapper(grid: torch.Tensor, coords: torch.Tensor, align_corners: bool = True) -> torch.Tensor: + grid_dim = coords.shape[-1] + + if grid.dim() == grid_dim + 1: + # no batch dimension present, need to add it + grid = grid.unsqueeze(0) + if coords.dim() == 2: + coords = coords.unsqueeze(0) + + if grid_dim == 2 or grid_dim == 3: + 
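+        # Both 2D (plane) and 3D (volume) grids go through F.grid_sample: `coords`
+        # is reshaped below to [B, 1, ..., n, grid_dim] so bilinear interpolation
+        # returns one feature_dim-dimensional vector per query point.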
grid_sampler = F.grid_sample + else: + raise NotImplementedError(f"Grid-sample was called with {grid_dim}D data but is only " + f"implemented for 2 and 3D data.") + + coords = coords.view([coords.shape[0]] + [1] * (grid_dim - 1) + list(coords.shape[1:])) + B, feature_dim = grid.shape[:2] + n = coords.shape[-2] + interp = grid_sampler( + grid, # [B, feature_dim, reso, ...] + coords, # [B, 1, ..., n, grid_dim] + align_corners=align_corners, + mode='bilinear', padding_mode='border') + interp = interp.view(B, feature_dim, n).transpose(-1, -2) # [B, n, feature_dim] + interp = interp.squeeze() # [B?, n, feature_dim?] + return interp + +def init_grid_param( + grid_nd: int, + in_dim: int, + out_dim: int, + reso: Sequence[int], + a: float = 0.1, + b: float = 0.5): + assert in_dim == len(reso), "Resolution must have same number of elements as input-dimension" + has_time_planes = in_dim == 4 + assert grid_nd <= in_dim + coo_combs = list(itertools.combinations(range(in_dim), grid_nd)) + grid_coefs = nn.ParameterList() + for ci, coo_comb in enumerate(coo_combs): + new_grid_coef = nn.Parameter(torch.empty( + [1, out_dim] + [reso[cc] for cc in coo_comb[::-1]] + )) + if has_time_planes and 3 in coo_comb: # Initialize time planes to 1 + # print('time planes', new_grid_coef.shape) + # nn.init.normal_(new_grid_coef) + nn.init.ones_(new_grid_coef) + else: + nn.init.uniform_(new_grid_coef, a=a, b=b) + grid_coefs.append(new_grid_coef) + + return grid_coefs + + +def interpolate_ms_features(pts: torch.Tensor, + ms_grids: Collection[Iterable[nn.Module]], + grid_dimensions: int, + concat_features: bool, + num_levels: Optional[int], + grid_merge: str, + ) -> torch.Tensor: + coo_combs = list(itertools.combinations( + range(pts.shape[-1]), grid_dimensions) + ) + if num_levels is None: + num_levels = len(ms_grids) + multi_scale_interp = [] if concat_features else 0. + grid: nn.ParameterList + for scale_id, grid in enumerate(ms_grids[:num_levels]): + if grid_merge == 'cat': + interp_space = [] + else: + interp_space = 1. + for ci, coo_comb in enumerate(coo_combs): + # interpolate in plane + feature_dim = grid[ci].shape[1] # shape of grid[ci]: 1, out_dim, *reso + interp_out_plane = ( + grid_sample_wrapper(grid[ci], pts[..., coo_comb]) + .view(-1, feature_dim) + ) + # print(ci, coo_comb, interp_out_plane) + # compute product over planes + if grid_merge == 'plus': + interp_space = interp_space + interp_out_plane + elif grid_merge == 'mul': + interp_space = interp_space * interp_out_plane + elif grid_merge == 'cat': + interp_space.append(interp_out_plane) + else: + raise NotImplementedError + + # combine over scales + # print('length per scale', len(interp_space)) # is 6 + if concat_features: + if grid_merge == 'cat': + for cur in interp_space: + multi_scale_interp.append(cur) + else: + multi_scale_interp.append(interp_space) + else: + raise NotImplementedError + multi_scale_interp = multi_scale_interp + interp_space + + if concat_features: + multi_scale_interp = torch.cat(multi_scale_interp, dim=-1) + return multi_scale_interp + + +class HexPlaneField(nn.Module): + def __init__( + self, + bounds, + planeconfig, + multires, + grid_merge, + ) -> None: + super().__init__() + aabb = torch.tensor([[bounds,bounds,bounds], + [-bounds,-bounds,-bounds]]) + self.aabb = nn.Parameter(aabb, requires_grad=False) + self.grid_config = [planeconfig] + self.multiscale_res_multipliers = multires + self.concat_features = True + self.grid_merge = grid_merge + + # 1. 
Init planes + self.grids = nn.ModuleList() + self.feat_dim = 0 + for res in self.multiscale_res_multipliers: + # initialize coordinate grid + config = self.grid_config[0].copy() + # Resolution fix: multi-res only on spatial planes + config["resolution"] = [ + r * res for r in config["resolution"][:3] + ] + config["resolution"][3:] + gp = init_grid_param( + grid_nd=config["grid_dimensions"], + in_dim=config["input_coordinate_dim"], + out_dim=config["output_coordinate_dim"], + reso=config["resolution"], + ) + # shape[1] is out-dim - Concatenate over feature len for each scale + if self.concat_features: + self.feat_dim += gp[-1].shape[1] + else: + self.feat_dim = gp[-1].shape[1] + self.grids.append(gp) + # print(f"Initialized model grids: {self.grids}") + # print("feature_dim:",self.feat_dim) + + + def set_aabb(self,xyz_max, xyz_min): + aabb = torch.tensor([ + xyz_max, + xyz_min + ]) + self.aabb = nn.Parameter(aabb,requires_grad=True) + print("Voxel Plane: set aabb=",self.aabb) + + def get_density(self, pts: torch.Tensor, timestamps: Optional[torch.Tensor] = None): + """Computes and returns the densities.""" + + pts = normalize_aabb(pts, self.aabb) + pts = torch.cat((pts, timestamps), dim=-1) # [n_rays, n_samples, 4] + # print('pts', pts, self.concat_features) + + pts = pts.reshape(-1, pts.shape[-1]) + features = interpolate_ms_features( + pts, ms_grids=self.grids, # noqa + grid_dimensions=self.grid_config[0]["grid_dimensions"], + concat_features=self.concat_features, num_levels=None, grid_merge=self.grid_merge) + # print('hexplane features', features.shape) + if len(features) < 1: + raise NotImplementedError + features = torch.zeros((0, 1)).to(features.device) + + return features + + def forward(self, + pts: torch.Tensor, + timestamps: Optional[torch.Tensor] = None): + + features = self.get_density(pts, timestamps) + + return features + +if __name__ == '__main__': + kplanes_config = { + 'grid_dimensions': 2, + 'input_coordinate_dim': 4, + 'output_coordinate_dim': 32, + 'resolution': [64, 64, 64, 16] + # 'resolution': [64, 64, 64, 150] + } + grid = HexPlaneField(2.5, kplanes_config, [1, 2, 4, 8 ]) + pts = torch.randn(1, 3) + for idx in range(16): + feat = grid(pts, torch.tensor([idx / 16]).unsqueeze(0).repeat(pts.shape[0],1)) + print(feat.std(), feat.mean()) + # print(feat.shape, feat[:,:5]) + # print() diff --git a/scene/hyper_loader.py b/scene/hyper_loader.py new file mode 100644 index 0000000..dd12027 --- /dev/null +++ b/scene/hyper_loader.py @@ -0,0 +1,188 @@ +import warnings + +warnings.filterwarnings("ignore") + +import json +import os +import random + +import numpy as np +import torch +from PIL import Image +import math +from tqdm import tqdm +from scene.utils import Camera +from typing import NamedTuple +from torch.utils.data import Dataset +from utils.general_utils import PILtoTorch +# from scene.dataset_readers import +from utils.graphics_utils import getWorld2View2, focal2fov, fov2focal +import copy +class CameraInfo(NamedTuple): + uid: int + R: np.array + T: np.array + FovY: np.array + FovX: np.array + image: np.array + image_path: str + image_name: str + width: int + height: int + time : float + + +class Load_hyper_data(Dataset): + def __init__(self, + datadir, + ratio=1.0, + use_bg_points=False, + split="train" + ): + + from .utils import Camera + datadir = os.path.expanduser(datadir) + with open(f'{datadir}/scene.json', 'r') as f: + scene_json = json.load(f) + with open(f'{datadir}/metadata.json', 'r') as f: + meta_json = json.load(f) + with open(f'{datadir}/dataset.json', 'r') 
as f: + dataset_json = json.load(f) + + self.near = scene_json['near'] + self.far = scene_json['far'] + self.coord_scale = scene_json['scale'] + self.scene_center = scene_json['center'] + + self.all_img = dataset_json['ids'] + self.val_id = dataset_json['val_ids'] + self.split = split + if len(self.val_id) == 0: + self.i_train = np.array([i for i in np.arange(len(self.all_img)) if + (i%4 == 0)]) + self.i_test = self.i_train+2 + self.i_test = self.i_test[:-1,] + else: + self.train_id = dataset_json['train_ids'] + self.i_test = [] + self.i_train = [] + for i in range(len(self.all_img)): + id = self.all_img[i] + if id in self.val_id: + self.i_test.append(i) + if id in self.train_id: + self.i_train.append(i) + + + self.all_cam = [meta_json[i]['camera_id'] for i in self.all_img] + self.all_time = [meta_json[i]['warp_id'] for i in self.all_img] + max_time = max(self.all_time) + self.all_time = [meta_json[i]['warp_id']/max_time for i in self.all_img] + self.selected_time = set(self.all_time) + self.ratio = ratio + self.max_time = max(self.all_time) + self.min_time = min(self.all_time) + self.i_video = [i for i in range(len(self.all_img))] + self.i_video.sort() + # all poses + self.all_cam_params = [] + for im in self.all_img: + camera = Camera.from_json(f'{datadir}/camera/{im}.json') + camera = camera.scale(ratio) + camera.position -= self.scene_center + camera.position *= self.coord_scale + self.all_cam_params.append(camera) + + self.all_img = [f'{datadir}/rgb/{int(1/ratio)}x/{i}.png' for i in self.all_img] + self.h, self.w = self.all_cam_params[0].image_shape + self.map = {} + self.image_one = Image.open(self.all_img[0]) + self.image_one_torch = PILtoTorch(self.image_one,None).to(torch.float32) + + def __getitem__(self, index): + if self.split == "train": + return self.load_raw(self.i_train[index]) + + elif self.split == "test": + return self.load_raw(self.i_test[index]) + elif self.split == "video": + return self.load_video(self.i_video[index]) + def __len__(self): + if self.split == "train": + return len(self.i_train) + elif self.split == "test": + return len(self.i_test) + elif self.split == "video": + # return len(self.i_video) + return len(self.video_v2) + def load_video(self, idx): + if idx in self.map.keys(): + return self.map[idx] + camera = self.all_cam_params[idx] + w = self.image_one.size[0] + h = self.image_one.size[1] + # image = PILtoTorch(image,None) + # image = image.to(torch.float32) + time = self.all_time[idx] + R = camera.orientation.T + T = - camera.position @ R + FovY = focal2fov(camera.focal_length, self.h) + FovX = focal2fov(camera.focal_length, self.w) + image_path = "/".join(self.all_img[idx].split("/")[:-1]) + image_name = self.all_img[idx].split("/")[-1] + caminfo = CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=self.image_one_torch, + image_path=image_path, image_name=image_name, width=w, height=h, time=time, + ) + self.map[idx] = caminfo + return caminfo + def load_raw(self, idx): + if idx in self.map.keys(): + return self.map[idx] + camera = self.all_cam_params[idx] + image = Image.open(self.all_img[idx]) + w = image.size[0] + h = image.size[1] + image = PILtoTorch(image,None) + image = image.to(torch.float32) + time = self.all_time[idx] + R = camera.orientation.T + T = - camera.position @ R + FovY = focal2fov(camera.focal_length, self.h) + FovX = focal2fov(camera.focal_length, self.w) + image_path = "/".join(self.all_img[idx].split("/")[:-1]) + image_name = self.all_img[idx].split("/")[-1] + caminfo = CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, 
image=image, + image_path=image_path, image_name=image_name, width=w, height=h, time=time, + ) + self.map[idx] = caminfo + return caminfo + + +def format_hyper_data(data_class, split): + if split == "train": + data_idx = data_class.i_train + elif split == "test": + data_idx = data_class.i_test + # dataset = data_class.copy() + # dataset.mode = split + cam_infos = [] + for uid, index in tqdm(enumerate(data_idx)): + camera = data_class.all_cam_params[index] + # image = Image.open(data_class.all_img[index]) + # image = PILtoTorch(image,None) + time = data_class.all_time[index] + R = camera.orientation.T + T = - camera.position @ R + FovY = focal2fov(camera.focal_length, data_class.h) + FovX = focal2fov(camera.focal_length, data_class.w) + image_path = "/".join(data_class.all_img[index].split("/")[:-1]) + image_name = data_class.all_img[index].split("/")[-1] + cam_info = CameraInfo(uid=uid, R=R, T=T, FovY=FovY, FovX=FovX, image=None, + image_path=image_path, image_name=image_name, width=int(data_class.w), height=int(data_class.h), time=time, + ) + cam_infos.append(cam_info) + return cam_infos + # matrix = np.linalg.inv(np.array(poses)) + # R = -np.transpose(matrix[:3,:3]) + # R[:,0] = -R[:,0] + # T = -matrix[:3, 3] \ No newline at end of file diff --git a/scene/i2v_dataset.py b/scene/i2v_dataset.py new file mode 100644 index 0000000..10ff3ba --- /dev/null +++ b/scene/i2v_dataset.py @@ -0,0 +1,589 @@ +from torch.utils.data import Dataset +# from scene.cameras import Camera +import numpy as np +from utils.general_utils import PILtoTorch +from utils.graphics_utils import fov2focal, focal2fov +import torch +from utils.camera_utils import loadCam +from utils.graphics_utils import focal2fov + +from torchvision.transforms import ToTensor +from PIL import Image +import glob +from scene.cam_utils import orbit_camera +import math, os + +def getProjectionMatrix(znear, zfar, fovX, fovY): + tanHalfFovY = math.tan((fovY / 2)) + tanHalfFovX = math.tan((fovX / 2)) + + P = torch.zeros(4, 4) + + z_sign = 1.0 + + P[0, 0] = 1 / tanHalfFovX + P[1, 1] = 1 / tanHalfFovY + P[3, 2] = z_sign + P[2, 2] = z_sign * zfar / (zfar - znear) + P[2, 3] = -(zfar * znear) / (zfar - znear) + return P + + +class MiniCam: + def __init__(self, c2w, width, height, fovy, fovx, znear, zfar): + # c2w (pose) should be in NeRF convention. + + self.image_width = width + self.image_height = height + self.FoVy = fovy + self.FoVx = fovx + self.znear = znear + self.zfar = zfar + + w2c = np.linalg.inv(c2w) + + # rectify... 
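As a side note on the `MiniCam` defined here (the same helper reappears in `scene/video_dataset.py`): it stores pre-transposed matrices in the 3DGS row-vector convention, so a world-space point is projected by right-multiplying with `full_proj_transform`. A minimal standalone sketch, assuming an identity pose and the 40° field of view used by the datasets in this file:

```python
import math
import numpy as np
import torch

def projection(znear, zfar, fovx, fovy):
    # Same layout as getProjectionMatrix above.
    t_x, t_y = math.tan(fovx / 2), math.tan(fovy / 2)
    P = torch.zeros(4, 4)
    P[0, 0], P[1, 1] = 1 / t_x, 1 / t_y
    P[3, 2] = 1.0
    P[2, 2] = zfar / (zfar - znear)
    P[2, 3] = -(zfar * znear) / (zfar - znear)
    return P

c2w = np.eye(4)                      # NeRF convention: camera at origin, looking down -z
w2c = np.linalg.inv(c2w)
w2c[1:3, :3] *= -1                   # the "rectify" flip above (NeRF -> COLMAP-style axes)
w2c[:3, 3] *= -1

world_view = torch.tensor(w2c, dtype=torch.float32).T
full_proj = world_view @ projection(0.01, 100, np.deg2rad(40), np.deg2rad(40)).T

p = torch.tensor([0.0, 0.0, -2.0, 1.0])  # a point 2 units in front of the camera
clip = p @ full_proj
print(clip[:3] / clip[3])                # ~[0, 0, 0.995]: centered, inside the frustum
```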
+ w2c[1:3, :3] *= -1 + w2c[:3, 3] *= -1 + + self.world_view_transform = torch.tensor(w2c).transpose(0, 1)#.cuda() + self.projection_matrix = ( + getProjectionMatrix( + znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy + ) + .transpose(0, 1) + # .cuda() + ) + self.full_proj_transform = self.world_view_transform @ self.projection_matrix + self.camera_center = -torch.tensor(c2w[:3, 3])#.cuda() + + +class FourDGSdataset(Dataset): + def __init__( + self, + split, + frame_num = 16, + name='panda', + rife=False, + static=False, + ): + self.split = split + # self.args = args + + # https://github.com/threestudio-project/threestudio/blob/main/configs/magic123-coarse-sd.yaml#L22 + self.radius = 2.5 + self.W = 512 + self.H = 512 + self.fovy = np.deg2rad(40) + self.fovx = np.deg2rad(40) + # self.fovy = np.deg2rad(49.1) + # self.fovx = np.deg2rad(49.1) + # align with zero123 rendering setting (ref: https://github.com/cvlab-columbia/zero123/blob/main/objaverse-rendering/scripts/blender_script.py#L61 + self.near = 0.01 + self.far = 100 + self.T = ToTensor() + self.len_pose0 = frame_num + self.name=name + self.rife=rife + self.static=static + + pose0_dir=f'data/{self.name}_pose0/' + # pose0_dir=f'data/{self.name}_rgba_pose0/' + + frame_list = range(frame_num) + pose0_im_names = [pose0_dir + f'{x}.png' for x in frame_list] + idx_list = range(frame_num) + if not os.path.exists(pose0_im_names[0]): # check 0 index + pose0_im_names = pose0_im_names[1:] + [pose0_dir + f'{frame_num}.png'] # use 1 index + idx_list = list(idx_list)[1:] + [frame_num] + + base_dir=f'./data/{self.name}_sync' + + syncdreamer_im = [] + # for fname in t0_im_names: + assert self.static==False + if self.static==False: + for frame_idx in idx_list: + # for frame_idx in range(1, frame_num + 1): + li = [] + for view_idx in range(16): + fname = os.path.join(base_dir, f"{frame_idx}_0_{view_idx}_rgba.png") + im = Image.open(fname).resize((self.W, self.H))#.convert('RGB') + # use RGBA + ww = self.T(im) + assert ww.shape[0] == 4 + ww[:3] = ww[:3] * ww[-1:] + (1 - ww[-1:]) + li.append(ww) + li = torch.stack(li, dim=0)#.permute(0, 2, 3, 1) + syncdreamer_im.append(li) + self.syncdreamer_im = torch.stack(syncdreamer_im, 0) # [fn, 16, 3, 512, 512] + else: + #sync only read frame0 + # (dejia): not used + for frame_idx in range(frame_num): + li = [] + frame_idx=0 + for view_idx in range(16): + fname = os.path.join(base_dir, f"{frame_idx}_0_{view_idx}_rgba.png") + # fname = os.path.join(base_dir, f"{self.name}{frame_idx}_0_{view_idx}_rgba.png") + im = Image.open(fname).resize((self.W, self.H))#.convert('RGB') + # use RGBA + ww = self.T(im) + assert ww.shape[0] == 4 + ww[:3] = ww[:3] * ww[-1:] + (1 - ww[-1:]) + li.append(ww) + li = torch.stack(li, dim=0)#.permute(0, 2, 3, 1) + syncdreamer_im.append(li) + self.syncdreamer_im = torch.stack(syncdreamer_im, 0) # [fn, 16, 3, 512, 512] + + print(f"syncdreamer images loaded {self.syncdreamer_im.shape}.") + + self.pose0_im_list = [] + # TODO: should images be RGBA when input?? 
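Both loading loops in this dataset composite the RGBA inputs onto a white background via `rgb * alpha + (1 - alpha)`, so fully transparent pixels become white. A self-contained sketch of that step, with a random tensor standing in for `ToTensor()(Image.open(...))`:

```python
import torch

# Stand-in for ToTensor()(Image.open(fname).resize((512, 512))): RGBA in [0, 1].
ww = torch.rand(4, 512, 512)

alpha = ww[-1:]                        # [1, H, W]
ww[:3] = ww[:3] * alpha + (1 - alpha)  # blend onto white: alpha == 0 -> pure white

assert ww[:3].min() >= 0 and ww[:3].max() <= 1
```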
+ for fname in pose0_im_names: + im = Image.open(fname).resize((self.W, self.H))#.convert('RGB') + ww = self.T(im) + ww[:3] = ww[:3] * ww[-1:] + (1 - ww[-1:]) + self.pose0_im_list.append(ww) + # self.pose0_im_list.append(self.T(im)) + while len(self.pose0_im_list) < self.len_pose0: + self.pose0_im_list.append(ww) + self.pose0_im_list = torch.stack(self.pose0_im_list, dim=0)#.permute(0, 2, 3, 1) + # self.pose0_im_list = self.pose0_im_list.expand(fn, 3, 256, 256) + print(f"Pose0 images loaded {self.pose0_im_list.shape}") + self.syncdreamer_im = torch.cat([self.pose0_im_list.unsqueeze(1), self.syncdreamer_im], 1) + print(f"New syncdreamer shape {self.syncdreamer_im.shape}") + self.max_frames = self.pose0_im_list.shape[0] + print(f"Loaded SDS Dataset. Max {self.max_frames} frames.") + + # self.t0_num = self.t0_im_list.shape[0] + self.pose0_num = self.pose0_im_list.shape[0] + if self.split == 'train': + self.t0_num = 16 + 1 # fixed + else: + self.t0_num = 100 + self.len_ = (self.t0_num) * (self.pose0_num) + + pose0_pose = orbit_camera(0, 0, self.radius) + self.pose0_cam = MiniCam( + pose0_pose, + self.H, # NOTE: order might be wrong + self.W, + self.fovy, + self.fovx, + self.near, + self.far, + ) + self.t0_pose = [self.pose0_cam] + [MiniCam( + # self.t0_pose = [MiniCam( + orbit_camera(-30, azimuth, self.radius), + self.H, # NOTE: order might be wrong + self.W, + self.fovy, + self.fovx, + self.near, + self.far, + ) for azimuth in np.concatenate([np.arange(0, 180, 22.5), np.arange(-180, 0, 22.5)])] + + # we sample (pose, t) + def __getitem__(self, index): + if self.split == 'train': + t0_idx = index // self.pose0_num + pose0_idx = index % self.pose0_num + time = torch.tensor([pose0_idx]).unsqueeze(0)#.expand(1, self.W * self.H) + else: + t0_idx = index # self.t0_num // 2 + pose0_idx = 1 + time = torch.tensor([pose0_idx]).unsqueeze(0) + + out = { + # timestamp is per pixel + "time": time / self.pose0_num, + 'pose0': self.pose0_im_list[pose0_idx], + 'pose0_idx': pose0_idx, + 't0_idx': t0_idx, + 't0_weight': min(abs(t0_idx), abs(self.t0_num - t0_idx)), + # 't0': self.t0_im_list[t0_idx].view(-1, 3), + # 'pose0': self.pose0_im_list[pose0_idx].view(-1, 3), + # 'bg_color': torch.ones((1, 3), dtype=torch.float32), + "pose0_cam": self.pose0_cam, + } + #t0_idx=0 + if self.split == 'train': + out['t0'] = self.syncdreamer_im[0][t0_idx] + out['gtim'] = self.syncdreamer_im[pose0_idx][t0_idx] # coarse stage + + t0_cam = self.t0_pose[t0_idx] + out['t0_cam'] = t0_cam + # out['sync_cam'] = self.sync_pose + + + + ## for render.py multiview_video + + ver = 0 + hor = (index / 100) * 360 + # ver = np.random.randint(-45, 45) + # hor = np.random.randint(-180, 180) + pose = orbit_camera(0 + ver, hor, self.radius) + out['hor'] = hor + out['ver'] = ver + + cur_cam = MiniCam( + pose, + self.H, # NOTE: order might be wrong + self.W, + self.fovy, + self.fovx, + self.near, + self.far, + ) + out['cur_cam'] = cur_cam + + # for fine stage, random seq + + rand_seq = [] + ver_list = [] + hor_list = [] + # for i in range(self.pose0_num - 1): + for i in range(self.pose0_num): + ver = np.random.randint(-30, 30) + hor = np.random.randint(-180, 180) + cur_pose = orbit_camera(ver, hor, self.radius) + ver_list.append(ver) + hor_list.append(hor) + # cur_pose = orbit_camera(ver_offset[i], hor_offset[i], self.radius) + rand_seq.append(MiniCam( + cur_pose if self.split == 'train' else pose, + # cur_pose, + self.H, # NOTE: order might be wrong + self.W, + self.fovy, + self.fovx, + self.near, + self.far, + )) + out['rand_poses'] = rand_seq + 
out['rand_ver'] = np.array(ver_list) + out['rand_hor'] = np.array(hor_list) + # out['rand_ver'] = ver_offset + # out['rand_hor'] = hor_offset + + back_pose=orbit_camera(0, 180, self.radius) + out['back_cam']=MiniCam( + back_pose, + self.H, # NOTE: order might be wrong + self.W, + self.fovy, + self.fovx, + self.near, + self.far, + ) + + side_pose=orbit_camera(0, 90, self.radius) + out['side_cam']=MiniCam( + side_pose, + self.H, # NOTE: order might be wrong + self.W, + self.fovy, + self.fovx, + self.near, + self.far, + ) + + side_pose=orbit_camera(0, 70, self.radius) + out['side_cam2']=MiniCam( + side_pose, + self.H, # NOTE: order might be wrong + self.W, + self.fovy, + self.fovx, + self.near, + self.far, + ) + + front_pose=orbit_camera(0, 0, self.radius) + out['front_cam']=MiniCam( + front_pose, + self.H, # NOTE: order might be wrong + self.W, + self.fovy, + self.fovx, + self.near, + self.far, + ) + return out + + def __len__(self): + # we sample (pose, t) + if self.split == 'train': + return self.len_ + if self.split == 'test': + return self.pose0_num + # return self.t0_num + if self.split == 'video': + return 100 + + +class ImageDreamdataset(Dataset): + def __init__( + self, + split, + frame_num = 16, + name='panda', + rife=False, + static=False, + ): + self.split = split + # self.args = args + + # https://github.com/threestudio-project/threestudio/blob/main/configs/magic123-coarse-sd.yaml#L22 + # self.radius = 2.5 + self.radius = 2.0 ## imagedream https://github.com/bytedance/ImageDream/blob/13e05566ca27c66b6bc5b3ee42bc68ddfb471585/configs/imagedream-sd21-shading.yaml#L20 + self.W = 512 + self.H = 512 + self.fovy = np.deg2rad(40) + self.fovx = np.deg2rad(40) + # self.fovy = np.deg2rad(49.1) + # self.fovx = np.deg2rad(49.1) + # align with zero123 rendering setting (ref: https://github.com/cvlab-columbia/zero123/blob/main/objaverse-rendering/scripts/blender_script.py#L61 + self.near = 0.01 + self.far = 100 + self.T = ToTensor() + self.len_pose0 = frame_num + self.name=name + self.rife=rife + self.static=static + + pose0_dir=f'./data/ImageDream/{self.name}/rgba/' + + frame_list = range(frame_num) + pose0_im_names = [pose0_dir + f'{x}.png' for x in frame_list] + idx_list = range(frame_num) + if not os.path.exists(pose0_im_names[0]): # check 0 index + pose0_im_names = pose0_im_names[1:] + [pose0_dir + f'{frame_num}.png'] # use 1 index + idx_list = list(idx_list)[1:] + [frame_num] + + base_dir=f'./data/output_svd/{self.name}' + syncdreamer_im = [] + assert self.static==False + if self.static==False: + for frame_idx in idx_list: + li = [] + for view_idx in range(4): + #view_idx=0 + fname = os.path.join(base_dir, f"{frame_idx}_{view_idx}_rgba.png") + im = Image.open(fname).resize((self.W, self.H))#.convert('RGB') + # use RGBA + ww = self.T(im) + assert ww.shape[0] == 4 + ww[:3] = ww[:3] * ww[-1:] + (1 - ww[-1:]) + li.append(ww) + li = torch.stack(li, dim=0)#.permute(0, 2, 3, 1) + syncdreamer_im.append(li) + self.syncdreamer_im = torch.stack(syncdreamer_im, 0) # [fn, 16, 3, 512, 512] + else: + raise NotImplementedError + + + + + print(f"imagedream images loaded {self.syncdreamer_im.shape}.") + + self.pose0_im_list = [] + # TODO: should images be RGBA when input?? 
+ for fname in pose0_im_names: + im = Image.open(fname).resize((self.W, self.H))#.convert('RGB') + ww = self.T(im) + ww[:3] = ww[:3] * ww[-1:] + (1 - ww[-1:]) + self.pose0_im_list.append(ww) + # self.pose0_im_list.append(self.T(im)) + while len(self.pose0_im_list) < self.len_pose0: + self.pose0_im_list.append(ww) + self.pose0_im_list = torch.stack(self.pose0_im_list, dim=0)#.permute(0, 2, 3, 1) + # self.pose0_im_list = self.pose0_im_list.expand(fn, 3, 256, 256) + print(f"Pose0 images loaded {self.pose0_im_list.shape}") + # self.syncdreamer_im = torch.cat([self.pose0_im_list.unsqueeze(1), self.syncdreamer_im], 1) + print(f"New syncdreamer shape {self.syncdreamer_im.shape}") + self.max_frames = self.pose0_im_list.shape[0] + print(f"Loaded SDS Dataset. Max {self.max_frames} frames.") + + # self.t0_num = self.t0_im_list.shape[0] + self.pose0_num = self.pose0_im_list.shape[0] + if self.split == 'train': + self.t0_num = 4# + 1 # fixed + else: + self.t0_num = 100 + self.len_ = (self.t0_num) * (self.pose0_num) + + # NOTE: this is different!! + pose0_pose = orbit_camera(0, 90, self.radius) + self.pose0_cam = MiniCam( + pose0_pose, + self.H, # NOTE: order might be wrong + self.W, + self.fovy, + self.fovx, + self.near, + self.far, + ) + # self.t0_pose = [self.pose0_cam] + [MiniCam( + self.t0_pose = [MiniCam( + orbit_camera(0, azimuth, self.radius), + self.H, # NOTE: order might be wrong + self.W, + self.fovy, + self.fovx, + self.near, + self.far, + ) for azimuth in np.concatenate([np.arange(0, 180, 90), np.arange(-180, 0, 90)])] + + # we sample (pose, t) + def __getitem__(self, index): + if self.split == 'train': + t0_idx = index // self.pose0_num + pose0_idx = index % self.pose0_num + time = torch.tensor([pose0_idx]).unsqueeze(0)#.expand(1, self.W * self.H) + else: + t0_idx = index # self.t0_num // 2 + pose0_idx = 1 + time = torch.tensor([pose0_idx]).unsqueeze(0) + + out = { + # timestamp is per pixel + "time": time / self.pose0_num, + 'pose0': self.pose0_im_list[pose0_idx], + 'pose0_idx': pose0_idx, + 't0_idx': t0_idx, + 't0_weight': min(abs(t0_idx), abs(self.t0_num - t0_idx)), + # 't0': self.t0_im_list[t0_idx].view(-1, 3), + # 'pose0': self.pose0_im_list[pose0_idx].view(-1, 3), + # 'bg_color': torch.ones((1, 3), dtype=torch.float32), + "pose0_cam": self.pose0_cam, + } + #t0_idx=0 + if self.split == 'train': + out['t0'] = self.syncdreamer_im[0][t0_idx] + out['gtim'] = self.syncdreamer_im[pose0_idx][t0_idx] # coarse stage + + t0_cam = self.t0_pose[t0_idx] + out['t0_cam'] = t0_cam + + ## for render.py multiview_video + ver = 0 + hor = (index / 100) * 360 + pose = orbit_camera(0 + ver, hor, self.radius) + out['hor'] = hor + out['ver'] = ver + + cur_cam = MiniCam( + pose, + self.H, # NOTE: order might be wrong + self.W, + self.fovy, + self.fovx, + self.near, + self.far, + ) + out['cur_cam'] = cur_cam + + # for fine stage, random seq + + rand_seq = [] + ver_list = [] + hor_list = [] + # for i in range(self.pose0_num - 1): + for i in range(self.pose0_num): + ver = np.random.randint(-30, 30) + hor = np.random.randint(-180, 180) + cur_pose = orbit_camera(ver, hor, self.radius) + ver_list.append(ver) + hor_list.append(hor) + # cur_pose = orbit_camera(ver_offset[i], hor_offset[i], self.radius) + rand_seq.append(MiniCam( + cur_pose if self.split == 'train' else pose, + # cur_pose, + self.H, # NOTE: order might be wrong + self.W, + self.fovy, + self.fovx, + self.near, + self.far, + )) + out['rand_poses'] = rand_seq + out['rand_ver'] = np.array(ver_list) + out['rand_hor'] = np.array(hor_list) + # 
out['rand_ver'] = ver_offset + # out['rand_hor'] = hor_offset + + back_pose=orbit_camera(0, 180, self.radius) + out['back_cam']=MiniCam( + back_pose, + self.H, # NOTE: order might be wrong + self.W, + self.fovy, + self.fovx, + self.near, + self.far, + ) + + side_pose=orbit_camera(0, 90, self.radius) + out['side_cam']=MiniCam( + side_pose, + self.H, # NOTE: order might be wrong + self.W, + self.fovy, + self.fovx, + self.near, + self.far, + ) + + side_pose=orbit_camera(0, 70, self.radius) + out['side_cam2']=MiniCam( + side_pose, + self.H, # NOTE: order might be wrong + self.W, + self.fovy, + self.fovx, + self.near, + self.far, + ) + + front_pose=orbit_camera(0, 0, self.radius) + out['front_cam']=MiniCam( + front_pose, + self.H, # NOTE: order might be wrong + self.W, + self.fovy, + self.fovx, + self.near, + self.far, + ) + + ver = np.random.randint(-30, 30) + hor = np.random.randint(-180, 180) + li = [orbit_camera(ver, hor, self.radius)] + for view_i in range(1, 4): + li.append(orbit_camera(ver, hor + 90 * view_i, self.radius)) + out['dream_pose_mat'] = torch.from_numpy(np.stack(li, axis=0)) + out['dream_pose'] = [MiniCam( + cur_pose, + # cur_pose, + self.H, # NOTE: order might be wrong + self.W, + self.fovy, + self.fovx, + self.near, + self.far, + ) for cur_pose in li] + return out + + def __len__(self): + # we sample (pose, t) + if self.split == 'train': + return self.len_ + if self.split == 'test': + return self.pose0_num + # return self.t0_num + if self.split == 'video': + return 100 diff --git a/scene/neural_3D_dataset_NDC.py b/scene/neural_3D_dataset_NDC.py new file mode 100644 index 0000000..63bbcad --- /dev/null +++ b/scene/neural_3D_dataset_NDC.py @@ -0,0 +1,376 @@ +import concurrent.futures +import gc +import glob +import os + +import cv2 +import numpy as np +import torch +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms as T +from tqdm import tqdm + + +def normalize(v): + """Normalize a vector.""" + return v / np.linalg.norm(v) + + +def average_poses(poses): + """ + Calculate the average pose, which is then used to center all poses + using @center_poses. Its computation is as follows: + 1. Compute the center: the average of pose centers. + 2. Compute the z axis: the normalized average z axis. + 3. Compute axis y': the average y axis. + 4. Compute x' = y' cross product z, then normalize it as the x axis. + 5. Compute the y axis: z cross product x. + + Note that at step 3, we cannot directly use y' as y axis since it's + not necessarily orthogonal to z axis. We need to pass from x to y. + Inputs: + poses: (N_images, 3, 4) + Outputs: + pose_avg: (3, 4) the average pose + """ + # 1. Compute the center + center = poses[..., 3].mean(0) # (3) + + # 2. Compute the z axis + z = normalize(poses[..., 2].mean(0)) # (3) + + # 3. Compute axis y' (no need to normalize as it's not the final output) + y_ = poses[..., 1].mean(0) # (3) + + # 4. Compute the x axis + x = normalize(np.cross(z, y_)) # (3) + + # 5. Compute the y axis (as z and x are normalized, y is already of norm 1) + y = np.cross(x, z) # (3) + + pose_avg = np.stack([x, y, z, center], 1) # (3, 4) + + return pose_avg + + +def center_poses(poses, blender2opencv): + """ + Center the poses so that we can use NDC. 
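A quick usage check of the `average_poses` helper above (the five steps listed in its docstring), with two toy cameras placed symmetrically about the origin:

```python
import numpy as np

# Two cameras at x = +/-1, both with identity rotation: (2, 3, 4) pose stack.
poses = np.stack([
    np.concatenate([np.eye(3), np.array([[ 1.0], [0.0], [0.0]])], axis=1),
    np.concatenate([np.eye(3), np.array([[-1.0], [0.0], [0.0]])], axis=1),
])

pose_avg = average_poses(poses)   # (3, 4)
print(pose_avg[:, 3])             # center of the rig, ~[0, 0, 0]
print(pose_avg[:, 2])             # averaged z axis, ~[0, 0, 1]
```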
+ See https://github.com/bmild/nerf/issues/34 + Inputs: + poses: (N_images, 3, 4) + Outputs: + poses_centered: (N_images, 3, 4) the centered poses + pose_avg: (3, 4) the average pose + """ + poses = poses @ blender2opencv + pose_avg = average_poses(poses) # (3, 4) + pose_avg_homo = np.eye(4) + pose_avg_homo[ + :3 + ] = pose_avg # convert to homogeneous coordinate for faster computation + pose_avg_homo = pose_avg_homo + # by simply adding 0, 0, 0, 1 as the last row + last_row = np.tile(np.array([0, 0, 0, 1]), (len(poses), 1, 1)) # (N_images, 1, 4) + poses_homo = np.concatenate( + [poses, last_row], 1 + ) # (N_images, 4, 4) homogeneous coordinate + + poses_centered = np.linalg.inv(pose_avg_homo) @ poses_homo # (N_images, 4, 4) + # poses_centered = poses_centered @ blender2opencv + poses_centered = poses_centered[:, :3] # (N_images, 3, 4) + + return poses_centered, pose_avg_homo + + +def viewmatrix(z, up, pos): + vec2 = normalize(z) + vec1_avg = up + vec0 = normalize(np.cross(vec1_avg, vec2)) + vec1 = normalize(np.cross(vec2, vec0)) + m = np.eye(4) + m[:3] = np.stack([-vec0, vec1, vec2, pos], 1) + return m + + +def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, N_rots=2, N=120): + render_poses = [] + rads = np.array(list(rads) + [1.0]) + + for theta in np.linspace(0.0, 2.0 * np.pi * N_rots, N + 1)[:-1]: + c = np.dot( + c2w[:3, :4], + np.array([np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.0]) + * rads, + ) + z = normalize(c - np.dot(c2w[:3, :4], np.array([0, 0, -focal, 1.0]))) + render_poses.append(viewmatrix(z, up, c)) + return render_poses + + + +def process_video(video_data_save, video_path, img_wh, downsample, transform): + """ + Load video_path data to video_data_save tensor. + """ + video_frames = cv2.VideoCapture(video_path) + count = 0 + video_images_path = video_path.split('.')[0] + image_path = os.path.join(video_images_path,"images") + + if not os.path.exists(image_path): + os.makedirs(image_path) + while video_frames.isOpened(): + ret, video_frame = video_frames.read() + if ret: + video_frame = cv2.cvtColor(video_frame, cv2.COLOR_BGR2RGB) + video_frame = Image.fromarray(video_frame) + if downsample != 1.0: + + img = video_frame.resize(img_wh, Image.LANCZOS) + img.save(os.path.join(image_path,"%04d.png"%count)) + + img = transform(img) + video_data_save[count] = img.permute(1,2,0) + count += 1 + else: + break + + else: + images_path = os.listdir(image_path) + images_path.sort() + + for path in images_path: + img = Image.open(os.path.join(image_path,path)) + if downsample != 1.0: + img = img.resize(img_wh, Image.LANCZOS) + img = transform(img) + video_data_save[count] = img.permute(1,2,0) + count += 1 + + video_frames.release() + print(f"Video {video_path} processed.") + return None + + +# define a function to process all videos +def process_videos(videos, skip_index, img_wh, downsample, transform, num_workers=1): + """ + A multi-threaded function to load all videos fastly and memory-efficiently. + To save memory, we pre-allocate a tensor to store all the images and spawn multi-threads to load the images into this tensor. 
+ """ + all_imgs = torch.zeros(len(videos) - 1, 300, img_wh[-1] , img_wh[-2], 3) + with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor: + # start a thread for each video + current_index = 0 + futures = [] + for index, video_path in enumerate(videos): + # skip the video with skip_index (eval video) + if index == skip_index: + continue + else: + future = executor.submit( + process_video, + all_imgs[current_index], + video_path, + img_wh, + downsample, + transform, + ) + futures.append(future) + current_index += 1 + return all_imgs + +def get_spiral(c2ws_all, near_fars, rads_scale=1.0, N_views=120): + """ + Generate a set of poses using NeRF's spiral camera trajectory as validation poses. + """ + # center pose + c2w = average_poses(c2ws_all) + + # Get average pose + up = normalize(c2ws_all[:, :3, 1].sum(0)) + + # Find a reasonable "focus depth" for this dataset + dt = 0.75 + close_depth, inf_depth = near_fars.min() * 0.9, near_fars.max() * 5.0 + focal = 1.0 / ((1.0 - dt) / close_depth + dt / inf_depth) + + # Get radii for spiral path + zdelta = near_fars.min() * 0.2 + tt = c2ws_all[:, :3, 3] + rads = np.percentile(np.abs(tt), 90, 0) * rads_scale + render_poses = render_path_spiral( + c2w, up, rads, focal, zdelta, zrate=0.5, N=N_views + ) + return np.stack(render_poses) + + +class Neural3D_NDC_Dataset(Dataset): + def __init__( + self, + datadir, + split="train", + downsample=1.0, + is_stack=True, + cal_fine_bbox=False, + N_vis=-1, + time_scale=1.0, + scene_bbox_min=[-1.0, -1.0, -1.0], + scene_bbox_max=[1.0, 1.0, 1.0], + N_random_pose=1000, + bd_factor=0.75, + eval_step=1, + eval_index=0, + sphere_scale=1.0, + ): + self.img_wh = ( + int(1352 / downsample), + int(1014 / downsample), + ) # According to the neural 3D paper, the default resolution is 1024x768 + self.root_dir = datadir + self.split = split + self.downsample = 2704 / self.img_wh[0] + self.is_stack = is_stack + self.N_vis = N_vis + self.time_scale = time_scale + self.scene_bbox = torch.tensor([scene_bbox_min, scene_bbox_max]) + + self.world_bound_scale = 1.1 + self.bd_factor = bd_factor + self.eval_step = eval_step + self.eval_index = eval_index + self.blender2opencv = np.eye(4) + self.transform = T.ToTensor() + + self.near = 0.0 + self.far = 1.0 + self.near_far = [self.near, self.far] # NDC near far is [0, 1.0] + self.white_bg = False + self.ndc_ray = True + self.depth_data = False + + self.load_meta() + print(f"meta data loaded, total image:{len(self)}") + + def load_meta(self): + """ + Load meta data from the dataset. + """ + # Read poses and video file paths. + poses_arr = np.load(os.path.join(self.root_dir, "poses_bounds.npy")) + poses = poses_arr[:, :-2].reshape([-1, 3, 5]) # (N_cams, 3, 5) + self.near_fars = poses_arr[:, -2:] + videos = glob.glob(os.path.join(self.root_dir, "cam*")) + videos = sorted(videos) + assert len(videos) == poses_arr.shape[0] + + H, W, focal = poses[0, :, -1] + focal = focal / self.downsample + self.focal = [focal, focal] + poses = np.concatenate([poses[..., 1:2], -poses[..., :1], poses[..., 2:4]], -1) + poses, _ = center_poses( + poses, self.blender2opencv + ) # Re-center poses so that the average is near the center. + + near_original = self.near_fars.min() + scale_factor = near_original * 0.75 + self.near_fars /= ( + scale_factor # rescale nearest plane so that it is at z = 4/3. + ) + poses[..., 3] /= scale_factor + + # Sample N_views poses for validation - NeRF-like camera trajectory. 
+ N_views = 120 + self.val_poses = get_spiral(poses, self.near_fars, N_views=N_views) + # self.val_poses = self.directions + W, H = self.img_wh + poses_i_train = [] + + for i in range(len(poses)): + if i != self.eval_index: + poses_i_train.append(i) + self.poses = poses[poses_i_train] + self.poses_all = poses + self.image_paths, self.image_poses, self.image_times, N_cam, N_time = self.load_images_path(videos, self.split) + self.cam_number = N_cam + self.time_number = N_time + def get_val_pose(self): + render_poses = self.val_poses + render_times = torch.linspace(0.0, 1.0, render_poses.shape[0]) * 2.0 - 1.0 + return render_poses, self.time_scale * render_times + def load_images_path(self,videos,split): + image_paths = [] + image_poses = [] + image_times = [] + N_cams = 0 + N_time = 0 + countss = 300 + for index, video_path in enumerate(videos): + + if index == self.eval_index: + if split =="train": + continue + else: + if split == "test": + continue + N_cams +=1 + count = 0 + video_images_path = video_path.split('.')[0] + image_path = os.path.join(video_images_path,"images") + video_frames = cv2.VideoCapture(video_path) + if not os.path.exists(image_path): + print(f"no images saved in {image_path}, extract images from video.") + os.makedirs(image_path) + this_count = 0 + while video_frames.isOpened(): + ret, video_frame = video_frames.read() + if this_count >= countss:break + if ret: + video_frame = cv2.cvtColor(video_frame, cv2.COLOR_BGR2RGB) + video_frame = Image.fromarray(video_frame) + if self.downsample != 1.0: + + img = video_frame.resize(self.img_wh, Image.LANCZOS) + img.save(os.path.join(image_path,"%04d.png"%count)) + + # img = transform(img) + count += 1 + this_count+=1 + else: + break + + images_path = os.listdir(image_path) + images_path.sort() + this_count = 0 + for idx, path in enumerate(images_path): + if this_count >=countss:break + image_paths.append(os.path.join(image_path,path)) + pose = np.array(self.poses_all[index]) + R = pose[:3,:3] + R = -R + R[:,0] = -R[:,0] + T = -pose[:3,3].dot(R) + image_times.append(idx/countss) + image_poses.append((R,T)) + # if self.downsample != 1.0: + # img = video_frame.resize(self.img_wh, Image.LANCZOS) + # img.save(os.path.join(image_path,"%04d.png"%count)) + this_count+=1 + N_time = len(images_path) + + # video_data_save[count] = img.permute(1,2,0) + # count += 1 + return image_paths, image_poses, image_times, N_cams, N_time + def __len__(self): + return len(self.image_paths) + def __getitem__(self,index): + img = Image.open(self.image_paths[index]) + img = img.resize(self.img_wh, Image.LANCZOS) + + img = self.transform(img) + return img, self.image_poses[index], self.image_times[index] + def load_pose(self,index): + return self.image_poses[index] + diff --git a/scene/regulation.py b/scene/regulation.py new file mode 100644 index 0000000..80583a3 --- /dev/null +++ b/scene/regulation.py @@ -0,0 +1,176 @@ +import abc +import os +from typing import Sequence + +import matplotlib.pyplot as plt +import numpy as np +import torch +import torch.optim.lr_scheduler +from torch import nn + + + +def compute_plane_tv(t): + batch_size, c, h, w = t.shape + count_h = batch_size * c * (h - 1) * w + count_w = batch_size * c * h * (w - 1) + h_tv = torch.square(t[..., 1:, :] - t[..., :h-1, :]).sum() + w_tv = torch.square(t[..., :, 1:] - t[..., :, :w-1]).sum() + return 2 * (h_tv / count_h + w_tv / count_w) # This is summing over batch and c instead of avg + + +def compute_plane_smoothness(t): + batch_size, c, h, w = t.shape + # Convolve with a second 
derivative filter, in the time dimension which is dimension 2 + first_difference = t[..., 1:, :] - t[..., :h-1, :] # [batch, c, h-1, w] + second_difference = first_difference[..., 1:, :] - first_difference[..., :h-2, :] # [batch, c, h-2, w] + # Take the L2 norm of the result + return torch.square(second_difference).mean() + + +class Regularizer(): + def __init__(self, reg_type, initialization): + self.reg_type = reg_type + self.initialization = initialization + self.weight = float(self.initialization) + self.last_reg = None + + def step(self, global_step): + pass + + def report(self, d): + if self.last_reg is not None: + d[self.reg_type].update(self.last_reg.item()) + + def regularize(self, *args, **kwargs) -> torch.Tensor: + out = self._regularize(*args, **kwargs) * self.weight + self.last_reg = out.detach() + return out + + @abc.abstractmethod + def _regularize(self, *args, **kwargs) -> torch.Tensor: + raise NotImplementedError() + + def __str__(self): + return f"Regularizer({self.reg_type}, weight={self.weight})" + + +class PlaneTV(Regularizer): + def __init__(self, initial_value, what: str = 'field'): + if what not in {'field', 'proposal_network'}: + raise ValueError(f'what must be one of "field" or "proposal_network" ' + f'but {what} was passed.') + name = f'planeTV-{what[:2]}' + super().__init__(name, initial_value) + self.what = what + + def step(self, global_step): + pass + + def _regularize(self, model, **kwargs): + multi_res_grids: Sequence[nn.ParameterList] + if self.what == 'field': + multi_res_grids = model.field.grids + elif self.what == 'proposal_network': + multi_res_grids = [p.grids for p in model.proposal_networks] + else: + raise NotImplementedError(self.what) + total = 0 + # Note: input to compute_plane_tv should be of shape [batch_size, c, h, w] + for grids in multi_res_grids: + if len(grids) == 3: + spatial_grids = [0, 1, 2] + else: + spatial_grids = [0, 1, 3] # These are the spatial grids; the others are spatiotemporal + for grid_id in spatial_grids: + total += compute_plane_tv(grids[grid_id]) + for grid in grids: + # grid: [1, c, h, w] + total += compute_plane_tv(grid) + return total + + +class TimeSmoothness(Regularizer): + def __init__(self, initial_value, what: str = 'field'): + if what not in {'field', 'proposal_network'}: + raise ValueError(f'what must be one of "field" or "proposal_network" ' + f'but {what} was passed.') + name = f'time-smooth-{what[:2]}' + super().__init__(name, initial_value) + self.what = what + + def _regularize(self, model, **kwargs) -> torch.Tensor: + multi_res_grids: Sequence[nn.ParameterList] + if self.what == 'field': + multi_res_grids = model.field.grids + elif self.what == 'proposal_network': + multi_res_grids = [p.grids for p in model.proposal_networks] + else: + raise NotImplementedError(self.what) + total = 0 + # model.grids is 6 x [1, rank * F_dim, reso, reso] + for grids in multi_res_grids: + if len(grids) == 3: + time_grids = [] + else: + time_grids = [2, 4, 5] + for grid_id in time_grids: + total += compute_plane_smoothness(grids[grid_id]) + return torch.as_tensor(total) + + + +class L1ProposalNetwork(Regularizer): + def __init__(self, initial_value): + super().__init__('l1-proposal-network', initial_value) + + def _regularize(self, model, **kwargs) -> torch.Tensor: + grids = [p.grids for p in model.proposal_networks] + total = 0.0 + for pn_grids in grids: + for grid in pn_grids: + total += torch.abs(grid).mean() + return torch.as_tensor(total) + + +class DepthTV(Regularizer): + def __init__(self, initial_value): + 
super().__init__('tv-depth', initial_value) + + def _regularize(self, model, model_out, **kwargs) -> torch.Tensor: + depth = model_out['depth'] + tv = compute_plane_tv( + depth.reshape(64, 64)[None, None, :, :] + ) + return tv + + +class L1TimePlanes(Regularizer): + def __init__(self, initial_value, what='field'): + if what not in {'field', 'proposal_network'}: + raise ValueError(f'what must be one of "field" or "proposal_network" ' + f'but {what} was passed.') + super().__init__(f'l1-time-{what[:2]}', initial_value) + self.what = what + + def _regularize(self, model, **kwargs) -> torch.Tensor: + # model.grids is 6 x [1, rank * F_dim, reso, reso] + multi_res_grids: Sequence[nn.ParameterList] + if self.what == 'field': + multi_res_grids = model.field.grids + elif self.what == 'proposal_network': + multi_res_grids = [p.grids for p in model.proposal_networks] + else: + raise NotImplementedError(self.what) + + total = 0.0 + for grids in multi_res_grids: + if len(grids) == 3: + continue + else: + # These are the spatiotemporal grids + spatiotemporal_grids = [2, 4, 5] + for grid_id in spatiotemporal_grids: + total += torch.abs(1 - grids[grid_id]).mean() + return torch.as_tensor(total) + diff --git a/scene/utils.py b/scene/utils.py new file mode 100644 index 0000000..d6edf7e --- /dev/null +++ b/scene/utils.py @@ -0,0 +1,429 @@ +import copy +import json +import math +import os +import pathlib +from typing import Any, Callable, List, Optional, Text, Tuple, Union + +import numpy as np +import scipy.signal +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + + +PRNGKey = Any +Shape = Tuple[int] +Dtype = Any # this could be a real type? +Array = Any +Activation = Callable[[Array], Array] +Initializer = Callable[[PRNGKey, Shape, Dtype], Array] +Normalizer = Callable[[], Callable[[Array], Array]] +PathType = Union[Text, pathlib.PurePosixPath] + +from pathlib import PurePosixPath as GPath + + +def _compute_residual_and_jacobian( + x: np.ndarray, + y: np.ndarray, + xd: np.ndarray, + yd: np.ndarray, + k1: float = 0.0, + k2: float = 0.0, + k3: float = 0.0, + p1: float = 0.0, + p2: float = 0.0, +) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, + np.ndarray]: + """Auxiliary function of radial_and_tangential_undistort().""" + + r = x * x + y * y + d = 1.0 + r * (k1 + r * (k2 + k3 * r)) + + fx = d * x + 2 * p1 * x * y + p2 * (r + 2 * x * x) - xd + fy = d * y + 2 * p2 * x * y + p1 * (r + 2 * y * y) - yd + + # Compute derivative of d over [x, y] + d_r = (k1 + r * (2.0 * k2 + 3.0 * k3 * r)) + d_x = 2.0 * x * d_r + d_y = 2.0 * y * d_r + + # Compute derivative of fx over x and y. + fx_x = d + d_x * x + 2.0 * p1 * y + 6.0 * p2 * x + fx_y = d_y * x + 2.0 * p1 * x + 2.0 * p2 * y + + # Compute derivative of fy over x and y. + fy_x = d_x * y + 2.0 * p2 * y + 2.0 * p1 * x + fy_y = d + d_y * y + 2.0 * p2 * x + 6.0 * p1 * y + + return fx, fy, fx_x, fx_y, fy_x, fy_y + + +def _radial_and_tangential_undistort( + xd: np.ndarray, + yd: np.ndarray, + k1: float = 0, + k2: float = 0, + k3: float = 0, + p1: float = 0, + p2: float = 0, + eps: float = 1e-9, + max_iterations=10) -> Tuple[np.ndarray, np.ndarray]: + """Computes undistorted (x, y) from (xd, yd).""" + # Initialize from the distorted point. 
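A round-trip sanity check for the Newton-style undistortion implemented here, using only the radial `k1` term. The tolerance is what one would expect for such a mild distortion, so treat it as a sketch rather than a test:

```python
import numpy as np

# Distort a point with a pure k1 radial term, then recover it with the solver above.
x_true, y_true, k1 = np.array([0.3]), np.array([0.2]), 0.1
r = x_true**2 + y_true**2
xd, yd = x_true * (1 + k1 * r), y_true * (1 + k1 * r)   # forward distortion

x_rec, y_rec = _radial_and_tangential_undistort(xd, yd, k1=k1)
assert np.allclose([x_rec, y_rec], [x_true, y_true], atol=1e-6)
```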
+ x = xd.copy() + y = yd.copy() + + for _ in range(max_iterations): + fx, fy, fx_x, fx_y, fy_x, fy_y = _compute_residual_and_jacobian( + x=x, y=y, xd=xd, yd=yd, k1=k1, k2=k2, k3=k3, p1=p1, p2=p2) + denominator = fy_x * fx_y - fx_x * fy_y + x_numerator = fx * fy_y - fy * fx_y + y_numerator = fy * fx_x - fx * fy_x + step_x = np.where( + np.abs(denominator) > eps, x_numerator / denominator, + np.zeros_like(denominator)) + step_y = np.where( + np.abs(denominator) > eps, y_numerator / denominator, + np.zeros_like(denominator)) + + x = x + step_x + y = y + step_y + + return x, y + + +class Camera: + """Class to handle camera geometry.""" + + def __init__(self, + orientation: np.ndarray, + position: np.ndarray, + focal_length: Union[np.ndarray, float], + principal_point: np.ndarray, + image_size: np.ndarray, + skew: Union[np.ndarray, float] = 0.0, + pixel_aspect_ratio: Union[np.ndarray, float] = 1.0, + radial_distortion: Optional[np.ndarray] = None, + tangential_distortion: Optional[np.ndarray] = None, + dtype=np.float32): + """Constructor for camera class.""" + if radial_distortion is None: + radial_distortion = np.array([0.0, 0.0, 0.0], dtype) + if tangential_distortion is None: + tangential_distortion = np.array([0.0, 0.0], dtype) + + self.orientation = np.array(orientation, dtype) + self.position = np.array(position, dtype) + self.focal_length = np.array(focal_length, dtype) + self.principal_point = np.array(principal_point, dtype) + self.skew = np.array(skew, dtype) + self.pixel_aspect_ratio = np.array(pixel_aspect_ratio, dtype) + self.radial_distortion = np.array(radial_distortion, dtype) + self.tangential_distortion = np.array(tangential_distortion, dtype) + self.image_size = np.array(image_size, np.uint32) + self.dtype = dtype + + @classmethod + def from_json(cls, path: PathType): + """Loads a JSON camera into memory.""" + path = GPath(path) + # with path.open('r') as fp: + with open(path, 'r') as fp: + camera_json = json.load(fp) + + # Fix old camera JSON. 
+ if 'tangential' in camera_json: + camera_json['tangential_distortion'] = camera_json['tangential'] + + return cls( + orientation=np.asarray(camera_json['orientation']), + position=np.asarray(camera_json['position']), + focal_length=camera_json['focal_length'], + principal_point=np.asarray(camera_json['principal_point']), + skew=camera_json['skew'], + pixel_aspect_ratio=camera_json['pixel_aspect_ratio'], + radial_distortion=np.asarray(camera_json['radial_distortion']), + tangential_distortion=np.asarray(camera_json['tangential_distortion']), + image_size=np.asarray(camera_json['image_size']), + ) + + def to_json(self): + return { + k: (v.tolist() if hasattr(v, 'tolist') else v) + for k, v in self.get_parameters().items() + } + + def get_parameters(self): + return { + 'orientation': self.orientation, + 'position': self.position, + 'focal_length': self.focal_length, + 'principal_point': self.principal_point, + 'skew': self.skew, + 'pixel_aspect_ratio': self.pixel_aspect_ratio, + 'radial_distortion': self.radial_distortion, + 'tangential_distortion': self.tangential_distortion, + 'image_size': self.image_size, + } + + @property + def scale_factor_x(self): + return self.focal_length + + @property + def scale_factor_y(self): + return self.focal_length * self.pixel_aspect_ratio + + @property + def principal_point_x(self): + return self.principal_point[0] + + @property + def principal_point_y(self): + return self.principal_point[1] + + @property + def has_tangential_distortion(self): + return any(self.tangential_distortion != 0.0) + + @property + def has_radial_distortion(self): + return any(self.radial_distortion != 0.0) + + @property + def image_size_y(self): + return self.image_size[1] + + @property + def image_size_x(self): + return self.image_size[0] + + @property + def image_shape(self): + return self.image_size_y, self.image_size_x + + @property + def optical_axis(self): + return self.orientation[2, :] + + @property + def translation(self): + return -np.matmul(self.orientation, self.position) + + def pixel_to_local_rays(self, pixels: np.ndarray): + """Returns the local ray directions for the provided pixels.""" + y = ((pixels[..., 1] - self.principal_point_y) / self.scale_factor_y) + x = ((pixels[..., 0] - self.principal_point_x - y * self.skew) / + self.scale_factor_x) + + if self.has_radial_distortion or self.has_tangential_distortion: + x, y = _radial_and_tangential_undistort( + x, + y, + k1=self.radial_distortion[0], + k2=self.radial_distortion[1], + k3=self.radial_distortion[2], + p1=self.tangential_distortion[0], + p2=self.tangential_distortion[1]) + + dirs = np.stack([x, y, np.ones_like(x)], axis=-1) + return dirs / np.linalg.norm(dirs, axis=-1, keepdims=True) + + def pixels_to_rays(self, pixels: np.ndarray) -> np.ndarray: + """Returns the rays for the provided pixels. + + Args: + pixels: [A1, ..., An, 2] tensor or np.array containing 2d pixel positions. + + Returns: + An array containing the normalized ray directions in world coordinates. + """ + if pixels.shape[-1] != 2: + raise ValueError('The last dimension of pixels must be 2.') + if pixels.dtype != self.dtype: + raise ValueError(f'pixels dtype ({pixels.dtype!r}) must match camera ' + f'dtype ({self.dtype!r})') + + batch_shape = pixels.shape[:-1] + pixels = np.reshape(pixels, (-1, 2)) + + local_rays_dir = self.pixel_to_local_rays(pixels) + rays_dir = np.matmul(self.orientation.T, local_rays_dir[..., np.newaxis]) + rays_dir = np.squeeze(rays_dir, axis=-1) + + # Normalize rays. 
+ rays_dir /= np.linalg.norm(rays_dir, axis=-1, keepdims=True) + rays_dir = rays_dir.reshape((*batch_shape, 3)) + return rays_dir + + def pixels_to_points(self, pixels: np.ndarray, depth: np.ndarray): + rays_through_pixels = self.pixels_to_rays(pixels) + cosa = np.matmul(rays_through_pixels, self.optical_axis) + points = ( + rays_through_pixels * depth[..., np.newaxis] / cosa[..., np.newaxis] + + self.position) + return points + + def points_to_local_points(self, points: np.ndarray): + translated_points = points - self.position + local_points = (np.matmul(self.orientation, translated_points.T)).T + return local_points + + def project(self, points: np.ndarray): + """Projects a 3D point (x,y,z) to a pixel position (x,y).""" + batch_shape = points.shape[:-1] + points = points.reshape((-1, 3)) + local_points = self.points_to_local_points(points) + + # Get normalized local pixel positions. + x = local_points[..., 0] / local_points[..., 2] + y = local_points[..., 1] / local_points[..., 2] + r2 = x**2 + y**2 + + # Apply radial distortion. + distortion = 1.0 + r2 * ( + self.radial_distortion[0] + r2 * + (self.radial_distortion[1] + self.radial_distortion[2] * r2)) + + # Apply tangential distortion. + x_times_y = x * y + x = ( + x * distortion + 2.0 * self.tangential_distortion[0] * x_times_y + + self.tangential_distortion[1] * (r2 + 2.0 * x**2)) + y = ( + y * distortion + 2.0 * self.tangential_distortion[1] * x_times_y + + self.tangential_distortion[0] * (r2 + 2.0 * y**2)) + + # Map the distorted ray to the image plane and return the depth. + pixel_x = self.focal_length * x + self.skew * y + self.principal_point_x + pixel_y = (self.focal_length * self.pixel_aspect_ratio * y + + self.principal_point_y) + + pixels = np.stack([pixel_x, pixel_y], axis=-1) + return pixels.reshape((*batch_shape, 2)) + + def get_pixel_centers(self): + """Returns the pixel centers.""" + xx, yy = np.meshgrid(np.arange(self.image_size_x, dtype=self.dtype), + np.arange(self.image_size_y, dtype=self.dtype)) + return np.stack([xx, yy], axis=-1) + 0.5 + + def scale(self, scale: float): + """Scales the camera.""" + if scale <= 0: + raise ValueError('scale needs to be positive.') + + new_camera = Camera( + orientation=self.orientation.copy(), + position=self.position.copy(), + focal_length=self.focal_length * scale, + principal_point=self.principal_point.copy() * scale, + skew=self.skew, + pixel_aspect_ratio=self.pixel_aspect_ratio, + radial_distortion=self.radial_distortion.copy(), + tangential_distortion=self.tangential_distortion.copy(), + image_size=np.array((int(round(self.image_size[0] * scale)), + int(round(self.image_size[1] * scale)))), + ) + return new_camera + + def look_at(self, position, look_at, up, eps=1e-6): + """Creates a copy of the camera which looks at a given point. + + Copies the provided vision_sfm camera and returns a new camera that is + positioned at `camera_position` while looking at `look_at_position`. + Camera intrinsics are copied by this method. A common value for the + up_vector is (0, 1, 0). + + Args: + position: A (3,) numpy array representing the position of the camera. + look_at: A (3,) numpy array representing the location the camera + looks at. + up: A (3,) numpy array representing the up direction, whose + projection is parallel to the y-axis of the image plane. + eps: a small number to prevent divides by zero. + + Returns: + A new camera that is copied from the original but is positioned and + looks at the provided coordinates. 
+ + Raises: + ValueError: If the camera position and look at position are very close + to each other or if the up-vector is parallel to the requested optical + axis. + """ + + look_at_camera = self.copy() + optical_axis = look_at - position + norm = np.linalg.norm(optical_axis) + if norm < eps: + raise ValueError('The camera center and look at position are too close.') + optical_axis /= norm + + right_vector = np.cross(optical_axis, up) + norm = np.linalg.norm(right_vector) + if norm < eps: + raise ValueError('The up-vector is parallel to the optical axis.') + right_vector /= norm + + # The three directions here are orthogonal to each other and form a right + # handed coordinate system. + camera_rotation = np.identity(3) + camera_rotation[0, :] = right_vector + camera_rotation[1, :] = np.cross(optical_axis, right_vector) + camera_rotation[2, :] = optical_axis + + look_at_camera.position = position + look_at_camera.orientation = camera_rotation + return look_at_camera + + def crop_image_domain( + self, left: int = 0, right: int = 0, top: int = 0, bottom: int = 0): + """Returns a copy of the camera with adjusted image bounds. + + Args: + left: number of pixels by which to reduce (or augment, if negative) the + image domain at the associated boundary. + right: likewise. + top: likewise. + bottom: likewise. + + The crop parameters may not cause the camera image domain dimensions to + become non-positive. + + Returns: + A camera with adjusted image dimensions. The focal length is unchanged, + and the principal point is updated to preserve the original principal + axis. + """ + + crop_left_top = np.array([left, top]) + crop_right_bottom = np.array([right, bottom]) + new_resolution = self.image_size - crop_left_top - crop_right_bottom + new_principal_point = self.principal_point - crop_left_top + if np.any(new_resolution <= 0): + raise ValueError('Crop would result in non-positive image dimensions.') + + new_camera = self.copy() + new_camera.image_size = np.array([int(new_resolution[0]), + int(new_resolution[1])]) + new_camera.principal_point = np.array([new_principal_point[0], + new_principal_point[1]]) + return new_camera + + def copy(self): + return copy.deepcopy(self) + + +''' Misc +''' +mse2psnr = lambda x : -10. * torch.log10(x) +to8b = lambda x : (255*np.clip(x,0,1)).astype(np.uint8) + + + +''' Checkpoint utils +''' \ No newline at end of file diff --git a/scene/video_dataset.py b/scene/video_dataset.py new file mode 100644 index 0000000..29e5203 --- /dev/null +++ b/scene/video_dataset.py @@ -0,0 +1,276 @@ +from torch.utils.data import Dataset +# from scene.cameras import Camera +import numpy as np +from utils.general_utils import PILtoTorch +from utils.graphics_utils import fov2focal, focal2fov +import torch +from utils.camera_utils import loadCam +from utils.graphics_utils import focal2fov + +from torchvision.transforms import ToTensor +from PIL import Image +import glob +from scene.cam_utils import orbit_camera +import math + +def getProjectionMatrix(znear, zfar, fovX, fovY): + tanHalfFovY = math.tan((fovY / 2)) + tanHalfFovX = math.tan((fovX / 2)) + + P = torch.zeros(4, 4) + + z_sign = 1.0 + + P[0, 0] = 1 / tanHalfFovX + P[1, 1] = 1 / tanHalfFovY + P[3, 2] = z_sign + P[2, 2] = z_sign * zfar / (zfar - znear) + P[2, 3] = -(zfar * znear) / (zfar - znear) + return P + + +class MiniCam: + def __init__(self, c2w, width, height, fovy, fovx, znear, zfar): + # c2w (pose) should be in NeRF convention. 
+ + self.image_width = width + self.image_height = height + self.FoVy = fovy + self.FoVx = fovx + self.znear = znear + self.zfar = zfar + + w2c = np.linalg.inv(c2w) + + # rectify... + w2c[1:3, :3] *= -1 + w2c[:3, 3] *= -1 + + self.world_view_transform = torch.tensor(w2c).transpose(0, 1)#.cuda() + self.projection_matrix = ( + getProjectionMatrix( + znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy + ) + .transpose(0, 1) + # .cuda() + ) + self.full_proj_transform = self.world_view_transform @ self.projection_matrix + self.camera_center = -torch.tensor(c2w[:3, 3])#.cuda() + + +class FourDGSdataset(Dataset): + def __init__( + self, + split, + frame_num = 16, + name='panda', + rife=False, + static=False, + ): + self.split = split + # self.args = args + + # https://github.com/threestudio-project/threestudio/blob/main/configs/magic123-coarse-sd.yaml#L22 + self.radius = 2.5 + self.W = 512 + self.H = 512 + self.fovy = np.deg2rad(40) + self.fovx = np.deg2rad(40) + # self.fovy = np.deg2rad(49.1) + # self.fovx = np.deg2rad(49.1) + # align with zero123 rendering setting (ref: https://github.com/cvlab-columbia/zero123/blob/main/objaverse-rendering/scripts/blender_script.py#L61 + self.near = 0.01 + self.far = 100 + self.T = ToTensor() + self.len_pose0 = frame_num + self.name=name + self.rife=rife + self.static=static + + # load t=0 sequences + dir=f'data/{self.name}_static_rgba/' + #dir = 'data/panda_static_rgba/' # generated from new.png + t0_im_names = [dir + str(x) + '_rgba.png' for x in range(1, 101)] + # t0_im_names = glob.glob(dir + '/*.png') + self.t0_im_list = [] + # TODO: should images be RGBA when input?? + for fname in t0_im_names: + im = Image.open(fname).resize((self.W, self.H))#.convert('RGB') + # use RGBA + ww = self.T(im) + assert ww.shape[0] == 4 + ww[:3] = ww[:3] * ww[-1:] + (1 - ww[-1:]) + self.t0_im_list.append(ww) + self.t0_im_list = torch.stack(self.t0_im_list, dim=0)#.permute(0, 2, 3, 1) + + print(f"T0 images loaded {self.t0_im_list.shape}.") + + # load pose0 (canonical pose) frames + # dir = 'data/panda_im/' + # pose0_im_names = [dir + x for x in ['new.png', '1.png', '2.png', '3.png']] + #dir = 'data/panda_rgba_pose0/' + dir=f'data/{self.name}_rgba_pose0/' + if self.rife==False: + if frame_num==4: + if self.name=='panda': + frame_list=[0,12,14,15] + # elif self.name=='rose': + # frame_list=[0,6,13,22] + else: + frame_list=range(frame_num) + pose0_im_names = [dir + f'{x}.png' for x in frame_list] + # pose0_im_names = [dir + f'frame_{x}_rgba.png' for x in frame_list] + else: + if self.name=='astronaut': + frame_list= [0] + list(range(12, 27)) + elif self.name=='kitten': + frame_list= [0] + list(range(16, 23))+ list(range(24, 32)) + else: + frame_list=range(frame_num) + pose0_im_names = [dir + f'{x}.png' for x in frame_list] + #pose0_im_names = [dir + f'frame_{x}_rgba.png' for x in range(frame_num)] + else: + + dir=f'data/{self.name}_rife/' + frame_list=range(frame_num) + pose0_im_names = [dir + f'img{x}.png' for x in frame_list] + + if self.static: + dir=f'data/{self.name}_rgba_pose0/' + frame_list=range(frame_num) + pose0_im_names = [dir + f'{0}.png' for _ in frame_list] + + + # pose0_im_names = pose0_im_names[:2] + # pose0_im_names = glob.glob(dir + '/*.png') + self.pose0_im_list = [] + # TODO: should images be RGBA when input?? 
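+        # As with the t0 images above, each RGBA frame is composited onto a white
+        # background (rgb = rgb * alpha + (1 - alpha)) and the alpha channel is kept
+        # as a foreground mask; if fewer than len_pose0 frames are available, the
+        # last loaded frame is repeated to pad the sequence.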
+ for fname in pose0_im_names: + im = Image.open(fname).resize((self.W, self.H))#.convert('RGB') + ww = self.T(im) + ww[:3] = ww[:3] * ww[-1:] + (1 - ww[-1:]) + self.pose0_im_list.append(ww) + # self.pose0_im_list.append(self.T(im)) + while len(self.pose0_im_list) < self.len_pose0: + self.pose0_im_list.append(ww) + self.pose0_im_list = torch.stack(self.pose0_im_list, dim=0)#.permute(0, 2, 3, 1) + # self.pose0_im_list = self.pose0_im_list.expand(16, 3, 256, 256) + print(f"Pose0 images loaded {self.pose0_im_list.shape}") + self.max_frames = self.pose0_im_list.shape[0] + print(f"Loaded SDS Dataset. Max {self.max_frames} frames.") + + self.t0_num = self.t0_im_list.shape[0] + self.pose0_num = self.pose0_im_list.shape[0] + self.len_ = (self.t0_num) * (self.pose0_num) + + pose0_pose = orbit_camera(0, 0, self.radius) + self.pose0_cam = MiniCam( + pose0_pose, + self.H, # NOTE: order might be wrong + self.W, + self.fovy, + self.fovx, + self.near, + self.far, + ) + # we sample (pose, t) + def __getitem__(self, index): + if self.split == 'train': + t0_idx = index // self.pose0_num + pose0_idx = index % self.pose0_num + time = torch.tensor([pose0_idx]).unsqueeze(0)#.expand(1, self.W * self.H) + else: + t0_idx = index # self.t0_num // 2 + pose0_idx = 1 + time = torch.tensor([pose0_idx]).unsqueeze(0) + + # return Camera(R=R,T=T,FoVx=FovX,FoVy=FovY,image=image,gt_alpha_mask=None, + # image_name=f"{index}",uid=index,data_device=torch.device("cuda"),time=time) + out = { + # timestamp is per pixel + "time": time / self.pose0_num, + 't0': self.t0_im_list[t0_idx], + 'pose0': self.pose0_im_list[pose0_idx], + # 't0': self.t0_im_list[t0_idx].view(-1, 3), + # 'pose0': self.pose0_im_list[pose0_idx].view(-1, 3), + # 'bg_color': torch.ones((1, 3), dtype=torch.float32), + "pose0_cam": self.pose0_cam, + } + + t0_pose = orbit_camera(0, (t0_idx / self.t0_num) * 360, self.radius) + ver = 0 + hor = (t0_idx / self.t0_num) * 360 + # ver = np.random.randint(-45, 45) + # hor = np.random.randint(-180, 180) + pose = orbit_camera(0 + ver, hor, self.radius) + out['hor'] = hor + out['ver'] = ver + + cur_cam = MiniCam( + pose, + self.H, # NOTE: order might be wrong + self.W, + self.fovy, + self.fovx, + self.near, + self.far, + ) + t0_cam = MiniCam( + t0_pose, + self.H, # NOTE: order might be wrong + self.W, + self.fovy, + self.fovx, + self.near, + self.far, + ) + out['cur_cam'] = cur_cam + out['t0_cam'] = t0_cam + + # rand_seq = [t0_cam] + # start from cur_cam, generate 6 sets of offsets + # rand_seq = [cur_cam] + # ver_offset = [np.random.randint(-10, 10) for i in range(self.pose0_num - 1)] + # hor_offset = [np.random.randint(-10, 10) for i in range(self.pose0_num - 1)] + # ver_offset = np.cumsum(ver_offset) + ver + # hor_offset = np.cumsum(hor_offset) + hor + # ver_offset = np.clip(ver_offset, -15, 45) + + rand_seq = [] + ver_list = [] + hor_list = [] + # for i in range(self.pose0_num - 1): + for i in range(self.pose0_num): + ver = np.random.randint(-30, 30) + hor = np.random.randint(-180, 180) + cur_pose = orbit_camera(ver, hor, self.radius) + ver_list.append(ver) + hor_list.append(hor) + # cur_pose = orbit_camera(ver_offset[i], hor_offset[i], self.radius) + rand_seq.append(MiniCam( + cur_pose if self.split == 'train' else pose, + # cur_pose, + self.H, # NOTE: order might be wrong + self.W, + self.fovy, + self.fovx, + self.near, + self.far, + )) + out['rand_poses'] = rand_seq + out['rand_ver'] = np.array(ver_list) + out['rand_hor'] = np.array(hor_list) + # out['rand_ver'] = ver_offset + # out['rand_hor'] = hor_offset + + return 
out + + def __len__(self): + # we sample (pose, t) + if self.split == 'train': + return self.len_ + if self.split == 'test': + return self.pose0_num + # return self.t0_num + if self.split == 'video': + return 100 diff --git a/simple-knn/ext.cpp b/simple-knn/ext.cpp new file mode 100644 index 0000000..ae6cefe --- /dev/null +++ b/simple-knn/ext.cpp @@ -0,0 +1,17 @@ +/* + * Copyright (C) 2023, Inria + * GRAPHDECO research group, https://team.inria.fr/graphdeco + * All rights reserved. + * + * This software is free for non-commercial, research and evaluation use + * under the terms of the LICENSE.md file. + * + * For inquiries contact george.drettakis@inria.fr + */ + +#include +#include "spatial.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("distCUDA2", &distCUDA2); +} diff --git a/simple-knn/setup.py b/simple-knn/setup.py new file mode 100644 index 0000000..580d2bd --- /dev/null +++ b/simple-knn/setup.py @@ -0,0 +1,35 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +from setuptools import setup +from torch.utils.cpp_extension import CUDAExtension, BuildExtension +import os + +cxx_compiler_flags = [] + +if os.name == 'nt': + cxx_compiler_flags.append("/wd4624") + +setup( + name="simple_knn", + ext_modules=[ + CUDAExtension( + name="simple_knn._C", + sources=[ + "spatial.cu", + "simple_knn.cu", + "ext.cpp"], + extra_compile_args={"nvcc": [], "cxx": cxx_compiler_flags}) + ], + cmdclass={ + 'build_ext': BuildExtension + } +) diff --git a/simple-knn/simple_knn.cu b/simple-knn/simple_knn.cu new file mode 100644 index 0000000..e72e4c9 --- /dev/null +++ b/simple-knn/simple_knn.cu @@ -0,0 +1,221 @@ +/* + * Copyright (C) 2023, Inria + * GRAPHDECO research group, https://team.inria.fr/graphdeco + * All rights reserved. + * + * This software is free for non-commercial, research and evaluation use + * under the terms of the LICENSE.md file. 
+ * + * For inquiries contact george.drettakis@inria.fr + */ + +#define BOX_SIZE 1024 + +#include "cuda_runtime.h" +#include "device_launch_parameters.h" +#include "simple_knn.h" +#include +#include +#include +#include +#include +#include +#define __CUDACC__ +#include +#include + +namespace cg = cooperative_groups; + +struct CustomMin +{ + __device__ __forceinline__ + float3 operator()(const float3& a, const float3& b) const { + return { min(a.x, b.x), min(a.y, b.y), min(a.z, b.z) }; + } +}; + +struct CustomMax +{ + __device__ __forceinline__ + float3 operator()(const float3& a, const float3& b) const { + return { max(a.x, b.x), max(a.y, b.y), max(a.z, b.z) }; + } +}; + +__host__ __device__ uint32_t prepMorton(uint32_t x) +{ + x = (x | (x << 16)) & 0x030000FF; + x = (x | (x << 8)) & 0x0300F00F; + x = (x | (x << 4)) & 0x030C30C3; + x = (x | (x << 2)) & 0x09249249; + return x; +} + +__host__ __device__ uint32_t coord2Morton(float3 coord, float3 minn, float3 maxx) +{ + uint32_t x = prepMorton(((coord.x - minn.x) / (maxx.x - minn.x)) * ((1 << 10) - 1)); + uint32_t y = prepMorton(((coord.y - minn.y) / (maxx.y - minn.y)) * ((1 << 10) - 1)); + uint32_t z = prepMorton(((coord.z - minn.z) / (maxx.z - minn.z)) * ((1 << 10) - 1)); + + return x | (y << 1) | (z << 2); +} + +__global__ void coord2Morton(int P, const float3* points, float3 minn, float3 maxx, uint32_t* codes) +{ + auto idx = cg::this_grid().thread_rank(); + if (idx >= P) + return; + + codes[idx] = coord2Morton(points[idx], minn, maxx); +} + +struct MinMax +{ + float3 minn; + float3 maxx; +}; + +__global__ void boxMinMax(uint32_t P, float3* points, uint32_t* indices, MinMax* boxes) +{ + auto idx = cg::this_grid().thread_rank(); + + MinMax me; + if (idx < P) + { + me.minn = points[indices[idx]]; + me.maxx = points[indices[idx]]; + } + else + { + me.minn = { FLT_MAX, FLT_MAX, FLT_MAX }; + me.maxx = { -FLT_MAX,-FLT_MAX,-FLT_MAX }; + } + + __shared__ MinMax redResult[BOX_SIZE]; + + for (int off = BOX_SIZE / 2; off >= 1; off /= 2) + { + if (threadIdx.x < 2 * off) + redResult[threadIdx.x] = me; + __syncthreads(); + + if (threadIdx.x < off) + { + MinMax other = redResult[threadIdx.x + off]; + me.minn.x = min(me.minn.x, other.minn.x); + me.minn.y = min(me.minn.y, other.minn.y); + me.minn.z = min(me.minn.z, other.minn.z); + me.maxx.x = max(me.maxx.x, other.maxx.x); + me.maxx.y = max(me.maxx.y, other.maxx.y); + me.maxx.z = max(me.maxx.z, other.maxx.z); + } + __syncthreads(); + } + + if (threadIdx.x == 0) + boxes[blockIdx.x] = me; +} + +__device__ __host__ float distBoxPoint(const MinMax& box, const float3& p) +{ + float3 diff = { 0, 0, 0 }; + if (p.x < box.minn.x || p.x > box.maxx.x) + diff.x = min(abs(p.x - box.minn.x), abs(p.x - box.maxx.x)); + if (p.y < box.minn.y || p.y > box.maxx.y) + diff.y = min(abs(p.y - box.minn.y), abs(p.y - box.maxx.y)); + if (p.z < box.minn.z || p.z > box.maxx.z) + diff.z = min(abs(p.z - box.minn.z), abs(p.z - box.maxx.z)); + return diff.x * diff.x + diff.y * diff.y + diff.z * diff.z; +} + +template +__device__ void updateKBest(const float3& ref, const float3& point, float* knn) +{ + float3 d = { point.x - ref.x, point.y - ref.y, point.z - ref.z }; + float dist = d.x * d.x + d.y * d.y + d.z * d.z; + for (int j = 0; j < K; j++) + { + if (knn[j] > dist) + { + float t = knn[j]; + knn[j] = dist; + dist = t; + } + } +} + +__global__ void boxMeanDist(uint32_t P, float3* points, uint32_t* indices, MinMax* boxes, float* dists) +{ + int idx = cg::this_grid().thread_rank(); + if (idx >= P) + return; + + float3 point = 
points[indices[idx]]; + float best[3] = { FLT_MAX, FLT_MAX, FLT_MAX }; + + for (int i = max(0, idx - 3); i <= min(P - 1, idx + 3); i++) + { + if (i == idx) + continue; + updateKBest<3>(point, points[indices[i]], best); + } + + float reject = best[2]; + best[0] = FLT_MAX; + best[1] = FLT_MAX; + best[2] = FLT_MAX; + + for (int b = 0; b < (P + BOX_SIZE - 1) / BOX_SIZE; b++) + { + MinMax box = boxes[b]; + float dist = distBoxPoint(box, point); + if (dist > reject || dist > best[2]) + continue; + + for (int i = b * BOX_SIZE; i < min(P, (b + 1) * BOX_SIZE); i++) + { + if (i == idx) + continue; + updateKBest<3>(point, points[indices[i]], best); + } + } + dists[indices[idx]] = (best[0] + best[1] + best[2]) / 3.0f; +} + +void SimpleKNN::knn(int P, float3* points, float* meanDists) +{ + float3* result; + cudaMalloc(&result, sizeof(float3)); + size_t temp_storage_bytes; + + float3 init = { 0, 0, 0 }, minn, maxx; + + cub::DeviceReduce::Reduce(nullptr, temp_storage_bytes, points, result, P, CustomMin(), init); + thrust::device_vector temp_storage(temp_storage_bytes); + + cub::DeviceReduce::Reduce(temp_storage.data().get(), temp_storage_bytes, points, result, P, CustomMin(), init); + cudaMemcpy(&minn, result, sizeof(float3), cudaMemcpyDeviceToHost); + + cub::DeviceReduce::Reduce(temp_storage.data().get(), temp_storage_bytes, points, result, P, CustomMax(), init); + cudaMemcpy(&maxx, result, sizeof(float3), cudaMemcpyDeviceToHost); + + thrust::device_vector morton(P); + thrust::device_vector morton_sorted(P); + coord2Morton << <(P + 255) / 256, 256 >> > (P, points, minn, maxx, morton.data().get()); + + thrust::device_vector indices(P); + thrust::sequence(indices.begin(), indices.end()); + thrust::device_vector indices_sorted(P); + + cub::DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, morton.data().get(), morton_sorted.data().get(), indices.data().get(), indices_sorted.data().get(), P); + temp_storage.resize(temp_storage_bytes); + + cub::DeviceRadixSort::SortPairs(temp_storage.data().get(), temp_storage_bytes, morton.data().get(), morton_sorted.data().get(), indices.data().get(), indices_sorted.data().get(), P); + + uint32_t num_boxes = (P + BOX_SIZE - 1) / BOX_SIZE; + thrust::device_vector boxes(num_boxes); + boxMinMax << > > (P, points, indices_sorted.data().get(), boxes.data().get()); + boxMeanDist << > > (P, points, indices_sorted.data().get(), boxes.data().get(), meanDists); + + cudaFree(result); +} \ No newline at end of file diff --git a/simple-knn/simple_knn.egg-info/PKG-INFO b/simple-knn/simple_knn.egg-info/PKG-INFO new file mode 100644 index 0000000..bb654da --- /dev/null +++ b/simple-knn/simple_knn.egg-info/PKG-INFO @@ -0,0 +1,3 @@ +Metadata-Version: 2.1 +Name: simple-knn +Version: 0.0.0 diff --git a/simple-knn/simple_knn.egg-info/SOURCES.txt b/simple-knn/simple_knn.egg-info/SOURCES.txt new file mode 100644 index 0000000..6d19758 --- /dev/null +++ b/simple-knn/simple_knn.egg-info/SOURCES.txt @@ -0,0 +1,8 @@ +ext.cpp +setup.py +simple_knn.cu +spatial.cu +simple_knn.egg-info/PKG-INFO +simple_knn.egg-info/SOURCES.txt +simple_knn.egg-info/dependency_links.txt +simple_knn.egg-info/top_level.txt \ No newline at end of file diff --git a/simple-knn/simple_knn.egg-info/dependency_links.txt b/simple-knn/simple_knn.egg-info/dependency_links.txt new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/simple-knn/simple_knn.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/simple-knn/simple_knn.egg-info/top_level.txt b/simple-knn/simple_knn.egg-info/top_level.txt new file mode 
100644 index 0000000..bae7cb8 --- /dev/null +++ b/simple-knn/simple_knn.egg-info/top_level.txt @@ -0,0 +1 @@ +simple_knn diff --git a/simple-knn/simple_knn.h b/simple-knn/simple_knn.h new file mode 100644 index 0000000..3fcfdb8 --- /dev/null +++ b/simple-knn/simple_knn.h @@ -0,0 +1,21 @@ +/* + * Copyright (C) 2023, Inria + * GRAPHDECO research group, https://team.inria.fr/graphdeco + * All rights reserved. + * + * This software is free for non-commercial, research and evaluation use + * under the terms of the LICENSE.md file. + * + * For inquiries contact george.drettakis@inria.fr + */ + +#ifndef SIMPLEKNN_H_INCLUDED +#define SIMPLEKNN_H_INCLUDED + +class SimpleKNN +{ +public: + static void knn(int P, float3* points, float* meanDists); +}; + +#endif \ No newline at end of file diff --git a/simple-knn/simple_knn/.gitkeep b/simple-knn/simple_knn/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/simple-knn/spatial.cu b/simple-knn/spatial.cu new file mode 100644 index 0000000..1a6a654 --- /dev/null +++ b/simple-knn/spatial.cu @@ -0,0 +1,26 @@ +/* + * Copyright (C) 2023, Inria + * GRAPHDECO research group, https://team.inria.fr/graphdeco + * All rights reserved. + * + * This software is free for non-commercial, research and evaluation use + * under the terms of the LICENSE.md file. + * + * For inquiries contact george.drettakis@inria.fr + */ + +#include "spatial.h" +#include "simple_knn.h" + +torch::Tensor +distCUDA2(const torch::Tensor& points) +{ + const int P = points.size(0); + + auto float_opts = points.options().dtype(torch::kFloat32); + torch::Tensor means = torch::full({P}, 0.0, float_opts); + + SimpleKNN::knn(P, (float3*)points.contiguous().data(), means.contiguous().data()); + + return means; +} \ No newline at end of file diff --git a/simple-knn/spatial.h b/simple-knn/spatial.h new file mode 100644 index 0000000..280c953 --- /dev/null +++ b/simple-knn/spatial.h @@ -0,0 +1,14 @@ +/* + * Copyright (C) 2023, Inria + * GRAPHDECO research group, https://team.inria.fr/graphdeco + * All rights reserved. + * + * This software is free for non-commercial, research and evaluation use + * under the terms of the LICENSE.md file. 
+ * + * For inquiries contact george.drettakis@inria.fr + */ + +#include + +torch::Tensor distCUDA2(const torch::Tensor& points); \ No newline at end of file diff --git a/train.py b/train.py new file mode 100644 index 0000000..3ccb995 --- /dev/null +++ b/train.py @@ -0,0 +1,486 @@ + + +import numpy as np +import random +import os +import torch + +from random import randint +from utils.loss_utils import l1_loss, ssim, l2_loss, lpips_loss +from gaussian_renderer import render, network_gui +import sys +from scene import Scene, GaussianModel +from utils.general_utils import safe_state +import uuid +from tqdm import tqdm +from utils.image_utils import psnr +from argparse import ArgumentParser, Namespace +from arguments import ModelParams, PipelineParams, OptimizationParams, ModelHiddenParams +from torch.utils.data import DataLoader +from utils.timer import Timer + +import lpips +import gc +from torchvision import transforms as T +from utils.scene_utils import render_training_image +from time import time +to8b = lambda x : (255*np.clip(x.cpu().numpy(),0,1)).astype(np.uint8) + +try: + from torch.utils.tensorboard import SummaryWriter + TENSORBOARD_FOUND = True +except ImportError: + TENSORBOARD_FOUND = False + +from guidance.zero123_utils import Zero123 + +from PIL import Image +from torchvision.transforms import ToTensor +from kaolin.metrics.pointcloud import chamfer_distance +from plyfile import PlyData + +def scene_reconstruction(dataset, opt, hyper, pipe, testing_iterations, saving_iterations, + checkpoint_iterations, checkpoint, debug_from, + gaussians, scene, stage, tb_writer, train_iter,timer, args): + first_iter = 0 + + + torch.cuda.empty_cache() + gc.collect() + print(f'Start training of stage {stage}: ') + zero123 = Zero123('cuda') + dir=f'data/{args.name}_pose0/' + # dir=f'data/{args.name}_rgba_pose0/' + # if args.i2v: + # frame_list=range(1, 1 + args.frame_num) + # e1lse: + frame_list = range(args.frame_num) + pose0_im_names = [dir + f'{x}.png' for x in frame_list] + if not os.path.exists(pose0_im_names[0]): # check 0 index + pose0_im_names = pose0_im_names[1:] + [dir + f'{args.frame_num}.png'] # use 1 index + print('pose0_im_names:',pose0_im_names) + T = ToTensor() + im_list = [] + for fname in pose0_im_names: + im = Image.open(fname).resize((512, 512)) + ww = T(im) + assert ww.shape[0] == 4 + ww[:3] = ww[:3] * ww[-1:] + (1 - ww[-1:]) + im_list.append((ww)) + pose0_im = torch.stack(im_list).cuda().detach() + print('pose0_im shape:',pose0_im.shape) + pose0_embed1, pose0_embed2 = zero123.get_img_embeds_pil(pose0_im[:,:3, :, :] , pose0_im[:,:3, :, :] ) + print('pose0_embed1 shape:',pose0_embed1.shape) + print('pose0_embed2 shape:',pose0_embed2.shape) + stage_ = ['static', 'coarse', 'fine'] + train_iter_ = [opt.static_iterations, opt.coarse_iterations, opt.iterations] + bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0] + background = torch.tensor(bg_color, dtype=torch.float32, device="cuda", requires_grad=False) + lpips_model = lpips.LPIPS(net="alex").cuda() + for cur_stage, train_iter in zip(stage_, train_iter_): + gaussians.training_setup(opt) + if checkpoint: + (model_params, first_iter) = torch.load(checkpoint) + gaussians.restore(model_params, opt) + iter_start = torch.cuda.Event(enable_timing = True) + iter_end = torch.cuda.Event(enable_timing = True) + viewpoint_stack = None + ema_loss_for_log = 0.0 + ema_psnr_for_log = 0.0 + + final_iter = train_iter + + progress_bar = tqdm(range(first_iter, final_iter), desc=f"[{args.expname}] Training progress") + video_cams = 
scene.getVideoCameras() + for iteration in range(first_iter, final_iter+1): + stage = cur_stage + loss_weight = 1 + + iter_start.record() + gaussians.update_learning_rate(iteration) + if not viewpoint_stack: + viewpoint_stack = scene.getTrainCameras() + viewpoint_stack_loader = DataLoader(viewpoint_stack, batch_size=1,shuffle=True,num_workers=4,collate_fn=list) + frame_num = viewpoint_stack.pose0_num + + loader = iter(viewpoint_stack_loader) + if True: + try: + data = next(loader) + except StopIteration: + print("reset dataloader") + batch_size = 1 + loader = iter(viewpoint_stack_loader) + if (iteration - 1) == debug_from: + pipe.debug = True + images = [] + gt_images = [] + radii_list = [] + visibility_filter_list = [] + viewspace_point_tensor_list = [] + dx = [] + ds = [] + dr = [] + do = [] + dc=[] + out_pts = [] + if stage in ['static']: + viewpoint_cam = data[0]['t0_cam'] + if stage == 'static': + render_pkg = render(viewpoint_cam, gaussians, pipe, background, stage=stage, time=0) + else: + raise NotImplementedError + image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"] + rgba = torch.cat([image, render_pkg['alpha']], dim=0) + images.append(rgba.unsqueeze(0)) + # gt_image = data[0]['gtim'].to(image.device) + gt_image = data[0]['t0'].to(image.device) + if data[0]['t0_idx'] == 0: + loss_weight = 10 + # gt_image = data[0]['gtim'].to(image.device) + gt_images.append(gt_image.unsqueeze(0)) + radii_list.append(radii.unsqueeze(0)) + visibility_filter_list.append(visibility_filter.unsqueeze(0)) + viewspace_point_tensor_list.append(viewspace_point_tensor) + elif stage == 'coarse': + for i in range(1): + time = data[0]['time'] + viewpoint_cam = data[0]['t0_cam'] + render_pkg = render(viewpoint_cam, gaussians, pipe, background, stage=stage, time=time, return_pts=True) + image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"] + means3D = render_pkg['means3D'] + rgba = torch.cat([image, render_pkg['alpha']], dim=0) + images.append(rgba.unsqueeze(0)) + gt_image = data[0]['gtim'].to(image.device) + gt_images.append(gt_image.unsqueeze(0)) + if data[0]['t0_idx'] == 0: + loss_weight = 10 + out_pts.append(means3D) + if 'dx' in render_pkg: + dx.append(render_pkg['dx']) + ds.append(render_pkg['ds']) + dr.append(render_pkg['dr']) + do.append(render_pkg['do']) + dc.append(render_pkg['dc']) + radii_list.append(radii.unsqueeze(0)) + visibility_filter_list.append(visibility_filter.unsqueeze(0)) + viewspace_point_tensor_list.append(viewspace_point_tensor) + else: + rand_seed=np.random.random() + if rand_seed< args.fine_rand_rate: + viewpoint_cam = data[0]['rand_poses'] + fps = 1 / frame_num + set_t0_frame0 = True + t0 = 0 + if frame_num > 16: + sds_idx_list = np.random.choice(range(frame_num), 16) + else: + sds_idx_list = range(frame_num) + # for i in range(frame_num): + for i in sds_idx_list: + time = torch.tensor([t0 + i * fps]).unsqueeze(0).float() + render_pkg = render(viewpoint_cam[i], gaussians, pipe, background, stage=stage, time=time) + image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"] + fg_mask = render_pkg['alpha'] + rgba = torch.cat([image, fg_mask], dim=0) + images.append(rgba.unsqueeze(0)) + if 'dx' in render_pkg: + dx.append(render_pkg['dx']) + 
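+                        # ds / dr / do / dc presumably hold the per-Gaussian deformation
+                        # deltas for scaling, rotation, opacity and color predicted by the
+                        # deformation field, mirroring dx (the per-point position offset).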
ds.append(render_pkg['ds']) + dr.append(render_pkg['dr']) + do.append(render_pkg['do']) + dc.append(render_pkg['dc']) + radii_list.append(radii.unsqueeze(0)) + visibility_filter_list.append(visibility_filter.unsqueeze(0)) + viewspace_point_tensor_list.append(viewspace_point_tensor) + else: + for i in range(1): + time = data[0]['time'] + viewpoint_cam = data[0]['t0_cam'] + render_pkg = render(viewpoint_cam, gaussians, pipe, background, stage=stage, time=time, return_pts=True) + image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"] + means3D = render_pkg['means3D'] + rgba = torch.cat([image, render_pkg['alpha']], dim=0) + images.append(rgba.unsqueeze(0)) + gt_image = data[0]['gtim'].to(image.device) + gt_images.append(gt_image.unsqueeze(0)) + if data[0]['t0_idx'] == 0: + loss_weight = 10 + out_pts.append(means3D) + if 'dx' in render_pkg: + dx.append(render_pkg['dx']) + ds.append(render_pkg['ds']) + dr.append(render_pkg['dr']) + do.append(render_pkg['do']) + dc.append(render_pkg['dc']) + radii_list.append(radii.unsqueeze(0)) + visibility_filter_list.append(visibility_filter.unsqueeze(0)) + viewspace_point_tensor_list.append(viewspace_point_tensor) + radii = torch.cat(radii_list,0).max(dim=0).values + visibility_filter = torch.cat(visibility_filter_list).any(dim=0) + image_tensor = torch.cat(images,0) + if len(out_pts): + out_pts = torch.stack(out_pts,0) + use_zero123 = True + use_animate = False + + if len(gt_images): + gt_image_tensor = torch.cat(gt_images,0) + if stage in ['static']: + Ll1 = l1_loss(image_tensor, gt_image_tensor) + tb_writer.add_scalar(f'{stage}/loss_recon', Ll1.item(), iteration) + lpipsloss = lpips_loss(image_tensor, gt_image_tensor,lpips_model) + tb_writer.add_scalar(f'{stage}/loss_lpips', lpipsloss.item(), iteration) + loss = Ll1 * 10 + lpipsloss * 20 + + + elif stage == 'coarse': + Ll1 = l1_loss(image_tensor, gt_image_tensor) + tb_writer.add_scalar(f'{stage}/loss_recon', Ll1.item(), iteration) + lpipsloss = lpips_loss(image_tensor, gt_image_tensor,lpips_model) + tb_writer.add_scalar(f'{stage}/loss_lpips', lpipsloss.item(), iteration) + loss = Ll1 * 10 + lpipsloss * 20 + + + time_now=int(time.item()*args.frame_num) + else: + if rand_seed < args.fine_rand_rate: + if use_zero123: + loss = 0 + loss_zero123_total=0 + + for idx in sds_idx_list: + # for idx in range(0, frame_num): + cur_emb = (pose0_embed1[idx].unsqueeze(0), pose0_embed2[idx]) + loss_zero123, im = zero123.train_step(image_tensor[idx:idx+1,:3, :, :], data[0]['rand_ver'][idx], data[0]['rand_hor'][idx], 0, cur_emb) + loss_zero123_total += loss_zero123 + tb_writer.add_scalar(f'{stage}/loss_zero123', loss_zero123_total.item(), iteration) + loss += loss_zero123_total / len(sds_idx_list) * args.lambda_zero123 + else: + Ll1 = l1_loss(image_tensor, gt_image_tensor) + tb_writer.add_scalar(f'{stage}/loss_recon', Ll1.item(), iteration) + lpipsloss = lpips_loss(image_tensor,gt_image_tensor,lpips_model) + tb_writer.add_scalar(f'{stage}/loss_lpips', lpipsloss.item(), iteration) + loss = Ll1 * 10 + lpipsloss * 20 + + time_now=int(time.item()*args.frame_num) + + loss = loss * loss_weight + + if stage == "fine" and hyper.time_smoothness_weight != 0: + tv_loss = gaussians.compute_regulation(hyper.time_smoothness_weight, hyper.plane_tv_weight, hyper.l1_time_planes) + loss += tv_loss + tb_writer.add_scalar(f'{stage}/loss_tv', tv_loss.item(), iteration) + if opt.lambda_dssim != 0 and len(gt_images) != 0: + ssim_loss = 1 - 
ssim(image_tensor,gt_image_tensor) + loss += opt.lambda_dssim * (ssim_loss) + tb_writer.add_scalar(f'{stage}/loss_ssim', ssim_loss.item(), iteration) + if opt.lambda_lpips != 0 and len(gt_images) != 0: + lpipsloss = lpips_loss(image_tensor,gt_image_tensor,lpips_model) + loss += opt.lambda_lpips * lpipsloss + tb_writer.add_scalar(f'{stage}/loss_lpips', lpipsloss.item(), iteration) + # if len(dx)!=1 and len(dx)!=0: + # loss_dx_tv = torch.stack([x if i else x.detach() for i, x in enumerate(dx[:-1])]) - torch.stack(dx[1:]) + loss.backward() + viewspace_point_tensor_grad = torch.zeros_like(viewspace_point_tensor) + for idx in range(0, len(viewspace_point_tensor_list)): + if viewspace_point_tensor_list[idx].grad is not None: + viewspace_point_tensor_grad = viewspace_point_tensor_grad + viewspace_point_tensor_list[idx].grad + iter_end.record() + with torch.no_grad(): + ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log + + total_point = gaussians._xyz.shape[0] + if iteration % 10 == 0: + progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}", + "point":f"{total_point}"}) + progress_bar.update(10) + if iteration == opt.iterations: + progress_bar.close() + timer.pause() + training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, pipe, background, stage) + if (iteration in saving_iterations): + print("\n[ITER {}] Saving Gaussians".format(iteration)) + scene.save(iteration, stage) + if dataset.render_process: + if (iteration < 1000 and iteration % 10 == 1) \ + or (iteration < 3000 and iteration % 50 == 1) \ + or (iteration < 10000 and iteration % 100 == 1) \ + or (iteration < 60000 and iteration % 100 ==1): + render_training_image(scene, gaussians, video_cams, render, pipe, background, stage, iteration-1,timer.get_elapsed_time()) + timer.start() + if stage == 'static' and iteration < opt.densify_until_iter: + gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter]) + gaussians.add_densification_stats(viewspace_point_tensor_grad, visibility_filter) + if stage in ['static', "coarse"]: + opacity_threshold = opt.opacity_threshold_coarse + densify_threshold = opt.densify_grad_threshold_coarse + else: + opacity_threshold = opt.opacity_threshold_fine_init - iteration*(opt.opacity_threshold_fine_init - opt.opacity_threshold_fine_after)/(opt.densify_until_iter) + densify_threshold = opt.densify_grad_threshold_fine_init - iteration*(opt.densify_grad_threshold_fine_init - opt.densify_grad_threshold_after)/(opt.densify_until_iter) + if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0 : + size_threshold = 20 if iteration > opt.opacity_reset_interval else None + print('>>>>>>> Now densify') + gaussians.densify(densify_threshold, opacity_threshold, scene.cameras_extent, size_threshold) + pruning_interval = opt.pruning_interval if stage != 'fine' else opt.pruning_interval_fine + if iteration > opt.pruning_from_iter and iteration % pruning_interval == 0: + print('>>>>>>> Now pruning', opt.opacity_reset_interval) + size_threshold = 20 if iteration > opt.opacity_reset_interval else None + + gaussians.prune(densify_threshold, opacity_threshold, scene.cameras_extent, size_threshold) + if iteration < opt.iterations: + gaussians.optimizer.step() + gaussians.optimizer.zero_grad(set_to_none = True) + if (iteration in checkpoint_iterations): + print("\n[ITER {}] Saving Checkpoint".format(iteration)) + torch.save((gaussians.capture(), 
iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth") + +def training(dataset, hyper, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint, debug_from, expname, args): + tb_writer = prepare_output_and_logger(expname) + gaussians = GaussianModel(dataset.sh_degree, hyper) + dataset.model_path = args.model_path + timer = Timer() + scene = Scene(dataset, gaussians,load_coarse=None) + timer.start() + scene_reconstruction(dataset, opt, hyper, pipe, testing_iterations, saving_iterations, + checkpoint_iterations, checkpoint, debug_from, + gaussians, scene, "coarse", tb_writer, opt.coarse_iterations,timer, args) + +from datetime import datetime + +def prepare_output_and_logger(expname): + if not args.model_path: + unique_str = str(datetime.today().strftime('%Y-%m-%d')) + '/' + expname + '_' + datetime.today().strftime('%H:%M:%S') + args.model_path = os.path.join("./output/", unique_str) + print("Output folder: {}".format(args.model_path)) + os.makedirs(args.model_path, exist_ok = True) + with open(os.path.join(args.model_path, "cfg_args"), 'w') as cfg_log_f: + cfg_log_f.write(str(Namespace(**vars(args)))) + tb_writer = None + if TENSORBOARD_FOUND: + tb_writer = SummaryWriter(args.model_path) + else: + print("Tensorboard not available: not logging progress") + return tb_writer + +def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_iterations, scene : Scene, renderFunc, pipe, bg, stage): + if tb_writer: + tb_writer.add_scalar(f'{stage}/train_loss_patches/l1_loss', Ll1.item(), iteration) + tb_writer.add_scalar(f'{stage}/train_loss_patchestotal_loss', loss.item(), iteration) + tb_writer.add_scalar(f'{stage}/iter_time', elapsed, iteration) + ww = iteration if stage == 'static' else iteration + if iteration % 100 == 0 and ww in testing_iterations: + torch.cuda.empty_cache() + train_set = scene.getTrainCameras() + validation_configs = [{'name': 'train', 'cameras' : [train_set[idx % len(train_set)] for idx in range(10, 5000, 299)]}] + for config in validation_configs: + if config['cameras'] and len(config['cameras']) > 0: + l1_test = 0.0 + psnr_test = 0.0 + ti = (torch.tensor([0]).unsqueeze(0)) + cam_li = config['cameras'][0]['rand_poses'] + im_li = [] + num = len(cam_li) + for tii in range(num): + if stage == 'static': + ti = (torch.tensor([tii * 0]).unsqueeze(0).cuda()) + else: + ti = (torch.tensor([tii / num]).unsqueeze(0).cuda()) + viewpoint = cam_li[tii] + image = torch.clamp(renderFunc(viewpoint, scene.gaussians,stage=stage, pipe=pipe, bg_color=bg, time=ti)["render"], 0.0, 1.0) + im_li.append(image) + ww = len(im_li) // 2 + r1 = torch.cat(im_li[:ww], dim=-1) + r2 = torch.cat(im_li[ww:], dim=-1) + im_li = torch.cat([r1, r2], dim=-2) + if tb_writer: + tb_writer.add_image(f"rand_seq/{stage}", im_li, global_step=iteration) + for idx, data in enumerate(config['cameras']): + if stage == 'static': + ti = (torch.tensor([0]).unsqueeze(0).cuda()) + viewpoint = data['t0_cam'] + else: + ti = data['time'] + viewpoint = data['t0_cam'] + image = torch.clamp(renderFunc(viewpoint, scene.gaussians,stage=stage, pipe=pipe, bg_color=bg, time=ti)["render"], 0.0, 1.0) + if stage == 'static': + gt_image = data['gtim'][:3].to(image.device) + elif stage == 'coarse': + gt_image = data['gtim'][:3].to(image.device) + else: + gt_image = data['gtim'][:3].to(image.device) + if tb_writer and (idx < 5): + tb_writer.add_images(stage + "/{}/render".format(idx), image[None], global_step=iteration) + tb_writer.add_images(stage + "/{}/gt".format(idx), gt_image[None], 
global_step=iteration) + l1_test += l1_loss(image, gt_image).mean().double() + psnr_test += psnr(image, gt_image).mean().double() + psnr_test /= len(config['cameras']) + l1_test /= len(config['cameras']) + print("\n[ITER {}] Evaluating {}: L1 {} PSNR {}".format(iteration, config['name'], l1_test, psnr_test)) + if tb_writer: + tb_writer.add_scalar(stage + "/"+config['name'] + '/loss_viewpoint - l1_loss', l1_test, iteration) + tb_writer.add_scalar(stage+"/"+config['name'] + '/loss_viewpoint - psnr', psnr_test, iteration) + if tb_writer: + tb_writer.add_histogram(f"{stage}/scene/opacity_histogram", scene.gaussians.get_opacity, iteration) + tb_writer.add_scalar(f'{stage}/total_points', scene.gaussians.get_xyz.shape[0], iteration) + tb_writer.add_scalar(f'{stage}/deformation_rate', scene.gaussians._deformation_table.sum()/scene.gaussians.get_xyz.shape[0], iteration) + tb_writer.add_histogram(f"{stage}/scene/motion_histogram", scene.gaussians._deformation_accum.mean(dim=-1)/100, iteration,max_bins=500) + torch.cuda.empty_cache() +def setup_seed(seed): + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + torch.backends.cudnn.deterministic = True +if __name__ == "__main__": + torch.cuda.empty_cache() + parser = ArgumentParser(description="Training script parameters") + setup_seed(6666) + lp = ModelParams(parser) + op = OptimizationParams(parser) + pp = PipelineParams(parser) + hp = ModelHiddenParams(parser) + parser.add_argument('--ip', type=str, default="127.0.0.1") + parser.add_argument('--port', type=int, default=6009) + parser.add_argument('--debug_from', type=int, default=-1) + parser.add_argument('--detect_anomaly', action='store_true', default=False) + parser.add_argument("--test_iterations", nargs="+", type=int, default=[i*50 for i in range(0,300)]) + parser.add_argument("--save_iterations", nargs="+", type=int, default=[2000,2500, 3000, 5000, 7_000, 8000, 9000, 14000, 20000, 30_000,45000,60000]) + parser.add_argument("--quiet", action="store_true") + parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[]) + parser.add_argument("--start_checkpoint", type=str, default = None) + parser.add_argument('-e', "--expname", type=str, default = "") + parser.add_argument("--configs", type=str, default = "arguments/ours/i2v_xdj.py") + parser.add_argument("--yyypath", type=str, default = "") + parser.add_argument("--t0_frame0_rate", type=float, default = 1) + parser.add_argument("--name_override", type=str, default="") + parser.add_argument("--sds_ratio_override", type=float, default=-1) + parser.add_argument("--sds_weight_override", type=float, default=-1) + parser.add_argument("--iteration", default=-1, type=int) + args = parser.parse_args(sys.argv[1:]) + args.save_iterations.append(args.iterations - 1) + if args.configs: + import mmcv + from utils.params_utils import merge_hparams + config = mmcv.Config.fromfile(args.configs) + args = merge_hparams(args, config) + if args.name_override != '': + args.name = args.name_override + if args.sds_ratio_override != -1: + args.fine_rand_rate = args.sds_ratio_override + if args.sds_weight_override != -1: + args.lambda_zero123 = args.sds_weight_override + # print(args.name) + print("Optimizing " + args.model_path) + safe_state(args.quiet) + torch.autograd.set_detect_anomaly(args.detect_anomaly) + timer1 = Timer() + timer1.start() + print('Configs: ', args) + training(lp.extract(args), hp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, 
args.checkpoint_iterations, args.start_checkpoint, args.debug_from, args.expname, args) + print("\nTraining complete.") + print('training time:',timer1.get_elapsed_time()) + from render import render_sets + + render_sets(lp.extract(args), hp.extract(args), op.extract(args), args.iterations, pp.extract(args), skip_train=True, skip_test=True, skip_video=False, multiview_video=True) + print("\Rendering complete.") diff --git a/utils/camera_utils.py b/utils/camera_utils.py new file mode 100644 index 0000000..4a23c1e --- /dev/null +++ b/utils/camera_utils.py @@ -0,0 +1,65 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +from scene.cameras import Camera +import numpy as np +from utils.general_utils import PILtoTorch +from utils.graphics_utils import fov2focal + +WARNED = False + +def loadCam(args, id, cam_info, resolution_scale): + + + # resized_image_rgb = PILtoTorch(cam_info.image, resolution) + + # gt_image = resized_image_rgb[:3, ...] + # loaded_mask = None + + # if resized_image_rgb.shape[1] == 4: + # loaded_mask = resized_image_rgb[3:4, ...] + + return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T, + FoVx=cam_info.FovX, FoVy=cam_info.FovY, + image=cam_info.image, gt_alpha_mask=None, + image_name=cam_info.image_name, uid=id, data_device=args.data_device, + time = cam_info.time, +) + +def cameraList_from_camInfos(cam_infos, resolution_scale, args): + camera_list = [] + + for id, c in enumerate(cam_infos): + camera_list.append(loadCam(args, id, c, resolution_scale)) + + return camera_list + +def camera_to_JSON(id, camera : Camera): + Rt = np.zeros((4, 4)) + Rt[:3, :3] = camera.R.transpose() + Rt[:3, 3] = camera.T + Rt[3, 3] = 1.0 + + W2C = np.linalg.inv(Rt) + pos = W2C[:3, 3] + rot = W2C[:3, :3] + serializable_array_2d = [x.tolist() for x in rot] + camera_entry = { + 'id' : id, + 'img_name' : camera.image_name, + 'width' : camera.width, + 'height' : camera.height, + 'position': pos.tolist(), + 'rotation': serializable_array_2d, + 'fy' : fov2focal(camera.FovY, camera.height), + 'fx' : fov2focal(camera.FovX, camera.width) + } + return camera_entry diff --git a/utils/general_utils.py b/utils/general_utils.py new file mode 100644 index 0000000..e6a8a81 --- /dev/null +++ b/utils/general_utils.py @@ -0,0 +1,136 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +import torch +import sys +from datetime import datetime +import numpy as np +import random + +def inverse_sigmoid(x): + return torch.log(x/(1-x)) + +def PILtoTorch(pil_image, resolution): + if resolution is not None: + resized_image_PIL = pil_image.resize(resolution) + else: + resized_image_PIL = pil_image + resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0 + if len(resized_image.shape) == 3: + return resized_image.permute(2, 0, 1) + else: + return resized_image.unsqueeze(dim=-1).permute(2, 0, 1) + +def get_expon_lr_func( + lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000 +): + """ + Copied from Plenoxels + + Continuous learning rate decay function. 
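+    Concretely, with t = clip(step / max_steps, 0, 1), the returned rate is
+    delay_rate * exp((1 - t) * log(lr_init) + t * log(lr_final)).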
Adapted from JaxNeRF + The returned rate is lr_init when step=0 and lr_final when step=max_steps, and + is log-linearly interpolated elsewhere (equivalent to exponential decay). + If lr_delay_steps>0 then the learning rate will be scaled by some smooth + function of lr_delay_mult, such that the initial learning rate is + lr_init*lr_delay_mult at the beginning of optimization but will be eased back + to the normal learning rate when steps>lr_delay_steps. + :param conf: config subtree 'lr' or similar + :param max_steps: int, the number of steps during optimization. + :return HoF which takes step as input + """ + + def helper(step): + if step < 0 or (lr_init == 0.0 and lr_final == 0.0): + # Disable this parameter + return 0.0 + if lr_delay_steps > 0: + # A kind of reverse cosine decay. + delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin( + 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1) + ) + else: + delay_rate = 1.0 + t = np.clip(step / max_steps, 0, 1) + log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t) + return delay_rate * log_lerp + + return helper + +def strip_lowerdiag(L): + uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda") + + uncertainty[:, 0] = L[:, 0, 0] + uncertainty[:, 1] = L[:, 0, 1] + uncertainty[:, 2] = L[:, 0, 2] + uncertainty[:, 3] = L[:, 1, 1] + uncertainty[:, 4] = L[:, 1, 2] + uncertainty[:, 5] = L[:, 2, 2] + return uncertainty + +def strip_symmetric(sym): + return strip_lowerdiag(sym) + +def build_rotation(r): + norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3]) + + q = r / norm[:, None] + + R = torch.zeros((q.size(0), 3, 3), device='cuda') + + r = q[:, 0] + x = q[:, 1] + y = q[:, 2] + z = q[:, 3] + + R[:, 0, 0] = 1 - 2 * (y*y + z*z) + R[:, 0, 1] = 2 * (x*y - r*z) + R[:, 0, 2] = 2 * (x*z + r*y) + R[:, 1, 0] = 2 * (x*y + r*z) + R[:, 1, 1] = 1 - 2 * (x*x + z*z) + R[:, 1, 2] = 2 * (y*z - r*x) + R[:, 2, 0] = 2 * (x*z - r*y) + R[:, 2, 1] = 2 * (y*z + r*x) + R[:, 2, 2] = 1 - 2 * (x*x + y*y) + return R + +def build_scaling_rotation(s, r): + L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda") + R = build_rotation(r) + + L[:,0,0] = s[:,0] + L[:,1,1] = s[:,1] + L[:,2,2] = s[:,2] + + L = R @ L + return L + +def safe_state(silent): + old_f = sys.stdout + class F: + def __init__(self, silent): + self.silent = silent + + def write(self, x): + if not self.silent: + if x.endswith("\n"): + old_f.write(x.replace("\n", " [{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S"))))) + else: + old_f.write(x) + + def flush(self): + old_f.flush() + + sys.stdout = F(silent) + + random.seed(0) + np.random.seed(0) + torch.manual_seed(0) + torch.cuda.set_device(torch.device("cuda:0")) diff --git a/utils/graphics_utils.py b/utils/graphics_utils.py new file mode 100644 index 0000000..b4627d8 --- /dev/null +++ b/utils/graphics_utils.py @@ -0,0 +1,77 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. 
+# +# For inquiries contact george.drettakis@inria.fr +# + +import torch +import math +import numpy as np +from typing import NamedTuple + +class BasicPointCloud(NamedTuple): + points : np.array + colors : np.array + normals : np.array + +def geom_transform_points(points, transf_matrix): + P, _ = points.shape + ones = torch.ones(P, 1, dtype=points.dtype, device=points.device) + points_hom = torch.cat([points, ones], dim=1) + points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0)) + + denom = points_out[..., 3:] + 0.0000001 + return (points_out[..., :3] / denom).squeeze(dim=0) + +def getWorld2View(R, t): + Rt = np.zeros((4, 4)) + Rt[:3, :3] = R.transpose() + Rt[:3, 3] = t + Rt[3, 3] = 1.0 + return np.float32(Rt) + +def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0): + Rt = np.zeros((4, 4)) + Rt[:3, :3] = R.transpose() + Rt[:3, 3] = t + Rt[3, 3] = 1.0 + + C2W = np.linalg.inv(Rt) + cam_center = C2W[:3, 3] + cam_center = (cam_center + translate) * scale + C2W[:3, 3] = cam_center + Rt = np.linalg.inv(C2W) + return np.float32(Rt) + +def getProjectionMatrix(znear, zfar, fovX, fovY): + tanHalfFovY = math.tan((fovY / 2)) + tanHalfFovX = math.tan((fovX / 2)) + + top = tanHalfFovY * znear + bottom = -top + right = tanHalfFovX * znear + left = -right + + P = torch.zeros(4, 4) + + z_sign = 1.0 + + P[0, 0] = 2.0 * znear / (right - left) + P[1, 1] = 2.0 * znear / (top - bottom) + P[0, 2] = (right + left) / (right - left) + P[1, 2] = (top + bottom) / (top - bottom) + P[3, 2] = z_sign + P[2, 2] = z_sign * zfar / (zfar - znear) + P[2, 3] = -(zfar * znear) / (zfar - znear) + return P + +def fov2focal(fov, pixels): + return pixels / (2 * math.tan(fov / 2)) + +def focal2fov(focal, pixels): + return 2*math.atan(pixels/(2*focal)) \ No newline at end of file diff --git a/utils/image_utils.py b/utils/image_utils.py new file mode 100644 index 0000000..b150699 --- /dev/null +++ b/utils/image_utils.py @@ -0,0 +1,19 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +import torch + +def mse(img1, img2): + return (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) +@torch.no_grad() +def psnr(img1, img2): + mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) + return 20 * torch.log10(1.0 / torch.sqrt(mse)) diff --git a/utils/loss_utils.py b/utils/loss_utils.py new file mode 100644 index 0000000..6c1b773 --- /dev/null +++ b/utils/loss_utils.py @@ -0,0 +1,69 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. 
+# +# For inquiries contact george.drettakis@inria.fr +# + +import torch +import torch.nn.functional as F +from torch.autograd import Variable +from math import exp +import lpips +def lpips_loss(img1, img2, lpips_model): + a, b, _, __ = img2.shape + ww = img1[:a, :3] + loss = lpips_model(ww * 2 - 1,img2[:, :3] * 2 - 1) + return loss.mean() +def l1_loss(network_output, gt): + return torch.abs((network_output - gt)).mean() + +def l2_loss(network_output, gt): + return ((network_output - gt) ** 2).mean() + +def gaussian(window_size, sigma): + gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) + return gauss / gauss.sum() + +def create_window(window_size, channel): + _1D_window = gaussian(window_size, 1.5).unsqueeze(1) + _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) + window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) + return window + +def ssim(img1, img2, window_size=11, size_average=True): + channel = img1.size(-3) + window = create_window(window_size, channel) + + if img1.is_cuda: + window = window.cuda(img1.get_device()) + window = window.type_as(img1) + + return _ssim(img1, img2, window, window_size, channel, size_average) + +def _ssim(img1, img2, window, window_size, channel, size_average=True): + mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) + mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) + + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1 * mu2 + + sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq + sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq + sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 + + C1 = 0.01 ** 2 + C2 = 0.03 ** 2 + + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) + + if size_average: + return ssim_map.mean() + else: + return ssim_map.mean(1).mean(1).mean(1) + diff --git a/utils/params_utils.py b/utils/params_utils.py new file mode 100644 index 0000000..6f2ea64 --- /dev/null +++ b/utils/params_utils.py @@ -0,0 +1,9 @@ +def merge_hparams(args, config): + params = ["OptimizationParams", "ModelHiddenParams", "ModelParams", "PipelineParams"] + for param in params: + if param in config.keys(): + for key, value in config[param].items(): + if hasattr(args, key): + setattr(args, key, value) + + return args \ No newline at end of file diff --git a/utils/scene_utils.py b/utils/scene_utils.py new file mode 100644 index 0000000..402e3c3 --- /dev/null +++ b/utils/scene_utils.py @@ -0,0 +1,97 @@ +import torch +import os +from PIL import Image, ImageDraw, ImageFont +from matplotlib import pyplot as plt +plt.rcParams['font.sans-serif'] = ['Times New Roman'] + +import numpy as np + +import copy +@torch.no_grad() +def render_training_image(scene, gaussians, viewpoints, render_func, pipe, background, stage, iteration, time_now): + def render(gaussians, viewpoint, path, scaling): + # scaling_copy = gaussians._scaling + render_pkg = render_func(viewpoint, gaussians, pipe, background, stage=stage) + label1 = f"stage:{stage},iter:{iteration}" + times = time_now/60 + if times < 1: + end = "min" + else: + end = "mins" + label2 = "time:%.2f" % times + end + image = render_pkg["render"] + depth = render_pkg["depth"] + image_np = image.permute(1, 2, 0).cpu().numpy() # 转换通道顺序为 (H, W, 3) + depth_np = depth.permute(1, 2, 
0).cpu().numpy() + depth_np /= depth_np.max() + depth_np = np.repeat(depth_np, 3, axis=2) + image_np = np.concatenate((image_np, depth_np), axis=1) + image_with_labels = Image.fromarray((np.clip(image_np,0,1) * 255).astype('uint8')) # 转换为8位图像 + # 创建PIL图像对象的副本以绘制标签 + draw1 = ImageDraw.Draw(image_with_labels) + + # 选择字体和字体大小 + font = ImageFont.truetype('./utils/TIMES.TTF', size=40) # 请将路径替换为您选择的字体文件路径 + + # 选择文本颜色 + text_color = (255, 0, 0) # 白色 + + # 选择标签的位置(左上角坐标) + label1_position = (10, 10) + label2_position = (image_with_labels.width - 100 - len(label2) * 10, 10) # 右上角坐标 + + # 在图像上添加标签 + draw1.text(label1_position, label1, fill=text_color, font=font) + draw1.text(label2_position, label2, fill=text_color, font=font) + + image_with_labels.save(path) + render_base_path = os.path.join(scene.model_path, f"{stage}_render") + point_cloud_path = os.path.join(render_base_path,"pointclouds") + image_path = os.path.join(render_base_path,"images") + if not os.path.exists(os.path.join(scene.model_path, f"{stage}_render")): + os.makedirs(render_base_path) + if not os.path.exists(point_cloud_path): + os.makedirs(point_cloud_path) + if not os.path.exists(image_path): + os.makedirs(image_path) + # image:3,800,800 + + # point_save_path = os.path.join(point_cloud_path,f"{iteration}.jpg") + for idx in range(len(viewpoints)): + image_save_path = os.path.join(image_path,f"{iteration}_{idx}.jpg") + # time = torch.tensor([idx]).unsqueeze(0) + # render(gaussians,viewpoints[idx]['pose0_cam'],image_save_path,scaling=1,time=time) + render(gaussians,viewpoints[idx]['t0_cam'],image_save_path,scaling = 1) + # render(gaussians,point_save_path,scaling = 0.1) + # 保存带有标签的图像 + + + + pc_mask = gaussians.get_opacity + pc_mask = pc_mask > 0.1 + xyz = gaussians.get_xyz.detach()[pc_mask.squeeze()].cpu().permute(1,0).numpy() + # visualize_and_save_point_cloud(xyz, viewpoint.R, viewpoint.T, point_save_path) + # 如果需要,您可以将PIL图像转换回PyTorch张量 + # return image + # image_with_labels_tensor = torch.tensor(image_with_labels, dtype=torch.float32).permute(2, 0, 1) / 255.0 +def visualize_and_save_point_cloud(point_cloud, R, T, filename): + # 创建3D散点图 + fig = plt.figure() + ax = fig.add_subplot(111, projection='3d') + R = R.T + # 应用旋转和平移变换 + T = -R.dot(T) + transformed_point_cloud = np.dot(R, point_cloud) + T.reshape(-1, 1) + # pcd = o3d.geometry.PointCloud() + # pcd.points = o3d.utility.Vector3dVector(transformed_point_cloud.T) # 转置点云数据以匹配Open3D的格式 + # transformed_point_cloud[2,:] = -transformed_point_cloud[2,:] + # 可视化点云 + ax.scatter(transformed_point_cloud[0], transformed_point_cloud[1], transformed_point_cloud[2], c='g', marker='o') + ax.axis("off") + # ax.set_xlabel('X Label') + # ax.set_ylabel('Y Label') + # ax.set_zlabel('Z Label') + + # 保存渲染结果为图片 + plt.savefig(filename) + diff --git a/utils/sh_utils.py b/utils/sh_utils.py new file mode 100644 index 0000000..bbca7d1 --- /dev/null +++ b/utils/sh_utils.py @@ -0,0 +1,118 @@ +# Copyright 2021 The PlenOctree Authors. +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. + +import torch + +C0 = 0.28209479177387814 +C1 = 0.4886025119029199 +C2 = [ + 1.0925484305920792, + -1.0925484305920792, + 0.31539156525252005, + -1.0925484305920792, + 0.5462742152960396 +] +C3 = [ + -0.5900435899266435, + 2.890611442640554, + -0.4570457994644658, + 0.3731763325901154, + -0.4570457994644658, + 1.445305721320277, + -0.5900435899266435 +] +C4 = [ + 2.5033429417967046, + -1.7701307697799304, + 0.9461746957575601, + -0.6690465435572892, + 0.10578554691520431, + -0.6690465435572892, + 0.47308734787878004, + -1.7701307697799304, + 0.6258357354491761, +] + + +def eval_sh(deg, sh, dirs): + """ + Evaluate spherical harmonics at unit directions + using hardcoded SH polynomials. + Works with torch/np/jnp. + ... Can be 0 or more batch dimensions. + Args: + deg: int SH deg. Currently, 0-3 supported + sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2] + dirs: jnp.ndarray unit directions [..., 3] + Returns: + [..., C] + """ + assert deg <= 4 and deg >= 0 + coeff = (deg + 1) ** 2 + assert sh.shape[-1] >= coeff + + result = C0 * sh[..., 0] + if deg > 0: + x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] + result = (result - + C1 * y * sh[..., 1] + + C1 * z * sh[..., 2] - + C1 * x * sh[..., 3]) + + if deg > 1: + xx, yy, zz = x * x, y * y, z * z + xy, yz, xz = x * y, y * z, x * z + result = (result + + C2[0] * xy * sh[..., 4] + + C2[1] * yz * sh[..., 5] + + C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] + + C2[3] * xz * sh[..., 7] + + C2[4] * (xx - yy) * sh[..., 8]) + + if deg > 2: + result = (result + + C3[0] * y * (3 * xx - yy) * sh[..., 9] + + C3[1] * xy * z * sh[..., 10] + + C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] + + C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] + + C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] + + C3[5] * z * (xx - yy) * sh[..., 14] + + C3[6] * x * (xx - 3 * yy) * sh[..., 15]) + + if deg > 3: + result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] + + C4[1] * yz * (3 * xx - yy) * sh[..., 17] + + C4[2] * xy * (7 * zz - 1) * sh[..., 18] + + C4[3] * yz * (7 * zz - 3) * sh[..., 19] + + C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] + + C4[5] * xz * (7 * zz - 3) * sh[..., 21] + + C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] + + C4[7] * xz * (xx - 3 * yy) * sh[..., 23] + + C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24]) + return result + +def RGB2SH(rgb): + return (rgb - 0.5) / C0 + +def SH2RGB(sh): + return sh * C0 + 0.5 \ No newline at end of file diff --git a/utils/system_utils.py b/utils/system_utils.py new file mode 100644 index 0000000..90ca6d7 --- /dev/null +++ b/utils/system_utils.py @@ -0,0 +1,28 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. 
+# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +from errno import EEXIST +from os import makedirs, path +import os + +def mkdir_p(folder_path): + # Creates a directory. equivalent to using mkdir -p on the command line + try: + makedirs(folder_path) + except OSError as exc: # Python >2.5 + if exc.errno == EEXIST and path.isdir(folder_path): + pass + else: + raise + +def searchForMaxIteration(folder): + saved_iters = [int(fname.split("_")[-1]) for fname in os.listdir(folder)] + return max(saved_iters) diff --git a/utils/timer.py b/utils/timer.py new file mode 100644 index 0000000..c01ff93 --- /dev/null +++ b/utils/timer.py @@ -0,0 +1,24 @@ +import time +class Timer: + def __init__(self): + self.start_time = None + self.elapsed = 0 + self.paused = False + + def start(self): + if self.start_time is None: + self.start_time = time.time() + elif self.paused: + self.start_time = time.time() - self.elapsed + self.paused = False + + def pause(self): + if not self.paused: + self.elapsed = time.time() - self.start_time + self.paused = True + + def get_elapsed_time(self): + if self.paused: + return self.elapsed + else: + return time.time() - self.start_time \ No newline at end of file diff --git a/zero123.py b/zero123.py new file mode 100644 index 0000000..158e31e --- /dev/null +++ b/zero123.py @@ -0,0 +1,666 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import math +import warnings +from typing import Any, Callable, Dict, List, Optional, Union + +import PIL +import torch +import torchvision.transforms.functional as TF +from diffusers.configuration_utils import ConfigMixin, FrozenDict, register_to_config +from diffusers.image_processor import VaeImageProcessor +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.models.modeling_utils import ModelMixin +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput +from diffusers.pipelines.stable_diffusion.safety_checker import ( + StableDiffusionSafetyChecker, +) +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import deprecate, is_accelerate_available, logging +from diffusers.utils.torch_utils import randn_tensor +from packaging import version +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class CLIPCameraProjection(ModelMixin, ConfigMixin): + """ + A Projection layer for CLIP embedding and camera embedding. + + Parameters: + embedding_dim (`int`, *optional*, defaults to 768): The dimension of the model input `clip_embed` + additional_embeddings (`int`, *optional*, defaults to 4): The number of additional tokens appended to the + projected `hidden_states`. 
The actual length of the used `hidden_states` is `num_embeddings + + additional_embeddings`. + """ + + @register_to_config + def __init__(self, embedding_dim: int = 768, additional_embeddings: int = 4): + super().__init__() + self.embedding_dim = embedding_dim + self.additional_embeddings = additional_embeddings + + self.input_dim = self.embedding_dim + self.additional_embeddings + self.output_dim = self.embedding_dim + + self.proj = torch.nn.Linear(self.input_dim, self.output_dim) + + def forward( + self, + embedding: torch.FloatTensor, + ): + """ + The [`PriorTransformer`] forward method. + + Args: + hidden_states (`torch.FloatTensor` of shape `(batch_size, input_dim)`): + The currently input embeddings. + + Returns: + The output embedding projection (`torch.FloatTensor` of shape `(batch_size, output_dim)`). + """ + proj_embedding = self.proj(embedding) + return proj_embedding + + +class Zero123Pipeline(DiffusionPipeline): + r""" + Pipeline to generate variations from an input image using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + image_encoder ([`CLIPVisionModelWithProjection`]): + Frozen CLIP image-encoder. Stable Diffusion Image Variation uses the vision portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModelWithProjection), + specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + # TODO: feature_extractor is required to encode images (if they are in PIL format), + # we should give a descriptive message if the pipeline doesn't have one. + _optional_components = ["safety_checker"] + + def __init__( + self, + vae: AutoencoderKL, + image_encoder: CLIPVisionModelWithProjection, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + clip_camera_projection: CLIPCameraProjection, + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warn( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. 
Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr( + unet.config, "_diffusers_version" + ) and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse( + "0.9.0.dev0" + ) + is_unet_sample_size_less_64 = ( + hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate( + "sample_size<64", "1.0.0", deprecation_message, standard_warn=False + ) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + image_encoder=image_encoder, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + clip_camera_projection=clip_camera_projection, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + for cpu_offloaded_model in [ + self.unet, + self.image_encoder, + self.vae, + self.safety_checker, + ]: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. 
After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def _encode_image( + self, + image, + elevation, + azimuth, + distance, + device, + num_images_per_prompt, + do_classifier_free_guidance, + clip_image_embeddings=None, + image_camera_embeddings=None, + ): + dtype = next(self.image_encoder.parameters()).dtype + + if image_camera_embeddings is None: + if image is None: + assert clip_image_embeddings is not None + image_embeddings = clip_image_embeddings.to(device=device, dtype=dtype) + else: + if not isinstance(image, torch.Tensor): + image = self.feature_extractor( + images=image, return_tensors="pt" + ).pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeddings = self.image_encoder(image).image_embeds + image_embeddings = image_embeddings.unsqueeze(1) + + bs_embed, seq_len, _ = image_embeddings.shape + + if isinstance(elevation, float): + elevation = torch.as_tensor( + [elevation] * bs_embed, dtype=dtype, device=device + ) + if isinstance(azimuth, float): + azimuth = torch.as_tensor( + [azimuth] * bs_embed, dtype=dtype, device=device + ) + if isinstance(distance, float): + distance = torch.as_tensor( + [distance] * bs_embed, dtype=dtype, device=device + ) + + camera_embeddings = torch.stack( + [ + torch.deg2rad(elevation), + torch.sin(torch.deg2rad(azimuth)), + torch.cos(torch.deg2rad(azimuth)), + distance, + ], + dim=-1, + )[:, None, :] + + image_embeddings = torch.cat([image_embeddings, camera_embeddings], dim=-1) + + # project (image, camera) embeddings to the same dimension as clip embeddings + image_embeddings = self.clip_camera_projection(image_embeddings) + else: + image_embeddings = image_camera_embeddings.to(device=device, dtype=dtype) + bs_embed, seq_len, _ = image_embeddings.shape + + # duplicate image embeddings for each generation per prompt, using mps friendly method + image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1) + image_embeddings = image_embeddings.view( + bs_embed * num_images_per_prompt, seq_len, -1 + ) + + if do_classifier_free_guidance: + negative_prompt_embeds = torch.zeros_like(image_embeddings) + + # For classifier free guidance, we need to do two forward passes. 
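+            # (Added note: `negative_prompt_embeds` is simply an all-zero tensor, so
+            #  classifier-free guidance contrasts the camera/image-conditioned prediction
+            #  with a fully unconditioned one inside `__call__`.)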
+ # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings]) + + return image_embeddings + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess( + image, output_type="pil" + ) + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor( + feature_extractor_input, return_tensors="pt" + ).to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set( + inspect.signature(self.scheduler.step).parameters.keys() + ) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set( + inspect.signature(self.scheduler.step).parameters.keys() + ) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs(self, image, height, width, callback_steps): + # TODO: check image size or adjust image size to (height, width) + + if height % 8 != 0 or width % 8 != 0: + raise ValueError( + f"`height` and `width` have to be divisible by 8 but are {height} and {width}." + ) + + if (callback_steps is None) or ( + callback_steps is not None + and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." 
+ ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + shape = ( + batch_size, + num_channels_latents, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor( + shape, generator=generator, device=device, dtype=dtype + ) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def _get_latent_model_input( + self, + latents: torch.FloatTensor, + image: Optional[ + Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor] + ], + num_images_per_prompt: int, + do_classifier_free_guidance: bool, + image_latents: Optional[torch.FloatTensor] = None, + ): + if isinstance(image, PIL.Image.Image): + image_pt = TF.to_tensor(image).unsqueeze(0).to(latents) + elif isinstance(image, list): + image_pt = torch.stack([TF.to_tensor(img) for img in image], dim=0).to( + latents + ) + elif isinstance(image, torch.Tensor): + image_pt = image + else: + image_pt = None + + if image_pt is None: + assert image_latents is not None + image_pt = image_latents.repeat_interleave(num_images_per_prompt, dim=0) + else: + image_pt = image_pt * 2.0 - 1.0 # scale to [-1, 1] + # FIXME: encoded latents should be multiplied with self.vae.config.scaling_factor + # but zero123 was not trained this way + image_pt = self.vae.encode(image_pt).latent_dist.mode() + image_pt = image_pt.repeat_interleave(num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + latent_model_input = torch.cat( + [ + torch.cat([latents, latents], dim=0), + torch.cat([torch.zeros_like(image_pt), image_pt], dim=0), + ], + dim=1, + ) + else: + latent_model_input = torch.cat([latents, image_pt], dim=1) + + return latent_model_input + + @torch.no_grad() + def __call__( + self, + image: Optional[ + Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor] + ] = None, + elevation: Optional[Union[float, torch.FloatTensor]] = None, + azimuth: Optional[Union[float, torch.FloatTensor]] = None, + distance: Optional[Union[float, torch.FloatTensor]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 3.0, + num_images_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + clip_image_embeddings: Optional[torch.FloatTensor] = None, + image_camera_embeddings: Optional[torch.FloatTensor] = None, + image_latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + ): + r""" + Function invoked when calling the pipeline for generation. 
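+
+        The target view is controlled by `elevation`, `azimuth` (both in degrees, typically
+        offsets relative to the conditioning view) and `distance`. `_encode_image` packs them
+        into a 4-D camera embedding (elevation in radians, sin/cos of azimuth, distance) that
+        is concatenated to the CLIP image embedding and projected by `clip_camera_projection`.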
+
+        Args:
+            image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
+                The image or images to guide the image generation. If you provide a tensor, it needs to comply with the
+                configuration of
+                [this](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json)
+                `CLIPImageProcessor`.
+            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+                The height in pixels of the generated image.
+            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+                The width in pixels of the generated image.
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
+            guidance_scale (`float`, *optional*, defaults to 3.0):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. A higher guidance scale encourages images that are closely linked to the conditioning image and
+                camera pose, usually at the expense of lower image quality.
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
+            eta (`float`, *optional*, defaults to 0.0):
+                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+                [`schedulers.DDIMScheduler`], will be ignored for others.
+            generator (`torch.Generator`, *optional*):
+                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                to make generation deterministic.
+            latents (`torch.FloatTensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different inputs. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+                plain tuple.
+            callback (`Callable`, *optional*):
+                A function that will be called every `callback_steps` steps during inference. The function will be
+                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+            callback_steps (`int`, *optional*, defaults to 1):
+                The frequency at which the `callback` function will be called. If not specified, the callback will be
+                called at every step.
+
+        Returns:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+            When returning a tuple, the first element is a list with the generated images, and the second element is a
+            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+            (nsfw) content, according to the `safety_checker`.
+        """
+        # 0.
Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + # TODO: check input elevation, azimuth, and distance + # TODO: check image, clip_image_embeddings, image_latents + self.check_inputs(image, height, width, callback_steps) + + # 2. Define call parameters + if isinstance(image, PIL.Image.Image): + batch_size = 1 + elif isinstance(image, list): + batch_size = len(image) + elif isinstance(image, torch.Tensor): + batch_size = image.shape[0] + else: + assert image_latents is not None + assert ( + clip_image_embeddings is not None or image_camera_embeddings is not None + ) + batch_size = image_latents.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input image + if isinstance(image, PIL.Image.Image) or isinstance(image, list): + pil_image = image + elif isinstance(image, torch.Tensor): + pil_image = [TF.to_pil_image(image[i]) for i in range(image.shape[0])] + else: + pil_image = None + image_embeddings = self._encode_image( + pil_image, + elevation, + azimuth, + distance, + device, + num_images_per_prompt, + do_classifier_free_guidance, + clip_image_embeddings, + image_camera_embeddings, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + # num_channels_latents = self.unet.config.in_channels + num_channels_latents = 4 # FIXME: hard-coded + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + image_embeddings.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. 
Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = self._get_latent_model_input( + latents, + image, + num_images_per_prompt, + do_classifier_free_guidance, + image_latents, + ) + latent_model_input = self.scheduler.scale_model_input( + latent_model_input, t + ) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=image_embeddings, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step( + noise_pred, t, latents, **extra_step_kwargs + ).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 + ): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if not output_type == "latent": + image = self.vae.decode( + latents / self.vae.config.scaling_factor, return_dict=False + )[0] + image, has_nsfw_concept = self.run_safety_checker( + image, device, image_embeddings.dtype + ) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess( + image, output_type=output_type, do_denormalize=do_denormalize + ) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput( + images=image, nsfw_content_detected=has_nsfw_concept + ) \ No newline at end of file
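
For reference, a minimal usage sketch of the `Zero123Pipeline` defined above (not part of the patch). The checkpoint id, image path, and resolution are assumptions; any diffusers-format Zero123 checkpoint that ships the extra `clip_camera_projection` module should work, and a CUDA device is assumed.

```python
# Illustrative only: synthesize one novel view with the Zero123Pipeline above.
# The checkpoint id and image path below are assumed placeholders.
import torch
from PIL import Image

from zero123 import Zero123Pipeline

pipe = Zero123Pipeline.from_pretrained(
    "ashawkey/zero123-xl-diffusers",  # assumed diffusers-format Zero123 checkpoint
    torch_dtype=torch.float16,
).to("cuda")

# Conditioning image; Zero123-style models are commonly trained at 256x256.
cond = Image.open("data/rose_pose0/0.png").convert("RGB").resize((256, 256))  # hypothetical path

out = pipe(
    image=cond,
    elevation=0.0,   # degrees, offset from the conditioning view
    azimuth=30.0,    # degrees
    distance=0.0,    # radius offset
    num_inference_steps=50,
    guidance_scale=3.0,
)
out.images[0].save("novel_view.png")
```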