Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add gradient with respect to trajectory #30

Merged
merged 25 commits into from
Mar 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions run_tests.sh
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#!/usr/bin/env bash
pip install torchkbnufft scikit-image pytest
python -m pytest tfkbnufft
pip install torch==1.7 torchkbnufft==0.3.4 scikit-image pytest
python -m pytest tfkbnufft --ignore=tfkbnufft/tests/ndft_test.py
python -m pytest tfkbnufft/tests/ndft_test.py
Comment on lines +3 to +4
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
python -m pytest tfkbnufft --ignore=tfkbnufft/tests/ndft_test.py
python -m pytest tfkbnufft/tests/ndft_test.py
python -m pytest tfkbnufft

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh no no, that wont work :P We cant merge them as then the codes just hang as we discussed..

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah ok understood that wrong

2 changes: 1 addition & 1 deletion tfkbnufft/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Package info"""

__version__ = '0.1.4'
__version__ = '0.2.0'
__author__ = 'Zaccharie Ramzi'
__author_email__ = 'zaccharie.ramzi@inria.fr'
__license__ = 'MIT'
Expand Down
56 changes: 38 additions & 18 deletions tfkbnufft/kbnufft.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,12 @@ class KbNufftModule(KbModule):

def __init__(self, im_size, grid_size=None, numpoints=6, n_shift=None,
table_oversamp=2**10, kbwidth=2.34, order=0, norm='None',
coil_broadcast=False, matadj=False):
coil_broadcast=False, matadj=False, grad_traj=False):
super(KbNufftModule, self).__init__()

self.im_size = im_size
self.im_rank = len(im_size)
self.grad_traj = grad_traj
if grid_size is None:
self.grid_size = tuple(np.array(self.im_size) * 2)
else:
Expand Down Expand Up @@ -135,6 +136,7 @@ def _extract_nufft_interpob(self):
interpob['norm'] = self.norm
interpob['coil_broadcast'] = self.coil_broadcast
interpob['matadj'] = self.matadj
interpob['grad_traj'] = self.grad_traj
Jgen = []
for i in range(self.im_rank):
# number of points to use for interpolation is numpoints
Expand Down Expand Up @@ -168,18 +170,29 @@ def kbnufft_forward_for_interpob(x, om):
grid_size = interpob['grid_size']
im_size = interpob['im_size']
norm = interpob['norm']
grad_traj = interpob['grad_traj']
im_rank = interpob.get('im_rank', 2)

x = scale_and_fft_on_image_volume(
fft_x = scale_and_fft_on_image_volume(
x, scaling_coef, grid_size, im_size, norm, im_rank=im_rank, multiprocessing=multiprocessing)

y = kbinterp(x, om, interpob)
y = kbinterp(fft_x, om, interpob)

def grad(dy):
x = adjkbinterp(dy, om, interpob)
x = ifft_and_scale_on_gridded_data(
x, scaling_coef, grid_size, im_size, norm, im_rank=im_rank)
return x, None
# Gradients with respect to image
grid_dy = adjkbinterp(dy, om, interpob)
ifft_dy = ifft_and_scale_on_gridded_data(
grid_dy, scaling_coef, grid_size, im_size, norm, im_rank=im_rank)
if grad_traj:
# Gradients with respect to trajectory locations
r = [tf.linspace(-im_size[i]/2, im_size[i]/2-1, im_size[i]) for i in range(im_rank)]
grid_r = tf.cast(tf.meshgrid(*r, indexing='ij'), x.dtype)[None, ...]
fft_dx_dom = scale_and_fft_on_image_volume(
x * grid_r, scaling_coef, grid_size, im_size, norm, im_rank=im_rank)
dy_dom = tf.cast(-1j * dy * kbinterp(fft_dx_dom, om, interpob), tf.float32)
else:
dy_dom = None
return ifft_dy, dy_dom

return y, grad
return kbnufft_forward_for_interpob
Expand All @@ -200,25 +213,32 @@ def kbnufft_adjoint_for_interpob(y, om):
Returns:
tensor: The image after adjoint NUFFT.
"""
x = adjkbinterp(y, om, interpob)

grid_y = adjkbinterp(y, om, interpob)
scaling_coef = interpob['scaling_coef']
grid_size = interpob['grid_size']
im_size = interpob['im_size']
norm = interpob['norm']
grad_traj = interpob['grad_traj']
im_rank = interpob.get('im_rank', 2)

x = ifft_and_scale_on_gridded_data(
x, scaling_coef, grid_size, im_size, norm, im_rank=im_rank, multiprocessing=multiprocessing)
ifft_y = ifft_and_scale_on_gridded_data(
grid_y, scaling_coef, grid_size, im_size, norm, im_rank=im_rank, multiprocessing=multiprocessing)

