From 3296c2bc0694a551800ba0b32d8e3a8dd707c2bb Mon Sep 17 00:00:00 2001 From: Yury Parfenov <4665475+warpuv@users.noreply.github.com> Date: Tue, 5 Nov 2024 12:24:03 +0300 Subject: [PATCH] [refactor] Generalization of dual_gemm_silu_identity_mul to calculate any activation function. (#1141) Renamed dual_gemm_silu_identity_mul* to dual_gemm_lhs_activation_and_mul* and added a template parameter for the activation function to it. Added a customizable epilogue class (EpilogueLHSActivationAndMul) with the same template parameter. --- .../cuda/dual_gemm_silu_identity_mul.cu | 32 ++++-- .../cuda/epilogue_lhs_activation_and_mul.h | 108 ++++++++++++++++++ 2 files changed, 128 insertions(+), 12 deletions(-) create mode 100644 xformers/csrc/swiglu/cuda/epilogue_lhs_activation_and_mul.h diff --git a/xformers/csrc/swiglu/cuda/dual_gemm_silu_identity_mul.cu b/xformers/csrc/swiglu/cuda/dual_gemm_silu_identity_mul.cu index a643f5f3c8..1e3def4352 100644 --- a/xformers/csrc/swiglu/cuda/dual_gemm_silu_identity_mul.cu +++ b/xformers/csrc/swiglu/cuda/dual_gemm_silu_identity_mul.cu @@ -13,12 +13,15 @@ #include #include <45_dual_gemm/device/dual_gemm.h> -#include <45_dual_gemm/thread/left_silu_and_mul.h> +#include "epilogue_lhs_activation_and_mul.h" namespace { -template -std::tuple dual_gemm_silu_identity_mul_( +template +using SiLu = cutlass::epilogue::thread::SiLu; + +template typename ActivationFn> +std::tuple dual_gemm_lhs_activation_and_mul_( const at::Tensor& x, const at::Tensor& w0, const std::optional& b0, @@ -59,9 +62,10 @@ std::tuple dual_gemm_silu_identity_mul_( ElementAccumulator, ElementCompute, cutlass::epilogue::thread::ScaleType::NoBetaScaling>; - using EpilogueOutputOp2 = cutlass::epilogue::thread::LeftSiLUAndMul< + using EpilogueOutputOp2 = EpilogueLHSActivationAndMul< ElementOutput, 128 / cutlass::sizeof_bits::value, + ActivationFn, ElementOutput, ElementCompute>; @@ -163,7 +167,7 @@ std::tuple dual_gemm_silu_identity_mul_( cutlass::Status status = 
dual_gemm.can_implement(arguments); TORCH_CHECK( status == cutlass::Status::kSuccess, - "`dual_gemm_silu_identity_mul` does not support this input: ", + "`dual_gemm_lhs_activation_and_mul` does not support this input: ", cutlass::cutlassGetStatusString(status)); status = dual_gemm.initialize(arguments, (uint8_t*)workspace.data_ptr()); @@ -174,7 +178,8 @@ std::tuple dual_gemm_silu_identity_mul_( return std::make_tuple(d0, d1, d2); } -std::tuple dual_gemm_silu_identity_mul( +template