[fix] Fix the activation checkpointing when using SwiGLUPackedFusedOp (#1127)

* [fix] Fix the activation checkpointing when using SwiGLUPackedFusedOp

According to the docs (https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function), the forward() method should not be called directly; the apply() method has to be used instead.
After removing the direct forward() call, activation checkpointing starts working (a Python sketch of this calling convention follows this list).

* [fix] Fix the activation checkpointing when using SwiGLUPackedFusedOp

The if conditional on the x.requires_grad state (used to switch between the inference and training modes) changes the behavior of the forward() recomputation, which breaks activation checkpointing
(during the recomputation phase x is detached with requires_grad==False, so a different number of tensors is passed to save_for_backward()).

* [fix] Fix the activation checkpointing when using SwiGLUPackedFusedOp by removing the inference path.

As above, the if conditional on the x.requires_grad state changes the behavior of the forward() recomputation, which breaks activation checkpointing
(during the recomputation phase x is detached with requires_grad==False, so a different number of tensors is passed to save_for_backward(); a Python sketch of this mismatch follows this list).
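
To illustrate the apply()-vs-forward() point above, here is a minimal Python sketch under assumed names (ToyPackedOp is a stand-in for illustration, not the xformers operator): a custom torch.autograd.Function is invoked through apply(), which sets up the autograd graph, instead of calling forward() directly.

import torch
from torch.utils.checkpoint import checkpoint

# Minimal sketch of the calling convention the commit switches to.
# ToyPackedOp is a stand-in for illustration, not the xformers operator.
class ToyPackedOp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, w):
        # Unconditionally record what backward() will need.
        ctx.save_for_backward(x, w)
        return x @ w

    @staticmethod
    def backward(ctx, grad_out):
        x, w = ctx.saved_tensors
        return grad_out @ w.t(), x.t() @ grad_out

x = torch.randn(2, 4, requires_grad=True)
w = torch.randn(4, 3, requires_grad=True)

# Always go through apply(); calling ToyPackedOp.forward(...) directly would
# bypass the autograd machinery, which is the pattern the commit removes.
y = ToyPackedOp.apply(x, w)
y.sum().backward()

# With apply() as the only entry point, the op composes with activation
# checkpointing in the usual way.
y2 = checkpoint(lambda a: ToyPackedOp.apply(a, w), x, use_reentrant=False)
y2.sum().backward()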
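
To illustrate the save_for_backward() mismatch described above, here is a second minimal Python sketch, again under assumed names (BranchyOp and saved_counts are illustrative only): when forward() branches on x.requires_grad, a call with the original input and a call with a detached input save a different number of tensors, which is the kind of bookkeeping mismatch the commit message says the checkpoint recomputation runs into.

import torch

# Sketch of the hazard the commit removes: what save_for_backward() stores
# depends on x.requires_grad, so a detached input (the commit message notes
# the checkpoint recomputation sees x detached, requires_grad==False) takes
# the other branch and saves fewer tensors.
# BranchyOp is a stand-in for illustration, not the xformers operator.
saved_counts = []

class BranchyOp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, w):
        if x.requires_grad:       # data-dependent branch (the problem)
            tensors = (x, w)      # training path: two tensors
        else:
            tensors = (w,)        # "inference" path: one tensor
        ctx.save_for_backward(*tensors)
        saved_counts.append(len(tensors))
        return x @ w

    @staticmethod
    def backward(ctx, grad_out):
        x, w = ctx.saved_tensors
        return grad_out @ w.t(), x.t() @ grad_out

x = torch.randn(2, 4, requires_grad=True)
w = torch.randn(4, 3, requires_grad=True)

BranchyOp.apply(x, w)           # original forward: saves 2 tensors
BranchyOp.apply(x.detach(), w)  # detached input, as in recomputation: saves 1
print(saved_counts)             # [2, 1] -- mismatched save_for_backward() bookkeeping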
warpuv authored Oct 29, 2024
1 parent a97a1e0 commit 85cc546
Showing 1 changed file with 5 additions and 11 deletions.
16 changes: 5 additions & 11 deletions xformers/csrc/swiglu/swiglu_packedw.cpp
@@ -101,11 +101,10 @@ class SwiGLUPackedWeights
     auto x5 = torch::nn::functional::linear(
         x4, w3, b3.has_value() ? b3.value() : at::Tensor());
 
-    if (ctx != nullptr) {
-      ctx->save_for_backward({x, w1w2, w3, x1, x2});
-      ctx->saved_data["has_b1b2"] = b1b2.has_value();
-      ctx->saved_data["has_b3"] = b3.has_value();
-    }
+    ctx->save_for_backward({x, w1w2, w3, x1, x2});
+    ctx->saved_data["has_b1b2"] = b1b2.has_value();
+    ctx->saved_data["has_b3"] = b3.has_value();
 
     return x5;
   }

@@ -211,12 +210,7 @@ at::Tensor swiglu_packedw_cuda(
     const std::optional<at::Tensor> b1b2,
     const at::Tensor w3,
     const std::optional<at::Tensor> b3) {
-  if (x.requires_grad()) {
-    return SwiGLUPackedWeights::apply(x, w1w2, b1b2, w3, b3);
-  } else {
-    return SwiGLUPackedWeights::forward(
-        /* ctx */ nullptr, x, w1w2, b1b2, w3, b3);
-  }
+  return SwiGLUPackedWeights::apply(x, w1w2, b1b2, w3, b3);
 }
 } // namespace
