diff --git a/src/deepali/data/flow.py b/src/deepali/data/flow.py index 75ea3b2..16f6d63 100644 --- a/src/deepali/data/flow.py +++ b/src/deepali/data/flow.py @@ -12,7 +12,7 @@ from deepali.core.enum import PaddingMode, Sampling from deepali.core.grid import ALIGN_CORNERS, Axes, Grid, grid_transform_vectors from deepali.core.tensor import move_dim -from deepali.core.typing import Array, Device, DType, EllipsisType, PathStr, Scalar +from deepali.core.typing import Array, Device, DType, EllipsisType, PathStr, Scalar, ScalarOrTuple from .image import Image, ImageBatch @@ -203,12 +203,25 @@ def axes(self: TFlowFields, axes: Optional[Axes] = None) -> Union[Axes, TFlowFie data = move_dim(data, -1, 1) return self._make_instance(data, self._grid, axes) - def curl(self: TFlowFields, mode: str = "central") -> ImageBatch: + def curl( + self: TFlowFields, + mode: Optional[str] = None, + sigma: Optional[float] = None, + spacing: Optional[Union[Scalar, Array]] = None, + stride: Optional[ScalarOrTuple[int]] = None, + ) -> ImageBatch: if self.ndim not in (2, 3): - raise RuntimeError("Cannot compute curl of {self.ndim}-dimensional flow field") - spacing = self.spacing() - data = self.tensor() - data = U.curl(data, spacing=spacing, mode=mode) + raise RuntimeError(f"Cannot compute curl of {self.ndim}-dimensional flow field") + if spacing is None: + if self.axes() is Axes.GRID: + spacing = 1 + elif self.axes() is Axes.WORLD: + spacing = self.spacing() + elif self.axes() is Axes.CUBE: + spacing = tuple(2 / n for n in self.grid().size()) + else: + spacing = tuple(2 / (n - 1) for n in self.grid().size()) + data = U.curl(self.tensor(), mode=mode, sigma=sigma, spacing=spacing, stride=stride) return ImageBatch(data, self._grid) def exp( @@ -529,10 +542,16 @@ def write(self, path: PathStr, axes: Optional[Axes] = None, compress: bool = Tru disp = disp.axes(axes or Axes.WORLD) Image.write(disp, path, compress=compress) - def curl(self: TFlowField, mode: str = "central") -> Image: + def curl( + self: TFlowField, + mode: Optional[str] = None, + sigma: Optional[float] = None, + spacing: Optional[Union[Scalar, Array]] = None, + stride: Optional[ScalarOrTuple[int]] = None, + ) -> Image: r"""Compute curl of vector field.""" batch = self.batch() - rotvec = batch.curl(mode=mode) + rotvec = batch.curl(mode=mode, sigma=sigma, spacing=spacing, stride=stride) return rotvec[0] def exp(