-
Notifications
You must be signed in to change notification settings - Fork 2
/
volume_rendering_utils.py
51 lines (44 loc) · 1.6 KB
/
volume_rendering_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import torch
from nerf_helpers import cumprod_exclusive
def volume_render_radiance_field(
    radiance_field,
    depth_values,
    ray_directions,
    radiance_field_noise_std=0.0,
    white_background=False,
    mip_nerf=False,
):
    """Alpha-composite per-sample colour and density into per-ray outputs.

    Args:
        radiance_field: Raw network output. Channels ``[..., :3]`` are colour
            logits (passed through a sigmoid) and channel ``[..., 3]`` is raw
            density (passed through a ReLU). Presumably shaped
            ``(num_rays, num_samples, 4)`` -- TODO confirm against callers.
        depth_values: Sample depths along each ray, samples on the last axis.
            With ``mip_nerf=False`` there is one depth per sample. With
            ``mip_nerf=True`` the values are treated as interval edges, so
            there is presumably one more entry than there are samples.
        ray_directions: Ray direction vectors, components on the last axis.
            Their Euclidean norm scales depth differences into distances
            along the ray.
        radiance_field_noise_std: If positive, zero-mean Gaussian noise with
            this standard deviation is added to the raw density before the
            ReLU.
        white_background: If True, fill the non-accumulated opacity of each
            ray with white (adds ``1 - acc_map`` to every colour channel).
        mip_nerf: If True, ``depth_values`` are interval edges. The last
            interval is not padded, and interval midpoints are used as the
            per-sample depth.

    Returns:
        Tuple ``(rgb_map, disp_map, acc_map, weights, depth_map)``:
        composited colour, disparity (``1 / (depth_map / acc_map)`` clamped
        so it saturates at 1e10; rays with zero accumulated weight also get
        1e10 instead of NaN), accumulated weight per ray, per-sample
        compositing weights, and weight-averaged depth (unnormalised sum).
    """
    # Distance between consecutive depth values.
    dists = depth_values[..., 1:] - depth_values[..., :-1]
    if not mip_nerf:
        # The last sample has no successor. Give it an effectively infinite
        # interval so it absorbs almost all remaining light whenever its
        # density is positive.
        one_e_10 = torch.tensor(
            [1e10], dtype=ray_directions.dtype, device=ray_directions.device
        )
        dists = torch.cat(
            (dists, one_e_10.expand(depth_values[..., :1].shape)),
            dim=-1,
        )
    # Scale by the direction length so non-unit directions still yield
    # distances measured along the ray.
    dists = dists * ray_directions[..., None, :].norm(p=2, dim=-1)

    rgb = torch.sigmoid(radiance_field[..., :3])
    raw_sigma = radiance_field[..., 3]
    noise = 0.0
    if radiance_field_noise_std > 0.0:
        # randn_like allocates directly with raw_sigma's device and dtype,
        # avoiding a host-side allocation plus copy on every call.
        noise = torch.randn_like(raw_sigma) * radiance_field_noise_std
    sigma_a = torch.nn.functional.relu(raw_sigma + noise)

    alpha = 1.0 - torch.exp(-sigma_a * dists)
    # cumprod_exclusive is a project helper (nerf_helpers); presumably the
    # exclusive cumulative product along the last axis, i.e. transmittance
    # up to each sample -- confirm. The 1e-10 keeps the factors off zero.
    weights = alpha * cumprod_exclusive(1.0 - alpha + 1e-10)
    rgb_map = (weights[..., None] * rgb).sum(dim=-2)

    if mip_nerf:
        # depth_values are interval edges: use interval midpoints so there is
        # exactly one depth per weight. Ellipsis indexing supports any number
        # of leading batch dimensions.
        sample_depths = 0.5 * (depth_values[..., :-1] + depth_values[..., 1:])
    else:
        sample_depths = depth_values
    depth_map = (weights * sample_depths).sum(dim=-1)
    acc_map = weights.sum(dim=-1)

    # A ray whose weights are all zero has depth_map == acc_map == 0, and
    # depth_map / acc_map would be 0/0 = NaN. Dividing by 1 there yields 0,
    # which the clamp lifts to 1e-10 (disparity 1e10). Substituting the
    # denominator, rather than masking the quotient, keeps gradients NaN-free.
    safe_acc = torch.where(acc_map == 0, torch.ones_like(acc_map), acc_map)
    disp_map = 1.0 / torch.clamp(depth_map / safe_acc, min=1e-10)

    if white_background:
        rgb_map = rgb_map + (1.0 - acc_map[..., None])
    return rgb_map, disp_map, acc_map, weights, depth_map