Remove using partitioner for fmha kernels
qianfengz committed Dec 29, 2024
1 parent 870fefc commit ecf7724
Showing 13 changed files with 78 additions and 199 deletions.
2 changes: 1 addition & 1 deletion third_party/composable_kernel_tiled
Submodule composable_kernel_tiled updated 38 files
+1 −2 example/ck_tile/01_fmha/README.md
+3 −17 example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py
+2 −4 example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py
+4 −10 example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py
+12 −2 example/ck_tile/01_fmha/fmha_fwd.hpp
+3 −3 example/ck_tile/03_gemm/gemm_basic.hpp
+4 −4 example/ck_tile/03_gemm/run_gemm_example.inc
+7 −13 example/ck_tile/03_gemm/universal_gemm.cpp
+34 −19 example/ck_tile/13_moe_sorting/moe_sorting_api.cpp
+2 −1 example/ck_tile/13_moe_sorting/script/smoke_test.sh
+34 −19 example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp
+9 −4 example/ck_tile/16_batched_gemm/batched_gemm.cpp
+2 −1 example/ck_tile/16_batched_gemm/batched_gemm.hpp
+4 −0 example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc
+2 −2 include/ck/config.h.in
+27 −4 include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp
+22 −4 include/ck_tile/ops/epilogue/default_2d_epilogue.hpp
+0 −3 include/ck_tile/ops/fmha.hpp
+20 −8 include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp
+71 −7 include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
+30 −9 include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp
+0 −48 include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp
+30 −10 include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp
+0 −54 include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp
+0 −105 include/ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp
+210 −37 include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp
+9 −4 include/ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp
+25 −7 include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp
+120 −44 include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
+2 −0 include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp
+2 −0 include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp
+2 −0 include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
+8 −6 include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
+2 −0 include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp
+2 −0 include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
+3 −0 include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp
+2 −1 test/ck_tile/batched_gemm/test_batched_gemm_util.hpp
+3 −1 test/ck_tile/gemm/test_gemm_pipeline_util.hpp
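The diffs below all follow one pattern: each explicit tile-partitioner alias (FmhaFwdTilePartitioner, FmhaFwdSplitKVTilePartitioner, FmhaFwdSplitKVCombineTilePartitioner) is deleted, the kernel aliases are built from pipeline and epilogue only, and GridSize gains a trailing boolean (false in the batched dispatchers, param.seqlen_k_dev_ptr != nullptr in the grouped one). The following is a minimal sketch of that calling-convention change; the stub types, the int return of GridSize, and the parameter names are illustrative placeholders, not the real ck_tile declarations.

// Illustrative stand-ins only; the real pipeline, epilogue, shape, and
// partitioner types are ck_tile's and are not reproduced here.
struct FmhaFwdPipeline_ {};
struct FmhaFwdEpilogue_ {};
struct FmhaFwdShape_ {};
template <typename Shape>
struct FmhaFwdTilePartitionerStub {};

// Shape of the kernel template before this commit: the caller supplied the
// tile partitioner as the first template argument.
template <typename TilePartitioner, typename Pipeline, typename Epilogue>
struct FmhaFwdKernelBefore {
  static int GridSize(int batch, int nhead_q, int seqlen_q, int hdim_v);
};

// Shape after this commit: no partitioner parameter, and GridSize takes one
// extra bool (batched dispatchers pass false; the grouped dispatcher passes
// param.seqlen_k_dev_ptr != nullptr).
template <typename Pipeline, typename Epilogue>
struct FmhaFwdKernelAfter {
  static int GridSize(int batch, int nhead_q, int seqlen_q, int hdim_v,
                      bool has_padded_seqlen_k);
};

// What the dispatch files used to instantiate ...
using KernelBefore =
    FmhaFwdKernelBefore<FmhaFwdTilePartitionerStub<FmhaFwdShape_>,
                        FmhaFwdPipeline_, FmhaFwdEpilogue_>;
// ... and what they instantiate now.
using KernelAfter = FmhaFwdKernelAfter<FmhaFwdPipeline_, FmhaFwdEpilogue_>;

int main() {
  // The aliases are only named, never launched, in this sketch.
  (void)sizeof(KernelBefore);
  (void)sizeof(KernelAfter);
  return 0;
}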

@@ -46,8 +46,6 @@ struct batched_forward_mask_bias_dropout_dispatch {
using FmhaMask = ck_tile::SimplifiedGenericAttentionMask<kHasMask>;

using FmhaFwdShape_ = typename FmhaFwdShape<MaxK, MTile>::Type;
- using FmhaFwdTilePartitioner_ =
-     ck_tile::FmhaFwdTilePartitioner<FmhaFwdShape_>;
constexpr ck_tile::index_t occupancy =
(MaxK == 64) ? 3 : ((MaxK == 256) ? 1 : 2);

@@ -101,10 +99,8 @@ struct batched_forward_mask_bias_dropout_dispatch {
kPadSeqLenQ,
kPadHeadDim>>;

- using FmhaFwdKernel_ = ck_tile::FmhaFwdKernel<
-     FmhaFwdTilePartitioner_,
-     FmhaFwdPipeline_,
-     FmhaFwdEpilogue_>;
+ using FmhaFwdKernel_ =
+     ck_tile::FmhaFwdKernel<FmhaFwdPipeline_, FmhaFwdEpilogue_>;

RunWithKernel<FmhaFwdKernel_>(param, stream);
});
@@ -163,7 +159,7 @@ struct batched_forward_mask_bias_dropout_dispatch {
}();

dim3 kGridSize =
-     FmhaFwdKernel::GridSize(param.B, param.Hq, param.M, param.Kv);
+     FmhaFwdKernel::GridSize(param.B, param.Hq, param.M, param.Kv, false);
constexpr dim3 kBlockSize = FmhaFwdKernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = FmhaFwdKernel::kBlockPerCu;


@@ -62,8 +62,6 @@ struct batched_forward_splitkv_mask_bias_dropout_dispatch {

using FmhaTileShape =
typename FmhaFwdSplitKVShape<MaxK, MaxSeqlenQ>::Type;
- using FmhaTilePartitioner =
-     ck_tile::FmhaFwdSplitKVTilePartitioner<FmhaTileShape>;
constexpr ck_tile::index_t occupancy = -1;

constexpr auto kBiasEnum = kHasBias
@@ -122,10 +120,8 @@ struct batched_forward_splitkv_mask_bias_dropout_dispatch {
false,
false>>;

- using FmhaKernel = ck_tile::FmhaFwdSplitKVKernel<
-     FmhaTilePartitioner,
-     FmhaPipeline,
-     FmhaEpilogue>;
+ using FmhaKernel =
+     ck_tile::FmhaFwdSplitKVKernel<FmhaPipeline, FmhaEpilogue>;

RunWithFwdSplitKVKernel<FmhaKernel>(param, stream);
} else {
@@ -146,10 +142,8 @@ struct batched_forward_splitkv_mask_bias_dropout_dispatch {
false,
false>>;

- using FmhaKernel = ck_tile::FmhaFwdSplitKVKernel<
-     FmhaTilePartitioner,
-     FmhaPipeline,
-     FmhaEpilogue>;
+ using FmhaKernel =
+     ck_tile::FmhaFwdSplitKVKernel<FmhaPipeline, FmhaEpilogue>;

RunWithFwdSplitKVKernel<FmhaKernel>(param, stream);
}
@@ -166,8 +160,6 @@ struct batched_forward_splitkv_mask_bias_dropout_dispatch {
typename FmhaFwdTypeConfig<ScalarType>::OaccDataType,
kN1>::kM0;

- using FmhaTilePartitioner =
-     ck_tile::FmhaFwdSplitKVCombineTilePartitioner<kM0, kN1>;
constexpr ck_tile::index_t occupancy = -1;

const bool pad_seqlen_q = !(param.M % kM0 == 0);
@@ -199,10 +191,8 @@ struct batched_forward_splitkv_mask_bias_dropout_dispatch {
kPadSeqLenQ,
kPadHeadDimV>>;

- using FmhaKernel = ck_tile::FmhaFwdSplitKVCombineKernel<
-     FmhaTilePartitioner,
-     FmhaPipeline,
-     FmhaEpilogue>;
+ using FmhaKernel = ck_tile::
+     FmhaFwdSplitKVCombineKernel<FmhaPipeline, FmhaEpilogue>;

RunWithSplitKVCombineKernel<FmhaKernel>(param, stream);
});

@@ -60,8 +60,6 @@ struct batched_forward_splitkv_smallq_mask_bias_dropout_dispatch {
using FmhaMask = ck_tile::SimplifiedGenericAttentionMask<kHasMask>;

using FmhaTileShape = typename FmhaFwdSplitKVSmallQShape<MaxK>::Type;
- using FmhaTilePartitioner =
-     ck_tile::FmhaFwdSplitKVTilePartitioner<FmhaTileShape>;
constexpr ck_tile::index_t occupancy = -1;

constexpr auto kBiasEnum = kHasBias
@@ -121,10 +119,8 @@ struct batched_forward_splitkv_smallq_mask_bias_dropout_dispatch {
false,
false>>;

- using FmhaKernel = ck_tile::FmhaFwdSplitKVKernel<
-     FmhaTilePartitioner,
-     FmhaPipeline,
-     FmhaEpilogue>;
+ using FmhaKernel =
+     ck_tile::FmhaFwdSplitKVKernel<FmhaPipeline, FmhaEpilogue>;

RunWithFwdSplitKVKernel<FmhaKernel>(param, stream);
} else {
@@ -146,10 +142,8 @@ struct batched_forward_splitkv_smallq_mask_bias_dropout_dispatch {
false,
false>>;

- using FmhaKernel = ck_tile::FmhaFwdSplitKVKernel<
-     FmhaTilePartitioner,
-     FmhaPipeline,
-     FmhaEpilogue>;
+ using FmhaKernel =
+     ck_tile::FmhaFwdSplitKVKernel<FmhaPipeline, FmhaEpilogue>;

RunWithFwdSplitKVKernel<FmhaKernel>(param, stream);
}
@@ -165,8 +159,6 @@ struct batched_forward_splitkv_smallq_mask_bias_dropout_dispatch {
typename FmhaFwdTypeConfig<ScalarType>::OaccDataType,
kN1>::kM0;

- using FmhaTilePartitioner =
-     ck_tile::FmhaFwdSplitKVCombineTilePartitioner<kM0, kN1>;
constexpr ck_tile::index_t occupancy = -1;

const bool pad_seqlen_q = !(param.M % kM0 == 0);
@@ -198,10 +190,8 @@ struct batched_forward_splitkv_smallq_mask_bias_dropout_dispatch {
kPadSeqLenQ,
kPadHeadDimV>>;

- using FmhaKernel = ck_tile::FmhaFwdSplitKVCombineKernel<
-     FmhaTilePartitioner,
-     FmhaPipeline,
-     FmhaEpilogue>;
+ using FmhaKernel = ck_tile::
+     FmhaFwdSplitKVCombineKernel<FmhaPipeline, FmhaEpilogue>;

RunWithSplitKVCombineKernel<FmhaKernel>(param, stream);
});

@@ -47,7 +47,6 @@ struct batched_infer_mask_bias_dropout_dispatch {
using FmhaMask = ck_tile::SimplifiedGenericAttentionMask<kHasMask>;

using FmhaShape = typename FmhaFwdShape<MaxK, MTile>::Type;
- using FmhaTilePartitioner = ck_tile::FmhaFwdTilePartitioner<FmhaShape>;
constexpr ck_tile::index_t occupancy =
(MaxK == 64) ? 3 : ((MaxK == 256) ? 1 : 2);

@@ -103,8 +102,8 @@ struct batched_infer_mask_bias_dropout_dispatch {
kPadSeqLenQ,
kPadHeadDim>>;

- using FmhaKernel = ck_tile::
-     FmhaFwdKernel<FmhaTilePartitioner, FmhaPipeline, FmhaEpilogue>;
+ using FmhaKernel =
+     ck_tile::FmhaFwdKernel<FmhaPipeline, FmhaEpilogue>;

RunWithKernel<FmhaKernel>(param, stream);
});
@@ -135,8 +134,7 @@ struct batched_infer_mask_bias_dropout_dispatch {
true,
true>>;

- using FmhaKernel = ck_tile::
-     FmhaFwdKernel<FmhaTilePartitioner, FmhaPipeline, FmhaEpilogue>;
+ using FmhaKernel = ck_tile::FmhaFwdKernel<FmhaPipeline, FmhaEpilogue>;

RunWithKernel<FmhaKernel>(param, stream);
});
@@ -195,7 +193,8 @@ struct batched_infer_mask_bias_dropout_dispatch {
std::make_pair(param.philox_seed, param.philox_offset));
}();

- dim3 kGridSize = FmhaKernel::GridSize(param.B, param.Hq, param.M, param.Kv);
+ dim3 kGridSize =
+     FmhaKernel::GridSize(param.B, param.Hq, param.M, param.Kv, false);
constexpr dim3 kBlockSize = FmhaKernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = FmhaKernel::kBlockPerCu;


@@ -62,8 +62,6 @@ struct batched_infer_splitkv_mask_bias_dropout_dispatch {

using FmhaTileShape =
typename FmhaFwdSplitKVShape<MaxK, MaxSeqlenQ>::Type;
- using FmhaTilePartitioner =
-     ck_tile::FmhaFwdSplitKVTilePartitioner<FmhaTileShape>;
constexpr ck_tile::index_t occupancy = -1;

constexpr auto kBiasEnum = kHasBias
@@ -122,10 +120,8 @@ struct batched_infer_splitkv_mask_bias_dropout_dispatch {
false,
false>>;

- using FmhaKernel = ck_tile::FmhaFwdSplitKVKernel<
-     FmhaTilePartitioner,
-     FmhaPipeline,
-     FmhaEpilogue>;
+ using FmhaKernel =
+     ck_tile::FmhaFwdSplitKVKernel<FmhaPipeline, FmhaEpilogue>;

RunWithFwdSplitKVKernel<FmhaKernel>(param, stream);
} else {
@@ -159,10 +155,8 @@ struct batched_infer_splitkv_mask_bias_dropout_dispatch {
false,
false>>;

- using FmhaKernel = ck_tile::FmhaFwdSplitKVKernel<
-     FmhaTilePartitioner,
-     FmhaPipeline,
-     FmhaEpilogue>;
+ using FmhaKernel =
+     ck_tile::FmhaFwdSplitKVKernel<FmhaPipeline, FmhaEpilogue>;

RunWithFwdSplitKVKernel<FmhaKernel>(param, stream);
}
@@ -179,8 +173,6 @@ struct batched_infer_splitkv_mask_bias_dropout_dispatch {
typename FmhaFwdTypeConfig<ScalarType>::OaccDataType,
kN1>::kM0;

- using FmhaTilePartitioner =
-     ck_tile::FmhaFwdSplitKVCombineTilePartitioner<kM0, kN1>;
constexpr ck_tile::index_t occupancy = -1;

const bool pad_seqlen_q = !(param.M % kM0 == 0);
@@ -212,10 +204,8 @@ struct batched_infer_splitkv_mask_bias_dropout_dispatch {
kPadSeqLenQ,
kPadHeadDimV>>;

- using FmhaKernel = ck_tile::FmhaFwdSplitKVCombineKernel<
-     FmhaTilePartitioner,
-     FmhaPipeline,
-     FmhaEpilogue>;
+ using FmhaKernel = ck_tile::
+     FmhaFwdSplitKVCombineKernel<FmhaPipeline, FmhaEpilogue>;

RunWithSplitKVCombineKernel<FmhaKernel>(param, stream);
});

@@ -60,8 +60,6 @@ struct batched_infer_splitkv_smallq_mask_bias_dropout_dispatch {
using FmhaMask = ck_tile::SimplifiedGenericAttentionMask<kHasMask>;

using FmhaTileShape = typename FmhaFwdSplitKVSmallQShape<MaxK>::Type;
- using FmhaTilePartitioner =
-     ck_tile::FmhaFwdSplitKVTilePartitioner<FmhaTileShape>;
constexpr ck_tile::index_t occupancy = -1;

constexpr auto kBiasEnum = kHasBias
@@ -121,10 +119,8 @@ struct batched_infer_splitkv_smallq_mask_bias_dropout_dispatch {
false,
false>>;

- using FmhaKernel = ck_tile::FmhaFwdSplitKVKernel<
-     FmhaTilePartitioner,
-     FmhaPipeline,
-     FmhaEpilogue>;
+ using FmhaKernel =
+     ck_tile::FmhaFwdSplitKVKernel<FmhaPipeline, FmhaEpilogue>;

RunWithFwdSplitKVKernel<FmhaKernel>(param, stream);
} else {
@@ -159,10 +155,8 @@ struct batched_infer_splitkv_smallq_mask_bias_dropout_dispatch {
false,
false>>;

- using FmhaKernel = ck_tile::FmhaFwdSplitKVKernel<
-     FmhaTilePartitioner,
-     FmhaPipeline,
-     FmhaEpilogue>;
+ using FmhaKernel =
+     ck_tile::FmhaFwdSplitKVKernel<FmhaPipeline, FmhaEpilogue>;

RunWithFwdSplitKVKernel<FmhaKernel>(param, stream);
}
@@ -178,8 +172,6 @@ struct batched_infer_splitkv_smallq_mask_bias_dropout_dispatch {
typename FmhaFwdTypeConfig<ScalarType>::OaccDataType,
kN1>::kM0;

- using FmhaTilePartitioner =
-     ck_tile::FmhaFwdSplitKVCombineTilePartitioner<kM0, kN1>;
constexpr ck_tile::index_t occupancy = -1;

const bool pad_seqlen_q = !(param.M % kM0 == 0);
@@ -211,10 +203,8 @@ struct batched_infer_splitkv_smallq_mask_bias_dropout_dispatch {
kPadSeqLenQ,
kPadHeadDimV>>;

- using FmhaKernel = ck_tile::FmhaFwdSplitKVCombineKernel<
-     FmhaTilePartitioner,
-     FmhaPipeline,
-     FmhaEpilogue>;
+ using FmhaKernel = ck_tile::
+     FmhaFwdSplitKVCombineKernel<FmhaPipeline, FmhaEpilogue>;

RunWithSplitKVCombineKernel<FmhaKernel>(param, stream);
});

@@ -88,26 +88,10 @@ struct grouped_forward_mask_bias_dropout_dispatch {
kPadSeqLenQ,
kPadHeadDimV>>;

- if (param.seqlen_k_dev_ptr !=
-     nullptr) { // seqlen_k of batches are padded
-   using FmhaTilePartitioner =
-       ck_tile::FmhaFwdTilePartitioner_HBS<FmhaFwdShape_>;
-   using FmhaFwdKernel_ = ck_tile::FmhaFwdKernel<
-       FmhaTilePartitioner,
-       FmhaFwdPipeline_,
-       FmhaFwdEpilogue_>;
-
-   RunWithKernel<FmhaFwdKernel_>(param, stream);
- } else {
-   using FmhaTilePartitioner =
-       ck_tile::FmhaFwdTilePartitioner_SHB<FmhaFwdShape_>;
-   using FmhaFwdKernel_ = ck_tile::FmhaFwdKernel<
-       FmhaTilePartitioner,
-       FmhaFwdPipeline_,
-       FmhaFwdEpilogue_>;
-
-   RunWithKernel<FmhaFwdKernel_>(param, stream);
- }
+ using FmhaFwdKernel_ =
+     ck_tile::FmhaFwdKernel<FmhaFwdPipeline_, FmhaFwdEpilogue_>;
+
+ RunWithKernel<FmhaFwdKernel_>(param, stream);
});
};

@@ -157,7 +141,11 @@ struct grouped_forward_mask_bias_dropout_dispatch {
}();

dim3 kGridSize = FmhaFwdKernel::GridSize(
-     param.num_batches, param.Hq, param.max_seqlen_q, param.Kv);
+     param.num_batches,
+     param.Hq,
+     param.max_seqlen_q,
+     param.Kv,
+     param.seqlen_k_dev_ptr != nullptr);
constexpr dim3 kBlockSize = FmhaFwdKernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = FmhaFwdKernel::kBlockPerCu;

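
The removed branch in the grouped dispatcher above is what the new GridSize flag replaces: the code used to pick ck_tile::FmhaFwdTilePartitioner_HBS when param.seqlen_k_dev_ptr was set and FmhaFwdTilePartitioner_SHB otherwise, and now it forwards that same condition into GridSize so the kernel decides the grid layout itself. Below is a self-contained sketch of how such a flag can steer a grid computation; the tile size, the two axis orderings, and the function name are illustrative assumptions, not the ck_tile implementation.

#include <cassert>

// Minimal stand-in for dim3 so this sketch builds without HIP/CUDA headers.
struct Dim3 { unsigned x, y, z; };

// Hypothetical grid computation mirroring the new GridSize argument list:
// batches, query heads, (max) query length, V head dim, plus the flag that
// replaces the old HBS-vs-SHB partitioner choice in the grouped dispatcher.
Dim3 FmhaFwdGridSizeSketch(int batch, int nhead_q, int max_seqlen_q,
                           int hdim_v, bool has_padded_seqlen_k) {
  (void)hdim_v;            // kept only to mirror the real argument list
  constexpr int kM0 = 128; // illustrative M-tile height
  const unsigned num_m_tiles =
      static_cast<unsigned>((max_seqlen_q + kM0 - 1) / kM0);
  if (has_padded_seqlen_k) {
    // One possible layout: M tiles on x, heads on y, batches on z.
    return Dim3{num_m_tiles, static_cast<unsigned>(nhead_q),
                static_cast<unsigned>(batch)};
  }
  // Another possible layout: batches folded into x together with the M tiles.
  return Dim3{static_cast<unsigned>(batch) * num_m_tiles,
              static_cast<unsigned>(nhead_q), 1u};
}

int main() {
  const Dim3 g = FmhaFwdGridSizeSketch(4, 16, 1000, 128, false);
  assert(g.x == 4 * 8 && g.y == 16 && g.z == 1); // ceil(1000 / 128) == 8
  return 0;
}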