Skip to content

Commit

Permalink
[core] Create unbatched image tensors when num=0 (#103)
Browse files Browse the repository at this point in the history
  • Loading branch information
aschuh-hf committed Jul 28, 2023
1 parent 125b3a9 commit d51bbf2
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 11 deletions.
5 changes: 2 additions & 3 deletions src/deepali/core/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,11 +429,10 @@ def warp_image(
def zeros_flow(
size: Optional[Union[int, Size, Grid]] = None,
shape: Optional[Shape] = None,
num: int = 1,
named: bool = False,
num: Optional[int] = None,
dtype: Optional[DType] = None,
device: Optional[Device] = None,
) -> Tensor:
r"""Create batch of flow fields filled with zeros for given image batch size or grid."""
size = _image_size("zeros_flow", size, shape)
return zeros_image(size, num=num, channels=len(size), named=named, dtype=dtype, device=device)
return zeros_image(size, num=num, channels=len(size), dtype=dtype, device=device)
28 changes: 20 additions & 8 deletions src/deepali/core/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -1700,6 +1700,7 @@ def circle_image(
size: Spatial size in the order ``(X, Y)``.
shape: Spatial size in the order ``(Y, X)``.
num: Number ``N`` of images in batch.
If zero, return a single unbatched image tensor.
center: Coordinates of center pixel in the order ``(x, y)``.
radius: Radius of circle in pixel units.
sigma: Standard deviation of isotropic Gaussian blurring kernel in pixel units.
Expand All @@ -1712,7 +1713,7 @@ def circle_image(
device: Device on which to create image tensor.
Returns:
Image tensor of shape ``(N, 1, Y, X)``.
Image tensor of shape ``(N, 1, Y, X)`` or ``(1, Y, X)`` (``num=0``).
"""
size = _image_size("circle_image", size, shape, ndim=2)
Expand Down Expand Up @@ -1769,6 +1770,7 @@ def cshape_image(
size: Spatial size in the order ``(X, Y)``.
shape: Spatial size in the order ``(Y, X)``.
num: Number ``N`` of images in batch.
If zero, return a single unbatched image tensor.
center: Coordinates of center pixel in the order ``(y, x)``.
radius: Radius of circle in pixel units.
width: Difference between outer and inner circle radius.
Expand All @@ -1784,7 +1786,7 @@ def cshape_image(
device: Device on which to create image tensor.
Returns:
Image tensor of shape ``(N, 1, Y, X)``.
Image tensor of shape ``(N, 1, Y, X)`` or ``(1, Y, X)`` (``num=0``).
"""
size = _image_size("cshape_image", size, shape, ndim=2)
Expand Down Expand Up @@ -1832,17 +1834,21 @@ def empty_image(
size: Spatial size in the order ``(X, ...)``.
shape: Spatial size in the order ``(..., X)``.
num: Number of images in batch.
If zero, return a single unbatched image tensor.
channels: Number of channels per image.
dtype: Data type of image tensor.
device: Device on which to store image data.
Returns:
Uninitialized image batch tensor.
Image tensor of shape ``(N, 1, ..., X)`` or ``(1, ..., X)`` (``num=0``).
"""
size = _image_size("empty_image", size, shape)
shape = (num or 1, channels or 1) + tuple(reversed(size))
return torch.empty(shape, dtype=dtype, device=device)
data = torch.empty(shape, dtype=dtype, device=device)
if num == 0:
data = data.squeeze_(0)
return data


def grid_image(
Expand All @@ -1861,6 +1867,7 @@ def grid_image(
shape: Spatial size in the order ``(..., X)``.
num: Number of images in batch. When ``shape`` is not a ``Grid``, must
match the size of the first dimension in ``shape`` if not ``None``.
If zero, return a single unbatched image tensor.
stride: Spacing between grid lines. To draw in-plane grid lines on a
D-dimensional image where ``D>2``, specify a sequence of two stride
values, where the first stride applies to the last tensor dimension,
Expand All @@ -1870,7 +1877,7 @@ def grid_image(
device: Device on which to store image data.
Returns:
Image tensor of shape ``(N, 1, ..., X)``. The default number of channels is 1.
Image tensor of shape ``(N, 1, ..., X)`` or ``(1, ..., X)`` (``num=0``).
"""
size = _image_size("grid_image", size, shape)
Expand All @@ -1889,7 +1896,12 @@ def grid_image(
n = data.shape[dim]
index = torch.arange((n % step) // 2, n, step, dtype=torch.int64, device=data.device)
data.index_fill_(dim, index, 0 if inverted else 1)
return data.expand(num or 1, *data.shape[1:])
if num is not None:
if num == 0:
data = data.squeeze_(0)
elif num > 1:
data = data.expand((1,) + data.shape[1:])
return data


def ones_image(
Expand All @@ -1911,7 +1923,7 @@ def ones_image(
device: Device on which to store image data.
Returns:
Image batch tensor filled with ones.
Image tensor of shape ``(N, 1, ..., X)`` or ``(1, ..., X)`` (``num=0``) filled with ones.
"""
size = _image_size("ones_image", size, shape)
Expand All @@ -1938,7 +1950,7 @@ def zeros_image(
device: Device on which to store image data.
Returns:
Image batch tensor filled with zeros.
Image tensor of shape ``(N, 1, ..., X)`` or ``(1, ..., X)`` (``num=0``) filled with zeros.
"""
size = _image_size("zeros_image", size, shape)
Expand Down

0 comments on commit d51bbf2

Please sign in to comment.