
[fix] Fix the activation checkpointing when using SwiGLUPackedFusedOp #1127

Merged · 3 commits merged into facebookresearch:main on Oct 29, 2024

Conversation

@warpuv (Contributor) commented Oct 11, 2024

According to the docs (https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function), the forward() method should not be called directly; the apply() method has to be used instead. After removing the direct forward() call, activation checkpointing starts working.
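
A minimal sketch of the usage pattern the PyTorch docs require (a toy op with hypothetical names, not the actual SwiGLUPackedFusedOp): custom autograd Functions are invoked through .apply() so that autograd records the op in the graph, which is what lets torch.utils.checkpoint recompute and backpropagate through it.

```python
import torch
from torch.utils.checkpoint import checkpoint

class ScaledSiLU(torch.autograd.Function):
    """Toy custom op standing in for the fused SwiGLU kernel."""

    @staticmethod
    def forward(ctx, x, scale):
        ctx.save_for_backward(x)
        ctx.scale = scale
        return torch.nn.functional.silu(x) * scale

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        sig = torch.sigmoid(x)
        # d/dx [silu(x) * scale] = scale * sigmoid(x) * (1 + x * (1 - sigmoid(x)))
        return grad_out * ctx.scale * sig * (1 + x * (1 - sig)), None

def block(x):
    # The fix amounts to always going through .apply() here,
    # never calling .forward() directly.
    return ScaledSiLU.apply(x, 2.0)

x = torch.randn(4, 8, requires_grad=True)
out = checkpoint(block, x, use_reentrant=False)  # activation checkpointing
out.sum().backward()
print(x.grad.shape)  # gradients flow through the checkpointed custom op
```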

What does this PR do?

Fixes #1126

Before submitting

  • Did you have fun?
    • Make sure you had fun coding 🙃
  • Did you read the contributor guideline?
  • Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
    • N/A
  • Did you make sure to update the docs?
    • N/A
  • Did you write any new necessary tests?
    • N/A
  • Did you update the changelog? (if needed)
    • N/A

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

According to the docs (https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function), the forward() method should not be called directly; the apply() method has to be used instead.
After removing the direct forward() call, activation checkpointing starts working.
@facebook-github-bot added the "CLA Signed" label (this label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) on Oct 11, 2024
@pansershrek commented:

@danthe3rd Can you review this PR, please? This change fixes the integration with FSDP + activation checkpointing.

@lw (Contributor) left a comment:

Thanks for the PR!

It seems you're basically trying to revert the change introduced in #706. I don't have full context on that old PR, but I wonder whether there are ways to achieve both goals at once.

For example, we could take your PR and also replace the check if (ctx == nullptr) above with if (x.requires_grad). I believe this should get us everything we want?

@warpuv force-pushed the main branch 2 times, most recently from 6033c6c to bb5ef21, on October 17, 2024, 17:27
The IF conditional on the x.requires_grad state (used to switch behavior between inference and training modes) changes the behavior of the recomputation of the forward() method, which breaks activation checkpointing
(during the recomputation phase x is detached with requires_grad==False, so a different number of tensors is saved in the save_for_backward() method).
@warpuv (Contributor, Author) commented Oct 17, 2024

> Thanks for the PR!
>
> It seems you're basically trying to revert the change introduced in #706. I don't have full context on that old PR, but I wonder whether there are ways to achieve both goals at once.
>
> For example, we could take your PR and also replace the check if (ctx == nullptr) above with if (x.requires_grad). I believe this should get us everything we want?

@lw thank you for your review and suggestions.
The “if” conditional on x.requires_grad changes the behavior of the forward recomputation: x is detached during the recomputation phase, so x.requires_grad has a different value and, in turn, save_for_backward is not called.
I have pushed an alternative solution using torch::GradMode::is_enabled(), I believe both goals are achieved this way.
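
A minimal sketch of the idea behind the alternative solution (hypothetical names, plain Python rather than the xformers C++ op): the decision to keep intermediates for backward is gated on grad mode instead of on x.requires_grad, because the recomputation pass of activation checkpointing runs with gradients enabled but feeds the op detached inputs, so a requires_grad test can pick a different path than the original forward did.

```python
import torch

class _FusedOp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, w):
        h = x @ w
        ctx.save_for_backward(x, w, h)  # intermediates kept for the backward pass
        return torch.nn.functional.silu(h)

    @staticmethod
    def backward(ctx, grad_out):
        x, w, h = ctx.saved_tensors
        sig = torch.sigmoid(h)
        grad_h = grad_out * sig * (1 + h * (1 - sig))
        return grad_h @ w.t(), x.t() @ grad_h

def fused_op(x, w):
    if torch.is_grad_enabled():
        # Training, or the recomputation pass of activation checkpointing
        # (which runs under torch.enable_grad() even though its inputs are detached).
        return _FusedOp.apply(x, w)
    # Pure inference under torch.no_grad(): nothing needs to be saved.
    return torch.nn.functional.silu(x @ w)
```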

@warpuv requested a review from lw on October 22, 2024, 12:22
@warpuv (Contributor, Author) commented Oct 24, 2024

Dear @zyan0, do you have any objections to this change?

@pansershrek commented:

@lw, @zyan0 Is this solution OK, or do you have any objections to this change?

@lw (Contributor) left a comment:

Thanks for the changes; the new version looks good to me, though I'm still not sure I understand why the original code was that way. Perhaps @danthe3rd remembers?

@danthe3rd (Contributor) commented:

Hi,
I believe the first version of the fix 46d2823 was simpler. Can you revert to that one? Then we can merge.

@warpuv (Contributor, Author) commented Oct 28, 2024

@lw my guess is that it was planned to optimize the inference path, but this was never done.
The current implementation of dual_gemm_silu_identity_mul produces two additional intermediate tensors that are not part of the output but are used in the backward pass. I can imagine implementing the forward pass without producing these two intermediate tensors to improve speed in inference mode.
@danthe3rd do you still think the first solution is better? If so, I will revert to it. With the second solution it would remain possible to optimize the inference path sometime in the future.
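
To make the intermediates described above concrete, here is a plain-PyTorch reference of the packed SwiGLU computation (hypothetical names and nn.Linear-style weight layout assumed; not the fused dual_gemm_silu_identity_mul kernel): the two GEMM outputs are not part of the returned value but are needed by the backward pass, which is why an inference-only path could avoid materializing them.

```python
import torch
import torch.nn.functional as F

def swiglu_packed_reference(x, w1, b1, w2, b2, w3, b3, keep_intermediates: bool):
    x1 = F.linear(x, w1, b1)   # first GEMM output (gate branch)
    x2 = F.linear(x, w2, b2)   # second GEMM output (value branch)
    hidden = F.silu(x1) * x2
    out = F.linear(hidden, w3, b3)
    if keep_intermediates:
        # Training: x1 and x2 (and hidden) are needed by the backward pass.
        return out, (x1, x2, hidden)
    # Inference: x1/x2 never need to leave the kernel, so a fused
    # implementation could avoid materializing them entirely.
    return out, None
```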

@lw (Contributor) commented Oct 28, 2024

I discussed it with @danthe3rd and we agree that it's ok to undo the separation of apply and forward that was introduced in #706. If we ever need it again we will evaluate other options.

… by removing the inference path.

The IF conditional on the x.requires_grad state changes the behavior of the recomputation of the forward() method, which breaks activation checkpointing
(during the recomputation phase x is detached with requires_grad==False, so a different number of tensors is saved in the save_for_backward() method).
@warpuv (Contributor, Author) commented Oct 28, 2024

@lw OK, I've uploaded the first version. Please check it.

@warpuv requested a review from lw on October 28, 2024, 16:02
@lw (Contributor) left a comment:

Great! Thanks for all the iterations!

@lw merged commit 85cc546 into facebookresearch:main on Oct 29, 2024
1 check passed
@warpuv deleted the main branch on October 29, 2024, 17:37
@warpuv restored the main branch on October 29, 2024, 17:38
bertmaher pushed a commit to bertmaher/xformers that referenced this pull request Dec 20, 2024
Successfully merging this pull request may close these issues: Activation checkpointing is not working on SwiGLU (#1126)
5 participants