Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement cone-beam max-intensity projection #331

Merged
merged 1 commit into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion diffdrr/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,8 @@
'diffdrr.renderers._get_alpha_minmax': ('api/renderers.html#_get_alpha_minmax', 'diffdrr/renderers.py'),
'diffdrr.renderers._get_alphas': ('api/renderers.html#_get_alphas', 'diffdrr/renderers.py'),
'diffdrr.renderers._get_voxel': ('api/renderers.html#_get_voxel', 'diffdrr/renderers.py'),
'diffdrr.renderers._get_xyzs': ('api/renderers.html#_get_xyzs', 'diffdrr/renderers.py')},
'diffdrr.renderers._get_xyzs': ('api/renderers.html#_get_xyzs', 'diffdrr/renderers.py'),
'diffdrr.renderers.reduce': ('api/renderers.html#reduce', 'diffdrr/renderers.py')},
'diffdrr.utils': { 'diffdrr.utils.get_focal_length': ('api/utils.html#get_focal_length', 'diffdrr/utils.py'),
'diffdrr.utils.get_principal_point': ('api/utils.html#get_principal_point', 'diffdrr/utils.py'),
'diffdrr.utils.make_intrinsic_matrix': ('api/utils.html#make_intrinsic_matrix', 'diffdrr/utils.py'),
Expand Down
20 changes: 17 additions & 3 deletions diffdrr/renderers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@ def __init__(
mode: str = "nearest", # Interpolation mode for grid_sample
stop_gradients_through_grid_sample: bool = False, # Apply torch.no_grad when calling grid_sample
filter_intersections_outside_volume: bool = True, # Use alphamin/max to filter the intersections
reducefn: str = "sum", # Function for combining samples along each ray
eps: float = 1e-8, # Small constant to avoid div by zero errors
):
super().__init__()
self.mode = mode
self.stop_gradients_through_grid_sample = stop_gradients_through_grid_sample
self.filter_intersections_outside_volume = filter_intersections_outside_volume
self.reducefn = reducefn
self.eps = eps

def dims(self, volume):
Expand Down Expand Up @@ -69,7 +71,7 @@ def forward(

# Handle optional masking
if mask is None:
img = img.sum(dim=-1)
img = reduce(img, self.reducefn)
img = img.unsqueeze(1)
else:
# Thanks to @Ivan for the clutch assist w/ pytorch tensor ops
Expand Down Expand Up @@ -162,17 +164,28 @@ def _get_voxel(volume, xyzs, img, mode, align_corners):
img = voxels
return img

# %% ../notebooks/api/01_renderers.ipynb 10
# %% ../notebooks/api/01_renderers.ipynb 9
def reduce(img, reducefn):
if reducefn == "sum":
return img.sum(dim=-1)
elif reducefn == "max":
return img.max(dim=-1).values
else:
raise ValueError(f"Only supports reducefn 'sum' or 'max', not {reducefn}")

# %% ../notebooks/api/01_renderers.ipynb 11
class Trilinear(torch.nn.Module):
"""Differentiable X-ray renderer implemented with trilinear interpolation."""

def __init__(
self,
mode: str = "bilinear", # Interpolation mode for grid_sample
reducefn: str = "sum", # Function for combining samples along each ray
eps: float = 1e-8, # Small constant to avoid div by zero errors
):
super().__init__()
self.mode = mode
self.reducefn = reducefn
self.eps = eps

def dims(self, volume):
Expand Down Expand Up @@ -213,7 +226,8 @@ def forward(

# Handle optional masking
if mask is None:
img = img.sum(dim=-1).unsqueeze(1)
img = reduce(img, self.reducefn)
img = img.unsqueeze(1)
else:
B, D, _ = img.shape
C = int(mask.max().item() + 1)
Expand Down
25 changes: 23 additions & 2 deletions notebooks/api/01_renderers.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -118,12 +118,14 @@
" mode: str = \"nearest\", # Interpolation mode for grid_sample\n",
" stop_gradients_through_grid_sample: bool = False, # Apply torch.no_grad when calling grid_sample\n",
" filter_intersections_outside_volume: bool = True, # Use alphamin/max to filter the intersections\n",
" reducefn: str = \"sum\", # Function for combining samples along each ray\n",
" eps: float = 1e-8, # Small constant to avoid div by zero errors\n",
" ):\n",
" super().__init__()\n",
" self.mode = mode\n",
" self.stop_gradients_through_grid_sample = stop_gradients_through_grid_sample\n",
" self.filter_intersections_outside_volume = filter_intersections_outside_volume\n",
" self.reducefn = reducefn\n",
" self.eps = eps\n",
"\n",
" def dims(self, volume):\n",
Expand Down Expand Up @@ -169,7 +171,7 @@
"\n",
" # Handle optional masking\n",
" if mask is None:\n",
" img = img.sum(dim=-1)\n",
" img = reduce(img, self.reducefn)\n",
" img = img.unsqueeze(1)\n",
" else:\n",
" # Thanks to @Ivan for the clutch assist w/ pytorch tensor ops\n",
Expand Down Expand Up @@ -270,6 +272,22 @@
" return img"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| exporti\n",
"def reduce(img, reducefn):\n",
" if reducefn == \"sum\":\n",
" return img.sum(dim=-1)\n",
" elif reducefn == \"max\":\n",
" return img.max(dim=-1).values\n",
" else:\n",
" raise ValueError(f\"Only supports reducefn 'sum' or 'max', not {reducefn}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -298,10 +316,12 @@
" def __init__(\n",
" self,\n",
" mode: str = \"bilinear\", # Interpolation mode for grid_sample\n",
" reducefn: str = \"sum\", # Function for combining samples along each ray\n",
" eps: float = 1e-8, # Small constant to avoid div by zero errors\n",
" ):\n",
" super().__init__()\n",
" self.mode = mode\n",
" self.reducefn = reducefn\n",
" self.eps = eps\n",
"\n",
" def dims(self, volume):\n",
Expand Down Expand Up @@ -342,7 +362,8 @@
"\n",
" # Handle optional masking\n",
" if mask is None:\n",
" img = img.sum(dim=-1).unsqueeze(1)\n",
" img = reduce(img, self.reducefn)\n",
" img = img.unsqueeze(1)\n",
" else:\n",
" B, D, _ = img.shape\n",
" C = int(mask.max().item() + 1)\n",
Expand Down
4 changes: 2 additions & 2 deletions notebooks/index.ipynb

Large diffs are not rendered by default.

44 changes: 43 additions & 1 deletion notebooks/tutorials/introduction.ipynb

Large diffs are not rendered by default.

Loading