Refactor rendering computation #329

Merged · 1 commit · Sep 11, 2024
1 change: 1 addition & 0 deletions diffdrr/_modidx.py
@@ -30,6 +30,7 @@
'diffdrr.drr.DRR.forward': ('api/drr.html#drr.forward', 'diffdrr/drr.py'),
'diffdrr.drr.DRR.inverse_projection': ('api/drr.html#drr.inverse_projection', 'diffdrr/drr.py'),
'diffdrr.drr.DRR.perspective_projection': ('api/drr.html#drr.perspective_projection', 'diffdrr/drr.py'),
'diffdrr.drr.DRR.render': ('api/drr.html#drr.render', 'diffdrr/drr.py'),
'diffdrr.drr.DRR.rescale_detector_': ('api/drr.html#drr.rescale_detector_', 'diffdrr/drr.py'),
'diffdrr.drr.DRR.reshape_transform': ('api/drr.html#drr.reshape_transform', 'diffdrr/drr.py'),
'diffdrr.drr.DRR.set_intrinsics_': ('api/drr.html#drr.set_intrinsics_', 'diffdrr/drr.py'),
46 changes: 33 additions & 13 deletions diffdrr/drr.py
@@ -105,7 +105,10 @@ def reshape_transform(self, img, batch_size):
if self.reshape:
if self.detector.n_subsample is None:
img = img.view(
batch_size, -1, self.detector.height, self.detector.width
batch_size,
-1,
self.detector.height,
self.detector.width,
)
else:
img = reshape_subsampled_drr(img, self.detector, batch_size)
@@ -147,37 +150,54 @@ def forward(
pose = args[0]
else:
pose = convert(*args, parameterization=parameterization, convention=convention)

# Create the source / target points and render the image
source, target = self.detector(pose, calibration)
img = self.render(self.density, source, target, mask_to_channels, **kwargs)
return self.reshape_transform(img, batch_size=len(pose))


@patch
def render(
self: DRR,
density: torch.tensor,
source: torch.tensor,
target: torch.tensor,
mask_to_channels: bool,
**kwargs,
):
# Initialize the image with the length of each cast ray
img = (target - source).norm(dim=-1).unsqueeze(1)

# Convert rays to voxelspace
source = self.affine_inverse(source)
target = self.affine_inverse(target)

# Render the DRR
# Render the image
kwargs["mask"] = self.mask if mask_to_channels else None
if self.patch_size is None:
img = self.renderer(
self.density,
density,
source,
target,
img,
**kwargs,
)
else:
n_points = target.shape[1] // self.n_patches
img = []
partials = []
for idx in range(self.n_patches):
t = target[:, idx * n_points : (idx + 1) * n_points]
partial = self.renderer(
self.density,
density,
source,
t,
target[:, idx * n_points : (idx + 1) * n_points],
img[:, idx * n_points : (idx + 1) * n_points],
**kwargs,
)
img.append(partial)
img = torch.cat(img, dim=-1)
partials.append(partial)
img = torch.cat(partials, dim=-1)

# Multiply by the raylength (in world coordinate units)
img *= self.affine(target - source).norm(dim=-1).unsqueeze(1)

return self.reshape_transform(img, batch_size=len(pose))
return img

# %% ../notebooks/api/00_drr.ipynb 11
@patch
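For context on the refactor above, here is a minimal sketch of the resulting call pattern: `forward()` now only parses the pose, builds the source/target rays, and delegates to the new `DRR.render()`, which receives the density volume explicitly (`self.render(self.density, source, target, ...)`). The constructor and call arguments below follow the DiffDRR README and are illustrative, not code from this PR:

```python
import torch
from diffdrr.data import load_example_ct
from diffdrr.drr import DRR

# Standard DiffDRR setup (values are illustrative)
subject = load_example_ct()
drr = DRR(subject, sdd=1020.0, height=200, delx=2.0)

# One pose: Euler angles + translation
rotations = torch.tensor([[0.0, 0.0, 0.0]])
translations = torch.tensor([[0.0, 850.0, 0.0]])

# forward() converts the parameters to a pose, gets source/target from the
# detector, calls render(self.density, source, target, ...), and reshapes
img = drr(rotations, translations, parameterization="euler_angles", convention="ZXY")
```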
22 changes: 15 additions & 7 deletions diffdrr/renderers.py
@@ -32,6 +32,7 @@ def forward(
volume,
source,
target,
img,
align_corners=False,
mask=None,
):
@@ -56,9 +57,11 @@
# Use torch.nn.functional.grid_sample to lookup the values of each intersected voxel
if self.stop_gradients_through_grid_sample:
with torch.no_grad():
img = _get_voxel(volume, xyzs, self.mode, align_corners=align_corners)
img = _get_voxel(
volume, xyzs, img, self.mode, align_corners=align_corners
)
else:
img = _get_voxel(volume, xyzs, self.mode, align_corners=align_corners)
img = _get_voxel(volume, xyzs, img, self.mode, align_corners=align_corners)

# Weight each intersected voxel by the length of the ray's intersection with the voxel
intersection_length = torch.diff(alphas, dim=-1)
@@ -74,7 +77,7 @@
B, D, _ = img.shape
C = int(mask.max().item() + 1)
channels = _get_voxel(
mask, xyzs, self.mode, align_corners=align_corners
mask, xyzs, img=None, mode=self.mode, align_corners=align_corners
).long()
img = (
torch.zeros(B, C, D)
@@ -144,7 +147,7 @@ def _get_xyzs(alpha, source, target, dims, eps):
return xyzs


def _get_voxel(volume, xyzs, mode, align_corners):
def _get_voxel(volume, xyzs, img, mode, align_corners):
"""Wraps torch.nn.functional.grid_sample to sample a volume at XYZ coordinates."""
batch_size = len(xyzs)
voxels = grid_sample(
@@ -153,7 +156,11 @@ def _get_voxel(volume, xyzs, mode, align_corners):
mode=mode,
align_corners=align_corners,
)[:, 0, 0]
return voxels
if img is not None:
img = torch.einsum("bcn, bnj -> bnj", img, voxels)
else:
img = voxels
return img

# %% ../notebooks/api/01_renderers.ipynb 10
class Trilinear(torch.nn.Module):
@@ -176,6 +183,7 @@ def forward(
volume,
source,
target,
img,
n_points=500,
align_corners=False,
mask=None,
@@ -197,7 +205,7 @@
xyzs = _get_xyzs(alphas, source, target, dims, self.eps)

# Sample the volume with trilinear interpolation
img = _get_voxel(volume, xyzs, self.mode, align_corners=align_corners)
img = _get_voxel(volume, xyzs, img, self.mode, align_corners=align_corners)

# Multiply by the step size to compute the rectangular rule for integration
step_size = (alphamax - alphamin) / (n_points - 1)
@@ -210,7 +218,7 @@
B, D, _ = img.shape
C = int(mask.max().item() + 1)
channels = _get_voxel(
mask, xyzs, self.mode, align_corners=align_corners
mask, xyzs, img=None, mode=self.mode, align_corners=align_corners
).long()
img = (
torch.zeros(B, C, D)
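The substantive change in `_get_voxel` is the new `img` argument: when it is a tensor, the sampled voxel values are scaled by the per-ray weights it carries; when it is `None` (as in the mask lookup), the raw samples are returned. A toy sketch of what the einsum computes, with shapes assumed from the diff (`img` as `(B, 1, n_rays)` ray lengths, `voxels` as `(B, n_rays, n_points)` samples):

```python
import torch

B, n_rays, n_points = 2, 8, 5
img = torch.rand(B, 1, n_rays)            # per-ray weights (one channel)
voxels = torch.rand(B, n_rays, n_points)  # grid_sample lookups

# "bcn, bnj -> bnj" sums over the channel axis c; with c == 1 this simply
# scales every sample along ray n by that ray's weight
weighted = torch.einsum("bcn, bnj -> bnj", img, voxels)
assert torch.allclose(weighted, img[:, 0, :, None] * voxels)
```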
48 changes: 34 additions & 14 deletions notebooks/api/00_drr.ipynb
@@ -195,12 +195,15 @@
" self.patch_size = patch_size\n",
" if self.patch_size is not None:\n",
" self.n_patches = (height * width) // (self.patch_size**2)\n",
" \n",
"\n",
" def reshape_transform(self, img, batch_size):\n",
" if self.reshape:\n",
" if self.detector.n_subsample is None:\n",
" img = img.view(\n",
" batch_size, -1, self.detector.height, self.detector.width\n",
" batch_size,\n",
" -1,\n",
" self.detector.height,\n",
" self.detector.width,\n",
" )\n",
" else:\n",
" img = reshape_subsampled_drr(img, self.detector, batch_size)\n",
@@ -266,37 +269,54 @@
" pose = args[0]\n",
" else:\n",
" pose = convert(*args, parameterization=parameterization, convention=convention)\n",
"\n",
" # Create the source / target points and render the image\n",
" source, target = self.detector(pose, calibration)\n",
" img = self.render(self.density, source, target, mask_to_channels, **kwargs)\n",
" return self.reshape_transform(img, batch_size=len(pose))\n",
"\n",
"\n",
"@patch\n",
"def render(\n",
" self: DRR,\n",
" density: torch.tensor,\n",
" source: torch.tensor,\n",
" target: torch.tensor,\n",
" mask_to_channels: bool,\n",
" **kwargs,\n",
"):\n",
" # Initialize the image with the length of each cast ray\n",
" img = (target - source).norm(dim=-1).unsqueeze(1)\n",
"\n",
" # Convert rays to voxelspace\n",
" source = self.affine_inverse(source)\n",
" target = self.affine_inverse(target)\n",
"\n",
" # Render the DRR\n",
" # Render the image\n",
" kwargs[\"mask\"] = self.mask if mask_to_channels else None\n",
" if self.patch_size is None:\n",
" img = self.renderer(\n",
" self.density,\n",
" density,\n",
" source,\n",
" target,\n",
" img,\n",
" **kwargs,\n",
" )\n",
" else:\n",
" n_points = target.shape[1] // self.n_patches\n",
" img = []\n",
" partials = []\n",
" for idx in range(self.n_patches):\n",
" t = target[:, idx * n_points : (idx + 1) * n_points]\n",
" partial = self.renderer(\n",
" self.density,\n",
" density,\n",
" source,\n",
" t,\n",
" target[:, idx * n_points : (idx + 1) * n_points],\n",
" img[:, idx * n_points : (idx + 1) * n_points],\n",
" **kwargs,\n",
" )\n",
" img.append(partial)\n",
" img = torch.cat(img, dim=-1)\n",
" \n",
" # Multiply by the raylength (in world coordinate units)\n",
" img *= self.affine(target - source).norm(dim=-1).unsqueeze(1)\n",
" partials.append(partial)\n",
" img = torch.cat(partials, dim=-1)\n",
"\n",
" return self.reshape_transform(img, batch_size=len(pose))"
" return img"
]
},
{
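For readers skimming the notebook mirror of `render()`: the patched branch now slices both the target rays and the pre-initialized image into `n_patches` contiguous chunks, renders each chunk, and concatenates the partials along the last dimension. A self-contained toy version with the renderer stubbed out (shapes and the slicing axis are assumptions based on the diff, not code from this PR):

```python
import torch

def toy_renderer(density, source, target, img, **kwargs):
    # Stand-in for self.renderer: pass the image slice through unchanged
    return img

B, n_rays, n_patches = 2, 12, 3
n_points = n_rays // n_patches

source = torch.rand(B, 1, 3)
target = torch.rand(B, n_rays, 3)
img = (target - source).norm(dim=-1).unsqueeze(1)  # (B, 1, n_rays)

partials = []
for idx in range(n_patches):
    sl = slice(idx * n_points, (idx + 1) * n_points)
    # Rays live on the last axis of img here, so slice with `...`
    partials.append(toy_renderer(None, source, target[:, sl], img[..., sl]))
out = torch.cat(partials, dim=-1)
assert out.shape == img.shape
```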
20 changes: 13 additions & 7 deletions notebooks/api/01_renderers.ipynb
@@ -134,6 +134,7 @@
" volume,\n",
" source,\n",
" target,\n",
" img,\n",
" align_corners=False,\n",
" mask=None,\n",
" ):\n",
@@ -158,9 +159,9 @@
" # Use torch.nn.functional.grid_sample to lookup the values of each intersected voxel\n",
" if self.stop_gradients_through_grid_sample:\n",
" with torch.no_grad():\n",
" img = _get_voxel(volume, xyzs, self.mode, align_corners=align_corners)\n",
" img = _get_voxel(volume, xyzs, img, self.mode, align_corners=align_corners)\n",
" else:\n",
" img = _get_voxel(volume, xyzs, self.mode, align_corners=align_corners)\n",
" img = _get_voxel(volume, xyzs, img, self.mode, align_corners=align_corners)\n",
"\n",
" # Weight each intersected voxel by the length of the ray's intersection with the voxel\n",
" intersection_length = torch.diff(alphas, dim=-1)\n",
@@ -176,7 +177,7 @@
" B, D, _ = img.shape\n",
" C = int(mask.max().item() + 1)\n",
" channels = _get_voxel(\n",
" mask, xyzs, self.mode, align_corners=align_corners\n",
" mask, xyzs, img=None, mode=self.mode, align_corners=align_corners\n",
" ).long()\n",
" img = (\n",
" torch.zeros(B, C, D)\n",
@@ -253,7 +254,7 @@
" return xyzs\n",
"\n",
"\n",
"def _get_voxel(volume, xyzs, mode, align_corners):\n",
"def _get_voxel(volume, xyzs, img, mode, align_corners):\n",
" \"\"\"Wraps torch.nn.functional.grid_sample to sample a volume at XYZ coordinates.\"\"\"\n",
" batch_size = len(xyzs)\n",
" voxels = grid_sample(\n",
Expand All @@ -262,7 +263,11 @@
" mode=mode,\n",
" align_corners=align_corners,\n",
" )[:, 0, 0]\n",
" return voxels"
" if img is not None:\n",
" img = torch.einsum(\"bcn, bnj -> bnj\", img, voxels)\n",
" else:\n",
" img = voxels\n",
" return img"
]
},
{
@@ -307,6 +312,7 @@
" volume,\n",
" source,\n",
" target,\n",
" img,\n",
" n_points=500,\n",
" align_corners=False,\n",
" mask=None,\n",
@@ -328,7 +334,7 @@
" xyzs = _get_xyzs(alphas, source, target, dims, self.eps)\n",
"\n",
" # Sample the volume with trilinear interpolation\n",
" img = _get_voxel(volume, xyzs, self.mode, align_corners=align_corners)\n",
" img = _get_voxel(volume, xyzs, img, self.mode, align_corners=align_corners)\n",
" \n",
" # Multiply by the step size to compute the rectangular rule for integration\n",
" step_size = (alphamax - alphamin) / (n_points - 1)\n",
@@ -341,7 +347,7 @@
" B, D, _ = img.shape\n",
" C = int(mask.max().item() + 1)\n",
" channels = _get_voxel(\n",
" mask, xyzs, self.mode, align_corners=align_corners\n",
" mask, xyzs, img=None, mode=self.mode, align_corners=align_corners\n",
" ).long()\n",
" img = (\n",
" torch.zeros(B, C, D)\n",
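Both renderers end with a mask-to-channels step that this diff truncates right after `torch.zeros(B, C, D)`. A hedged reconstruction of the idea, where the `scatter_add_` call is an assumption about the elided code rather than a quote of it: each sampled point gets a structure label from the mask volume, and each ray's contributions are accumulated into one output channel per label.

```python
import torch

B, D, J = 2, 6, 4  # batch, rays, samples per ray
C = 3              # number of labels, i.e. int(mask.max().item() + 1)

img = torch.rand(B, D, J)                  # per-sample contributions
channels = torch.randint(0, C, (B, D, J))  # per-sample labels

# out[b, c, d] = sum of img[b, d, j] over samples j with channels[b, d, j] == c
out = torch.zeros(B, C, D).scatter_add_(
    1, channels.permute(0, 2, 1), img.permute(0, 2, 1)
)

# Sanity check against an explicit loop
expected = torch.zeros(B, C, D)
for b in range(B):
    for d in range(D):
        for j in range(J):
            expected[b, channels[b, d, j], d] += img[b, d, j]
assert torch.allclose(out, expected)
```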