NeRF, Overturning Traditional 3D Reconstruction (Part 11): pytorch-nerf Model Training 1, Overview

For now, assume use_batching = False.
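
To make this setting concrete: the sketch below assumes that with use_batching = False each training iteration draws its rays from a single image, by picking a random subset of that image's pixels. The helper name `sample_pixel_coords` and the parameter `n_rand` are hypothetical and used only for illustration. Plain Python stands in for the tensor indexing a real training script would use.

```python
import random


def sample_pixel_coords(H, W, n_rand, rng=random):
    """Pick up to n_rand distinct (row, col) pixel coordinates from an H x W image.

    Hypothetical helper: illustrates per-image random ray selection only.
    """
    n_rand = min(n_rand, H * W)              # cannot draw more pixels than exist
    flat = rng.sample(range(H * W), n_rand)  # distinct flat indices, no replacement
    return [divmod(idx, W) for idx in flat]  # flat index -> (row, col)


# Example: choose 8 pixels from a 4 x 6 image with a fixed seed.
coords = sample_pixel_coords(4, 6, 8, random.Random(0))
for row, col in coords:
    # In a real pipeline these coordinates would index rays_o, rays_d
    # (each of shape (H, W, 3)) and the target image to form one ray batch.
    pass
```

Sampling without replacement keeps every ray in a batch distinct, which is the usual choice when the batch is much smaller than the image.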

Core training code

rays_o, rays_d = get_rays(H, W, K, torch.Tensor(pose))  # (H, W, 3), (H, W, 3)

#####  Core optimization loop  #####
rgb, disp, acc, extras = render(H, W, K, chunk=args.chunk, rays=batch_rays,
                                verbose=i < 10, retraw=True,
                                **render_kwargs_train)

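optimizer.zero_grad()
img_loss = img2mse(rgb, target_s)  # MSE between the rendered colors and the target pixels
trans = extras['raw'][...,-1]      # last channel of the raw network output (not used below)
loss = img_loss
psnr = mse2psnr(img_loss)

# 'rgb0' is the coarse-network render in nerf-pytorch; its MSE is added to the total loss
if 'rgb0' in extras:
    img_loss0 = img2mse(extras['rgb0'], target_s)
    loss = loss + img_loss0
    psnr0 = mse2psnr(img_loss0)

loss.backward()
optimizer.step()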
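
The two helpers used in the loop, `img2mse` and `mse2psnr`, are not defined in the excerpt. As their names suggest, they can be read as a mean squared error and a conversion from MSE to PSNR in decibels. The sketch below is a plain-Python stand-in under that assumption. It works on flat lists of floats rather than tensors, and the sample values are made up for illustration.

```python
import math


def img2mse(pred, target):
    """Mean squared error between two equally sized flat sequences of floats."""
    if len(pred) != len(target) or len(pred) == 0:
        raise ValueError("pred and target must be non-empty and the same length")
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def mse2psnr(mse):
    """PSNR in dB for values in [0, 1]: -10 * log10(mse). Requires mse > 0."""
    return -10.0 * math.log10(mse)


# Illustrative values only (not from any dataset).
target = [0.2, 0.4, 0.6]
fine = [0.25, 0.35, 0.6]    # stands in for rgb
coarse = [0.3, 0.3, 0.7]    # stands in for extras['rgb0']

loss_fine = img2mse(fine, target)
loss_coarse = img2mse(coarse, target)
loss = loss_fine + loss_coarse   # same sum as loss = loss + img_loss0 above
psnr = mse2psnr(loss_fine)       # PSNR is reported per render, not on the sum
```

Summing the two terms means both renders receive gradient from the same target pixels, while the PSNR values stay separate so the quality of each render can be tracked on its own.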