Skip to content
This repository has been archived by the owner on Apr 29, 2024. It is now read-only.

Commit

Permalink
* minor fixes for reproducibility
Browse files Browse the repository at this point in the history
  • Loading branch information
dependabot[bot] authored and wdika committed Jan 23, 2023
1 parent 19b5023 commit 6d15dd5
Show file tree
Hide file tree
Showing 48 changed files with 2,121 additions and 948 deletions.
2 changes: 1 addition & 1 deletion codecov.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ coverage:
status:
patch: true

range: 70..100 # First number represents red, and second represents green
range: 60..100 # First number represents red, and second represents green
# (default is 70..100)
round: nearest # up, down, or nearest
precision: 2 # Number of decimal places, between 0 and 5
Expand Down
2 changes: 1 addition & 1 deletion docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ Welcome to mridc's documentation!
:parser: myst_parser.sphinx_

.. toctree::
:maxdepth: 4
:maxdepth: 6
:caption: API Documentation:

modules.rst
212 changes: 161 additions & 51 deletions mridc/collections/common/parts/fft.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,67 @@
# coding=utf-8
__author__ = "Dimitrios Karkalousos"

# Parts of the code have been taken from https://github.com/facebookresearch/fastMRI

from typing import List, Sequence, Union

import numpy as np
import torch
from omegaconf import ListConfig

__all__ = ["fft2", "ifft2"]
__all__ = ["fft2", "ifft2", "fftshift", "ifftshift"]


