From 97b990df794d102da49c9f233f42f3ce1aa0823a Mon Sep 17 00:00:00 2001
From: chaithyagr
Date: Tue, 16 Mar 2021 14:32:49 +0100
Subject: [PATCH 01/24] Add gradient with respect to trajectory

---
 tfkbnufft/kbnufft.py                | 59 ++++++++++++++++++++---------
 tfkbnufft/nufft/fft_functions.py    | 42 +++++++++++++++-----
 tfkbnufft/nufft/interp_functions.py | 13 +++++--
 3 files changed, 85 insertions(+), 29 deletions(-)

diff --git a/tfkbnufft/kbnufft.py b/tfkbnufft/kbnufft.py
index 497fd78..1af5364 100644
--- a/tfkbnufft/kbnufft.py
+++ b/tfkbnufft/kbnufft.py
@@ -170,16 +170,30 @@ def kbnufft_forward_for_interpob(x, om):
         norm = interpob['norm']
         im_rank = interpob.get('im_rank', 2)
 
-        x = scale_and_fft_on_image_volume(
+        fft_x = scale_and_fft_on_image_volume(
             x, scaling_coef, grid_size, im_size, norm, im_rank=im_rank, multiprocessing=multiprocessing)
 
-        y = kbinterp(x, om, interpob)
+        y = kbinterp(fft_x, om, interpob)
 
         def grad(dy):
-            x = adjkbinterp(dy, om, interpob)
-            x = ifft_and_scale_on_gridded_data(
-                x, scaling_coef, grid_size, im_size, norm, im_rank=im_rank)
-            return x, None
+            # Gradients with respect to image
+            grid_dy = adjkbinterp(dy, om, interpob)
+            ifft_dy = ifft_and_scale_on_gridded_data(
+                grid_dy, scaling_coef, grid_size, im_size, norm, im_rank=im_rank)
+
+            # Gradients with respect to trajectory locations
+            r = (
+                tf.linspace(-im_size[0]/2, im_size[0]/2-1, im_size[0]),
+            )
+            if im_rank >=2:
+                r = r + (tf.linspace(-im_size[1]/2, im_size[1]/2-1, im_size[1]),)
+            if im_rank == 3:
+                r = r + (tf.linspace(-im_size[2]/2, im_size[2]/2-1, im_size[2]),)
+            grid_r = tf.cast(tf.meshgrid(*r, indexing='ij'), x.dtype)[None, ...]
+            fft_dx_dom = scale_and_fft_on_image_volume(
+                x * grid_r, scaling_coef, grid_size, im_size, norm, im_rank=im_rank, multiprocessing=multiprocessing)
+            dy_dom = tf.cast(-1j * kbinterp(fft_dx_dom, om, interpob), tf.float32)
+            return ifft_dy, dy_dom
 
         return y, grad
     return kbnufft_forward_for_interpob
@@ -200,25 +214,36 @@ def kbnufft_adjoint_for_interpob(y, om):
         Returns:
             tensor: The image after adjoint NUFFT.
         """
-        x = adjkbinterp(y, om, interpob)
-
+        grid_y = adjkbinterp(y, om, interpob)
         scaling_coef = interpob['scaling_coef']
         grid_size = interpob['grid_size']
         im_size = interpob['im_size']
         norm = interpob['norm']
         im_rank = interpob.get('im_rank', 2)
-
-        x = ifft_and_scale_on_gridded_data(
-            x, scaling_coef, grid_size, im_size, norm, im_rank=im_rank, multiprocessing=multiprocessing)
+        ifft_y = ifft_and_scale_on_gridded_data(
+            grid_y, scaling_coef, grid_size, im_size, norm, im_rank=im_rank, multiprocessing=multiprocessing)
 
         def grad(dx):
-            x = scale_and_fft_on_image_volume(
+            # Gradients with respect to off-grid signal
+            fft_dx = scale_and_fft_on_image_volume(
                 dx, scaling_coef, grid_size, im_size, norm, im_rank=im_rank)
-
-            y = kbinterp(x, om, interpob)
-
-            return y, None
-        return x, grad
+            dx_dy = kbinterp(fft_dx, om, interpob)
+
+            # Gradients with respect to trajectory locations
+            r = (
+                tf.linspace(-im_size[0]/2, im_size[0]/2-1, im_size[0]),
+            )
+            if im_rank >=2:
+                r = r + (tf.linspace(-im_size[1]/2, im_size[1]/2-1, im_size[1]),)
+            if im_rank == 3:
+                r = r + (tf.linspace(-im_size[2]/2, im_size[2]/2-1, im_size[2]),)
+            grid_r = tf.cast(tf.meshgrid(*r, indexing='ij'), y.dtype)[None, ...]
+            ifft_dxr = scale_and_fft_on_image_volume(
+                dx * grid_r, scaling_coef, grid_size, im_size, norm, im_rank=im_rank, do_ifft=True)
+            dx_dom = tf.cast(1j * y * kbinterp(ifft_dxr, om, interpob, conj=True), om.dtype)
+
+            return dx_dy, dx_dom
+        return ifft_y, grad
 
     return kbnufft_adjoint_for_interpob
 
diff --git a/tfkbnufft/nufft/fft_functions.py b/tfkbnufft/nufft/fft_functions.py
index efb5364..da76642 100644
--- a/tfkbnufft/nufft/fft_functions.py
+++ b/tfkbnufft/nufft/fft_functions.py
@@ -61,7 +61,8 @@ def tf_mp_fourier3d(x, trans_type='inv'):
     y = tf.transpose(y_reshaped, [0, 1, 4, 2, 3])
     return y
 
-def scale_and_fft_on_image_volume(x, scaling_coef, grid_size, im_size, norm, im_rank=2, multiprocessing=False):
+def scale_and_fft_on_image_volume(x, scaling_coef, grid_size, im_size, norm, im_rank=2, multiprocessing=False,
+                                  do_ifft=False):
     """Applies the FFT and any relevant scaling factors to x.
 
     Args:
@@ -72,6 +73,8 @@ def scale_and_fft_on_image_volume(x, scaling_coef, grid_size, im_size, norm, im_
         im_size (tensor): The image dimensions for x.
         norm (str): Type of normalization factor to use. If 'ortho', uses
             orthogonal FFT, otherwise, no normalization is applied.
+        do_ifft (bool, optional, default False): When True, the IFFT is
+            carried out on the signal rather than the FFT. This is needed for the gradient.
 
     Returns:
         tensor: The oversampled FFT of x.
@@ -84,31 +87,50 @@ def scale_and_fft_on_image_volume(x, scaling_coef, grid_size, im_size, norm, im_
         (0, 0),  # coil dimension
     ] + [
         (0, grid_size[0] - im_size[0]),  # nx
-        (0, grid_size[1] - im_size[1]),  # ny
     ]
+    if im_rank >= 2:
+        pad_sizes += [(0, grid_size[1] - im_size[1])]
     if im_rank == 3:
         pad_sizes += [(0, grid_size[2] - im_size[2])]  # nz
     scaling_coef = tf.cast(scaling_coef, x.dtype)
     scaling_coef = scaling_coef[None, None, ...]
 
     # multiply by scaling coefs
-    x = x * scaling_coef
+    if do_ifft:
+        x = x * tf.math.conj(scaling_coef)
+    else:
+        x = x * scaling_coef
 
     # zero pad and fft
     x = tf.pad(x, pad_sizes)
     # this might have to be a tf py function, or I could use tf cond
     if im_rank == 2:
         if multiprocessing:
-            x = tf_mp_fft2d(x)
+            if do_ifft:
+                x = tf_mp_ifft2d(x)
+            else:
+                x = tf_mp_fft2d(x)
         else:
-            x = tf.signal.fft2d(x)
+            if do_ifft:
+                x = tf.signal.ifft2d(x)
+            else:
+                x = tf.signal.fft2d(x)
     else:
         if multiprocessing:
-            x = tf_mp_fft3d(x)
+            if do_ifft:
+                x = tf_mp_ifft3d(x)
+            else:
+                x = tf_mp_fft3d(x)
         else:
-            x = tf.signal.fft3d(x)
+            if do_ifft:
+                x = tf.signal.ifft3d(x)
+            else:
+                x = tf.signal.fft3d(x)
     if norm == 'ortho':
         scaling_factor = tf.cast(tf.reduce_prod(grid_size), x.dtype)
-        x = x / tf.sqrt(scaling_factor)
+        if do_ifft:
+            x = x * tf.sqrt(scaling_factor)
+        else:
+            x = x / tf.sqrt(scaling_factor)
 
     return x
@@ -143,7 +165,9 @@ def ifft_and_scale_on_gridded_data(x, scaling_coef, grid_size, im_size, norm, im
 
     im_size = tf.cast(im_size, tf.int32)
     # crop to output size
-    x = x[:, :, :im_size[0], :im_size[1]]
+    x = x[:, :, :im_size[0]]
+    if im_rank >=2:
+        x = x[..., :im_size[1]]
     if im_rank == 3:
         x = x[..., :im_size[2]]
 
diff --git a/tfkbnufft/nufft/interp_functions.py b/tfkbnufft/nufft/interp_functions.py
index 89275d9..deb6649 100644
--- a/tfkbnufft/nufft/interp_functions.py
+++ b/tfkbnufft/nufft/interp_functions.py
@@ -98,7 +98,7 @@ def run_interp(griddat, tm, params):
     # loop over offsets and take advantage of broadcasting
     for J in Jlist:
         coef, arr_ind = calc_coef_and_indices(
-            tm, kofflist, J, table, centers, L, dims)
+            tm, kofflist, J, table, centers, L, dims, conjcoef=params['conjcoef'])
         coef = tf.cast(coef, griddat.dtype)
         # I don't need to expand on coil dimension since I use tf gather and not
         # gather_nd
@@ -164,7 +164,7 @@ def run_interp_back(kdat, tm, params):
     return griddat
 
 @tf.function(experimental_relax_shapes=True)
-def kbinterp(x, om, interpob):
+def kbinterp(x, om, interpob, conj=False):
     """Apply table interpolation.
 
     Inputs are assumed to be batch/chans x coil x image dims.
@@ -176,6 +176,9 @@ def kbinterp(x, om, interpob):
             interpolate to in radians/voxel.
         interpob (dict): An interpolation object with 'table', 'n_shift',
             'grid_size', 'numpoints', and 'table_oversamp' keys.
+        conj (bool, optional, default False): Boolean value to check if
+            the conjugate value of the interpolator coefficient must be used.
+            This is needed for gradient calculation.
 
     Returns:
         tensor: The signal interpolated to off-grid locations.
@@ -205,10 +208,14 @@ def kbinterp(x, om, interpob):
     # run the table interpolator for each batch element
     # TODO: look into how to use tf.while_loop
     params['dims'] = tf.cast(tf.shape(x[0])[1:], 'int64')
+    params['conjcoef'] = conj
     def _map_body(inputs):
         _x, _tm, _om = inputs
         y_not_shifted = run_interp(tf.reshape(_x, (tf.shape(_x)[0], -1)), _tm, params)
-        y = y_not_shifted * tf.exp(1j * tf.cast(tf.linalg.matvec(tf.transpose(_om), n_shift), y_not_shifted.dtype))[None, ...]
+        if conj:
+            y = y_not_shifted * tf.math.conj(tf.exp(1j * tf.cast(tf.linalg.matvec(tf.transpose(_om), n_shift), y_not_shifted.dtype)))[None, ...]
+        else:
+            y = y_not_shifted * tf.exp(1j * tf.cast(tf.linalg.matvec(tf.transpose(_om), n_shift), y_not_shifted.dtype))[None, ...]
return y y = tf.map_fn(_map_body, [x, tm, om], dtype=x.dtype) From bc40c14b1e25b2abc0d361af090f03c96721204c Mon Sep 17 00:00:00 2001 From: chaithyagr Date: Tue, 16 Mar 2021 15:37:28 +0100 Subject: [PATCH 02/24] Add back dy --- tfkbnufft/kbnufft.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tfkbnufft/kbnufft.py b/tfkbnufft/kbnufft.py index 1af5364..b0ae15f 100644 --- a/tfkbnufft/kbnufft.py +++ b/tfkbnufft/kbnufft.py @@ -192,7 +192,7 @@ def grad(dy): grid_r = tf.cast(tf.meshgrid(*r, indexing='ij'), x.dtype)[None, ...] fft_dx_dom = scale_and_fft_on_image_volume( x * grid_r, scaling_coef, grid_size, im_size, norm, im_rank=im_rank, multiprocessing=multiprocessing) - dy_dom = tf.cast(-1j * kbinterp(fft_dx_dom, om, interpob), tf.float32) + dy_dom = tf.cast(-1j * dy * kbinterp(fft_dx_dom, om, interpob), tf.float32) return ifft_dy, dy_dom return y, grad From 9b66b9bff83dcfed5698eb700f8e45690e9ec86f Mon Sep 17 00:00:00 2001 From: Chaithya G R Date: Tue, 16 Mar 2021 15:47:30 +0100 Subject: [PATCH 03/24] Update tfkbnufft/kbnufft.py Co-authored-by: Zaccharie Ramzi --- tfkbnufft/kbnufft.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tfkbnufft/kbnufft.py b/tfkbnufft/kbnufft.py index b0ae15f..a385b64 100644 --- a/tfkbnufft/kbnufft.py +++ b/tfkbnufft/kbnufft.py @@ -191,7 +191,7 @@ def grad(dy): r = r + (tf.linspace(-im_size[2]/2, im_size[2]/2-1, im_size[2]),) grid_r = tf.cast(tf.meshgrid(*r, indexing='ij'), x.dtype)[None, ...] fft_dx_dom = scale_and_fft_on_image_volume( - x * grid_r, scaling_coef, grid_size, im_size, norm, im_rank=im_rank, multiprocessing=multiprocessing) + x * grid_r, scaling_coef, grid_size, im_size, norm, im_rank=im_rank) dy_dom = tf.cast(-1j * dy * kbinterp(fft_dx_dom, om, interpob), tf.float32) return ifft_dy, dy_dom From bd19d46b08f059f308d469ad3d14c5767a623341 Mon Sep 17 00:00:00 2001 From: chaithyagr Date: Tue, 16 Mar 2021 15:53:25 +0100 Subject: [PATCH 04/24] Move r to new for loop --- notebooks/2d.py | 35 +++++++++++++++++++++++++ notebooks/2d_adj.py | 46 +++++++++++++++++++++++++++++++++ notebooks/nfft.py | 25 ++++++++++++++++++ tfkbnufft/kbnufft.py | 16 ++---------- tfkbnufft/tests/kbnufft_test.py | 9 ++++++- 5 files changed, 116 insertions(+), 15 deletions(-) create mode 100644 notebooks/2d.py create mode 100644 notebooks/2d_adj.py create mode 100644 notebooks/nfft.py diff --git a/notebooks/2d.py b/notebooks/2d.py new file mode 100644 index 0000000..657e22b --- /dev/null +++ b/notebooks/2d.py @@ -0,0 +1,35 @@ +import tensorflow as tf +import numpy as np +import matplotlib.pyplot as plt +from tfkbnufft import kbnufft_forward, kbnufft_adjoint +from tfkbnufft.kbnufft import KbNufftModule +from mri.operators import NonCartesianFFT +tf.config.run_functions_eagerly(True) + +N = 20 +M = 20*5 +nufft_ob = KbNufftModule(im_size=(N, N), grid_size=(2*N, 2*N), norm='ortho') +ktraj = tf.Variable(tf.random.uniform((1, 2, M), minval=-1/2, maxval=1/2)*2*np.pi) +ktraj_seq = tf.Variable(tf.cast( + np.reshape(np.meshgrid( + np.linspace(-1/2, 1/2, N, endpoint=False), + np.linspace(-1/2, 1/2, N, endpoint=False), + indexing='ij' + ), (2, N*N) + )[None, :], + tf.float32 +)) +signal = tf.Variable(tf.cast(tf.random.uniform((1, 1, N, N)), tf.complex64)) +with tf.GradientTape(persistent=True) as g: + kdata_nufft = kbnufft_forward(nufft_ob._extract_nufft_interpob())(signal, ktraj) + r = tf.cast(tf.reshape(tf.meshgrid( + np.linspace(-N/2, N/2, N, endpoint=False), + np.linspace(-N/2, N/2, N, endpoint=False), + indexing='ij' + ), (2, N*N)), tf.float32) + A = 
tf.exp(-2j * np.pi * tf.cast(tf.matmul(tf.transpose(ktraj[0])/2/np.pi, r), tf.complex64))/N/2 + kdata_ndft = tf.matmul(A, tf.reshape(signal[0][0], (N*N, 1))) + +grad2 = g.gradient(kdata_ndft, ktraj)[0] +grad3 = -2j * np.pi * tf.transpose(tf.matmul(A, tf.transpose(tf.cast(r, tf.complex64) * tf.reshape(signal[0][0], (N*N,))))) +grad4 = g.gradient(kdata_nufft, ktraj)[0] \ No newline at end of file diff --git a/notebooks/2d_adj.py b/notebooks/2d_adj.py new file mode 100644 index 0000000..9e52886 --- /dev/null +++ b/notebooks/2d_adj.py @@ -0,0 +1,46 @@ +import tensorflow as tf +import numpy as np +import matplotlib.pyplot as plt +from tfkbnufft import kbnufft_forward, kbnufft_adjoint +from tfkbnufft.kbnufft import KbNufftModule +from mri.operators import NonCartesianFFT +tf.config.run_functions_eagerly(True) +tf.random.set_seed(0) +N = 20 +M = 20*5 +nufft_ob = KbNufftModule(im_size=(N, N), grid_size=(2*N, 2*N), norm='ortho') +ktraj = tf.Variable(tf.random.uniform((1, 2, M), minval=-1/2, maxval=1/2)) +ktraj_seq = tf.Variable(tf.cast( + np.reshape(np.meshgrid( + np.linspace(-1/2, 1/2, N, endpoint=False), + np.linspace(-1/2, 1/2, N, endpoint=False), + indexing='ij' + ), (2, N*N) + )[None, :], + tf.float32 +)) +signal = tf.Variable(tf.cast(tf.random.uniform((1, 1, N, N)), tf.complex64)) +kdata = tf.Variable(kbnufft_forward(nufft_ob._extract_nufft_interpob())(signal, ktraj*np.pi*2)) + +with tf.GradientTape(persistent=True) as g: + I_nufft = kbnufft_adjoint(nufft_ob._extract_nufft_interpob())(kdata, ktraj*np.pi*2)[0][0] + r = tf.cast(tf.reshape(tf.meshgrid( + np.linspace(-N/2, N/2, N, endpoint=False), + np.linspace(-N/2, N/2, N, endpoint=False), + indexing='ij' + ), (2, N*N)), tf.float32) + A = tf.exp(2j * np.pi * tf.cast(tf.matmul(tf.transpose(ktraj[0]), r), tf.complex64))/N/2 + I_ndft = tf.reshape(tf.matmul(tf.transpose(A), kdata[0][0][..., None]), (N, N)) + +grad2 = g.gradient(I_ndft, ktraj)[0] +grad3 = 2*np.pi*1j*tf.matmul(tf.cast(r, tf.complex64), tf.transpose(A))*kdata[0][0] +grad4 = g.gradient(I_nufft, ktraj)[0] +C = kbnufft_adjoint(nufft_ob._extract_nufft_interpob())(kdata, ktraj*np.pi*2) * tf.reshape(tf.cast(r, tf.complex64), (1, 2, 20, 20)) +grad4 = tf.reshape( + 2j * np.pi * kbnufft_adjoint(nufft_ob._extract_nufft_interpob())( + kdata * tf.cast(r, tf.complex64), + ktraj*np.pi*2 + ), + (2, N*N) +) +grad4 \ No newline at end of file diff --git a/notebooks/nfft.py b/notebooks/nfft.py new file mode 100644 index 0000000..58079d7 --- /dev/null +++ b/notebooks/nfft.py @@ -0,0 +1,25 @@ +import tensorflow as tf +import numpy as np +import matplotlib.pyplot as plt +from tfkbnufft import kbnufft_forward, kbnufft_adjoint +from tfkbnufft.kbnufft import KbNufftModule + + +N = 1024 +M = 512 +nufft_ob = KbNufftModule(im_size=(N, ), grid_size=(2*N, ), norm='ortho') +ktraj = tf.Variable(tf.random.uniform((1, 1, M), minval=-1/2, maxval=1/2)) +ktraj_seq = tf.Variable(tf.cast(np.linspace(-1/2, 1/2, N, endpoint=False)[None, None, :], tf.float32)) +signal = tf.Variable(tf.cast(tf.random.uniform((1, 1, N)), tf.complex64)) +signal_spectrum = tf.signal.fftshift(tf.signal.fft(tf.signal.ifftshift(signal[0][0]))) / N + +with tf.GradientTape(persistent=True) as g: + kdata_nufft = kbnufft_forward(nufft_ob._extract_nufft_interpob())(signal, ktraj*np.pi*2) + r = tf.cast(tf.linspace(0, N-1, N), tf.float32) + A = tf.exp(-2j * np.pi * tf.cast(ktraj[0][0][..., None] * r[None, :], tf.complex64))/N + kdata_ndft = tf.matmul(A, tf.transpose(signal[0])) + +grad1 = g.gradient(kdata_nufft, ktraj) +grad2 = g.gradient(kdata_ndft, 
ktraj) +grad3 = -2j * np.pi * tf.matmul(A, tf.transpose(signal[0] * tf.cast(r, tf.complex64))) +grad_nufft_mine =-2j * np.pi * kbnufft_forward(nufft_ob._extract_nufft_interpob())(signal * tf.cast(r, tf.complex64), ktraj*np.pi*2) \ No newline at end of file diff --git a/tfkbnufft/kbnufft.py b/tfkbnufft/kbnufft.py index b0ae15f..324753c 100644 --- a/tfkbnufft/kbnufft.py +++ b/tfkbnufft/kbnufft.py @@ -182,13 +182,7 @@ def grad(dy): grid_dy, scaling_coef, grid_size, im_size, norm, im_rank=im_rank) # Gradients with respect to trajectory locations - r = ( - tf.linspace(-im_size[0]/2, im_size[0]/2-1, im_size[0]), - ) - if im_rank >=2: - r = r + (tf.linspace(-im_size[1]/2, im_size[1]/2-1, im_size[1]),) - if im_rank == 3: - r = r + (tf.linspace(-im_size[2]/2, im_size[2]/2-1, im_size[2]),) + r = [tf.linspace(-im_size[i]/2, im_size[i]/2-1, im_size[i]) for i in range(im_rank)] grid_r = tf.cast(tf.meshgrid(*r, indexing='ij'), x.dtype)[None, ...] fft_dx_dom = scale_and_fft_on_image_volume( x * grid_r, scaling_coef, grid_size, im_size, norm, im_rank=im_rank, multiprocessing=multiprocessing) @@ -230,13 +224,7 @@ def grad(dx): dx_dy = kbinterp(fft_dx, om, interpob) # Gradients with respect to trajectory locations - r = ( - tf.linspace(-im_size[0]/2, im_size[0]/2-1, im_size[0]), - ) - if im_rank >=2: - r = r + (tf.linspace(-im_size[1]/2, im_size[1]/2-1, im_size[1]),) - if im_rank == 3: - r = r + (tf.linspace(-im_size[2]/2, im_size[2]/2-1, im_size[2]),) + r = [tf.linspace(-im_size[i]/2, im_size[i]/2-1, im_size[i]) for i in range(im_rank)] grid_r = tf.cast(tf.meshgrid(*r, indexing='ij'), y.dtype)[None, ...] ifft_dxr = scale_and_fft_on_image_volume( dx * grid_r, scaling_coef, grid_size, im_size, norm, im_rank=im_rank, do_ifft=True) diff --git a/tfkbnufft/tests/kbnufft_test.py b/tfkbnufft/tests/kbnufft_test.py index 9fc604c..ea7909c 100644 --- a/tfkbnufft/tests/kbnufft_test.py +++ b/tfkbnufft/tests/kbnufft_test.py @@ -57,7 +57,14 @@ def test_forward_gradient(multiprocessing): forward_op = kbnufft_forward(nufft_ob._extract_nufft_interpob(), multiprocessing) with tf.GradientTape() as tape: tape.watch(image) - res = forward_op(image, traj) + r = tf.cast(tf.reshape(tf.meshgrid( + np.linspace(-640/2, 640/2, 640, endpoint=False), + np.linspace(-400/2, 400/2, 400, endpoint=False), + indexing='ij' + ), (2, -640*400)), tf.float32) + A = tf.exp(-2j * np.pi * tf.cast(tf.matmul(tf.transpose(traj[0])/2/np.pi, r), tf.complex64))/np.sqrt(640*400)/2 + print(A) + res = forward_op(image, traj) grad = tape.gradient(res, image) tf_test = tf.test.TestCase() tf_test.assertEqual(grad.shape, image.shape) From bcf5eb0d5d30df35b264ddf67ffcce29a0bd0013 Mon Sep 17 00:00:00 2001 From: chaithyagr Date: Tue, 16 Mar 2021 15:54:29 +0100 Subject: [PATCH 05/24] Remove unwanted test files --- notebooks/2d.py | 35 ---------------------------------- notebooks/2d_adj.py | 46 --------------------------------------------- notebooks/nfft.py | 25 ------------------------ 3 files changed, 106 deletions(-) delete mode 100644 notebooks/2d.py delete mode 100644 notebooks/2d_adj.py delete mode 100644 notebooks/nfft.py diff --git a/notebooks/2d.py b/notebooks/2d.py deleted file mode 100644 index 657e22b..0000000 --- a/notebooks/2d.py +++ /dev/null @@ -1,35 +0,0 @@ -import tensorflow as tf -import numpy as np -import matplotlib.pyplot as plt -from tfkbnufft import kbnufft_forward, kbnufft_adjoint -from tfkbnufft.kbnufft import KbNufftModule -from mri.operators import NonCartesianFFT -tf.config.run_functions_eagerly(True) - -N = 20 -M = 20*5 -nufft_ob = 
KbNufftModule(im_size=(N, N), grid_size=(2*N, 2*N), norm='ortho') -ktraj = tf.Variable(tf.random.uniform((1, 2, M), minval=-1/2, maxval=1/2)*2*np.pi) -ktraj_seq = tf.Variable(tf.cast( - np.reshape(np.meshgrid( - np.linspace(-1/2, 1/2, N, endpoint=False), - np.linspace(-1/2, 1/2, N, endpoint=False), - indexing='ij' - ), (2, N*N) - )[None, :], - tf.float32 -)) -signal = tf.Variable(tf.cast(tf.random.uniform((1, 1, N, N)), tf.complex64)) -with tf.GradientTape(persistent=True) as g: - kdata_nufft = kbnufft_forward(nufft_ob._extract_nufft_interpob())(signal, ktraj) - r = tf.cast(tf.reshape(tf.meshgrid( - np.linspace(-N/2, N/2, N, endpoint=False), - np.linspace(-N/2, N/2, N, endpoint=False), - indexing='ij' - ), (2, N*N)), tf.float32) - A = tf.exp(-2j * np.pi * tf.cast(tf.matmul(tf.transpose(ktraj[0])/2/np.pi, r), tf.complex64))/N/2 - kdata_ndft = tf.matmul(A, tf.reshape(signal[0][0], (N*N, 1))) - -grad2 = g.gradient(kdata_ndft, ktraj)[0] -grad3 = -2j * np.pi * tf.transpose(tf.matmul(A, tf.transpose(tf.cast(r, tf.complex64) * tf.reshape(signal[0][0], (N*N,))))) -grad4 = g.gradient(kdata_nufft, ktraj)[0] \ No newline at end of file diff --git a/notebooks/2d_adj.py b/notebooks/2d_adj.py deleted file mode 100644 index 9e52886..0000000 --- a/notebooks/2d_adj.py +++ /dev/null @@ -1,46 +0,0 @@ -import tensorflow as tf -import numpy as np -import matplotlib.pyplot as plt -from tfkbnufft import kbnufft_forward, kbnufft_adjoint -from tfkbnufft.kbnufft import KbNufftModule -from mri.operators import NonCartesianFFT -tf.config.run_functions_eagerly(True) -tf.random.set_seed(0) -N = 20 -M = 20*5 -nufft_ob = KbNufftModule(im_size=(N, N), grid_size=(2*N, 2*N), norm='ortho') -ktraj = tf.Variable(tf.random.uniform((1, 2, M), minval=-1/2, maxval=1/2)) -ktraj_seq = tf.Variable(tf.cast( - np.reshape(np.meshgrid( - np.linspace(-1/2, 1/2, N, endpoint=False), - np.linspace(-1/2, 1/2, N, endpoint=False), - indexing='ij' - ), (2, N*N) - )[None, :], - tf.float32 -)) -signal = tf.Variable(tf.cast(tf.random.uniform((1, 1, N, N)), tf.complex64)) -kdata = tf.Variable(kbnufft_forward(nufft_ob._extract_nufft_interpob())(signal, ktraj*np.pi*2)) - -with tf.GradientTape(persistent=True) as g: - I_nufft = kbnufft_adjoint(nufft_ob._extract_nufft_interpob())(kdata, ktraj*np.pi*2)[0][0] - r = tf.cast(tf.reshape(tf.meshgrid( - np.linspace(-N/2, N/2, N, endpoint=False), - np.linspace(-N/2, N/2, N, endpoint=False), - indexing='ij' - ), (2, N*N)), tf.float32) - A = tf.exp(2j * np.pi * tf.cast(tf.matmul(tf.transpose(ktraj[0]), r), tf.complex64))/N/2 - I_ndft = tf.reshape(tf.matmul(tf.transpose(A), kdata[0][0][..., None]), (N, N)) - -grad2 = g.gradient(I_ndft, ktraj)[0] -grad3 = 2*np.pi*1j*tf.matmul(tf.cast(r, tf.complex64), tf.transpose(A))*kdata[0][0] -grad4 = g.gradient(I_nufft, ktraj)[0] -C = kbnufft_adjoint(nufft_ob._extract_nufft_interpob())(kdata, ktraj*np.pi*2) * tf.reshape(tf.cast(r, tf.complex64), (1, 2, 20, 20)) -grad4 = tf.reshape( - 2j * np.pi * kbnufft_adjoint(nufft_ob._extract_nufft_interpob())( - kdata * tf.cast(r, tf.complex64), - ktraj*np.pi*2 - ), - (2, N*N) -) -grad4 \ No newline at end of file diff --git a/notebooks/nfft.py b/notebooks/nfft.py deleted file mode 100644 index 58079d7..0000000 --- a/notebooks/nfft.py +++ /dev/null @@ -1,25 +0,0 @@ -import tensorflow as tf -import numpy as np -import matplotlib.pyplot as plt -from tfkbnufft import kbnufft_forward, kbnufft_adjoint -from tfkbnufft.kbnufft import KbNufftModule - - -N = 1024 -M = 512 -nufft_ob = KbNufftModule(im_size=(N, ), grid_size=(2*N, ), norm='ortho') 
-ktraj = tf.Variable(tf.random.uniform((1, 1, M), minval=-1/2, maxval=1/2)) -ktraj_seq = tf.Variable(tf.cast(np.linspace(-1/2, 1/2, N, endpoint=False)[None, None, :], tf.float32)) -signal = tf.Variable(tf.cast(tf.random.uniform((1, 1, N)), tf.complex64)) -signal_spectrum = tf.signal.fftshift(tf.signal.fft(tf.signal.ifftshift(signal[0][0]))) / N - -with tf.GradientTape(persistent=True) as g: - kdata_nufft = kbnufft_forward(nufft_ob._extract_nufft_interpob())(signal, ktraj*np.pi*2) - r = tf.cast(tf.linspace(0, N-1, N), tf.float32) - A = tf.exp(-2j * np.pi * tf.cast(ktraj[0][0][..., None] * r[None, :], tf.complex64))/N - kdata_ndft = tf.matmul(A, tf.transpose(signal[0])) - -grad1 = g.gradient(kdata_nufft, ktraj) -grad2 = g.gradient(kdata_ndft, ktraj) -grad3 = -2j * np.pi * tf.matmul(A, tf.transpose(signal[0] * tf.cast(r, tf.complex64))) -grad_nufft_mine =-2j * np.pi * kbnufft_forward(nufft_ob._extract_nufft_interpob())(signal * tf.cast(r, tf.complex64), ktraj*np.pi*2) \ No newline at end of file From 0fd2beb5df972b30a129c7028af01c860629d3a2 Mon Sep 17 00:00:00 2001 From: chaithyagr Date: Tue, 16 Mar 2021 16:13:28 +0100 Subject: [PATCH 06/24] Remove unwated test lines --- tfkbnufft/tests/kbnufft_test.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/tfkbnufft/tests/kbnufft_test.py b/tfkbnufft/tests/kbnufft_test.py index ea7909c..9fc604c 100644 --- a/tfkbnufft/tests/kbnufft_test.py +++ b/tfkbnufft/tests/kbnufft_test.py @@ -57,14 +57,7 @@ def test_forward_gradient(multiprocessing): forward_op = kbnufft_forward(nufft_ob._extract_nufft_interpob(), multiprocessing) with tf.GradientTape() as tape: tape.watch(image) - r = tf.cast(tf.reshape(tf.meshgrid( - np.linspace(-640/2, 640/2, 640, endpoint=False), - np.linspace(-400/2, 400/2, 400, endpoint=False), - indexing='ij' - ), (2, -640*400)), tf.float32) - A = tf.exp(-2j * np.pi * tf.cast(tf.matmul(tf.transpose(traj[0])/2/np.pi, r), tf.complex64))/np.sqrt(640*400)/2 - print(A) - res = forward_op(image, traj) + res = forward_op(image, traj) grad = tape.gradient(res, image) tf_test = tf.test.TestCase() tf_test.assertEqual(grad.shape, image.shape) From 63c0d06baa32cc36896cc8341eb3ceb637be1992 Mon Sep 17 00:00:00 2001 From: Chaithya G R Date: Tue, 16 Mar 2021 16:25:06 +0100 Subject: [PATCH 07/24] Update tfkbnufft/kbnufft.py Co-authored-by: Zaccharie Ramzi --- tfkbnufft/kbnufft.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tfkbnufft/kbnufft.py b/tfkbnufft/kbnufft.py index d8ecd0d..dcda93b 100644 --- a/tfkbnufft/kbnufft.py +++ b/tfkbnufft/kbnufft.py @@ -225,7 +225,7 @@ def grad(dx): # Gradients with respect to trajectory locations r = [tf.linspace(-im_size[i]/2, im_size[i]/2-1, im_size[i]) for i in range(im_rank)] - grid_r = tf.cast(tf.meshgrid(*r, indexing='ij'), y.dtype)[None, ...] + grid_r = tf.cast(tf.meshgrid(*r, indexing='ij'), dx.dtype)[None, ...] 
ifft_dxr = scale_and_fft_on_image_volume( dx * grid_r, scaling_coef, grid_size, im_size, norm, im_rank=im_rank, do_ifft=True) dx_dom = tf.cast(1j * y * kbinterp(ifft_dxr, om, interpob, conj=True), om.dtype) From 19452fc6d77e2356a5cd79c93e679994a4cdf1f3 Mon Sep 17 00:00:00 2001 From: chaithyagr Date: Wed, 17 Mar 2021 09:03:12 +0100 Subject: [PATCH 08/24] #29 Refactorinf codes --- tfkbnufft/nufft/fft_functions.py | 88 ++++++++++++++++++----------- tfkbnufft/nufft/interp_functions.py | 5 +- 2 files changed, 58 insertions(+), 35 deletions(-) diff --git a/tfkbnufft/nufft/fft_functions.py b/tfkbnufft/nufft/fft_functions.py index da76642..65834a9 100644 --- a/tfkbnufft/nufft/fft_functions.py +++ b/tfkbnufft/nufft/fft_functions.py @@ -3,6 +3,27 @@ import tensorflow as tf from tensorflow.python.ops.signal.fft_ops import ifft2d, fft2d, fft, ifft +def tf_mp_ifft(kspace): + k_shape_x = tf.shape(kspace)[-1] + batched_kspace = tf.reshape(kspace, (-1, k_shape_x)) + batched_image = tf.map_fn( + ifft, + batched_kspace, + parallel_iterations=multiprocessing.cpu_count(), + ) + image = tf.reshape(batched_image, tf.shape(kspace)) + return image + +def tf_mp_fft(kspace): + k_shape_x = tf.shape(kspace)[-1] + batched_kspace = tf.reshape(kspace, (-1, k_shape_x)) + batched_image = tf.map_fn( + fft, + batched_kspace, + parallel_iterations=multiprocessing.cpu_count(), + ) + image = tf.reshape(batched_image, tf.shape(kspace)) + return image def tf_mp_ifft2d(kspace): k_shape_x = tf.shape(kspace)[-2] @@ -61,6 +82,37 @@ def tf_mp_fourier3d(x, trans_type='inv'): y = tf.transpose(y_reshaped, [0, 1, 4, 2, 3]) return y +# Generate a fourier dictionary to simplify its use below. +# In the end we have the following list: +# fourier_dict[do_ifft][multiprocessing][rank of image - 1] +fourier_dict = [ + [ + [ + tf.signal.fft, + tf.signal.fft2d, + tf.signal.fft3d, + ], + [ + tf_mp_fft, + tf_mp_fft2d, + tf_mp_fft3d, + ] + ], + [ + [ + tf.signal.ifft, + tf.signal.ifft2d, + tf.signal.ifft3d, + ], + [ + tf_mp_ifft, + tf_mp_ifft2d, + tf_mp_ifft3d, + ] + ] +] + + def scale_and_fft_on_image_volume(x, scaling_coef, grid_size, im_size, norm, im_rank=2, multiprocessing=False, do_ifft=False): """Applies the FFT and any relevant scaling factors to x. @@ -103,28 +155,7 @@ def scale_and_fft_on_image_volume(x, scaling_coef, grid_size, im_size, norm, im_ # zero pad and fft x = tf.pad(x, pad_sizes) # this might have to be a tf py function, or I could use tf cond - if im_rank == 2: - if multiprocessing: - if do_ifft: - x = tf_mp_ifft2d(x) - else: - x = tf_mp_fft2d(x) - else: - if do_ifft: - x = tf.signal.ifft2d(x) - else: - x = tf.signal.fft2d(x) - else: - if multiprocessing: - if do_ifft: - x = tf_mp_ifft3d(x) - else: - x = tf_mp_fft3d(x) - else: - if do_ifft: - x = tf.signal.ifft3d(x) - else: - x = tf.signal.fft3d(x) + x = fourier_dict[do_ifft][multiprocessing][im_rank - 1] if norm == 'ortho': scaling_factor = tf.cast(tf.reduce_prod(grid_size), x.dtype) if do_ifft: @@ -134,6 +165,7 @@ def scale_and_fft_on_image_volume(x, scaling_coef, grid_size, im_size, norm, im_ return x + def ifft_and_scale_on_gridded_data(x, scaling_coef, grid_size, im_size, norm, im_rank=2, multiprocessing=False): """Applies the iFFT and any relevant scaling factors to x. 
@@ -152,17 +184,7 @@ def ifft_and_scale_on_gridded_data(x, scaling_coef, grid_size, im_size, norm, im # we don't need permutations since the fft in fourier is done on the # innermost dimensions and we are handling complex tensors # do the inverse fft - if im_rank == 2: - if multiprocessing: - x = tf_mp_ifft2d(x) - else: - x = tf.signal.ifft2d(x) - else: - if multiprocessing: - x = tf_mp_ifft3d(x) - else: - x = tf.signal.ifft3d(x) - + x = fourier_dict[True][multiprocessing][im_rank - 1] im_size = tf.cast(im_size, tf.int32) # crop to output size x = x[:, :, :im_size[0]] diff --git a/tfkbnufft/nufft/interp_functions.py b/tfkbnufft/nufft/interp_functions.py index deb6649..73e1fb7 100644 --- a/tfkbnufft/nufft/interp_functions.py +++ b/tfkbnufft/nufft/interp_functions.py @@ -212,10 +212,11 @@ def kbinterp(x, om, interpob, conj=False): def _map_body(inputs): _x, _tm, _om = inputs y_not_shifted = run_interp(tf.reshape(_x, (tf.shape(_x)[0], -1)), _tm, params) + shift = tf.exp(1j * tf.cast(tf.linalg.matvec(tf.transpose(_om), n_shift), y_not_shifted.dtype))[None, ...] if conj: - y = y_not_shifted * tf.math.conj(tf.exp(1j * tf.cast(tf.linalg.matvec(tf.transpose(_om), n_shift), y_not_shifted.dtype)))[None, ...] + y = y_not_shifted * tf.math.conj(shift) else: - y = y_not_shifted * tf.exp(1j * tf.cast(tf.linalg.matvec(tf.transpose(_om), n_shift), y_not_shifted.dtype))[None, ...] + y = y_not_shifted * shift return y y = tf.map_fn(_map_body, [x, tm, om], dtype=x.dtype) From 81558bd20107cd927694a64de21fb082601d0d1d Mon Sep 17 00:00:00 2001 From: chaithyagr Date: Wed, 17 Mar 2021 09:10:07 +0100 Subject: [PATCH 09/24] Add grad_traj --- tfkbnufft/kbnufft.py | 37 ++++++++++++++++++++++--------------- 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/tfkbnufft/kbnufft.py b/tfkbnufft/kbnufft.py index dcda93b..8e267ce 100644 --- a/tfkbnufft/kbnufft.py +++ b/tfkbnufft/kbnufft.py @@ -37,11 +37,12 @@ class KbNufftModule(KbModule): def __init__(self, im_size, grid_size=None, numpoints=6, n_shift=None, table_oversamp=2**10, kbwidth=2.34, order=0, norm='None', - coil_broadcast=False, matadj=False): + coil_broadcast=False, matadj=False, grad_traj=False): super(KbNufftModule, self).__init__() self.im_size = im_size self.im_rank = len(im_size) + self.grad_traj = grad_traj if grid_size is None: self.grid_size = tuple(np.array(self.im_size) * 2) else: @@ -135,6 +136,7 @@ def _extract_nufft_interpob(self): interpob['norm'] = self.norm interpob['coil_broadcast'] = self.coil_broadcast interpob['matadj'] = self.matadj + interpob['grad_traj'] = self.grad_traj Jgen = [] for i in range(self.im_rank): # number of points to use for interpolation is numpoints @@ -168,6 +170,7 @@ def kbnufft_forward_for_interpob(x, om): grid_size = interpob['grid_size'] im_size = interpob['im_size'] norm = interpob['norm'] + grad_traj = interpob['grad_traj'] im_rank = interpob.get('im_rank', 2) fft_x = scale_and_fft_on_image_volume( @@ -180,13 +183,15 @@ def grad(dy): grid_dy = adjkbinterp(dy, om, interpob) ifft_dy = ifft_and_scale_on_gridded_data( grid_dy, scaling_coef, grid_size, im_size, norm, im_rank=im_rank) - - # Gradients with respect to trajectory locations - r = [tf.linspace(-im_size[i]/2, im_size[i]/2-1, im_size[i]) for i in range(im_rank)] - grid_r = tf.cast(tf.meshgrid(*r, indexing='ij'), x.dtype)[None, ...] 
- fft_dx_dom = scale_and_fft_on_image_volume( + if grad_traj: + # Gradients with respect to trajectory locations + r = [tf.linspace(-im_size[i]/2, im_size[i]/2-1, im_size[i]) for i in range(im_rank)] + grid_r = tf.cast(tf.meshgrid(*r, indexing='ij'), x.dtype)[None, ...] + fft_dx_dom = scale_and_fft_on_image_volume( x * grid_r, scaling_coef, grid_size, im_size, norm, im_rank=im_rank) - dy_dom = tf.cast(-1j * dy * kbinterp(fft_dx_dom, om, interpob), tf.float32) + dy_dom = tf.cast(-1j * dy * kbinterp(fft_dx_dom, om, interpob), tf.float32) + else: + dy_dom = None return ifft_dy, dy_dom return y, grad @@ -213,6 +218,7 @@ def kbnufft_adjoint_for_interpob(y, om): grid_size = interpob['grid_size'] im_size = interpob['im_size'] norm = interpob['norm'] + grad_traj = interpob['grad_traj'] im_rank = interpob.get('im_rank', 2) ifft_y = ifft_and_scale_on_gridded_data( grid_y, scaling_coef, grid_size, im_size, norm, im_rank=im_rank, multiprocessing=multiprocessing) @@ -222,14 +228,15 @@ def grad(dx): fft_dx = scale_and_fft_on_image_volume( dx, scaling_coef, grid_size, im_size, norm, im_rank=im_rank) dx_dy = kbinterp(fft_dx, om, interpob) - - # Gradients with respect to trajectory locations - r = [tf.linspace(-im_size[i]/2, im_size[i]/2-1, im_size[i]) for i in range(im_rank)] - grid_r = tf.cast(tf.meshgrid(*r, indexing='ij'), dx.dtype)[None, ...] - ifft_dxr = scale_and_fft_on_image_volume( - dx * grid_r, scaling_coef, grid_size, im_size, norm, im_rank=im_rank, do_ifft=True) - dx_dom = tf.cast(1j * y * kbinterp(ifft_dxr, om, interpob, conj=True), om.dtype) - + if grad_traj: + # Gradients with respect to trajectory locations + r = [tf.linspace(-im_size[i]/2, im_size[i]/2-1, im_size[i]) for i in range(im_rank)] + grid_r = tf.cast(tf.meshgrid(*r, indexing='ij'), dx.dtype)[None, ...] 
+ ifft_dxr = scale_and_fft_on_image_volume( + dx * grid_r, scaling_coef, grid_size, im_size, norm, im_rank=im_rank, do_ifft=True) + dx_dom = tf.cast(1j * y * kbinterp(ifft_dxr, om, interpob, conj=True), om.dtype) + else: + dx_dom = None return dx_dy, dx_dom return ifft_y, grad return kbnufft_adjoint_for_interpob From ac2e7f95f1c1251329b70c5700ee3f6c8d7333e4 Mon Sep 17 00:00:00 2001 From: chaithyagr Date: Wed, 17 Mar 2021 09:26:23 +0100 Subject: [PATCH 10/24] Update the tests --- run_tests.sh | 2 +- tfkbnufft/nufft/fft_functions.py | 4 ++-- tfkbnufft/tests/nufft/interp_functions_test.py | 4 +++- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/run_tests.sh b/run_tests.sh index 1261fd6..8b8c837 100644 --- a/run_tests.sh +++ b/run_tests.sh @@ -1,3 +1,3 @@ #!/usr/bin/env bash -pip install torchkbnufft scikit-image pytest +pip install torchkbnufft==0.3.4 scikit-image pytest python -m pytest tfkbnufft diff --git a/tfkbnufft/nufft/fft_functions.py b/tfkbnufft/nufft/fft_functions.py index 65834a9..b73474f 100644 --- a/tfkbnufft/nufft/fft_functions.py +++ b/tfkbnufft/nufft/fft_functions.py @@ -155,7 +155,7 @@ def scale_and_fft_on_image_volume(x, scaling_coef, grid_size, im_size, norm, im_ # zero pad and fft x = tf.pad(x, pad_sizes) # this might have to be a tf py function, or I could use tf cond - x = fourier_dict[do_ifft][multiprocessing][im_rank - 1] + x = fourier_dict[do_ifft][multiprocessing][im_rank - 1](x) if norm == 'ortho': scaling_factor = tf.cast(tf.reduce_prod(grid_size), x.dtype) if do_ifft: @@ -184,7 +184,7 @@ def ifft_and_scale_on_gridded_data(x, scaling_coef, grid_size, im_size, norm, im # we don't need permutations since the fft in fourier is done on the # innermost dimensions and we are handling complex tensors # do the inverse fft - x = fourier_dict[True][multiprocessing][im_rank - 1] + x = fourier_dict[True][multiprocessing][im_rank - 1](x) im_size = tf.cast(im_size, tf.int32) # crop to output size x = x[:, :, :im_size[0]] diff --git a/tfkbnufft/tests/nufft/interp_functions_test.py b/tfkbnufft/tests/nufft/interp_functions_test.py index 87d273d..f257152 100644 --- a/tfkbnufft/tests/nufft/interp_functions_test.py +++ b/tfkbnufft/tests/nufft/interp_functions_test.py @@ -47,7 +47,8 @@ def test_calc_coef_and_indices(conjcoef): np.testing.assert_allclose(res_torch_coefs, res_tf_coefs.numpy()) @pytest.mark.parametrize('n_coil', [1, 2, 5, 16]) -def test_run_interp(n_coil): +@pytest.mark.parametrize('conjcoef', [True, False]) +def test_run_interp(n_coil, conjcoef): tm, Jgen, table, numpoints, L, grid_size = setup() grid_size = grid_size.astype('int') griddat = np.stack([ @@ -62,6 +63,7 @@ def test_run_interp(n_coil): 'numpoints': numpoints, 'Jlist': Jgen, 'table_oversamp': L, + 'conjcoef': False, } args = [griddat, tm, params] torch_args = [to_torch_arg(arg) for arg in args] From d89f21dab976ed3a479a5723283111174b54dd8c Mon Sep 17 00:00:00 2001 From: chaithyagr Date: Wed, 17 Mar 2021 09:38:25 +0100 Subject: [PATCH 11/24] Update fft_function, cropping function was broken --- tfkbnufft/nufft/fft_functions.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tfkbnufft/nufft/fft_functions.py b/tfkbnufft/nufft/fft_functions.py index b73474f..9ace35b 100644 --- a/tfkbnufft/nufft/fft_functions.py +++ b/tfkbnufft/nufft/fft_functions.py @@ -189,9 +189,10 @@ def ifft_and_scale_on_gridded_data(x, scaling_coef, grid_size, im_size, norm, im # crop to output size x = x[:, :, :im_size[0]] if im_rank >=2: - x = x[..., :im_size[1]] - if im_rank == 3: - x = x[..., 
:im_size[2]] + if im_rank == 3: + x = x[..., :im_size[1], :im_size[2]] + else: + x = x[..., :im_size[1]] # scaling scaling_factor = tf.cast(tf.reduce_prod(grid_size), x.dtype) From b4e1a391d3ff5b75befa86cf97df9d084ab6c4c2 Mon Sep 17 00:00:00 2001 From: chaithyagr Date: Wed, 17 Mar 2021 10:40:25 +0100 Subject: [PATCH 12/24] Add tests for NDFT --- tfkbnufft/tests/ndft_tests.py | 78 +++++++++++++++++++++++++++++++++++ 1 file changed, 78 insertions(+) create mode 100644 tfkbnufft/tests/ndft_tests.py diff --git a/tfkbnufft/tests/ndft_tests.py b/tfkbnufft/tests/ndft_tests.py new file mode 100644 index 0000000..5788703 --- /dev/null +++ b/tfkbnufft/tests/ndft_tests.py @@ -0,0 +1,78 @@ +import tensorflow as tf +import numpy as np +from tfkbnufft import kbnufft_forward, kbnufft_adjoint +from tfkbnufft.kbnufft import KbNufftModule + + +def test_adjoint_and_gradients(): + N = 20 + M = 20*5 + nufft_ob = KbNufftModule(im_size=(N, N), grid_size=(2*N, 2*N), norm='ortho', grad_traj=True) + # Generate Trajectory + ktraj = tf.Variable(tf.random.uniform((1, 2, M), minval=-1/2, maxval=1/2)*2*np.pi) + # Have a random signal + signal = tf.Variable(tf.cast(tf.random.uniform((1, 1, N, N)), tf.complex64)) + kdata = tf.Variable(kbnufft_forward(nufft_ob._extract_nufft_interpob())(signal, ktraj)) + + with tf.GradientTape(persistent=True) as g: + I_nufft = kbnufft_adjoint(nufft_ob._extract_nufft_interpob())(kdata, ktraj)[0][0] + r = tf.cast(tf.reshape(tf.meshgrid( + tf.linspace(-N/2, N/2-1, N), + tf.linspace(-N/2, N/2-1, N), + indexing='ij' + ), (2, N * N)), tf.float32) + A = tf.exp(2j * np.pi * tf.cast(tf.matmul(tf.transpose(ktraj[0]/2/np.pi), r), tf.complex64))/N/2 + I_ndft = tf.reshape(tf.matmul(tf.transpose(A), kdata[0][0][..., None]), (N, N)) + + tf_test = tf.test.TestCase() + # Test if the NUFFT and NDFT operation is same + tf_test.assertAllClose(I_nufft, I_ndft, rtol=1e-1) + + # Test gradients with respect to kdata + gradient_ndft_kdata = g.gradient(I_ndft, kdata)[0] + gradient_nufft_kdata = g.gradient(I_nufft, kdata)[0] + tf_test.assertAllClose(gradient_ndft_kdata, gradient_nufft_kdata, atol=1e-1) + + # Test gradients with respect to trajectory location + gradient_ndft_traj = g.gradient(I_ndft, ktraj)[0] + gradient_nufft_traj = g.gradient(I_nufft, ktraj)[0] + tf_test.assertAllClose(gradient_ndft_traj, gradient_nufft_traj, atol=1e-1) + # This is gradient of NDFT from matrix, will help in debug + # gradient_from_matrix = 2*np.pi*1j*tf.matmul(tf.cast(r, tf.complex64), tf.transpose(A))*kdata[0][0] + + +def test_forward_and_gradients(): + N = 20 + M = 20*5 + nufft_ob = KbNufftModule(im_size=(N, N), grid_size=(2*N, 2*N), norm='ortho', grad_traj=True) + # Generate Trajectory + ktraj = tf.Variable(tf.random.uniform((1, 2, M), minval=-1/2, maxval=1/2)*2*np.pi) + # Have a random signal + signal = tf.Variable(tf.cast(tf.random.uniform((1, 1, N, N)), tf.complex64)) + + with tf.GradientTape(persistent=True) as g: + kdata_nufft = kbnufft_forward(nufft_ob._extract_nufft_interpob())(signal, ktraj)[0] + r = tf.cast(tf.reshape(tf.meshgrid( + tf.linspace(-N/2, N/2-1, N), + tf.linspace(-N/2, N/2-1, N), + indexing='ij' + ), (2, N * N)), tf.float32) + A = tf.exp(-2j * np.pi * tf.cast(tf.matmul(tf.transpose(ktraj[0])/2/np.pi, r), tf.complex64))/N/2 + kdata_ndft = tf.transpose(tf.matmul(A, tf.reshape(signal[0][0], (N*N, 1)))) + + tf_test = tf.test.TestCase() + # Test if the NUFFT and NDFT operation is same + tf_test.assertAllClose(kdata_nufft, kdata_ndft, atol=1e-1) + + # Test gradients with respect to kdata + gradient_ndft_kdata = 
g.gradient(kdata_ndft, signal)[0] + gradient_nufft_kdata = g.gradient(kdata_nufft, signal)[0] + tf_test.assertAllClose(gradient_ndft_kdata, gradient_nufft_kdata, atol=1) + + # Test gradients with respect to trajectory location + gradient_ndft_traj = g.gradient(kdata_ndft, ktraj)[0] + gradient_nufft_traj = g.gradient(kdata_nufft, ktraj)[0] + tf_test.assertAllClose(gradient_ndft_traj, gradient_nufft_traj, atol=1) + # This is gradient of NDFT from matrix, will help in debug + # gradient_ndft_matrix = -2j * np.pi * tf.transpose(tf.matmul(A, tf.transpose(tf.cast(r, tf.complex64) * tf.reshape(signal[0][0], (N*N,))))) + From 5112603ff3312f4e4422d076ee786e4feee76c41 Mon Sep 17 00:00:00 2001 From: Chaithya G R Date: Wed, 17 Mar 2021 11:10:51 +0100 Subject: [PATCH 13/24] Update tfkbnufft/tests/nufft/interp_functions_test.py Co-authored-by: Zaccharie Ramzi --- tfkbnufft/tests/nufft/interp_functions_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tfkbnufft/tests/nufft/interp_functions_test.py b/tfkbnufft/tests/nufft/interp_functions_test.py index f257152..18a1ede 100644 --- a/tfkbnufft/tests/nufft/interp_functions_test.py +++ b/tfkbnufft/tests/nufft/interp_functions_test.py @@ -63,7 +63,7 @@ def test_run_interp(n_coil, conjcoef): 'numpoints': numpoints, 'Jlist': Jgen, 'table_oversamp': L, - 'conjcoef': False, + 'conjcoef': conjcoef, } args = [griddat, tm, params] torch_args = [to_torch_arg(arg) for arg in args] From fcd47ef5f9e7110d72621b9a88bf9a3d768144ef Mon Sep 17 00:00:00 2001 From: chaithyagr Date: Wed, 17 Mar 2021 11:19:46 +0100 Subject: [PATCH 14/24] Update tests to be more flexible and extensive --- tfkbnufft/__init__.py | 2 +- tfkbnufft/tests/ndft_tests.py | 58 ++++++++++++++++++----------------- 2 files changed, 31 insertions(+), 29 deletions(-) diff --git a/tfkbnufft/__init__.py b/tfkbnufft/__init__.py index d3bb613..7c652f4 100644 --- a/tfkbnufft/__init__.py +++ b/tfkbnufft/__init__.py @@ -1,6 +1,6 @@ """Package info""" -__version__ = '0.1.4' +__version__ = '0.2.0' __author__ = 'Zaccharie Ramzi' __author_email__ = 'zaccharie.ramzi@inria.fr' __license__ = 'MIT' diff --git a/tfkbnufft/tests/ndft_tests.py b/tfkbnufft/tests/ndft_tests.py index 5788703..76fba84 100644 --- a/tfkbnufft/tests/ndft_tests.py +++ b/tfkbnufft/tests/ndft_tests.py @@ -1,32 +1,34 @@ +import pytest import tensorflow as tf import numpy as np from tfkbnufft import kbnufft_forward, kbnufft_adjoint from tfkbnufft.kbnufft import KbNufftModule -def test_adjoint_and_gradients(): - N = 20 - M = 20*5 - nufft_ob = KbNufftModule(im_size=(N, N), grid_size=(2*N, 2*N), norm='ortho', grad_traj=True) +@pytest.mark.parametrize('im_size', [(20, ), (10, 15), (10, 15, 12)]) +def test_adjoint_and_gradients(im_size): + grid_size = tuple(np.array(im_size)*2) + im_rank = len(im_size) + M = im_size[0] * 3**im_rank + nufft_ob = KbNufftModule(im_size=im_size, grid_size=grid_size, norm='ortho', grad_traj=True) # Generate Trajectory - ktraj = tf.Variable(tf.random.uniform((1, 2, M), minval=-1/2, maxval=1/2)*2*np.pi) + ktraj = tf.Variable(tf.random.uniform((1, im_rank, M), minval=-1/2, maxval=1/2)*2*np.pi) # Have a random signal - signal = tf.Variable(tf.cast(tf.random.uniform((1, 1, N, N)), tf.complex64)) + signal = tf.Variable(tf.cast(tf.random.uniform((1, 1, *im_size)), tf.complex64)) kdata = tf.Variable(kbnufft_forward(nufft_ob._extract_nufft_interpob())(signal, ktraj)) with tf.GradientTape(persistent=True) as g: I_nufft = kbnufft_adjoint(nufft_ob._extract_nufft_interpob())(kdata, ktraj)[0][0] - r = 
tf.cast(tf.reshape(tf.meshgrid( - tf.linspace(-N/2, N/2-1, N), - tf.linspace(-N/2, N/2-1, N), - indexing='ij' - ), (2, N * N)), tf.float32) - A = tf.exp(2j * np.pi * tf.cast(tf.matmul(tf.transpose(ktraj[0]/2/np.pi), r), tf.complex64))/N/2 - I_ndft = tf.reshape(tf.matmul(tf.transpose(A), kdata[0][0][..., None]), (N, N)) + r = [tf.linspace(-im_size[i]/2, im_size[i]/2-1, im_size[i]) for i in range(im_rank)] + grid_r =tf.cast(tf.reshape(tf.meshgrid(*r ,indexing='ij'), (im_rank, tf.reduce_prod(im_size))), tf.float32) + A = tf.exp(1j * tf.cast(tf.matmul(tf.transpose(ktraj[0]), grid_r), tf.complex64)) / ( + np.sqrt(tf.reduce_prod(im_size)) * np.power(np.sqrt(2), im_rank) + ) + I_ndft = tf.reshape(tf.matmul(tf.transpose(A), kdata[0][0][..., None]), im_size) tf_test = tf.test.TestCase() # Test if the NUFFT and NDFT operation is same - tf_test.assertAllClose(I_nufft, I_ndft, rtol=1e-1) + tf_test.assertAllClose(I_nufft, I_ndft, atol=1e-1) # Test gradients with respect to kdata gradient_ndft_kdata = g.gradient(I_ndft, kdata)[0] @@ -41,24 +43,25 @@ def test_adjoint_and_gradients(): # gradient_from_matrix = 2*np.pi*1j*tf.matmul(tf.cast(r, tf.complex64), tf.transpose(A))*kdata[0][0] -def test_forward_and_gradients(): - N = 20 - M = 20*5 - nufft_ob = KbNufftModule(im_size=(N, N), grid_size=(2*N, 2*N), norm='ortho', grad_traj=True) +@pytest.mark.parametrize('im_size', [(20, ), (10, 15), (10, 15, 12)]) +def test_forward_and_gradients(im_size): + grid_size = tuple(np.array(im_size)*2) + im_rank = len(im_size) + M = im_size[0] * 3**im_rank + nufft_ob = KbNufftModule(im_size=im_size, grid_size=grid_size, norm='ortho', grad_traj=True) # Generate Trajectory - ktraj = tf.Variable(tf.random.uniform((1, 2, M), minval=-1/2, maxval=1/2)*2*np.pi) + ktraj = tf.Variable(tf.random.uniform((1, im_rank, M), minval=-1/2, maxval=1/2)*2*np.pi) # Have a random signal - signal = tf.Variable(tf.cast(tf.random.uniform((1, 1, N, N)), tf.complex64)) + signal = tf.Variable(tf.cast(tf.random.uniform((1, 1, *im_size)), tf.complex64)) with tf.GradientTape(persistent=True) as g: kdata_nufft = kbnufft_forward(nufft_ob._extract_nufft_interpob())(signal, ktraj)[0] - r = tf.cast(tf.reshape(tf.meshgrid( - tf.linspace(-N/2, N/2-1, N), - tf.linspace(-N/2, N/2-1, N), - indexing='ij' - ), (2, N * N)), tf.float32) - A = tf.exp(-2j * np.pi * tf.cast(tf.matmul(tf.transpose(ktraj[0])/2/np.pi, r), tf.complex64))/N/2 - kdata_ndft = tf.transpose(tf.matmul(A, tf.reshape(signal[0][0], (N*N, 1)))) + r = [tf.linspace(-im_size[i]/2, im_size[i]/2-1, im_size[i]) for i in range(im_rank)] + grid_r =tf.cast(tf.reshape(tf.meshgrid(*r ,indexing='ij'), (im_rank, tf.reduce_prod(im_size))), tf.float32) + A = tf.exp(-1j * tf.cast(tf.matmul(tf.transpose(ktraj[0]), grid_r), tf.complex64)) / ( + np.sqrt(tf.reduce_prod(im_size)) * np.power(np.sqrt(2), im_rank) + ) + kdata_ndft = tf.transpose(tf.matmul(A, tf.reshape(signal[0][0], (tf.reduce_prod(im_size), 1)))) tf_test = tf.test.TestCase() # Test if the NUFFT and NDFT operation is same @@ -75,4 +78,3 @@ def test_forward_and_gradients(): tf_test.assertAllClose(gradient_ndft_traj, gradient_nufft_traj, atol=1) # This is gradient of NDFT from matrix, will help in debug # gradient_ndft_matrix = -2j * np.pi * tf.transpose(tf.matmul(A, tf.transpose(tf.cast(r, tf.complex64) * tf.reshape(signal[0][0], (N*N,))))) - From 34b8527ae3c4a7b786d0112c7e874a560db42d23 Mon Sep 17 00:00:00 2001 From: chaithyagr Date: Wed, 17 Mar 2021 13:08:51 +0100 Subject: [PATCH 15/24] Update the tests, workflow complete --- tfkbnufft/nufft/fft_functions.py 
| 6 +++--- tfkbnufft/tests/ndft_tests.py | 18 ++++++++++-------- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/tfkbnufft/nufft/fft_functions.py b/tfkbnufft/nufft/fft_functions.py index 9ace35b..cdb718f 100644 --- a/tfkbnufft/nufft/fft_functions.py +++ b/tfkbnufft/nufft/fft_functions.py @@ -85,7 +85,7 @@ def tf_mp_fourier3d(x, trans_type='inv'): # Generate a fourier dictionary to simplify its use below. # In the end we have the following list: # fourier_dict[do_ifft][multiprocessing][rank of image - 1] -fourier_dict = [ +fourier_list = [ [ [ tf.signal.fft, @@ -155,7 +155,7 @@ def scale_and_fft_on_image_volume(x, scaling_coef, grid_size, im_size, norm, im_ # zero pad and fft x = tf.pad(x, pad_sizes) # this might have to be a tf py function, or I could use tf cond - x = fourier_dict[do_ifft][multiprocessing][im_rank - 1](x) + x = fourier_list[do_ifft][multiprocessing][im_rank - 1](x) if norm == 'ortho': scaling_factor = tf.cast(tf.reduce_prod(grid_size), x.dtype) if do_ifft: @@ -184,7 +184,7 @@ def ifft_and_scale_on_gridded_data(x, scaling_coef, grid_size, im_size, norm, im # we don't need permutations since the fft in fourier is done on the # innermost dimensions and we are handling complex tensors # do the inverse fft - x = fourier_dict[True][multiprocessing][im_rank - 1](x) + x = fourier_list[True][multiprocessing][im_rank - 1](x) im_size = tf.cast(im_size, tf.int32) # crop to output size x = x[:, :, :im_size[0]] diff --git a/tfkbnufft/tests/ndft_tests.py b/tfkbnufft/tests/ndft_tests.py index 76fba84..81ad6f2 100644 --- a/tfkbnufft/tests/ndft_tests.py +++ b/tfkbnufft/tests/ndft_tests.py @@ -5,8 +5,9 @@ from tfkbnufft.kbnufft import KbNufftModule -@pytest.mark.parametrize('im_size', [(20, ), (10, 15), (10, 15, 12)]) +@pytest.mark.parametrize('im_size', [(20, ), (10, 10), (10, 10, 10)]) def test_adjoint_and_gradients(im_size): + tf.random.set_seed(0) grid_size = tuple(np.array(im_size)*2) im_rank = len(im_size) M = im_size[0] * 3**im_rank @@ -28,23 +29,24 @@ def test_adjoint_and_gradients(im_size): tf_test = tf.test.TestCase() # Test if the NUFFT and NDFT operation is same - tf_test.assertAllClose(I_nufft, I_ndft, atol=1e-1) + tf_test.assertAllClose(I_nufft, I_ndft, atol=1e-2) # Test gradients with respect to kdata gradient_ndft_kdata = g.gradient(I_ndft, kdata)[0] gradient_nufft_kdata = g.gradient(I_nufft, kdata)[0] - tf_test.assertAllClose(gradient_ndft_kdata, gradient_nufft_kdata, atol=1e-1) + tf_test.assertAllClose(gradient_ndft_kdata, gradient_nufft_kdata, atol=1e-2) # Test gradients with respect to trajectory location gradient_ndft_traj = g.gradient(I_ndft, ktraj)[0] gradient_nufft_traj = g.gradient(I_nufft, ktraj)[0] - tf_test.assertAllClose(gradient_ndft_traj, gradient_nufft_traj, atol=1e-1) + tf_test.assertAllClose(gradient_ndft_traj, gradient_nufft_traj, atol=1e-2) # This is gradient of NDFT from matrix, will help in debug # gradient_from_matrix = 2*np.pi*1j*tf.matmul(tf.cast(r, tf.complex64), tf.transpose(A))*kdata[0][0] -@pytest.mark.parametrize('im_size', [(20, ), (10, 15), (10, 15, 12)]) +@pytest.mark.parametrize('im_size', [(20, ), (10, 10)]) def test_forward_and_gradients(im_size): + tf.random.set_seed(0) grid_size = tuple(np.array(im_size)*2) im_rank = len(im_size) M = im_size[0] * 3**im_rank @@ -65,16 +67,16 @@ def test_forward_and_gradients(im_size): tf_test = tf.test.TestCase() # Test if the NUFFT and NDFT operation is same - tf_test.assertAllClose(kdata_nufft, kdata_ndft, atol=1e-1) + tf_test.assertAllClose(kdata_nufft, kdata_ndft, atol=1e-2) # Test 
gradients with respect to kdata gradient_ndft_kdata = g.gradient(kdata_ndft, signal)[0] gradient_nufft_kdata = g.gradient(kdata_nufft, signal)[0] - tf_test.assertAllClose(gradient_ndft_kdata, gradient_nufft_kdata, atol=1) + tf_test.assertAllClose(gradient_ndft_kdata, gradient_nufft_kdata, atol=1e-2) # Test gradients with respect to trajectory location gradient_ndft_traj = g.gradient(kdata_ndft, ktraj)[0] gradient_nufft_traj = g.gradient(kdata_nufft, ktraj)[0] - tf_test.assertAllClose(gradient_ndft_traj, gradient_nufft_traj, atol=1) + tf_test.assertAllClose(gradient_ndft_traj, gradient_nufft_traj, atol=1e-2) # This is gradient of NDFT from matrix, will help in debug # gradient_ndft_matrix = -2j * np.pi * tf.transpose(tf.matmul(A, tf.transpose(tf.cast(r, tf.complex64) * tf.reshape(signal[0][0], (N*N,))))) From 93d7f1bc93de6bbeefd4385b509f54c48ef75ed1 Mon Sep 17 00:00:00 2001 From: chaithyagr Date: Wed, 17 Mar 2021 13:10:31 +0100 Subject: [PATCH 16/24] ssh debug --- .github/workflows/test.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index b294c05..12e980a 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -27,3 +27,6 @@ jobs: pip install . - name: Test with pytest run: bash run_tests.sh + - name: Setup tmate session + if: ${{ failure() }} + uses: mxschmitt/action-tmate@v3 From 853753f9947d85e1c638626dd8336a345b859dae Mon Sep 17 00:00:00 2001 From: chaithyagr Date: Wed, 17 Mar 2021 13:11:36 +0100 Subject: [PATCH 17/24] ssh debug --- .github/workflows/test.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 12e980a..a76b67c 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -28,5 +28,5 @@ jobs: - name: Test with pytest run: bash run_tests.sh - name: Setup tmate session - if: ${{ failure() }} - uses: mxschmitt/action-tmate@v3 + if: ${{ failure() }} + uses: mxschmitt/action-tmate@v3 From 8f2c9ff56ac92e58d8e8f8f52810e76fd9b67751 Mon Sep 17 00:00:00 2001 From: chaithyagr Date: Wed, 17 Mar 2021 13:39:11 +0100 Subject: [PATCH 18/24] Locally running fine --- tfkbnufft/tests/nufft/interp_functions_test.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/tfkbnufft/tests/nufft/interp_functions_test.py b/tfkbnufft/tests/nufft/interp_functions_test.py index 18a1ede..466a70f 100644 --- a/tfkbnufft/tests/nufft/interp_functions_test.py +++ b/tfkbnufft/tests/nufft/interp_functions_test.py @@ -66,15 +66,18 @@ def test_run_interp(n_coil, conjcoef): 'conjcoef': conjcoef, } args = [griddat, tm, params] - torch_args = [to_torch_arg(arg) for arg in args] - # I need this because griddat is first n_coil then real/imag - torch_args[0] = torch_args[0].permute(1, 0, 2) - res_torch = torch_interp_functions.run_interp(*torch_args) + if not conjcoef: + torch_args = [to_torch_arg(arg) for arg in args] + # I need this because griddat is first n_coil then real/imag + torch_args[0] = torch_args[0].permute(1, 0, 2) + res_torch = torch_interp_functions.run_interp(*torch_args) # I need this because I create Jlist in a neater way for tensorflow params['Jlist'] = Jgen.T tf_args = [to_tf_arg(arg) for arg in args] res_tf = tf_interp_functions.run_interp(*tf_args) - np.testing.assert_allclose(torch_to_numpy(res_torch, complex_dim=1), res_tf.numpy()) + if not conjcoef: + # Compare results with torch + np.testing.assert_allclose(torch_to_numpy(res_torch, complex_dim=1), res_tf.numpy()) 
@pytest.mark.parametrize('n_coil', [1, 2, 5, 16]) def test_run_interp_back(n_coil): From d3c8ba2caec0dbd2719b00c2b85d9a2329022c08 Mon Sep 17 00:00:00 2001 From: chaithyagr Date: Wed, 17 Mar 2021 14:23:37 +0100 Subject: [PATCH 19/24] Add torch 1.7 for testing torchkbnufft --- run_tests.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/run_tests.sh b/run_tests.sh index 8b8c837..4d4a7ef 100644 --- a/run_tests.sh +++ b/run_tests.sh @@ -1,3 +1,3 @@ #!/usr/bin/env bash -pip install torchkbnufft==0.3.4 scikit-image pytest +pip install torch==1.7 torchkbnufft==0.3.4 scikit-image pytest python -m pytest tfkbnufft From 745754e6322ab5dad8fd5725dbe282f1f72ceb54 Mon Sep 17 00:00:00 2001 From: chaithyagr Date: Wed, 17 Mar 2021 14:34:46 +0100 Subject: [PATCH 20/24] mark test so that it runs. and remove ssh --- .github/workflows/test.yml | 3 --- tfkbnufft/tests/{ndft_tests.py => ndft_test.py} | 0 2 files changed, 3 deletions(-) rename tfkbnufft/tests/{ndft_tests.py => ndft_test.py} (100%) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index a76b67c..b294c05 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -27,6 +27,3 @@ jobs: pip install . - name: Test with pytest run: bash run_tests.sh - - name: Setup tmate session - if: ${{ failure() }} - uses: mxschmitt/action-tmate@v3 diff --git a/tfkbnufft/tests/ndft_tests.py b/tfkbnufft/tests/ndft_test.py similarity index 100% rename from tfkbnufft/tests/ndft_tests.py rename to tfkbnufft/tests/ndft_test.py From 7a3f380799b3514067486883eb05482460b5c4a4 Mon Sep 17 00:00:00 2001 From: Chaithya G R Date: Wed, 17 Mar 2021 14:58:42 +0100 Subject: [PATCH 21/24] Update tfkbnufft/nufft/fft_functions.py Co-authored-by: Zaccharie Ramzi --- tfkbnufft/nufft/fft_functions.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tfkbnufft/nufft/fft_functions.py b/tfkbnufft/nufft/fft_functions.py index cdb718f..12b274c 100644 --- a/tfkbnufft/nufft/fft_functions.py +++ b/tfkbnufft/nufft/fft_functions.py @@ -154,7 +154,6 @@ def scale_and_fft_on_image_volume(x, scaling_coef, grid_size, im_size, norm, im_ # zero pad and fft x = tf.pad(x, pad_sizes) - # this might have to be a tf py function, or I could use tf cond x = fourier_list[do_ifft][multiprocessing][im_rank - 1](x) if norm == 'ortho': scaling_factor = tf.cast(tf.reduce_prod(grid_size), x.dtype) From 5a78ff3a3e1c98fc9691b161dcd62f1df52ffe65 Mon Sep 17 00:00:00 2001 From: chaithyagr Date: Wed, 17 Mar 2021 15:47:24 +0100 Subject: [PATCH 22/24] Update testing --- run_tests.sh | 3 ++- tfkbnufft/tests/ndft_test.py | 26 ++++++++++++++++---------- 2 files changed, 18 insertions(+), 11 deletions(-) diff --git a/run_tests.sh b/run_tests.sh index 4d4a7ef..e41c0ec 100644 --- a/run_tests.sh +++ b/run_tests.sh @@ -1,3 +1,4 @@ #!/usr/bin/env bash pip install torch==1.7 torchkbnufft==0.3.4 scikit-image pytest -python -m pytest tfkbnufft +python -m pytest tfkbnufft --ignore=tfkbnufft/tests/ndft_test.py +python -m pytest tfkbnufft/tests/ndft_test.py diff --git a/tfkbnufft/tests/ndft_test.py b/tfkbnufft/tests/ndft_test.py index 81ad6f2..9313e89 100644 --- a/tfkbnufft/tests/ndft_test.py +++ b/tfkbnufft/tests/ndft_test.py @@ -5,12 +5,22 @@ from tfkbnufft.kbnufft import KbNufftModule -@pytest.mark.parametrize('im_size', [(20, ), (10, 10), (10, 10, 10)]) +def get_fourier_matrix(ktraj, grid_r, im_size, im_rank, do_ifft=False): + traj_grid = tf.cast(tf.matmul(tf.transpose(ktraj[0]), grid_r), tf.complex64) + if do_ifft: + A = tf.exp(1j * traj_grid) + else: + A = tf.exp(-1j * traj_grid) + A = 
From 5a78ff3a3e1c98fc9691b161dcd62f1df52ffe65 Mon Sep 17 00:00:00 2001
From: chaithyagr
Date: Wed, 17 Mar 2021 15:47:24 +0100
Subject: [PATCH 22/24] Update testing

---
 run_tests.sh                 |  3 ++-
 tfkbnufft/tests/ndft_test.py | 26 ++++++++++++++++----------
 2 files changed, 18 insertions(+), 11 deletions(-)

diff --git a/run_tests.sh b/run_tests.sh
index 4d4a7ef..e41c0ec 100644
--- a/run_tests.sh
+++ b/run_tests.sh
@@ -1,3 +1,4 @@
 #!/usr/bin/env bash
 pip install torch==1.7 torchkbnufft==0.3.4 scikit-image pytest
-python -m pytest tfkbnufft
+python -m pytest tfkbnufft --ignore=tfkbnufft/tests/ndft_test.py
+python -m pytest tfkbnufft/tests/ndft_test.py
diff --git a/tfkbnufft/tests/ndft_test.py b/tfkbnufft/tests/ndft_test.py
index 81ad6f2..9313e89 100644
--- a/tfkbnufft/tests/ndft_test.py
+++ b/tfkbnufft/tests/ndft_test.py
@@ -5,12 +5,22 @@
 from tfkbnufft.kbnufft import KbNufftModule
 
 
-@pytest.mark.parametrize('im_size', [(20, ), (10, 10), (10, 10, 10)])
+def get_fourier_matrix(ktraj, grid_r, im_size, im_rank, do_ifft=False):
+    traj_grid = tf.cast(tf.matmul(tf.transpose(ktraj[0]), grid_r), tf.complex64)
+    if do_ifft:
+        A = tf.exp(1j * traj_grid)
+    else:
+        A = tf.exp(-1j * traj_grid)
+    A = A / (np.sqrt(tf.reduce_prod(im_size)) * np.power(np.sqrt(2), im_rank))
+    return A
+
+
+@pytest.mark.parametrize('im_size', [(10, ), (10, 10)])
 def test_adjoint_and_gradients(im_size):
     tf.random.set_seed(0)
     grid_size = tuple(np.array(im_size)*2)
     im_rank = len(im_size)
-    M = im_size[0] * 3**im_rank
+    M = im_size[0] * 2**im_rank
     nufft_ob = KbNufftModule(im_size=im_size, grid_size=grid_size, norm='ortho', grad_traj=True)
     # Generate Trajectory
     ktraj = tf.Variable(tf.random.uniform((1, im_rank, M), minval=-1/2, maxval=1/2)*2*np.pi)
@@ -22,9 +32,7 @@ def test_adjoint_and_gradients(im_size):
         I_nufft = kbnufft_adjoint(nufft_ob._extract_nufft_interpob())(kdata, ktraj)[0][0]
         r = [tf.linspace(-im_size[i]/2, im_size[i]/2-1, im_size[i]) for i in range(im_rank)]
         grid_r =tf.cast(tf.reshape(tf.meshgrid(*r ,indexing='ij'), (im_rank, tf.reduce_prod(im_size))), tf.float32)
-        A = tf.exp(1j * tf.cast(tf.matmul(tf.transpose(ktraj[0]), grid_r), tf.complex64)) / (
-            np.sqrt(tf.reduce_prod(im_size)) * np.power(np.sqrt(2), im_rank)
-        )
+        A = get_fourier_matrix(ktraj, grid_r, im_size, im_rank, do_ifft=True)
         I_ndft = tf.reshape(tf.matmul(tf.transpose(A), kdata[0][0][..., None]), im_size)
 
     tf_test = tf.test.TestCase()
@@ -44,12 +52,12 @@ def test_adjoint_and_gradients(im_size):
     # gradient_from_matrix = 2*np.pi*1j*tf.matmul(tf.cast(r, tf.complex64), tf.transpose(A))*kdata[0][0]
 
 
-@pytest.mark.parametrize('im_size', [(20, ), (10, 10)])
+@pytest.mark.parametrize('im_size', [(10, ), (10, 10)])
 def test_forward_and_gradients(im_size):
     tf.random.set_seed(0)
     grid_size = tuple(np.array(im_size)*2)
     im_rank = len(im_size)
-    M = im_size[0] * 3**im_rank
+    M = im_size[0] * 2**im_rank
     nufft_ob = KbNufftModule(im_size=im_size, grid_size=grid_size, norm='ortho', grad_traj=True)
     # Generate Trajectory
     ktraj = tf.Variable(tf.random.uniform((1, im_rank, M), minval=-1/2, maxval=1/2)*2*np.pi)
@@ -60,9 +68,7 @@ def test_forward_and_gradients(im_size):
         kdata_nufft = kbnufft_forward(nufft_ob._extract_nufft_interpob())(signal, ktraj)[0]
         r = [tf.linspace(-im_size[i]/2, im_size[i]/2-1, im_size[i]) for i in range(im_rank)]
         grid_r =tf.cast(tf.reshape(tf.meshgrid(*r ,indexing='ij'), (im_rank, tf.reduce_prod(im_size))), tf.float32)
-        A = tf.exp(-1j * tf.cast(tf.matmul(tf.transpose(ktraj[0]), grid_r), tf.complex64)) / (
-            np.sqrt(tf.reduce_prod(im_size)) * np.power(np.sqrt(2), im_rank)
-        )
+        A = get_fourier_matrix(ktraj, grid_r, im_size, im_rank, do_ifft=False)
         kdata_ndft = tf.transpose(tf.matmul(A, tf.reshape(signal[0][0], (tf.reduce_prod(im_size), 1))))
 
     tf_test = tf.test.TestCase()
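Note (editor's illustration, not part of the patch series): the scaling applied in get_fourier_matrix above is consistent with the 'ortho' normalisation of the oversampled grid, because the tests choose grid_size = 2 * im_size, so sqrt(prod(im_size)) * sqrt(2)**im_rank equals sqrt(prod(grid_size)). A minimal self-contained check of that identity, using the same sizes as the tests:

import numpy as np

# Sizes mirror the test setup; any im_size with grid_size = 2 * im_size satisfies the identity.
im_size = (10, 10)
grid_size = tuple(np.array(im_size) * 2)
im_rank = len(im_size)

scale_in_test = np.sqrt(np.prod(im_size)) * np.power(np.sqrt(2), im_rank)  # denominator used in get_fourier_matrix
ortho_scale = np.sqrt(np.prod(grid_size))                                  # 'ortho' FFT scale on the oversampled grid
assert np.isclose(scale_in_test, ortho_scale)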
From 6009f99ede1ed600ab5cd27e99d7d3990449f0db Mon Sep 17 00:00:00 2001
From: chaithyagr
Date: Wed, 17 Mar 2021 16:04:12 +0100
Subject: [PATCH 23/24] Remove 1D tests

---
 tfkbnufft/tests/ndft_test.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tfkbnufft/tests/ndft_test.py b/tfkbnufft/tests/ndft_test.py
index 9313e89..f21ad3a 100644
--- a/tfkbnufft/tests/ndft_test.py
+++ b/tfkbnufft/tests/ndft_test.py
@@ -15,7 +15,7 @@ def get_fourier_matrix(ktraj, grid_r, im_size, im_rank, do_ifft=False):
     return A
 
 
-@pytest.mark.parametrize('im_size', [(10, ), (10, 10)])
+@pytest.mark.parametrize('im_size', [(10, 10)])
 def test_adjoint_and_gradients(im_size):
     tf.random.set_seed(0)
     grid_size = tuple(np.array(im_size)*2)
@@ -52,7 +52,7 @@ def test_adjoint_and_gradients(im_size):
     # gradient_from_matrix = 2*np.pi*1j*tf.matmul(tf.cast(r, tf.complex64), tf.transpose(A))*kdata[0][0]
 
 
-@pytest.mark.parametrize('im_size', [(10, ), (10, 10)])
+@pytest.mark.parametrize('im_size', [(10, 10)])
 def test_forward_and_gradients(im_size):
     tf.random.set_seed(0)
     grid_size = tuple(np.array(im_size)*2)

From 8a7c6e43f8f2b7faf2b88996ee89682769b17bce Mon Sep 17 00:00:00 2001
From: chaithyagr
Date: Wed, 17 Mar 2021 16:12:37 +0100
Subject: [PATCH 24/24] Final refactoring

---
 tfkbnufft/tests/ndft_test.py | 12 +++++-------
 1 file changed, 5 insertions(+), 7 deletions(-)

diff --git a/tfkbnufft/tests/ndft_test.py b/tfkbnufft/tests/ndft_test.py
index f21ad3a..2cdcd7b 100644
--- a/tfkbnufft/tests/ndft_test.py
+++ b/tfkbnufft/tests/ndft_test.py
@@ -5,7 +5,9 @@
 from tfkbnufft.kbnufft import KbNufftModule
 
 
-def get_fourier_matrix(ktraj, grid_r, im_size, im_rank, do_ifft=False):
+def get_fourier_matrix(ktraj, im_size, im_rank, do_ifft=False):
+    r = [tf.linspace(-im_size[i]/2, im_size[i]/2-1, im_size[i]) for i in range(im_rank)]
+    grid_r =tf.cast(tf.reshape(tf.meshgrid(*r ,indexing='ij'), (im_rank, tf.reduce_prod(im_size))), tf.float32)
     traj_grid = tf.cast(tf.matmul(tf.transpose(ktraj[0]), grid_r), tf.complex64)
     if do_ifft:
         A = tf.exp(1j * traj_grid)
@@ -30,9 +32,7 @@ def test_adjoint_and_gradients(im_size):
 
     with tf.GradientTape(persistent=True) as g:
         I_nufft = kbnufft_adjoint(nufft_ob._extract_nufft_interpob())(kdata, ktraj)[0][0]
-        r = [tf.linspace(-im_size[i]/2, im_size[i]/2-1, im_size[i]) for i in range(im_rank)]
-        grid_r =tf.cast(tf.reshape(tf.meshgrid(*r ,indexing='ij'), (im_rank, tf.reduce_prod(im_size))), tf.float32)
-        A = get_fourier_matrix(ktraj, grid_r, im_size, im_rank, do_ifft=True)
+        A = get_fourier_matrix(ktraj, im_size, im_rank, do_ifft=True)
         I_ndft = tf.reshape(tf.matmul(tf.transpose(A), kdata[0][0][..., None]), im_size)
 
     tf_test = tf.test.TestCase()
@@ -66,9 +66,7 @@ def test_forward_and_gradients(im_size):
 
     with tf.GradientTape(persistent=True) as g:
         kdata_nufft = kbnufft_forward(nufft_ob._extract_nufft_interpob())(signal, ktraj)[0]
-        r = [tf.linspace(-im_size[i]/2, im_size[i]/2-1, im_size[i]) for i in range(im_rank)]
-        grid_r =tf.cast(tf.reshape(tf.meshgrid(*r ,indexing='ij'), (im_rank, tf.reduce_prod(im_size))), tf.float32)
-        A = get_fourier_matrix(ktraj, grid_r, im_size, im_rank, do_ifft=False)
+        A = get_fourier_matrix(ktraj, im_size, im_rank, do_ifft=False)
         kdata_ndft = tf.transpose(tf.matmul(A, tf.reshape(signal[0][0], (tf.reduce_prod(im_size), 1))))
 
     tf_test = tf.test.TestCase()
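Note (editor's illustration, not part of the patch series): the point of grad_traj=True and the custom gradients added in this series is to make the k-space trajectory itself differentiable and therefore trainable. A minimal sketch of that use, built from the test setup above; the package-root import path, the target k-space data and the loss are assumptions for the example, not something defined in these patches:

import numpy as np
import tensorflow as tf

from tfkbnufft import kbnufft_forward  # assumed import path
from tfkbnufft.kbnufft import KbNufftModule

im_size = (10, 10)
im_rank = len(im_size)
M = im_size[0] * 2**im_rank

nufft_ob = KbNufftModule(im_size=im_size, grid_size=tuple(np.array(im_size) * 2),
                         norm='ortho', grad_traj=True)
forward_op = kbnufft_forward(nufft_ob._extract_nufft_interpob())

# The trajectory is a tf.Variable so the tape tracks it, exactly as in the tests above.
ktraj = tf.Variable(tf.random.uniform((1, im_rank, M), minval=-1/2, maxval=1/2) * 2 * np.pi)
image = tf.cast(tf.random.uniform((1, 1, *im_size)), tf.complex64)
target_kdata = tf.zeros((1, 1, M), tf.complex64)  # hypothetical measured data

with tf.GradientTape() as g:
    kdata = forward_op(image, ktraj)
    loss = tf.reduce_mean(tf.abs(kdata - target_kdata) ** 2)

grad_ktraj = g.gradient(loss, ktraj)  # d(loss)/d(trajectory), enabled by this patch series
ktraj.assign_sub(1e-3 * grad_ktraj)   # one plain gradient-descent step on the sampling locations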