Add gradient with respect to trajectory #30
Conversation
I think there is some level of confusion, from at least one of us, in the adjoint Jacobian computation.
I also have 2 remarks:
- How did you test this?
- How costly is this new gradient to compute compared to the old one? I am wondering if we shouldn't skip computing it by default (it's an additional NUFFT op) and have a flag to enable it if required (see the sketch after this list).
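To make that suggestion concrete, here is a hypothetical sketch of what an opt-in flag could look like. The `grad_traj` keyword and the `make_forward_op` wrapper are assumptions for illustration only, not part of the current tfkbnufft API; only `kbnufft_forward` and `_extract_nufft_interpob` are taken from the example further down.

```python
import tensorflow as tf
from tfkbnufft import kbnufft_forward


def make_forward_op(nufft_ob, grad_traj=False):
    """Hypothetical wrapper: forward NUFFT, trajectory gradient only on demand."""
    forward = kbnufft_forward(nufft_ob._extract_nufft_interpob())

    def op(image, ktraj):
        if not grad_traj:
            # Block the gradient w.r.t. the trajectory so the additional NUFFT
            # op needed for d(kdata)/d(ktraj) is never requested by the tape.
            ktraj = tf.stop_gradient(ktraj)
        return forward(image, ktraj)

    return op
```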
So basically, while I was writing the code, I wanted to make sure I was right, so I wrote down the NDFT and carried out actual autodiff for the gradient computation, which I can treat as ground truth. We can add this as a test, but then the dimensions must be small. Here is an example:

N = 20
M = 20*5
nufft_ob = KbNufftModule(im_size=(N, N), grid_size=(2*N, 2*N), norm='ortho')
ktraj = tf.Variable(tf.random.uniform((1, 2, M), minval=-1/2, maxval=1/2)*2*np.pi)
signal = tf.Variable(tf.cast(tf.random.uniform((1, 1, N, N)), tf.complex64))
with tf.GradientTape(persistent=True) as g:
    kdata_nufft = kbnufft_forward(nufft_ob._extract_nufft_interpob())(signal, ktraj)
    r = tf.cast(tf.reshape(tf.meshgrid(
        np.linspace(-N/2, N/2, N, endpoint=False),
        np.linspace(-N/2, N/2, N, endpoint=False),
        indexing='ij'
    ), (2, N*N)), tf.float32)
    A = tf.exp(-2j * np.pi * tf.cast(tf.matmul(tf.transpose(ktraj[0])/2/np.pi, r), tf.complex64))/N/2
    kdata_ndft = tf.matmul(A, tf.reshape(signal[0][0], (N*N, 1)))
grad2 = g.gradient(kdata_ndft, ktraj)[0]
grad3 = -2j * np.pi * tf.transpose(tf.matmul(A, tf.transpose(tf.cast(r, tf.complex64) * tf.reshape(signal[0][0], (N*N,)))))
grad4 = g.gradient(kdata_nufft, ktraj)[0]

In the above example, all of grad2 (autodiff through the NDFT), grad3 (the analytical NDFT gradient) and grad4 (the gradient through the NUFFT) agree.
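If this gets turned into a test, the comparison could be asserted roughly as follows. This is only a sketch reusing the variables from the snippet above; the tolerances are arbitrary illustrative values, not ones from the PR.

```python
import numpy as np

# Compare the autodiff gradient of the NDFT (grad2) with the gradient obtained
# through the NUFFT op (grad4): both come from the same GradientTape, so they
# share shape and dtype. Tolerances are illustrative only.
np.testing.assert_allclose(grad2.numpy(), grad4.numpy(), rtol=1e-2, atol=1e-3)
```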
Well, I have not done performance checks yet, but yes, we can surely have that as part of the initialization.
Some comments on readability.
Also, I think that allowing im_rank == 1 is a bit out of the scope of this PR. I would rather have it in a different PR, also to make sure this important PR can be read without mixing concerns.
I think that the most convincing argument for the correctness of your implementation is the test you are going to provide, so I am waiting for it impatiently, because it's going to be the core of the review.
Done with the above refactoring, and I also added a test. A few things are still left in this PR.
However, @zaccharieramzi, you can maybe review the code so far, so that you won't have a lot of things to look into at once.
Some comments regarding the tests and minor nitpicks
tfkbnufft/tests/ndft_tests.py
r = tf.cast(tf.reshape(tf.meshgrid(
    tf.linspace(-N/2, N/2-1, N),
    tf.linspace(-N/2, N/2-1, N),
    indexing='ij'
), (2, N * N)), tf.float32)
A = tf.exp(-2j * np.pi * tf.cast(tf.matmul(tf.transpose(ktraj[0])/2/np.pi, r), tf.complex64))/N/2
kdata_ndft = tf.transpose(tf.matmul(A, tf.reshape(signal[0][0], (N*N, 1))))
I think both these things should be refactored with the above test to enforce consistency.
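To illustrate the kind of refactoring this could converge on, a shared helper building the dense NDFT matrix might look like the sketch below. The function name and placement are hypothetical; the grid and normalisation simply mirror the 2D snippet quoted above.

```python
import numpy as np
import tensorflow as tf


def ndft_matrix_2d(ktraj, N):
    """Dense 2D NDFT matrix for a (1, 2, M) radian trajectory (sketch only)."""
    # Image-domain grid, as in the quoted test code.
    r = tf.cast(tf.reshape(tf.meshgrid(
        tf.linspace(-N / 2, N / 2 - 1, N),
        tf.linspace(-N / 2, N / 2 - 1, N),
        indexing='ij'
    ), (2, N * N)), tf.float32)
    # Same 'ortho'-style normalisation as the quoted snippet.
    return tf.exp(-2j * np.pi * tf.cast(
        tf.matmul(tf.transpose(ktraj[0]) / 2 / np.pi, r), tf.complex64
    )) / N / 2
```

Both the forward test and the gradient test could then call this helper instead of rebuilding `r` and `A` inline.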
Re the version bump: be my guest! Re the seg fault: I think there must be a way to debug GitHub Actions; I can help you figure out how to do that.
@zaccharieramzi I think all our issues were due to the fact that we were trying to use torch 1.8 with an old version of torchkbnufft. I have fixed it in this PR itself. I feel #31 is not needed. We can remove the test when we also remove the code in #33.
Still some comments on tests
tfkbnufft/tests/ndft_test.py
grid_r = tf.cast(tf.reshape(tf.meshgrid(*r, indexing='ij'), (im_rank, tf.reduce_prod(im_size))), tf.float32)
A = tf.exp(-1j * tf.cast(tf.matmul(tf.transpose(ktraj[0]), grid_r), tf.complex64)) / (
    np.sqrt(tf.reduce_prod(im_size)) * np.power(np.sqrt(2), im_rank)
)
As I said in the previous review, I think some refactoring should happen between the forward and adjoint tests in order to keep the review and maintenance effort minimal.
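For example, once the dense NDFT matrix is built by a shared helper, the adjoint reference can reuse it directly rather than rebuilding `grid_r` and `A`. This is only a sketch, assuming `A` is the matrix from the quoted snippet and `image`/`kdata` have the shapes used in these tests:

```python
import tensorflow as tf


def ndft_forward(A, image):
    """Reference forward NDFT: flatten the image and apply the dense matrix."""
    return tf.matmul(A, tf.reshape(image, (-1, 1)))


def ndft_adjoint(A, kdata, im_size):
    """Reference adjoint NDFT: conjugate transpose of the same matrix."""
    flat = tf.matmul(tf.transpose(A, conjugate=True), kdata)
    return tf.reshape(flat, im_size)
```

Sharing `A` this way keeps the forward and adjoint tests consistent by construction.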
Final bit
python -m pytest tfkbnufft --ignore=tfkbnufft/tests/ndft_test.py
python -m pytest tfkbnufft/tests/ndft_test.py
Suggested change:
- python -m pytest tfkbnufft --ignore=tfkbnufft/tests/ndft_test.py
- python -m pytest tfkbnufft/tests/ndft_test.py
+ python -m pytest tfkbnufft
Oh no no, that won't work :P We can't merge them, as then the code just hangs, as we discussed.
Ah ok, I understood that wrong.
LGTM
This PR handles adding the gradients of the NUFFT and adjoint NUFFT with respect to the trajectory locations.
Further, I updated some code to make the NUFFT work with 1D signals (although there isn't any direct use for this immediately).
@zaccharieramzi You can proceed with the review as the code is complete.
However, we will merge once I get tests working for this, for both the 2D and 3D cases at least.
This resolves #29.
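For readers landing on this PR, here is a minimal sketch of what the new capability enables, assuming the import paths from the package README; the image, trajectory and loss below are placeholders, not code from the PR itself.

```python
import numpy as np
import tensorflow as tf
from tfkbnufft import kbnufft_forward
from tfkbnufft.kbnufft import KbNufftModule

# Minimal sketch: backpropagate through the forward NUFFT w.r.t. the trajectory.
N, M = 64, 64 * 5
nufft_ob = KbNufftModule(im_size=(N, N), grid_size=(2 * N, 2 * N), norm='ortho')
forward_op = kbnufft_forward(nufft_ob._extract_nufft_interpob())

image = tf.cast(tf.random.uniform((1, 1, N, N)), tf.complex64)
ktraj = tf.Variable(tf.random.uniform((1, 2, M), minval=-np.pi, maxval=np.pi))

with tf.GradientTape() as tape:
    kdata = forward_op(image, ktraj)
    loss = tf.reduce_mean(tf.abs(kdata) ** 2)  # placeholder objective

# With this PR, the gradient with respect to the k-space locations is available.
grad_ktraj = tape.gradient(loss, ktraj)
```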