[refactor] Generalization of dual_gemm_silu_identity_mul to support any activation function (#1141)

Renamed dual_gemm_silu_identity_mul* to dual_gemm_lhs_activation_and_mul* and added a template
parameter for the activation function.
Added a customizable epilogue class (EpilogueLHSActivationAndMul) with the same template parameter.
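
As a minimal sketch of the resulting call shape (reusing the SiLu alias and function name that appear in the diff below; the snippet itself is not part of the commit):

template <typename T>
using SiLu = cutlass::epilogue::thread::SiLu<T>;

// SwiGLU-style path: d2 = SiLu(x @ w0 + b0) * (x @ w1 + b1)
auto [d0, d1, d2] = dual_gemm_lhs_activation_and_mul<SiLu>(x, w0, b0, w1, b1);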
warpuv authored Nov 5, 2024
1 parent 1277989 commit 3296c2b
Showing 2 changed files with 128 additions and 12 deletions.
32 changes: 20 additions & 12 deletions xformers/csrc/swiglu/cuda/dual_gemm_silu_identity_mul.cu
@@ -13,12 +13,15 @@
#include <torch/library.h>

#include <45_dual_gemm/device/dual_gemm.h>
#include <45_dual_gemm/thread/left_silu_and_mul.h>
#include "epilogue_lhs_activation_and_mul.h"

namespace {

template <typename scalar_t>
std::tuple<at::Tensor, at::Tensor, at::Tensor> dual_gemm_silu_identity_mul_(
template <typename T>
using SiLu = cutlass::epilogue::thread::SiLu<T>;

template <typename scalar_t, template <typename> typename ActivationFn>
std::tuple<at::Tensor, at::Tensor, at::Tensor> dual_gemm_lhs_activation_and_mul_(
const at::Tensor& x,
const at::Tensor& w0,
const std::optional<at::Tensor>& b0,
@@ -59,9 +62,10 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> dual_gemm_silu_identity_mul_(
ElementAccumulator,
ElementCompute,
cutlass::epilogue::thread::ScaleType::NoBetaScaling>;
using EpilogueOutputOp2 = cutlass::epilogue::thread::LeftSiLUAndMul<
using EpilogueOutputOp2 = EpilogueLHSActivationAndMul<
ElementOutput,
128 / cutlass::sizeof_bits<ElementOutput>::value,
ActivationFn,
ElementOutput,
ElementCompute>;

@@ -163,7 +167,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> dual_gemm_silu_identity_mul_(
cutlass::Status status = dual_gemm.can_implement(arguments);
TORCH_CHECK(
status == cutlass::Status::kSuccess,
"`dual_gemm_silu_identity_mul` does not support this input: ",
"`dual_gemm_lhs_activation_and_mul` does not support this input: ",
cutlass::cutlassGetStatusString(status));

status = dual_gemm.initialize(arguments, (uint8_t*)workspace.data_ptr());
@@ -174,7 +178,8 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> dual_gemm_silu_identity_mul_(
return std::make_tuple(d0, d1, d2);
}

std::tuple<at::Tensor, at::Tensor, at::Tensor> dual_gemm_silu_identity_mul(
template <template <typename> typename ActivationFn>
std::tuple<at::Tensor, at::Tensor, at::Tensor> dual_gemm_lhs_activation_and_mul(
const at::Tensor& x,
const at::Tensor& w0,
const std::optional<at::Tensor>& b0,
@@ -187,24 +192,27 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> dual_gemm_silu_identity_mul(

#define FWD_PARAMS x, w0, b0, w1, b1
if (x.scalar_type() == at::ScalarType::Half) {
return dual_gemm_silu_identity_mul_<cutlass::half_t>(FWD_PARAMS);
return dual_gemm_lhs_activation_and_mul_<cutlass::half_t, ActivationFn>(
FWD_PARAMS);
} else {
TORCH_CHECK(
x.scalar_type() == at::ScalarType::BFloat16, "Only supports bf16/f16");
return dual_gemm_silu_identity_mul_<cutlass::bfloat16_t>(FWD_PARAMS);
return dual_gemm_lhs_activation_and_mul_<cutlass::bfloat16_t, ActivationFn>(
FWD_PARAMS);
}
}

template <template <typename> typename ActivationFn>
std::tuple<at::Tensor, at::Tensor, at::Tensor>
dual_gemm_silu_identity_mul_autocast(
dual_gemm_lhs_activation_and_mul_autocast(
const at::Tensor& x,
const at::Tensor& w0,
const std::optional<at::Tensor>& b0,
const at::Tensor& w1,
const std::optional<at::Tensor>& b1) {
c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
auto exec_type = at::autocast::get_autocast_dtype(at::kCUDA);
return dual_gemm_silu_identity_mul(
return dual_gemm_lhs_activation_and_mul<ActivationFn>(
at::autocast::cached_cast(exec_type, x),
at::autocast::cached_cast(exec_type, w0),
at::autocast::cached_cast(exec_type, b0),
@@ -217,11 +225,11 @@ dual_gemm_silu_identity_mul_autocast(
TORCH_LIBRARY_IMPL(xformers, CUDA, m) {
m.impl(
TORCH_SELECTIVE_NAME("xformers::dual_gemm_silu_identity_mul"),
TORCH_FN(dual_gemm_silu_identity_mul));
TORCH_FN(dual_gemm_lhs_activation_and_mul<SiLu>));
}

TORCH_LIBRARY_IMPL(xformers, Autocast, m) {
m.impl(
TORCH_SELECTIVE_NAME("xformers::dual_gemm_silu_identity_mul"),
TORCH_FN(dual_gemm_silu_identity_mul_autocast));
TORCH_FN(dual_gemm_lhs_activation_and_mul_autocast<SiLu>));
}
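
For illustration only, a hypothetical sketch of what the generalization enables: binding the same kernel to another CUTLASS activation functor. The schema name below and the assumption that cutlass::epilogue::thread::GELU is a drop-in replacement are illustrative and not part of this commit.

// Hypothetical: a GELU-backed variant of the operator. The schema
// "xformers::dual_gemm_gelu_identity_mul" would still need to be declared
// in the library definition; it does not exist in this commit.
template <typename T>
using GELU = cutlass::epilogue::thread::GELU<T>;

TORCH_LIBRARY_IMPL(xformers, CUDA, m) {
  m.impl(
      TORCH_SELECTIVE_NAME("xformers::dual_gemm_gelu_identity_mul"),
      TORCH_FN(dual_gemm_lhs_activation_and_mul<GELU>));
}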
108 changes: 108 additions & 0 deletions xformers/csrc/swiglu/cuda/epilogue_lhs_activation_and_mul.h
@@ -0,0 +1,108 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

/*
Implementation of the element_wise epilogue of the DualGemm kernel
with a custom activation function passed as a template parameter.
(third_party/cutlass/examples/45_dual_gemm/dual_gemm.cu)
DualGemm defined as:
D0 = epilogue0(X @ B0, C0)
D1 = epilogue1(X @ B1, C1)
D2 = element_wise(D0, D1)
where element_wise(D0, D1) = eltwise_mul(activation_func(D0), D1)
Code from CUTLASS examples used as reference:
third_party/cutlass/examples/45_dual_gemm/thread/left_silu_and_mul.h
*/

#pragma once

#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/functional.h"
#include "cutlass/numeric_conversion.h"

template <
typename ElementOutput_, // Type used for load and store
int Count, // Number of elements computed per operation
template <typename> typename ActivationFn_, // Activation functor
typename ElementAccumulator_ = ElementOutput_, // Accumulator type
typename ElementCompute_ = ElementOutput_, // Type used for internal compute
cutlass::FloatRoundStyle Round = cutlass::FloatRoundStyle::round_to_nearest>
class EpilogueLHSActivationAndMul {
public:
static int const kCount = Count;
static cutlass::FloatRoundStyle const kRound = Round;

using ElementOutput = ElementOutput_;
using ElementAccumulator = ElementAccumulator_;
using ElementCompute = ElementCompute_;
using ActivationFnElementCompute = ActivationFn_<ElementCompute>;

using FragmentOutput = cutlass::Array<ElementOutput, kCount>;
using FragmentAccumulator = cutlass::Array<ElementAccumulator, kCount>;
using FragmentCompute = cutlass::Array<ElementCompute, kCount>;
using ActivationFnFragmentCompute = ActivationFn_<FragmentCompute>;

struct Params {};

public:
CUTLASS_HOST_DEVICE
EpilogueLHSActivationAndMul(Params const& /*params*/) {}

CUTLASS_HOST_DEVICE
bool is_source_needed() const {
return true;
}

CUTLASS_HOST_DEVICE
void set_k_partition(int k_partition, int k_partition_count) {
assert(false);
}

CUTLASS_HOST_DEVICE
FragmentOutput operator()(
FragmentAccumulator const& input_lhs,
FragmentAccumulator const& input_rhs) const {
cutlass::NumericArrayConverter<
ElementCompute,
ElementAccumulator,
kCount,
kRound>
acc_to_compute;

cutlass::
NumericArrayConverter<ElementOutput, ElementCompute, kCount, kRound>
compute_to_out;

FragmentCompute casted_lhs = acc_to_compute(input_lhs);
FragmentCompute casted_rhs = acc_to_compute(input_rhs);

ActivationFnFragmentCompute activation_func;
cutlass::multiplies<FragmentCompute> mul_func;

auto activation_lhs_out = activation_func(casted_lhs);
return compute_to_out(mul_func(activation_lhs_out, casted_rhs));
}

CUTLASS_HOST_DEVICE
ElementOutput operator()(
ElementAccumulator const& input_lhs,
ElementAccumulator const& input_rhs) const {
ElementCompute casted_lhs(input_lhs);
ElementCompute casted_rhs(input_rhs);

ActivationFnElementCompute activation_func;
cutlass::multiplies<ElementCompute> mul_func;

auto activation_lhs_out = activation_func(casted_lhs);
return ElementOutput(mul_func(activation_lhs_out, casted_rhs));
}
};
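
For reference, a minimal host-side sketch (not part of the commit) exercising the scalar path of EpilogueLHSActivationAndMul; it assumes this header and cutlass/epilogue/thread/activation.h (which provides cutlass::epilogue::thread::SiLu) are on the include path.

#include <cutlass/epilogue/thread/activation.h>
#include "epilogue_lhs_activation_and_mul.h"

// Scalar path: d2 = SiLu(d0) * d1, matching D2 = element_wise(D0, D1) above.
using Epilogue = EpilogueLHSActivationAndMul<
    float, /*Count=*/1, cutlass::epilogue::thread::SiLu>;

float lhs_activation_mul(float d0, float d1) {
  Epilogue::Params params;
  Epilogue epilogue(params);
  return epilogue(d0, d1);
}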
