[fix] Fix the activation checkpointing when using SwiGLUPackedFusedOp #1127
Conversation
According to the docs (https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function), the forward() method should not be called directly; the apply() method has to be used instead. After removing the direct forward() call, activation checkpointing starts working.
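For illustration, here is a minimal generic sketch of the rule cited from the docs (a toy Square op, not the actual SwiGLUPackedFusedOp kernel): only apply() sets up the backward node that autograd, and hence activation checkpointing, relies on.

```python
import torch

# Toy custom op, for illustration only (not the xformers SwiGLU kernel).
class Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_out

x = torch.randn(4, requires_grad=True)

y = Square.apply(x)   # correct: autograd creates a SquareBackward node
y.sum().backward()    # the custom backward() above is actually invoked

# Calling Square.forward(...) directly (as the old SwiGLU code path did)
# bypasses apply(), so no backward node for the Function is created and its
# backward() never runs.
```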
@danthe3rd Can you review this PR, please? This change fixes the integration with FSDP + activation_checkpointing.
Thanks for the PR!
It seems you're basically trying to revert the change introduced in #706. I don't have full context on that old PR, but I wonder whether there are ways to achieve both goals at once.
For example, we could take your PR and also replace the check if (ctx == nullptr) above with if (x.requires_grad). I believe this should get us everything we want?
The branch was force-pushed from 6033c6c to bb5ef21.
An IF conditional on the x.requires_grad state (to switch between inference and training behavior) changes how forward() behaves during recomputation, which breaks activation checkpointing: in the recomputation phase x is detached with requires_grad==False, so a different number of tensors ends up saved by save_for_backward().
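To make this failure mode concrete, here is a hedged, generic sketch (a toy linear op with made-up names, not the real fused kernel): forward() saves a different set of tensors depending on x.requires_grad, so a call with a detached x, as reportedly happens during checkpoint recomputation, leaves backward() with fewer saved tensors than it expects.

```python
import torch

# Toy op, for illustration only: the conditional save_for_backward mirrors the
# pattern discussed in this thread, not the actual xformers implementation.
class CondSaveLinear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, w):
        out = x @ w
        if x.requires_grad:
            ctx.save_for_backward(x, w)   # training path: save both tensors
        else:
            ctx.save_for_backward(w)      # "inference" path: save less
        return out

    @staticmethod
    def backward(ctx, grad_out):
        x, w = ctx.saved_tensors          # written assuming the training path
        return grad_out @ w.t(), x.t() @ grad_out

w = torch.randn(3, 3, requires_grad=True)
x = torch.randn(2, 3, requires_grad=True)

# Normal training call: both tensors are saved and backward() works.
CondSaveLinear.apply(x, w).sum().backward()

# Recomputation-style call: x arrives detached (requires_grad == False), only w
# is saved, and backward() fails to unpack the expected number of tensors.
y = CondSaveLinear.apply(x.detach(), w)
try:
    y.sum().backward()
except Exception as e:
    print(type(e).__name__, e)
```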
@lw thank you for your review and suggestions.
Dear @zyan0, do you have any objections to this change?
Thanks for the changes; the new version looks good to me, though I'm still not sure I understand why the original code was that way. Perhaps @danthe3rd remembers?
Hi @lw, my guess is that it was planned to optimize the inference path, but this was never done.
I discussed it with @danthe3rd and we agree that it's ok to undo the separation of … by removing the inference path. An IF conditional on the x.requires_grad state changes how forward() behaves during recomputation, which breaks activation checkpointing: in the recomputation phase x is detached with requires_grad==False, so a different number of tensors ends up saved by save_for_backward().
@lw Ok, I've uploaded the first version. Please check it.
Great! Thanks for all the iterations!
What does this PR do?
Fixes #1126
Before submitting
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.