From d51bbf24d5e52651e2612e80925beab4574f9061 Mon Sep 17 00:00:00 2001 From: Andreas Schuh <77496589+aschuh-hf@users.noreply.github.com> Date: Fri, 28 Jul 2023 06:26:20 -0700 Subject: [PATCH] [core] Create unbatched image tensors when num=0 (#103) --- src/deepali/core/flow.py | 5 ++--- src/deepali/core/image.py | 28 ++++++++++++++++++++-------- 2 files changed, 22 insertions(+), 11 deletions(-) diff --git a/src/deepali/core/flow.py b/src/deepali/core/flow.py index 9e3fe99..f1f0efd 100644 --- a/src/deepali/core/flow.py +++ b/src/deepali/core/flow.py @@ -429,11 +429,10 @@ def warp_image( def zeros_flow( size: Optional[Union[int, Size, Grid]] = None, shape: Optional[Shape] = None, - num: int = 1, - named: bool = False, + num: Optional[int] = None, dtype: Optional[DType] = None, device: Optional[Device] = None, ) -> Tensor: r"""Create batch of flow fields filled with zeros for given image batch size or grid.""" size = _image_size("zeros_flow", size, shape) - return zeros_image(size, num=num, channels=len(size), named=named, dtype=dtype, device=device) + return zeros_image(size, num=num, channels=len(size), dtype=dtype, device=device) diff --git a/src/deepali/core/image.py b/src/deepali/core/image.py index 671ecd5..7c8648b 100644 --- a/src/deepali/core/image.py +++ b/src/deepali/core/image.py @@ -1700,6 +1700,7 @@ def circle_image( size: Spatial size in the order ``(X, Y)``. shape: Spatial size in the order ``(Y, X)``. num: Number ``N`` of images in batch. + If zero, return a single unbatched image tensor. center: Coordinates of center pixel in the order ``(x, y)``. radius: Radius of circle in pixel units. sigma: Standard deviation of isotropic Gaussian blurring kernel in pixel units. @@ -1712,7 +1713,7 @@ def circle_image( device: Device on which to create image tensor. Returns: - Image tensor of shape ``(N, 1, Y, X)``. + Image tensor of shape ``(N, 1, Y, X)`` or ``(1, Y, X)`` (``num=0``). """ size = _image_size("circle_image", size, shape, ndim=2) @@ -1769,6 +1770,7 @@ def cshape_image( size: Spatial size in the order ``(X, Y)``. shape: Spatial size in the order ``(Y, X)``. num: Number ``N`` of images in batch. + If zero, return a single unbatched image tensor. center: Coordinates of center pixel in the order ``(y, x)``. radius: Radius of circle in pixel units. width: Difference between outer and inner circle radius. @@ -1784,7 +1786,7 @@ def cshape_image( device: Device on which to create image tensor. Returns: - Image tensor of shape ``(N, 1, Y, X)``. + Image tensor of shape ``(N, 1, Y, X)`` or ``(1, Y, X)`` (``num=0``). """ size = _image_size("cshape_image", size, shape, ndim=2) @@ -1832,17 +1834,21 @@ def empty_image( size: Spatial size in the order ``(X, ...)``. shape: Spatial size in the order ``(..., X)``. num: Number of images in batch. + If zero, return a single unbatched image tensor. channels: Number of channels per image. dtype: Data type of image tensor. device: Device on which to store image data. Returns: - Uninitialized image batch tensor. + Image tensor of shape ``(N, 1, ..., X)`` or ``(1, ..., X)`` (``num=0``). """ size = _image_size("empty_image", size, shape) shape = (num or 1, channels or 1) + tuple(reversed(size)) - return torch.empty(shape, dtype=dtype, device=device) + data = torch.empty(shape, dtype=dtype, device=device) + if num == 0: + data = data.squeeze_(0) + return data def grid_image( @@ -1861,6 +1867,7 @@ def grid_image( shape: Spatial size in the order ``(..., X)``. num: Number of images in batch. When ``shape`` is not a ``Grid``, must match the size of the first dimension in ``shape`` if not ``None``. + If zero, return a single unbatched image tensor. stride: Spacing between grid lines. To draw in-plane grid lines on a D-dimensional image where ``D>2``, specify a sequence of two stride values, where the first stride applies to the last tensor dimension, @@ -1870,7 +1877,7 @@ def grid_image( device: Device on which to store image data. Returns: - Image tensor of shape ``(N, 1, ..., X)``. The default number of channels is 1. + Image tensor of shape ``(N, 1, ..., X)`` or ``(1, ..., X)`` (``num=0``). """ size = _image_size("grid_image", size, shape) @@ -1889,7 +1896,12 @@ def grid_image( n = data.shape[dim] index = torch.arange((n % step) // 2, n, step, dtype=torch.int64, device=data.device) data.index_fill_(dim, index, 0 if inverted else 1) - return data.expand(num or 1, *data.shape[1:]) + if num is not None: + if num == 0: + data = data.squeeze_(0) + elif num > 1: + data = data.expand((1,) + data.shape[1:]) + return data def ones_image( @@ -1911,7 +1923,7 @@ def ones_image( device: Device on which to store image data. Returns: - Image batch tensor filled with ones. + Image tensor of shape ``(N, 1, ..., X)`` or ``(1, ..., X)`` (``num=0``) filled with ones. """ size = _image_size("ones_image", size, shape) @@ -1938,7 +1950,7 @@ def zeros_image( device: Device on which to store image data. Returns: - Image batch tensor filled with zeros. + Image tensor of shape ``(N, 1, ..., X)`` or ``(1, ..., X)`` (``num=0``) filled with zeros. """ size = _image_size("zeros_image", size, shape)