def fft2(
data: torch.Tensor,
centered: bool = True,
normalization: str = "ortho",
centered: bool = False,
normalization: str = "backward",
spatial_dims: Sequence[int] = None,
) -> torch.Tensor:
"""
Apply 2 dimensional Fast Fourier Transform.
Parameters
----------
data: Complex valued input data containing at least 3 dimensions: dimensions -2 & -1 are spatial dimensions. All
other dimensions are assumed to be batch dimensions.
centered: Whether to center the fft.
normalization: "ortho" is the default normalization used by PyTorch. Can be changed to "ortho" or None.
spatial_dims: dimensions to apply the FFT
data : Complex valued input data containing at least 2 dimensions: dimensions -2 & -1 are spatial dimensions.
centered : Whether to center the fft. If True, the fft will be shifted so that the zero frequency component is
in the center of the spectrum. Default is False.
normalization : Normalization mode. For the forward transform (fft2()), these correspond to:
* "forward" - normalize by 1/n
* "backward" - no normalization
* "ortho" - normalize by 1/sqrt(n) (making the FFT orthonormal)
Where n = prod(s) is the logical FFT size.
Calling the backward transform (ifft2()) with the same normalization mode will apply an overall normalization
of 1/n between the two transforms. This is required to make ifft2() the exact inverse.
Default is "backward" (no normalization).
spatial_dims : Dimensions to apply the FFT. Default is the last two dimensions.
If tensor is viewed as real, the last dimension is assumed to be the complex dimension.
Returns
-------
The FFT of the input.
The 2D FFT of the input.
Examples
--------
>>> import torch
>>> from mridc.collections.common.parts.fft import fft2
>>> x = torch.randn(2, 3, 4, 5, 2)
>>> fft2(x).shape
torch.Size([2, 3, 4, 5, 2])
>>> fft2(x, centered=True).shape
torch.Size([2, 3, 4, 5, 2])
>>> fft2(x, centered=True, normalization="ortho").shape
torch.Size([2, 3, 4, 5, 2])
>>> fft2(x, centered=True, normalization="ortho", spatial_dims=[-3, -2]).shape
torch.Size([2, 3, 4, 5, 2])
Notes
-----
The PyTorch fft2 function does not support complex tensors. Therefore, the input is converted to a complex tensor
and then converted back to a real tensor. This is done by using the torch.view_as_complex and torch.view_as_real
functions. The input is assumed to be a real tensor with the last dimension being the complex dimension.
The PyTorch fft2 function performs a separate fft, so fft2 is the same as fft(fft(data, dim=-2), dim=-1).
Source: https://pytorch.org/docs/stable/fft.html#torch.fft.fft2
"""
if data.shape[-1] == 2:
data = torch.view_as_complex(data)
Expand Down Expand Up @@ -60,24 +90,56 @@ def fft2(

def ifft2(
data: torch.Tensor,
centered: bool = True,
normalization: str = "ortho",
centered: bool = False,
normalization: str = "backward",
spatial_dims: Sequence[int] = None,
) -> torch.Tensor:
"""
Apply 2 dimensional Inverse Fast Fourier Transform.
Parameters
----------
data: Complex valued input data containing at least 3 dimensions: dimensions -2 & -1 are spatial dimensions. All
other dimensions are assumed to be batch dimensions.
centered: Whether to center the fft.
normalization: "ortho" is the default normalization used by PyTorch. Can be changed to "ortho" or None.
spatial_dims: dimensions to apply the FFT
data : Complex valued input data containing at least 2 dimensions: dimensions -2 & -1 are spatial dimensions.
centered : Whether to center the ifft. If True, the ifft will be shifted so that the zero frequency component is
in the center of the spectrum. Default is False.
normalization : Normalization mode. For the backward transform (ifft2()), these correspond to:
* "forward" - normalize by 1/n
* "backward" - no normalization
* "ortho" - normalize by 1/sqrt(n) (making the IFFT orthonormal)
Where n = prod(s) is the logical IFFT size.
Calling the forward transform (fft2()) with the same normalization mode will apply an overall normalization
of 1/n between the two transforms. This is required to make ifft2() the exact inverse.
Default is "backward" (no normalization).
spatial_dims : Dimensions to apply the IFFT. Default is the last two dimensions.
If tensor is viewed as real, the last dimension is assumed to be the complex dimension.
Returns
-------
The FFT of the input.
The 2D IFFT of the input.
Examples
--------
>>> import torch
>>> from mridc.collections.common.parts.fft import ifft2
>>> x = torch.randn(2, 3, 4, 5, 2)
>>> ifft2(x).shape
torch.Size([2, 3, 4, 5, 2])
>>> ifft2(x, centered=True).shape
torch.Size([2, 3, 4, 5, 2])
>>> ifft2(x, centered=True, normalization="ortho").shape
torch.Size([2, 3, 4, 5, 2])
>>> ifft2(x, centered=True, normalization="ortho", spatial_dims=[-3, -2]).shape
torch.Size([2, 3, 4, 5, 2])
Notes
-----
The PyTorch ifft2 function does not support complex tensors. Therefore, the input is converted to a complex tensor
and then converted back to a real tensor. This is done by using the torch.view_as_complex and torch.view_as_real
functions. The input is assumed to be a real tensor with the last dimension being the complex dimension.
The PyTorch ifft2 function performs a separate ifft, so ifft2 is the same as ifft(ifft(data, dim=-2), dim=-1).
Source: https://pytorch.org/docs/stable/fft.html#torch.fft.ifft2
"""
if data.shape[-1] == 2:
data = torch.view_as_complex(data)
Expand All @@ -104,43 +166,67 @@ def ifft2(
return data


def roll_one_dim(x: torch.Tensor, shift: int, dim: int) -> torch.Tensor:
def roll_one_dim(data: torch.Tensor, shift: int, dim: int) -> torch.Tensor:
"""
Similar to roll but for only one dim.
Parameters
----------
x: A PyTorch tensor.
shift: Amount to roll.
dim: Which dimension to roll.
data : A PyTorch tensor.
shift : Amount to roll.
dim : Which dimension to roll.
Returns
-------
Rolled version of x.
Rolled version of data.
Examples
--------
>>> import torch
>>> from mridc.collections.common.parts.fft import roll_one_dim
>>> x = torch.randn(2, 3, 4, 5)
>>> roll_one_dim(x, 1, 0).shape
torch.Size([2, 3, 4, 5])
Notes
-----
Source: https://github.com/facebookresearch/fastMRI/blob/main/fastmri/fftc.py
"""
shift %= x.size(dim)
shift %= data.size(dim)
if shift == 0:
return x
return data

left = x.narrow(dim, 0, x.size(dim) - shift)
right = x.narrow(dim, x.size(dim) - shift, shift)
left = data.narrow(dim, 0, data.size(dim) - shift)
right = data.narrow(dim, data.size(dim) - shift, shift)

return torch.cat((right, left), dim=dim)


def roll(x: torch.Tensor, shift: List[int], dim: Union[List[int], Sequence[int]]) -> torch.Tensor:
def roll(data: torch.Tensor, shift: List[int], dim: Union[List[int], Sequence[int]]) -> torch.Tensor:
"""
Similar to np.roll but applies to PyTorch Tensors.
Parameters
----------
x: A PyTorch tensor.
shift: Amount to roll.
dim: Which dimension to roll.
data : A PyTorch tensor.
shift : Amount to roll.
dim : Which dimension to roll.
Returns
-------
Rolled version of x.
Rolled version of data.
Examples
--------
>>> import torch
>>> from mridc.collections.common.parts.fft import roll
>>> x = torch.randn(2, 3, 4, 5)
>>> roll(x, [1, 2], [0, 1]).shape
torch.Size([2, 3, 4, 5])
Notes
-----
Source: https://github.com/facebookresearch/fastMRI/blob/main/fastmri/fftc.py
"""
if len(shift) != len(dim):
raise ValueError("len(shift) must match len(dim)")
Expand All @@ -149,64 +235,88 @@ def roll(x: torch.Tensor, shift: List[int], dim: Union[List[int], Sequence[int]]
dim = list(dim)

for (s, d) in zip(shift, dim):
x = roll_one_dim(x, s, d)
data = roll_one_dim(data, s, d)

return x
return data


def fftshift(x: torch.Tensor, dim: Union[List[int], Sequence[int]] = None) -> torch.Tensor:
def fftshift(data: torch.Tensor, dim: Union[List[int], Sequence[int]] = None) -> torch.Tensor:
"""
Similar to np.fft.fftshift but applies to PyTorch Tensors
Parameters
----------
x: A PyTorch tensor.
dim: Which dimension to fftshift.
data : A PyTorch tensor.
dim : Which dimension to fftshift.
Returns
-------
fftshifted version of x.
fftshifted version of data.
Examples
--------
>>> import torch
>>> from mridc.collections.common.parts.fft import fftshift
>>> x = torch.randn(2, 3, 4, 5)
>>> fftshift(x).shape
torch.Size([2, 3, 4, 5])
Notes
-----
Source: https://github.com/facebookresearch/fastMRI/blob/main/fastmri/fftc.py
"""
if dim is None:
# this weird code is necessary for torch.jit.script typing
dim = [0] * (x.dim())
for i in range(1, x.dim()):
dim = [0] * (data.dim())
for i in range(1, data.dim()):
dim[i] = i
elif isinstance(dim, ListConfig):
dim = list(dim)

# Also necessary for torch.jit.script
shift = [0] * len(dim)
for i, dim_num in enumerate(dim):
shift[i] = np.floor_divide(x.shape[dim_num], 2)
shift[i] = np.floor_divide(data.shape[dim_num], 2)

return roll(x, shift, dim)
return roll(data, shift, dim)


def ifftshift(x: torch.Tensor, dim: Union[List[int], Sequence[int]] = None) -> torch.Tensor:
def ifftshift(data: torch.Tensor, dim: Union[List[int], Sequence[int]] = None) -> torch.Tensor:
"""
Similar to np.fft.ifftshift but applies to PyTorch Tensors
Parameters
----------
x: A PyTorch tensor.
dim: Which dimension to ifftshift.
data : A PyTorch tensor.
dim : Which dimension to ifftshift.
Returns
-------
ifftshifted version of x.
ifftshifted version of data.
Examples
--------
>>> import torch
>>> from mridc.collections.common.parts.fft import ifftshift
>>> x = torch.randn(2, 3, 4, 5)
>>> ifftshift(x).shape
torch.Size([2, 3, 4, 5])
Notes
-----
Source: https://github.com/facebookresearch/fastMRI/blob/main/fastmri/fftc.py
"""
if dim is None:
# this weird code is necessary for torch.jit.script typing
dim = [0] * (x.dim())
for i in range(1, x.dim()):
dim = [0] * (data.dim())
for i in range(1, data.dim()):
dim[i] = i
elif isinstance(dim, ListConfig):
dim = list(dim)

# Also necessary for torch.jit.script
shift = [0] * len(dim)
for i, dim_num in enumerate(dim):
shift[i] = np.floor_divide(x.shape[dim_num] + 1, 2)
shift[i] = np.floor_divide(data.shape[dim_num] + 1, 2)

return roll(x, shift, dim)
return roll(data, shift, dim)
4 changes: 2 additions & 2 deletions mridc/collections/reconstruction/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,9 +307,9 @@ def validation_step(self, batch: Dict[float, torch.Tensor], batch_idx: int) -> D
preds = next(preds)
except StopIteration:
pass
val_loss = sum(self.process_loss(target, preds, _loss_fn=self.eval_loss_fn, mask=None))
val_loss = sum(self.process_loss(target, preds, _loss_fn=self.val_loss_fn, mask=None))
else:
val_loss = self.process_loss(target, preds, _loss_fn=self.eval_loss_fn, mask=None)
val_loss = self.process_loss(target, preds, _loss_fn=self.val_loss_fn, mask=None)

# Cascades
if isinstance(preds, list):
Expand Down
2 changes: 1 addition & 1 deletion mridc/collections/reconstruction/models/crnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def process_intermediate_pred(self, pred, sensitivity_maps, target):
_, pred = utils.center_crop_to_smallest(target, pred)
return pred

def process_loss(self, target, pred, _loss_fn):
def process_loss(self, target, pred, _loss_fn=None, mask=None):
"""
Process the loss.
Expand Down
1 change: 1 addition & 0 deletions mridc/collections/reconstruction/models/dunet.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
elif data_consistency_term == "PROX":
dc_layer = dc_layers.DataProxCGLayer(
lambda_init=cfg_dict.get("data_consistency_lambda_init"),
iter=cfg_dict.get("data_consistency_iterations"),
fft_centered=self.fft_centered,
fft_normalization=self.fft_normalization,
spatial_dims=self.spatial_dims,
Expand Down
3 changes: 2 additions & 1 deletion mridc/collections/reconstruction/models/rim/rim_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ def log_likelihood_gradient(
-------
Gradient of the log-likelihood function.
"""
coil_dim = 1
if coil_dim == 0:
coil_dim += 1

eta_real, eta_imag = map(lambda x: torch.unsqueeze(x, coil_dim), eta.chunk(2, -1))
sense_real, sense_imag = sense.chunk(2, -1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ def AT(x):
x * mask,
centered=ctx.fft_centered,
normalization=ctx.fft_normalization,
spatial_dims=ctx.spatial_dimso,
spatial_dims=ctx.spatial_dims,
),
utils.complex_conj(smaps),
),
Expand Down
Loading

0 comments on commit 6d15dd5

Please sign in to comment.