Skip to content

Commit

Permalink
Make it easier to change patch size
Browse files Browse the repository at this point in the history
  • Loading branch information
eigenvivek committed Oct 28, 2024
1 parent 66b7b5f commit 2eb6162
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 5 deletions.
1 change: 1 addition & 0 deletions diffdrr/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
'diffdrr.drr.DRR.affine_inverse': ('api/drr.html#drr.affine_inverse', 'diffdrr/drr.py'),
'diffdrr.drr.DRR.forward': ('api/drr.html#drr.forward', 'diffdrr/drr.py'),
'diffdrr.drr.DRR.inverse_projection': ('api/drr.html#drr.inverse_projection', 'diffdrr/drr.py'),
'diffdrr.drr.DRR.n_patches': ('api/drr.html#drr.n_patches', 'diffdrr/drr.py'),
'diffdrr.drr.DRR.perspective_projection': ('api/drr.html#drr.perspective_projection', 'diffdrr/drr.py'),
'diffdrr.drr.DRR.render': ('api/drr.html#drr.render', 'diffdrr/drr.py'),
'diffdrr.drr.DRR.rescale_detector_': ('api/drr.html#drr.rescale_detector_', 'diffdrr/drr.py'),
Expand Down
6 changes: 4 additions & 2 deletions diffdrr/drr.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,6 @@ def __init__(
)
self.reshape = reshape
self.patch_size = patch_size
if self.patch_size is not None:
self.n_patches = (height * width) // (self.patch_size**2)

def reshape_transform(self, img, batch_size):
if self.reshape:
Expand All @@ -122,6 +120,10 @@ def affine(self):
def affine_inverse(self):
return RigidTransform(self._affine_inverse)

@property
def n_patches(self):
return (self.detector.height * self.detector.width) // (self.patch_size**2)

# %% ../notebooks/api/00_drr.ipynb 8
def reshape_subsampled_drr(img: torch.Tensor, detector: Detector, batch_size: int):
n_points = detector.height * detector.width
Expand Down
8 changes: 5 additions & 3 deletions notebooks/api/00_drr.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,6 @@
" )\n",
" self.reshape = reshape\n",
" self.patch_size = patch_size\n",
" if self.patch_size is not None:\n",
" self.n_patches = (height * width) // (self.patch_size**2)\n",
"\n",
" def reshape_transform(self, img, batch_size):\n",
" if self.reshape:\n",
Expand All @@ -215,7 +213,11 @@
"\n",
" @property\n",
" def affine_inverse(self):\n",
" return RigidTransform(self._affine_inverse)"
" return RigidTransform(self._affine_inverse)\n",
"\n",
" @property\n",
" def n_patches(self):\n",
" return (self.detector.height * self.detector.width) // (self.patch_size**2)"
]
},
{
Expand Down

0 comments on commit 2eb6162

Please sign in to comment.