def grad(dx):
x = scale_and_fft_on_image_volume(
# Gradients with respect to off grid signal
fft_dx = scale_and_fft_on_image_volume(
dx, scaling_coef, grid_size, im_size, norm, im_rank=im_rank)

y = kbinterp(x, om, interpob)

return y, None
return x, grad
dx_dy = kbinterp(fft_dx, om, interpob)
if grad_traj:
# Gradients with respect to trajectory locations
r = [tf.linspace(-im_size[i]/2, im_size[i]/2-1, im_size[i]) for i in range(im_rank)]
grid_r = tf.cast(tf.meshgrid(*r, indexing='ij'), dx.dtype)[None, ...]
ifft_dxr = scale_and_fft_on_image_volume(
dx * grid_r, scaling_coef, grid_size, im_size, norm, im_rank=im_rank, do_ifft=True)
dx_dom = tf.cast(1j * y * kbinterp(ifft_dxr, om, interpob, conj=True), om.dtype)
else:
dx_dom = None
return dx_dy, dx_dom
return ifft_y, grad
return kbnufft_adjoint_for_interpob


Expand Down
104 changes: 75 additions & 29 deletions tfkbnufft/nufft/fft_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,27 @@
import tensorflow as tf
from tensorflow.python.ops.signal.fft_ops import ifft2d, fft2d, fft, ifft

def tf_mp_ifft(kspace):
k_shape_x = tf.shape(kspace)[-1]
batched_kspace = tf.reshape(kspace, (-1, k_shape_x))
batched_image = tf.map_fn(
ifft,
batched_kspace,
parallel_iterations=multiprocessing.cpu_count(),
)
image = tf.reshape(batched_image, tf.shape(kspace))
return image

def tf_mp_fft(kspace):
k_shape_x = tf.shape(kspace)[-1]
batched_kspace = tf.reshape(kspace, (-1, k_shape_x))
batched_image = tf.map_fn(
fft,
batched_kspace,
parallel_iterations=multiprocessing.cpu_count(),
)
image = tf.reshape(batched_image, tf.shape(kspace))
return image

def tf_mp_ifft2d(kspace):
k_shape_x = tf.shape(kspace)[-2]
Expand Down Expand Up @@ -61,7 +82,39 @@ def tf_mp_fourier3d(x, trans_type='inv'):
y = tf.transpose(y_reshaped, [0, 1, 4, 2, 3])
return y

def scale_and_fft_on_image_volume(x, scaling_coef, grid_size, im_size, norm, im_rank=2, multiprocessing=False):
# Generate a fourier dictionary to simplify its use below.
# In the end we have the following list:
# fourier_dict[do_ifft][multiprocessing][rank of image - 1]
fourier_list = [
[
[
tf.signal.fft,
tf.signal.fft2d,
tf.signal.fft3d,
],
[
tf_mp_fft,
tf_mp_fft2d,
tf_mp_fft3d,
]
],
[
[
tf.signal.ifft,
tf.signal.ifft2d,
tf.signal.ifft3d,
],
[
tf_mp_ifft,
tf_mp_ifft2d,
tf_mp_ifft3d,
]
]
]


def scale_and_fft_on_image_volume(x, scaling_coef, grid_size, im_size, norm, im_rank=2, multiprocessing=False,
do_ifft=False):
zaccharieramzi marked this conversation as resolved.
Show resolved Hide resolved
"""Applies the FFT and any relevant scaling factors to x.

Args:
Expand All @@ -72,6 +125,8 @@ def scale_and_fft_on_image_volume(x, scaling_coef, grid_size, im_size, norm, im_
im_size (tensor): The image dimensions for x.
norm (str): Type of normalization factor to use. If 'ortho', uses
orthogonal FFT, otherwise, no normalization is applied.
do_ifft (bool, optional, default False): When true, the IFFT is
carried out on signal rather than FFT. This is needed for gradient.

Returns:
tensor: The oversampled FFT of x.
Expand All @@ -84,34 +139,32 @@ def scale_and_fft_on_image_volume(x, scaling_coef, grid_size, im_size, norm, im_
(0, 0), # coil dimension
] + [
(0, grid_size[0] - im_size[0]), # nx
(0, grid_size[1] - im_size[1]), # ny
]
if im_rank >= 2:
pad_sizes += [(0, grid_size[1] - im_size[1])]
if im_rank == 3:
pad_sizes += [(0, grid_size[2] - im_size[2])] # nz
scaling_coef = tf.cast(scaling_coef, x.dtype)
scaling_coef = scaling_coef[None, None, ...]
# multiply by scaling coefs
x = x * scaling_coef
if do_ifft:
x = x * tf.math.conj(scaling_coef)
else:
x = x * scaling_coef

# zero pad and fft
x = tf.pad(x, pad_sizes)
# this might have to be a tf py function, or I could use tf cond
if im_rank == 2:
if multiprocessing:
x = tf_mp_fft2d(x)
else:
x = tf.signal.fft2d(x)
else:
if multiprocessing:
x = tf_mp_fft3d(x)
else:
x = tf.signal.fft3d(x)
x = fourier_list[do_ifft][multiprocessing][im_rank - 1](x)
if norm == 'ortho':
scaling_factor = tf.cast(tf.reduce_prod(grid_size), x.dtype)
x = x / tf.sqrt(scaling_factor)
if do_ifft:
x = x * tf.sqrt(scaling_factor)
else:
x = x / tf.sqrt(scaling_factor)

return x


def ifft_and_scale_on_gridded_data(x, scaling_coef, grid_size, im_size, norm, im_rank=2, multiprocessing=False):
"""Applies the iFFT and any relevant scaling factors to x.

Expand All @@ -130,22 +183,15 @@ def ifft_and_scale_on_gridded_data(x, scaling_coef, grid_size, im_size, norm, im
# we don't need permutations since the fft in fourier is done on the
# innermost dimensions and we are handling complex tensors
# do the inverse fft
if im_rank == 2:
if multiprocessing:
x = tf_mp_ifft2d(x)
else:
x = tf.signal.ifft2d(x)
else:
if multiprocessing:
x = tf_mp_ifft3d(x)
else:
x = tf.signal.ifft3d(x)

x = fourier_list[True][multiprocessing][im_rank - 1](x)
im_size = tf.cast(im_size, tf.int32)
# crop to output size
x = x[:, :, :im_size[0], :im_size[1]]
if im_rank == 3:
x = x[..., :im_size[2]]
x = x[:, :, :im_size[0]]
if im_rank >=2:
if im_rank == 3:
x = x[..., :im_size[1], :im_size[2]]
else:
x = x[..., :im_size[1]]

# scaling
scaling_factor = tf.cast(tf.reduce_prod(grid_size), x.dtype)
Expand Down
14 changes: 11 additions & 3 deletions tfkbnufft/nufft/interp_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def run_interp(griddat, tm, params):
# loop over offsets and take advantage of broadcasting
for J in Jlist:
coef, arr_ind = calc_coef_and_indices(
tm, kofflist, J, table, centers, L, dims)
tm, kofflist, J, table, centers, L, dims, conjcoef=params['conjcoef'])
coef = tf.cast(coef, griddat.dtype)
# I don't need to expand on coil dimension since I use tf gather and not
# gather_nd
Expand Down Expand Up @@ -164,7 +164,7 @@ def run_interp_back(kdat, tm, params):
return griddat

@tf.function(experimental_relax_shapes=True)
def kbinterp(x, om, interpob):
def kbinterp(x, om, interpob, conj=False):
chaithyagr marked this conversation as resolved.
Show resolved Hide resolved
"""Apply table interpolation.

Inputs are assumed to be batch/chans x coil x image dims.
Expand All @@ -176,6 +176,9 @@ def kbinterp(x, om, interpob):
interpolate to in radians/voxel.
interpob (dict): An interpolation object with 'table', 'n_shift',
'grid_size', 'numpoints', and 'table_oversamp' keys.
conj (bool, optional, default False): Boolean value to check if
conjugate value of interpolator coefficient must be used.
This is need for gradients calculation

Returns:
tensor: The signal interpolated to off-grid locations.
Expand Down Expand Up @@ -205,10 +208,15 @@ def kbinterp(x, om, interpob):
# run the table interpolator for each batch element
# TODO: look into how to use tf.while_loop
params['dims'] = tf.cast(tf.shape(x[0])[1:], 'int64')
params['conjcoef'] = conj
def _map_body(inputs):
_x, _tm, _om = inputs
y_not_shifted = run_interp(tf.reshape(_x, (tf.shape(_x)[0], -1)), _tm, params)
y = y_not_shifted * tf.exp(1j * tf.cast(tf.linalg.matvec(tf.transpose(_om), n_shift), y_not_shifted.dtype))[None, ...]
shift = tf.exp(1j * tf.cast(tf.linalg.matvec(tf.transpose(_om), n_shift), y_not_shifted.dtype))[None, ...]
if conj:
y = y_not_shifted * tf.math.conj(shift)
else:
y = y_not_shifted * shift
return y

y = tf.map_fn(_map_body, [x, tm, om], dtype=x.dtype)
Expand Down
Loading