[fix] Fix the activation checkpointing when using SwiGLUPackedFusedOp (#1127)

* [fix] Fix the activation checkpointing when using SwiGLUPackedFusedOp

  According to the docs (https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function), the forward() method should not be called directly; the apply() method has to be used instead. After removing the direct forward() call, activation checkpointing starts working.

* [fix] Fix the activation checkpointing when using SwiGLUPackedFusedOp

  The if conditional on the x.requires_grad state (used to switch between inference and training behavior) changes how forward() behaves during recomputation, which breaks activation checkpointing: in the recomputation phase x is detached with requires_grad == False, so a different number of tensors is saved in save_for_backward().

* [fix] Fix the activation checkpointing when using SwiGLUPackedFusedOp by removing the inference path

  Same rationale as the previous commit: with the requires_grad conditional removed, forward() saves the same set of tensors in both the original pass and the recomputation, so activation checkpointing works.
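The two failure modes described above can be illustrated with a toy custom autograd Function. The sketch below is not the real SwiGLUPackedFusedOp; ToyPackedOp, its matmul body, and the tensor shapes are made up for illustration. It shows the fixed pattern: the op is invoked through apply() rather than forward(), and save_for_backward() stores the same tensors regardless of x.requires_grad, so torch.utils.checkpoint can recompute the forward pass consistently.

```python
import torch
from torch.utils.checkpoint import checkpoint


class ToyPackedOp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, w):
        out = x @ w
        # The bug pattern fixed in this commit: branching on x.requires_grad
        # changes how many tensors are saved. Under checkpointing, the
        # recomputed forward() sees a detached x with requires_grad == False,
        # so it would save a different set of tensors than the original pass
        # and backward() would fail.
        #
        #   if x.requires_grad:
        #       ctx.save_for_backward(x, w)
        #   else:
        #       ctx.save_for_backward(w)
        #
        # Fixed behavior: always save the same tensors.
        ctx.save_for_backward(x, w)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        x, w = ctx.saved_tensors
        return grad_out @ w.t(), x.t() @ grad_out


x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(8, 16, requires_grad=True)

# Per the PyTorch autograd docs, call apply(); never call forward() directly,
# otherwise autograd bookkeeping (and hence checkpoint recomputation) breaks.
y = checkpoint(ToyPackedOp.apply, x, w, use_reentrant=False)
y.sum().backward()
```

With the conditional save removed, the recomputed forward pass saves exactly the tensors that backward() expects, which is why the fix restores activation checkpointing.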