Skip to content

Commit

Permalink
[data] Add FlowFields.curl() method
Browse files Browse the repository at this point in the history
  • Loading branch information
aschuh-hf committed Aug 31, 2023
1 parent 55352ea commit 9cbcb02
Showing 1 changed file with 27 additions and 8 deletions.
35 changes: 27 additions & 8 deletions src/deepali/data/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from deepali.core.enum import PaddingMode, Sampling
from deepali.core.grid import ALIGN_CORNERS, Axes, Grid, grid_transform_vectors
from deepali.core.tensor import move_dim
from deepali.core.typing import Array, Device, DType, EllipsisType, PathStr, Scalar
from deepali.core.typing import Array, Device, DType, EllipsisType, PathStr, Scalar, ScalarOrTuple

from .image import Image, ImageBatch

Expand Down Expand Up @@ -203,12 +203,25 @@ def axes(self: TFlowFields, axes: Optional[Axes] = None) -> Union[Axes, TFlowFie
data = move_dim(data, -1, 1)
return self._make_instance(data, self._grid, axes)

def curl(self: TFlowFields, mode: str = "central") -> ImageBatch:
def curl(
self: TFlowFields,
mode: Optional[str] = None,
sigma: Optional[float] = None,
spacing: Optional[Union[Scalar, Array]] = None,
stride: Optional[ScalarOrTuple[int]] = None,
) -> ImageBatch:
if self.ndim not in (2, 3):
raise RuntimeError("Cannot compute curl of {self.ndim}-dimensional flow field")
spacing = self.spacing()
data = self.tensor()
data = U.curl(data, spacing=spacing, mode=mode)
raise RuntimeError(f"Cannot compute curl of {self.ndim}-dimensional flow field")
if spacing is None:
if self.axes() is Axes.GRID:
spacing = 1
elif self.axes() is Axes.WORLD:
spacing = self.spacing()
elif self.axes() is Axes.CUBE:
spacing = tuple(2 / n for n in self.grid().size())
else:
spacing = tuple(2 / (n - 1) for n in self.grid().size())
data = U.curl(self.tensor(), mode=mode, sigma=sigma, spacing=spacing, stride=stride)
return ImageBatch(data, self._grid)

def exp(
Expand Down Expand Up @@ -529,10 +542,16 @@ def write(self, path: PathStr, axes: Optional[Axes] = None, compress: bool = Tru
disp = disp.axes(axes or Axes.WORLD)
Image.write(disp, path, compress=compress)

def curl(self: TFlowField, mode: str = "central") -> Image:
def curl(
self: TFlowField,
mode: Optional[str] = None,
sigma: Optional[float] = None,
spacing: Optional[Union[Scalar, Array]] = None,
stride: Optional[ScalarOrTuple[int]] = None,
) -> Image:
r"""Compute curl of vector field."""
batch = self.batch()
rotvec = batch.curl(mode=mode)
rotvec = batch.curl(mode=mode, sigma=sigma, spacing=spacing, stride=stride)
return rotvec[0]

def exp(
Expand Down

0 comments on commit 9cbcb02

Please sign in to comment.