From c9e2ce90b3328a0fb4f27d98e43148ad0507d25e Mon Sep 17 00:00:00 2001 From: Andrew Kwangwoong Park Date: Mon, 13 May 2024 14:03:02 +0900 Subject: [PATCH] [GPU] Enable Indirect gemm for non-batch gather axis support (#24321) ### Details: - Currently `indirect Gemm` is implemented with an assumption of `KV Cache`'s gather axis is 0 - Update `KV Cache` and `indirect Gemm` to support non-batch gather axis support ### Tickets: - 137142 --------- Signed-off-by: Andrew Park --- .../include/intel_gpu/op/indirect_gemm.hpp | 3 +++ .../include/intel_gpu/primitives/gemm.hpp | 9 ++++++- .../intel_gpu/src/graph/impls/ocl/gemm.cpp | 4 +++- .../src/graph/impls/ocl/kv_cache.cpp | 3 +++ .../cl_kernels/beam_table_update_ref.cl | 24 ++++++++++++++++++- .../kernel_selector/cl_kernels/gemm_ref.cl | 16 +++++++++++-- .../cl_kernels/gemm_tiled_opt.cl | 16 ++++++++++--- .../beam_table_update_kernel_ref.cpp | 13 +++++++--- .../beam_table_update_kernel_ref.hpp | 1 + .../kernels/gemm/gemm_kernel_base.cpp | 1 + .../kernels/gemm/gemm_kernel_base.h | 1 + .../kernels/gemm/gemm_kernel_tiled_opt.cpp | 2 +- .../intel_gpu/src/plugin/ops/matmul.cpp | 1 + .../transformations/indirect_kv_cache.cpp | 6 +++-- .../transformations/op/indirect_gemm.cpp | 6 ++++- .../tests/unit/test_cases/gemm_gpu_test.cpp | 3 ++- .../unit/test_cases/hash_key_gpu_test.cpp | 4 ++-- .../indirect_kv_cache_test.cpp | 16 ++++++++++--- 18 files changed, 108 insertions(+), 21 deletions(-) diff --git a/src/plugins/intel_gpu/include/intel_gpu/op/indirect_gemm.hpp b/src/plugins/intel_gpu/include/intel_gpu/op/indirect_gemm.hpp index ef51a9fb11d7f4..654f049f278c79 100644 --- a/src/plugins/intel_gpu/include/intel_gpu/op/indirect_gemm.hpp +++ b/src/plugins/intel_gpu/include/intel_gpu/op/indirect_gemm.hpp @@ -24,6 +24,7 @@ class IndirectGemm : public ov::intel_gpu::op::Gemm { const ov::Output& I, bool indirect_a, bool indirect_b, + int64_t indirect_axis, const std::vector& order_a, const std::vector& order_b, const std::vector& order_c, @@ -38,12 +39,14 @@ class IndirectGemm : public ov::intel_gpu::op::Gemm { bool get_indirect_a() const { return m_indirect_a; } bool get_indirect_b() const { return m_indirect_b; } + int64_t get_indirect_axis() const { return m_indirect_axis; } using ov::intel_gpu::op::Gemm::default_order; protected: bool m_indirect_a = false; bool m_indirect_b = false; + int64_t m_indirect_axis = 0; }; } // namespace op diff --git a/src/plugins/intel_gpu/include/intel_gpu/primitives/gemm.hpp b/src/plugins/intel_gpu/include/intel_gpu/primitives/gemm.hpp index b5d2dd66508af8..0917074cce66c0 100644 --- a/src/plugins/intel_gpu/include/intel_gpu/primitives/gemm.hpp +++ b/src/plugins/intel_gpu/include/intel_gpu/primitives/gemm.hpp @@ -117,6 +117,7 @@ struct gemm : public primitive_base { const std::vector& output_transpose_order, bool indirect_a, bool indirect_b, + int64_t indirect_axis, const float alpha = 1.0f, const float beta = 0.0f, const padding& output_padding = padding()) @@ -130,7 +131,8 @@ struct gemm : public primitive_base { weight_rank(input1_transpose_order.size()), beam_table(beam_table), indirect_a(indirect_a), - indirect_b(indirect_b) { + indirect_b(indirect_b), + indirect_axis(indirect_axis) { if (inputs.size() != 2 && inputs.size() != 3) { throw std::invalid_argument("Invalid inputs count - gemm expects either two or three inputs"); } @@ -162,6 +164,7 @@ struct gemm : public primitive_base { input_info beam_table = {}; bool indirect_a = false; bool indirect_b = false; + int64_t indirect_axis = 0; size_t hash() const override { size_t seed = primitive::hash(); @@ -169,6 +172,7 @@ struct gemm : public primitive_base { seed = hash_combine(seed, transpose_input1); seed = hash_combine(seed, indirect_a); seed = hash_combine(seed, indirect_b); + seed = hash_combine(seed, indirect_axis); seed = hash_range(seed, input0_transpose_order.begin(), input0_transpose_order.end()); seed = hash_range(seed, input1_transpose_order.begin(), input1_transpose_order.end()); seed = hash_range(seed, output_transpose_order.begin(), output_transpose_order.end()); @@ -189,6 +193,7 @@ struct gemm : public primitive_base { beta == rhs_casted.beta && indirect_a == rhs_casted.indirect_a && indirect_b == rhs_casted.indirect_b && + indirect_axis == rhs_casted.indirect_axis && input_rank == rhs_casted.input_rank && weight_rank == rhs_casted.weight_rank; } @@ -206,6 +211,7 @@ struct gemm : public primitive_base { ob << weight_rank; ob << indirect_a; ob << indirect_b; + ob << indirect_axis; ob << beam_table.pid; ob << beam_table.idx; } @@ -223,6 +229,7 @@ struct gemm : public primitive_base { ib >> weight_rank; ib >> indirect_a; ib >> indirect_b; + ib >> indirect_axis; ib >> beam_table.pid; ib >> beam_table.idx; } diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl/gemm.cpp b/src/plugins/intel_gpu/src/graph/impls/ocl/gemm.cpp index bcbae3b9dc7f6b..f7220dad387348 100644 --- a/src/plugins/intel_gpu/src/graph/impls/ocl/gemm.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/ocl/gemm.cpp @@ -138,7 +138,8 @@ struct gemm_impl : multi_stage_primitive { return false; const auto& params = *inst.get_impl_params(); - if (params.input_layouts[get_beam_table_id(desc)].get_partial_shape()[0].get_length() == 1) + const auto indirect_axis = desc->indirect_axis; + if (params.input_layouts[get_beam_table_id(desc)].get_partial_shape()[indirect_axis].get_length() == 1) return false; const auto& deps = inst.dependencies(); @@ -226,6 +227,7 @@ struct gemm_impl : multi_stage_primitive { params.indirect_input0 = primitive->indirect_a && indirect; params.indirect_input1 = primitive->indirect_b && indirect; + params.indirect_axis = primitive->indirect_axis; if (indirect && (primitive->indirect_a || primitive->indirect_b)) { OPENVINO_ASSERT(impl_param.input_layouts.size() >= 3, "[GPU] Actual inputs count: ", impl_param.input_layouts.size()); params.inputs.push_back(convert_data_tensor(impl_param.input_layouts[get_beam_table_id(primitive)])); diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl/kv_cache.cpp b/src/plugins/intel_gpu/src/graph/impls/ocl/kv_cache.cpp index b84de8bdfb2f18..4c463d5d6b7c1e 100644 --- a/src/plugins/intel_gpu/src/graph/impls/ocl/kv_cache.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/ocl/kv_cache.cpp @@ -243,7 +243,9 @@ struct kv_cache_impl : multi_stage_primitive { } static bt_kernel_params_t get_bt_update_kernel_params(const kernel_impl_params& impl_param, bool is_state_set = false) { + const auto& primitive = impl_param.typed_desc(); auto params = get_default_params(impl_param, true); + auto indirect_axis = primitive->gather_axis; auto inputs_count = 2; auto bt_present_layout = impl_param.output_layouts[1]; @@ -260,6 +262,7 @@ struct kv_cache_impl : multi_stage_primitive { params.outputs[0] = convert_data_tensor(bt_present_layout); params.inputs.resize(inputs_count); params.is_state_set = is_state_set; + params.indirect_axis = indirect_axis; const auto& in_offsets_map = impl_param.in_port_to_shape_info_offset; // [kv_past, kv_new_token, [beam_idx, beam_table_past]] const auto& out_offsets_map = impl_param.out_port_to_shape_info_offset; // [kv_present, beam_table_present] diff --git a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/beam_table_update_ref.cl b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/beam_table_update_ref.cl index 009cef79c25c53..ded3c333da05eb 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/beam_table_update_ref.cl +++ b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/beam_table_update_ref.cl @@ -15,18 +15,40 @@ KERNEL(beam_table_update)( const unsigned int s = (uint)get_global_id(1); const unsigned int out_offset = b * OUTPUT_BATCH_PITCH + s; +#if INDIRECT_AXIS == 0 const unsigned int in_offset = beam_idx[b] * INPUT0_BATCH_PITCH + s; +#elif INDIRECT_AXIS == 1 + const unsigned int in_offset = b * INPUT0_BATCH_PITCH + beam_idx[s]; +#else +# error beam_table_update_ref.cl : Unsupported indirect axis for beam table +#endif if (s >= OUTPUT_BATCH_PITCH) return; if (!is_state_set) { + #if INDIRECT_AXIS == 0 state_new[out_offset] = TO_OUTPUT_TYPE(b); + #elif INDIRECT_AXIS == 1 + state_new[out_offset] = TO_OUTPUT_TYPE(s); + #else + # error beam_table_update_ref.cl : Unsupported indirect axis for beam table + #endif } else { + #if INDIRECT_AXIS == 0 if (s < INPUT0_BATCH_PITCH) { state_new[out_offset] = state_prev[in_offset]; } else { - state_new[out_offset] = b; + state_new[out_offset] = TO_OUTPUT_TYPE(b); } + #elif INDIRECT_AXIS == 1 + if (b < INPUT0_BATCH_NUM) { + state_new[out_offset] = state_prev[in_offset]; + } else { + state_new[out_offset] = TO_OUTPUT_TYPE(s); + } + #else + # error beam_table_update_ref.cl : Unsupported indirect axis for beam table + #endif } } diff --git a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/gemm_ref.cl b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/gemm_ref.cl index e90841d56fd33d..6a2cf8e268f649 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/gemm_ref.cl +++ b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/gemm_ref.cl @@ -131,10 +131,22 @@ KERNEL(gemm_ref)( uint b0 = b; uint b1 = b; #if INDIRECT_INPUT0 - b0 = BEAM_TABLE_BATCH_NUM > 1 ? beam_table[FUNC_CALL(get_bt_index)(OPTIONAL_SHAPE_INFO_TENSOR b, f, w, z, y, ki)] : b; + #if INDIRECT_AXIS == 0 + b0 = BEAM_TABLE_BATCH_NUM > 1 ? beam_table[FUNC_CALL(get_bt_index)(OPTIONAL_SHAPE_INFO_TENSOR b, f, w, z, y, ki)] : b; + #elif INDIRECT_AXIS == 1 + b0 = BEAM_TABLE_FEATURE_NUM > 1 ? beam_table[FUNC_CALL(get_bt_index)(OPTIONAL_SHAPE_INFO_TENSOR b, f, w, z, y, ki)] : b; + #else + # error gemm_ref.cl : Unsupported indirect axis for beam table + #endif #endif #if INDIRECT_INPUT1 - b1 = BEAM_TABLE_BATCH_NUM > 1 ? beam_table[FUNC_CALL(get_bt_index)(OPTIONAL_SHAPE_INFO_TENSOR b, f, w, z, ki, x)] : b; + #if INDIRECT_AXIS == 0 + b1 = BEAM_TABLE_BATCH_NUM > 1 ? beam_table[FUNC_CALL(get_bt_index)(OPTIONAL_SHAPE_INFO_TENSOR b, f, w, z, ki, x)] : b; + #elif INDIRECT_AXIS == 1 + b1 = BEAM_TABLE_FEATURE_NUM > 1 ? beam_table[FUNC_CALL(get_bt_index)(OPTIONAL_SHAPE_INFO_TENSOR b, f, w, z, ki, x)] : b; + #else + # error gemm_ref.cl : Unsupported indirect axis for beam table + #endif #endif uint in0_idx = FUNC_CALL(get_input0_index)(OPTIONAL_SHAPE_INFO_TENSOR b0, f, w, z, y, ki); diff --git a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/gemm_tiled_opt.cl b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/gemm_tiled_opt.cl index eb732b122a1c4a..8dc2103fdca5a3 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/gemm_tiled_opt.cl +++ b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/gemm_tiled_opt.cl @@ -79,14 +79,20 @@ inline uint FUNC(get_bt_index)(OPTIONAL_SHAPE_INFO_ARG uint b, uint f, uint w, u #if INDIRECT_INPUT0 inline uint FUNC(get_input0_indirect_index)(OPTIONAL_SHAPE_INFO_ARG uint b, uint f, uint w, uint z, uint y, uint x, __global BEAM_TABLE_TYPE* beam_table) { +#if INDIRECT_AXIS == 0 int b_index = BEAM_TABLE_BATCH_NUM > 1 ? beam_table[FUNC_CALL(get_bt_index)(OPTIONAL_SHAPE_INFO_TENSOR b, f, w, z, y, x)] : b; +#elif INDIRECT_AXIS == 1 + int b_index = BEAM_TABLE_FEATURE_NUM > 1 ? beam_table[FUNC_CALL(get_bt_index)(OPTIONAL_SHAPE_INFO_TENSOR b, f, w, z, y, x)] : b; +#else +# error gemm_tiled_opt.cl : Unsupported indirect axis for beam table +#endif return FUNC_CALL(get_input0_index)(OPTIONAL_SHAPE_INFO_TENSOR b_index, f, w, z, y, x); } #endif #if INDIRECT_INPUT1 inline uint FUNC(get_input1_indirect_index)(OPTIONAL_SHAPE_INFO_ARG uint b, uint f, uint w, uint z, uint y, uint x, __global BEAM_TABLE_TYPE* beam_table) { - int b_index = BEAM_TABLE_BATCH_NUM > 1 ? beam_table[FUNC_CALL(get_bt_index)(OPTIONAL_SHAPE_INFO_TENSOR b, f, w, z, y, x)] : b; + int b_index = beam_table[FUNC_CALL(get_bt_index)(OPTIONAL_SHAPE_INFO_TENSOR b, f, w, z, y, x)]; return FUNC_CALL(get_input1_index)(OPTIONAL_SHAPE_INFO_TENSOR b_index, f, w, z, y, x); } #endif @@ -218,7 +224,13 @@ KERNEL(gemm_tiled_opt)( const uint b_raw_global_id = tile_n_offset + sglid; #if INDIRECT_INPUT0 || INDIRECT_INPUT1 +#if INDIRECT_AXIS == 0 const char do_indirect_load = BEAM_TABLE_BATCH_NUM > 1; +#elif INDIRECT_AXIS == 1 + const char do_indirect_load = BEAM_TABLE_FEATURE_NUM > 1; +#else +# error gemm_tiled_opt.cl : Unsupported indirect axis for beam table +#endif #endif #if TRANSPOSE_INPUT0 != TRANSPOSE_X_LAST @@ -344,7 +356,6 @@ KERNEL(gemm_tiled_opt)( unroll_for (uint b_load_id = 0; b_load_id < TILE_K; b_load_id++) { uint b_load_offset = (k * TILE_K) + b_load_id; uint b_idx = FUNC_CALL(get_input1_indirect_index)(OPTIONAL_SHAPE_INFO_TENSOR b, f, w, z, b_load_offset, x, beam_table); - uint bt_idx = FUNC_CALL(get_bt_index)(OPTIONAL_SHAPE_INFO_TENSOR b, f, w, z, b_load_offset, x); b_tile[b_load_id] = b_raw_global_id >= N ? 0 : input1[b_idx]; } #else @@ -352,7 +363,6 @@ KERNEL(gemm_tiled_opt)( unroll_for (uint b_load_id = 0; b_load_id < TILE_K; b_load_id++) { uint b_load_offset = k * TILE_K + b_load_id; uint b_idx = FUNC_CALL(get_input1_indirect_index)(OPTIONAL_SHAPE_INFO_TENSOR b, f, w, z, b_load_offset, x + sglid + SIMD_WIDTH * b_elem, beam_table); - uint bt_idx = FUNC_CALL(get_bt_index)(OPTIONAL_SHAPE_INFO_TENSOR b, f, w, z, b_load_offset, x + sglid + SIMD_WIDTH * b_elem); b_tile[b_elem][b_load_id] = b_raw_global_id + SIMD_WIDTH * b_elem >= N ? 0 : input1[b_idx]; } } diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/beam_table_update/beam_table_update_kernel_ref.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/beam_table_update/beam_table_update_kernel_ref.cpp index 11cd76a3aa62f0..f6d098cfe00d5c 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/kernels/beam_table_update/beam_table_update_kernel_ref.cpp +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/beam_table_update/beam_table_update_kernel_ref.cpp @@ -94,7 +94,9 @@ bool BeamTableUpdateKernelRef::Validate(const Params& params) const { } JitConstants BeamTableUpdateKernelRef::GetJitConstants(const beam_table_update_params& kernel_params) const { - return MakeBaseParamsJitConstants(kernel_params); + JitConstants jit = MakeBaseParamsJitConstants(kernel_params); + jit.AddConstant({MakeJitConstant("INDIRECT_AXIS", kernel_params.indirect_axis)}); + return jit; } CommonDispatchData BeamTableUpdateKernelRef::SetDefault(const beam_table_update_params& kernel_params) { @@ -102,8 +104,13 @@ CommonDispatchData BeamTableUpdateKernelRef::SetDefault(const beam_table_update_ auto output = kernel_params.outputs[0]; if (!output.is_dynamic()) { - dispatch_data.gws = {output.Batch().v, Align(output.LogicalSize() / output.Batch().v, 16), 1}; - dispatch_data.lws = {1, 16, 1}; + if (kernel_params.indirect_axis == 0) { + dispatch_data.gws = {output.Batch().v, Align(output.LogicalSize() / output.Batch().v, 16), 1}; + dispatch_data.lws = {1, 16, 1}; + } else { + dispatch_data.gws = {Align(output.LogicalSize() / output.Feature().v, 16), output.Feature().v, 1}; + dispatch_data.lws = {16, 1, 1}; + } } return dispatch_data; diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/beam_table_update/beam_table_update_kernel_ref.hpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/beam_table_update/beam_table_update_kernel_ref.hpp index 5df1c009ff358e..e565d49db353f4 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/kernels/beam_table_update/beam_table_update_kernel_ref.hpp +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/beam_table_update/beam_table_update_kernel_ref.hpp @@ -11,6 +11,7 @@ namespace kernel_selector { struct beam_table_update_params : base_params { beam_table_update_params() : base_params(KernelType::BEAM_TABLE_UPDATE) {} bool is_state_set = false; + int64_t indirect_axis = 0; }; class BeamTableUpdateKernelRef : public KernelBaseOpenCL { diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/gemm/gemm_kernel_base.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/gemm/gemm_kernel_base.cpp index 461c306a56758d..3b9a348622f4fb 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/kernels/gemm/gemm_kernel_base.cpp +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/gemm/gemm_kernel_base.cpp @@ -184,6 +184,7 @@ JitConstants GemmKernelBase::GetJitConstants(const gemm_params& params) const { MakeJitConstant("QUANTIZATION_TERM", params.quantization != QuantizationType::NONE), MakeJitConstant("INDIRECT_INPUT0", params.indirect_input0), MakeJitConstant("INDIRECT_INPUT1", params.indirect_input1), + MakeJitConstant("INDIRECT_AXIS", params.indirect_axis), MakeJitConstant("BEAM_TABLE_TERM", params.indirect_input0 || params.indirect_input1), }); diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/gemm/gemm_kernel_base.h b/src/plugins/intel_gpu/src/kernel_selector/kernels/gemm/gemm_kernel_base.h index 32a186412d3d66..e2a94584ce0dcd 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/kernels/gemm/gemm_kernel_base.h +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/gemm/gemm_kernel_base.h @@ -29,6 +29,7 @@ struct gemm_params : public base_params { DataTensor beam_table; bool indirect_input0 = false; bool indirect_input1 = false; + int64_t indirect_axis = 0; QuantizationType quantization = QuantizationType::NONE; ParamsKey GetParamsKey() const override { diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/gemm/gemm_kernel_tiled_opt.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/gemm/gemm_kernel_tiled_opt.cpp index 2041cc983837ef..b367e40308104d 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/kernels/gemm/gemm_kernel_tiled_opt.cpp +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/gemm/gemm_kernel_tiled_opt.cpp @@ -108,7 +108,7 @@ GemmKernelTiledOpt::GemmTuningData GemmKernelTiledOpt::SetTuningParams(const gem tuning_data.tile_m_size = tuning_data.simd_size; bool output_ndim_transposed = (params.output_order.size() > 0 && (params.output_order.back() != (static_cast(params.output_order.size()) - 1))); if ((params.transpose_input0 == 0 /*X_LAST*/) && (params.transpose_input1 == 0 /*X_LAST*/ || params.transpose_input1 == 1 /*Y_LAST*/) - && (!params.indirect_input0 && !params.inputs[0].has_dynamic_pad()) + && (!params.indirect_input0 && !params.inputs[0].has_dynamic_pad() && params.indirect_axis != 1) && (!output_ndim_transposed || params.fused_ops.empty()) && !params.engineInfo.supports_immad) { // - Not supports transposed input0 / transposed input1 for OTHER mode yet diff --git a/src/plugins/intel_gpu/src/plugin/ops/matmul.cpp b/src/plugins/intel_gpu/src/plugin/ops/matmul.cpp index 05335d2193b1c3..9cbbe179173915 100644 --- a/src/plugins/intel_gpu/src/plugin/ops/matmul.cpp +++ b/src/plugins/intel_gpu/src/plugin/ops/matmul.cpp @@ -205,6 +205,7 @@ static void CreateIndirectGemmOp(ProgramBuilder& p, const std::shared_ptrget_output_transpose_order(), op->get_indirect_a(), op->get_indirect_b(), + op->get_indirect_axis(), alpha, beta); diff --git a/src/plugins/intel_gpu/src/plugin/transformations/indirect_kv_cache.cpp b/src/plugins/intel_gpu/src/plugin/transformations/indirect_kv_cache.cpp index 026853e8439c9a..d612ad03886f19 100644 --- a/src/plugins/intel_gpu/src/plugin/transformations/indirect_kv_cache.cpp +++ b/src/plugins/intel_gpu/src/plugin/transformations/indirect_kv_cache.cpp @@ -49,7 +49,7 @@ IndirectKVCache::IndirectKVCache() { auto gather_input = wrap_type(); auto axis_const = wrap_type( ov::op::util::constant_predicate([](const std::vector& value) -> bool { - return value.size() == 1 && value[0] == 0; + return value.size() == 1 && (value[0] == 0 || value[0] == 1); })); auto gather_past = wrap_type({gather_input, beam_idx, axis_const}); auto kv_cache = wrap_type({gather_past, any_input()}); @@ -68,6 +68,7 @@ IndirectKVCache::IndirectKVCache() { auto beam_idx_node = pattern_map.at(beam_idx).get_node_shared_ptr(); auto gather_input_node = pattern_map.at(gather_input).get_node_shared_ptr(); auto gather_node = std::dynamic_pointer_cast(pattern_map.at(gather_past).get_node_shared_ptr()); + auto gather_axis = gather_node->get_axis(); ov::replace_node(gather_node, gather_input_node); auto indirect_kv_cache = std::make_shared(gather_input_node, @@ -75,7 +76,7 @@ IndirectKVCache::IndirectKVCache() { beam_idx_node, kv_cache_node->get_variable(), kv_cache_node->get_concat_axis(), - gather_node->get_axis(), + gather_axis, kv_cache_node->get_output_element_type(0)); indirect_kv_cache->set_friendly_name(kv_cache_node->get_friendly_name()); @@ -95,6 +96,7 @@ IndirectKVCache::IndirectKVCache() { indirect_kv_cache->output(1), // beam table matmul_kv_cache_index == 0, matmul_kv_cache_index == 1, + gather_axis, order_in0, order_in1, order_out); diff --git a/src/plugins/intel_gpu/src/plugin/transformations/op/indirect_gemm.cpp b/src/plugins/intel_gpu/src/plugin/transformations/op/indirect_gemm.cpp index 80e8a7a602c222..e530859f608fb5 100644 --- a/src/plugins/intel_gpu/src/plugin/transformations/op/indirect_gemm.cpp +++ b/src/plugins/intel_gpu/src/plugin/transformations/op/indirect_gemm.cpp @@ -14,13 +14,15 @@ IndirectGemm::IndirectGemm(const ov::Output& A, const ov::Output& I, bool indirect_a, bool indirect_b, + int64_t indirect_axis, const std::vector& order_a, const std::vector& order_b, const std::vector& order_c, const ov::element::Type output_type) : ov::intel_gpu::op::Gemm(A, B, order_a, order_b, order_c, output_type) , m_indirect_a(indirect_a) - , m_indirect_b(indirect_b) { + , m_indirect_b(indirect_b) + , m_indirect_axis(indirect_axis) { set_argument(2, I); OPENVINO_ASSERT((indirect_a && indirect_b) == false, "[GPU] Gemm supports indirect addressing for one input only"); validate_and_infer_types(); @@ -34,6 +36,7 @@ std::shared_ptr IndirectGemm::clone_with_new_inputs(const ov::OutputVe new_args.at(2), m_indirect_a, m_indirect_b, + m_indirect_axis, m_order_a, m_order_b, m_order_c, @@ -62,6 +65,7 @@ bool IndirectGemm::visit_attributes(ov::AttributeVisitor &visitor) { Gemm::visit_attributes(visitor); visitor.on_attribute("indirect_a", m_indirect_a); visitor.on_attribute("indirect_b", m_indirect_b); + visitor.on_attribute("indirect_axis", m_indirect_axis); return true; } diff --git a/src/plugins/intel_gpu/tests/unit/test_cases/gemm_gpu_test.cpp b/src/plugins/intel_gpu/tests/unit/test_cases/gemm_gpu_test.cpp index 45acc204da033d..8ce9e294a867fe 100644 --- a/src/plugins/intel_gpu/tests/unit/test_cases/gemm_gpu_test.cpp +++ b/src/plugins/intel_gpu/tests/unit/test_cases/gemm_gpu_test.cpp @@ -740,6 +740,7 @@ class gemm_gpu_tests: public ::testing::Test { beam_table_shape = { BATCH_SIZE, K_SIZE, 1, 1 }; else if (indirect_input1) beam_table_shape = { BATCH_SIZE, 1, 1, K_SIZE }; + int64_t indirect_axis = 0; cldnn::layout input0_layout = layout{ov::PartialShape::dynamic(input0_shape.size()), data_types::f32, format::bfyx}; cldnn::layout input1_layout = layout{ov::PartialShape::dynamic(input1_shape.size()), data_types::f32, format::bfyx}; @@ -761,7 +762,7 @@ class gemm_gpu_tests: public ::testing::Test { topology.add(input_layout("input0", input0_layout), input_layout("input1", input1_layout), input_layout("beam_table", beam_table_layout), - gemm("gemm", { input_info("input0"), input_info("input1") }, input_info("beam_table"), data_types::f32, input0_order, input1_order, {}, indirect_input0, indirect_input1) + gemm("gemm", { input_info("input0"), input_info("input1") }, input_info("beam_table"), data_types::f32, input0_order, input1_order, {}, indirect_input0, indirect_input1, indirect_axis) ); ExecutionConfig config = get_test_default_config(engine); diff --git a/src/plugins/intel_gpu/tests/unit/test_cases/hash_key_gpu_test.cpp b/src/plugins/intel_gpu/tests/unit/test_cases/hash_key_gpu_test.cpp index a4837187b29a6a..2d9cfaa68e1243 100644 --- a/src/plugins/intel_gpu/tests/unit/test_cases/hash_key_gpu_test.cpp +++ b/src/plugins/intel_gpu/tests/unit/test_cases/hash_key_gpu_test.cpp @@ -127,8 +127,8 @@ class check_hash_value: public ::testing::Test { const auto primitive_hash = primitve->hash(); const auto params_hash = prim_inst->get_impl_params()->hash(); - ASSERT_EQ(primitive_hash, 15839977233203008631UL); - ASSERT_EQ(params_hash, 15375157605915685928UL); + ASSERT_EQ(primitive_hash, 13388149315122571178UL); + ASSERT_EQ(params_hash, 2108356776161884759UL); } void test_permute_basic(bool is_caching_test) { diff --git a/src/plugins/intel_gpu/tests/unit/transformations/indirect_kv_cache_test.cpp b/src/plugins/intel_gpu/tests/unit/transformations/indirect_kv_cache_test.cpp index 074d3420b28636..15f3bffb0ff9ca 100644 --- a/src/plugins/intel_gpu/tests/unit/transformations/indirect_kv_cache_test.cpp +++ b/src/plugins/intel_gpu/tests/unit/transformations/indirect_kv_cache_test.cpp @@ -54,7 +54,7 @@ TEST_F(TransformationTestsF, IndirectKVCache1) { auto past = std::make_shared(variable); auto kv_cache = std::make_shared(past, parameter, beam_idx, variable, 2, 0, ov::element::f32); auto gemm_in = std::make_shared(ov::element::f32, ov::PartialShape{1, 32, -1, -1}); - auto gemm = std::make_shared(gemm_in, kv_cache->output(0), kv_cache->output(1), false, true, + auto gemm = std::make_shared(gemm_in, kv_cache->output(0), kv_cache->output(1), false, true, 0, in0_order, in1_order, out_order); auto result = std::make_shared(gemm); @@ -89,7 +89,7 @@ TEST_F(TransformationTestsF, IndirectKVCache2) { auto past = std::make_shared(variable); auto kv_cache = std::make_shared(past, parameter, beam_idx, variable, 2, 0, ov::element::f32); auto gemm_in = std::make_shared(ov::element::f32, ov::PartialShape{1, 32, -1, -1}); - auto gemm = std::make_shared(kv_cache->output(0), gemm_in, kv_cache->output(1), true, false, + auto gemm = std::make_shared(kv_cache->output(0), gemm_in, kv_cache->output(1), true, false, 0, in0_order, in1_order, out_order); auto result = std::make_shared(gemm); @@ -118,7 +118,17 @@ TEST_F(TransformationTestsF, IndirectKVCache3) { manager.register_pass(); } { - model_ref = model->clone(); + auto variable = std::make_shared(ov::op::util::VariableInfo{{1, 32, -1, 80}, ov::element::f32, "v0"}); + auto parameter = std::make_shared(ov::element::f32, ov::PartialShape{1, 32, -1, 80}); + auto beam_idx = std::make_shared(ov::element::i32, ov::PartialShape{1}); + auto past = std::make_shared(variable); + auto kv_cache = std::make_shared(past, parameter, beam_idx, variable, 2, 1, ov::element::f32); + auto gemm_in = std::make_shared(ov::element::f32, ov::PartialShape{1, 32, -1, -1}); + auto gemm = std::make_shared(gemm_in, kv_cache->output(0), kv_cache->output(1), false, true, 1, + in0_order, in1_order, out_order); + auto result = std::make_shared(gemm); + + model_ref = std::make_shared(ov::ResultVector{result}, ov::ParameterVector{parameter, beam_idx, gemm_in}); comparator.enable(FunctionsComparator::ATTRIBUTES); } }