
Add gradient with respect to trajectory #30

Merged · 25 commits · Mar 17, 2021

Conversation

chaithyagr (Collaborator) commented Mar 16, 2021

This PR adds the gradients of the NUFFT and adjoint NUFFT with respect to the trajectory locations.
Additionally, I updated some code so that the NUFFT works with 1D signals (although there isn't any direct use for this immediately).

@zaccharieramzi You can proceed with the review as the code is complete.
However, we will merge once I get tests working for this, for both the 2D and 3D cases at least.

This resolves #29
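For intuition, the pattern this PR adds can be sketched with a tiny real-valued stand-in for the NUFFT: a `tf.custom_gradient` op whose backward pass returns gradients with respect to both the signal and the sample locations. Everything below (the toy transform `y_m = sum_n x_n * cos(k_m * n)` and all names) is illustrative, NOT the tfkbnufft implementation:

```python
import numpy as np
import tensorflow as tf

# Minimal, hypothetical sketch (real-valued toy transform, not tfkbnufft code)
# of an op exposing gradients w.r.t. BOTH the signal x and the locations k.
# Toy op: y_m = sum_n x_n * cos(k_m * n).
@tf.custom_gradient
def toy_transform(x, k):
    n = tf.range(tf.cast(tf.shape(x)[0], tf.float32))
    A = tf.cos(k[:, None] * n[None, :])        # (M, N) transform matrix
    y = tf.linalg.matvec(A, x)

    def grad(dy):
        # gradient w.r.t. the signal: the usual adjoint (transpose) pass
        dx = tf.linalg.matvec(A, dy, transpose_a=True)
        # gradient w.r.t. the locations: differentiating the kernel brings
        # down a factor -n * sin(k_m * n), i.e. one extra transform-like
        # product on top of the adjoint pass
        dA = -n[None, :] * tf.sin(k[:, None] * n[None, :])
        dk = dy * tf.linalg.matvec(dA, x)
        return dx, dk

    return y, grad
```

Under a `tf.GradientTape`, both inputs then receive gradients, which can be checked against plain autodiff of the same dense formula.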

zaccharieramzi self-requested a review on March 16, 2021 13:56
zaccharieramzi (Owner) left a comment
I think there is some confusion, from at least one of us, in the adjoint Jacobian computation.

I also have 2 remarks:

  • How did you test this?
  • How costly is this new gradient to compute compared to the old one? I am wondering whether we should skip computing it by default (it's an additional NUFFT op) and have a flag to compute it when required.

(Inline review comments on tfkbnufft/kbnufft.py, tfkbnufft/nufft/fft_functions.py and tfkbnufft/nufft/interp_functions.py)
chaithyagr (Collaborator, Author) commented Mar 16, 2021

> How did you test this?

So basically, while I was writing the code, I really wanted to make sure it was right, so I wrote down the NDFT and carried out actual autodiff for the gradient computation, which I can treat as ground truth. We can add this as a test; however, the dimensions must be small. Here is an example:

import numpy as np
import tensorflow as tf

from tfkbnufft import kbnufft_forward
from tfkbnufft.kbnufft import KbNufftModule

N = 20
M = 20*5
nufft_ob = KbNufftModule(im_size=(N, N), grid_size=(2*N, 2*N), norm='ortho')
ktraj = tf.Variable(tf.random.uniform((1, 2, M), minval=-1/2, maxval=1/2)*2*np.pi)
signal = tf.Variable(tf.cast(tf.random.uniform((1, 1, N, N)), tf.complex64))
with tf.GradientTape(persistent=True) as g:
    kdata_nufft = kbnufft_forward(nufft_ob._extract_nufft_interpob())(signal, ktraj)
    r = tf.cast(tf.reshape(tf.meshgrid(
        np.linspace(-N/2, N/2, N, endpoint=False),
        np.linspace(-N/2, N/2, N, endpoint=False),
        indexing='ij'
    ), (2, N*N)), tf.float32)
    A = tf.exp(-2j * np.pi * tf.cast(tf.matmul(tf.transpose(ktraj[0])/2/np.pi, r), tf.complex64))/N/2
    kdata_ndft = tf.matmul(A, tf.reshape(signal[0][0], (N*N, 1)))

grad2 = g.gradient(kdata_ndft, ktraj)[0]  # autodiff gradient of the NDFT
grad3 = -2j * np.pi * tf.transpose(tf.matmul(A, tf.transpose(tf.cast(r, tf.complex64) * tf.reshape(signal[0][0], (N*N,)))))  # analytic NDFT gradient
grad4 = g.gradient(kdata_nufft, ktraj)[0]  # gradient of the implemented NUFFT op

In the above example, grad2 (autodiff gradient of the NDFT), grad3 (gradient from the matrix formula I derived for the NDFT) and grad4 (gradient of the NUFFT, as implemented) must all be the same!
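The same validation pattern in miniature (a toy, real-valued "DFT-like" transform, purely illustrative and not the tfkbnufft code): autodiff through a dense reference matrix serves as ground truth, and a hand-derived gradient formula is checked against it.

```python
import numpy as np
import tensorflow as tf

# Toy version of the check above: autodiff through a dense transform
# vs a hand-derived gradient formula for the sample locations k.
N = 3
k = tf.Variable([0.1, 0.5, 1.2, 2.0])
x = tf.constant([1.0, -2.0, 0.5])
n = tf.range(float(N))

with tf.GradientTape() as g:
    A = tf.cos(k[:, None] * n[None, :])   # dense (M, N) reference matrix
    y = tf.linalg.matvec(A, x)
    loss = tf.reduce_sum(y ** 2)

grad_auto = g.gradient(loss, k)           # ground-truth autodiff gradient

# hand-derived: dL/dk_m = 2 * y_m * sum_n x_n * (-n) * sin(k_m * n)
k_np, n_np = k.numpy(), np.arange(N)
y_np = np.cos(np.outer(k_np, n_np)) @ x.numpy()
grad_hand = 2 * y_np * ((-n_np * np.sin(np.outer(k_np, n_np))) @ x.numpy())
```

Both gradients should agree to floating-point precision, which is exactly the kind of equality the NDFT-based test asserts.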

> How costly is this new gradient to compute compared to the old one? I am wondering whether we should skip computing it by default (it's an additional NUFFT op) and have a flag to compute it when required.

Well, I have not done performance checks yet, but yes, we can surely have that as part of the initialization.
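One way such a flag could look (a hypothetical sketch with made-up names, not the eventual grad_traj argument of this PR): when the flag is off, the backward function returns None for the trajectory input, so TensorFlow never pays for the extra transform.

```python
import numpy as np
import tensorflow as tf

# Hypothetical sketch of an optional trajectory gradient: returning None
# from the backward function tells TensorFlow there is no gradient for
# that input, so the extra (costly) computation is skipped entirely.
def make_op(grad_traj=False):
    @tf.custom_gradient
    def op(x, k):
        y = x * tf.cos(k)          # stand-in for the forward NUFFT

        def grad(dy):
            dx = dy * tf.cos(k)
            # only pay for the trajectory gradient when asked for
            dk = dy * x * -tf.sin(k) if grad_traj else None
            return dx, dk

        return y, grad

    return op
```

With `grad_traj=False`, `tape.gradient(y, ktraj)` simply comes back as None instead of triggering an extra transform.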

zaccharieramzi (Owner) left a comment
Some comments on readability.
Also, I think that allowing im_rank == 1 is a bit out of the scope of this PR. I would rather have it in a different PR, also to make sure this important PR can be read without mixing concerns.

I think that the most convincing argument for the correctness of your implementation is the test you are going to provide, so I am waiting for it impatiently, because it's going to be the core of the review.

(Inline review comments on tfkbnufft/kbnufft.py, tfkbnufft/nufft/fft_functions.py and tfkbnufft/nufft/interp_functions.py)
chaithyagr and others added 3 commits March 16, 2021 16:25
Co-authored-by: Zaccharie Ramzi <zaccharie.ramzi@gmail.com>
chaithyagr (Collaborator, Author) commented Mar 17, 2021

Done with the above refactoring; I also added a grad_traj argument...

Things left in this PR:

  • Version bump (maybe this will be a major update, so I feel 0.2.0 maybe?)
  • Add tests
  • Check why the current tests are failing with a segfault and close in on it.

However, @zaccharieramzi, you can maybe review the code so far, so that you won't have a lot of things to look into at once.

zaccharieramzi (Owner) left a comment

Some comments regarding the tests and minor nitpicks

(Inline review comments on tfkbnufft/tests/ndft_tests.py)
Comment on lines 55 to 61
r = tf.cast(tf.reshape(tf.meshgrid(
    tf.linspace(-N/2, N/2-1, N),
    tf.linspace(-N/2, N/2-1, N),
    indexing='ij'
), (2, N * N)), tf.float32)
A = tf.exp(-2j * np.pi * tf.cast(tf.matmul(tf.transpose(ktraj[0])/2/np.pi, r), tf.complex64))/N/2
kdata_ndft = tf.transpose(tf.matmul(A, tf.reshape(signal[0][0], (N*N, 1))))
zaccharieramzi (Owner):

I think both these things should be refactored with the above test to enforce consistency.

(Inline review comments on tfkbnufft/nufft/fft_functions.py and tfkbnufft/tests/nufft/interp_functions_test.py)
zaccharieramzi (Owner) commented:

Re the version bump: be my guest!

Re the segfault: I think there must be a way to debug GitHub Actions. I can help you figure out how to do that.

chaithyagr (Collaborator, Author) commented:

@zaccharieramzi I think all our issues were due to the fact that we were trying to use torch 1.8 with an old version of torchkbnufft. I have fixed it in this PR itself. I feel #31 is not needed. We can remove the test when we also remove the code in #33.
WDYT?
For now, I think I am done with this PR; can you see if everything is fine and whether we can go ahead with the merge when green?

zaccharieramzi (Owner) commented:

Yes we can close #31 for the time being and reopen it when we tackle #33 entirely.

Reviewing in a few.

zaccharieramzi (Owner) left a comment

Still some comments on tests

(Inline review comment on tfkbnufft/nufft/fft_functions.py)
Comment on lines 62 to 65
grid_r = tf.cast(tf.reshape(tf.meshgrid(*r, indexing='ij'), (im_rank, tf.reduce_prod(im_size))), tf.float32)
A = tf.exp(-1j * tf.cast(tf.matmul(tf.transpose(ktraj[0]), grid_r), tf.complex64)) / (
    np.sqrt(tf.reduce_prod(im_size)) * np.power(np.sqrt(2), im_rank)
)
zaccharieramzi (Owner):
As I said in the previous review, I think some refactoring should happen between the forward and adjoint tests in order to keep the review/maintenance effort minimal.
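A sketch of what that shared piece could look like (the helper name and signature are my own assumption, not code from the PR): a single dense-NDFT-matrix builder that the forward test, the adjoint test and the gradient checks all reuse, so the reference transform is defined in one place.

```python
import numpy as np
import tensorflow as tf

# Hypothetical refactoring sketch: build the dense NDFT matrix in ONE
# helper shared by the forward and adjoint tests.
def ndft_matrix(ktraj, im_size):
    """Dense NDFT matrix for k-space locations ktraj of shape (rank, M)."""
    im_rank = len(im_size)
    r = [np.linspace(-n / 2, n / 2 - 1, n) for n in im_size]
    grid_r = tf.cast(tf.reshape(
        tf.meshgrid(*r, indexing='ij'),
        (im_rank, int(np.prod(im_size))),
    ), tf.float32)
    traj_grid = tf.matmul(tf.transpose(tf.cast(ktraj, tf.float32)), grid_r)
    A = tf.exp(-1j * tf.cast(traj_grid, tf.complex64))
    # 'ortho'-style normalisation, matching the snippet above
    scale = float(np.sqrt(np.prod(im_size)) * np.sqrt(2) ** im_rank)
    return A / tf.cast(scale, tf.complex64)
```

Each test would then compute its reference k-space data as `tf.matmul(ndft_matrix(ktraj[0], im_size), flat_image)` instead of rebuilding the matrix inline.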

chaithyagr and others added 4 commits March 17, 2021 14:58
zaccharieramzi (Owner) left a comment
Final bit

Comment on lines +3 to +4
python -m pytest tfkbnufft --ignore=tfkbnufft/tests/ndft_test.py
python -m pytest tfkbnufft/tests/ndft_test.py
zaccharieramzi (Owner):
Suggested change:
- python -m pytest tfkbnufft --ignore=tfkbnufft/tests/ndft_test.py
- python -m pytest tfkbnufft/tests/ndft_test.py
+ python -m pytest tfkbnufft

chaithyagr (Collaborator, Author):

Oh no no, that won't work :P We can't merge them, as then the code just hangs, as we discussed.

zaccharieramzi (Owner):
Ah OK, I understood that wrong.

zaccharieramzi (Owner) left a comment

LGTM


zaccharieramzi merged commit c086022 into zaccharieramzi:master on Mar 17, 2021.
chaithyagr deleted the grad_traj branch on March 17, 2021 18:53.
Successfully merging this pull request may close this issue:

Have Implementation for gradients with respect to k-space locations