From 85cc54633b17d9e4322e2f26e60bfe3a640a7d47 Mon Sep 17 00:00:00 2001
From: Yury Parfenov <4665475+warpuv@users.noreply.github.com>
Date: Tue, 29 Oct 2024 14:41:17 +0300
Subject: [PATCH] [fix] Fix the activation checkpointing when using
 SwiGLUPackedFusedOp (#1127)

* [fix] Fix the activation checkpointing when using SwiGLUPackedFusedOp

According to the docs
(https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function), the
forward() method should not be called directly; the apply() method has to be
used instead. After removing the direct forward() call, activation
checkpointing starts working.

* [fix] Fix the activation checkpointing when using SwiGLUPackedFusedOp

The IF conditional on the x.requires_grad state (used to switch between the
inference and training modes) changes the behavior of the recomputation of
the forward() method, which breaks activation checkpointing: during the
recomputation phase x is detached with requires_grad == False, so a
different number of tensors is saved by the save_for_backward() method.

* [fix] Fix the activation checkpointing when using SwiGLUPackedFusedOp by
  removing the inference path.

The IF conditional on the x.requires_grad state changes the behavior of the
recomputation of the forward() method, which breaks activation checkpointing:
during the recomputation phase x is detached with requires_grad == False, so
a different number of tensors is saved by the save_for_backward() method.
---
 xformers/csrc/swiglu/swiglu_packedw.cpp | 16 +++++-----------
 1 file changed, 5 insertions(+), 11 deletions(-)

diff --git a/xformers/csrc/swiglu/swiglu_packedw.cpp b/xformers/csrc/swiglu/swiglu_packedw.cpp
index e70a3a72fe..d27241730f 100644
--- a/xformers/csrc/swiglu/swiglu_packedw.cpp
+++ b/xformers/csrc/swiglu/swiglu_packedw.cpp
@@ -101,11 +101,10 @@ class SwiGLUPackedWeights
     auto x5 = torch::nn::functional::linear(
         x4, w3, b3.has_value() ? b3.value() : at::Tensor());
 
-    if (ctx != nullptr) {
-      ctx->save_for_backward({x, w1w2, w3, x1, x2});
-      ctx->saved_data["has_b1b2"] = b1b2.has_value();
-      ctx->saved_data["has_b3"] = b3.has_value();
-    }
+    ctx->save_for_backward({x, w1w2, w3, x1, x2});
+    ctx->saved_data["has_b1b2"] = b1b2.has_value();
+    ctx->saved_data["has_b3"] = b3.has_value();
+
     return x5;
   }
 
@@ -211,12 +210,7 @@ at::Tensor swiglu_packedw_cuda(
     const std::optional<at::Tensor> b1b2,
     const at::Tensor w3,
     const std::optional<at::Tensor> b3) {
-  if (x.requires_grad()) {
-    return SwiGLUPackedWeights::apply(x, w1w2, b1b2, w3, b3);
-  } else {
-    return SwiGLUPackedWeights::forward(
-        /* ctx */ nullptr, x, w1w2, b1b2, w3, b3);
-  }
+  return SwiGLUPackedWeights::apply(x, w1w2, b1b2, w3, b3);
 }
 
 } // namespace
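
Note (editor's illustration, not part of the patch): both halves of the fix
can be reproduced with a small, self-contained PyTorch sketch. The names
Square, packed_op_buggy and packed_op_fixed below are hypothetical stand-ins
for SwiGLUPackedWeights and swiglu_packedw_cuda(); under that assumption the
sketch shows why apply() has to be used instead of a direct forward() call,
and why dispatching on x.requires_grad interferes with activation
checkpointing, as the commit message describes.

import torch
from torch.utils.checkpoint import checkpoint

class Square(torch.autograd.Function):
    """Toy custom op; a hypothetical stand-in for SwiGLUPackedWeights."""

    @staticmethod
    def forward(ctx, x):
        if ctx is not None:           # mirrors the removed "if (ctx != nullptr)" guard
            ctx.save_for_backward(x)  # one tensor saved when routed through apply()
        return x * x

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_out

def packed_op_buggy(x):
    # Mirrors the removed dispatch in swiglu_packedw_cuda(): the two branches
    # save a different number of tensors (one vs. zero). As the commit message
    # explains, the checkpoint recomputation can see a detached x with
    # requires_grad == False, take the other branch, and save a different
    # number of tensors than the original forward did, which breaks the
    # recomputation.
    if x.requires_grad:
        return Square.apply(x)
    return Square.forward(None, x)

def packed_op_fixed(x):
    # Mirrors the patched swiglu_packedw_cuda(): always go through apply(),
    # so save_for_backward() behaves identically in both phases.
    return Square.apply(x)

x = torch.randn(8, requires_grad=True)

# apply() records a SquareBackward node, so the custom backward() above will
# run; a direct forward() call executes it as a plain function, and autograd
# only records the ops inside it (a MulBackward0 node here).
print(type(Square.apply(x).grad_fn).__name__)           # SquareBackward
print(type(Square.forward(None, x).grad_fn).__name__)   # MulBackward0

# With the fixed path, activation checkpointing recomputes the forward during
# backward and saves the same single tensor in both phases, so this runs
# cleanly end to end.
out = checkpoint(packed_op_fixed, x, use_reentrant=False)
out.sum().backward()
print(x.grad.shape)  # torch.Size([8])

For the same reason the patched C++ code can drop the "if (ctx != nullptr)"
guard: with apply() as the only entry point, ctx is always provided.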