[GPU] Enable Indirect gemm for non-batch gather axis support (openvinotoolkit#24321)

### Details:
- Currently, `indirect Gemm` assumes that the `KV Cache`'s gather axis is 0
- Update `KV Cache` and `indirect Gemm` to support a non-batch gather axis (see the sketch below)
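
For context, the gather axis is the dimension of the cached key/value tensor that the per-step beam indices permute. A minimal host-side sketch of that reordering, with assumed shapes and names (illustration only, not plugin code):

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// kv is flattened row-major as [d0, d1, rest]; gather_axis selects whether
// beam_idx permutes d0 (axis 0, the batch case) or d1 (axis 1).
// beam_idx has length d0 for axis 0, d1 for axis 1.
std::vector<float> gather_kv(const std::vector<float>& kv,
                             int64_t d0, int64_t d1, int64_t rest,
                             const std::vector<int64_t>& beam_idx,
                             int64_t gather_axis) {
    std::vector<float> out(kv.size());
    for (int64_t i = 0; i < d0; ++i) {
        for (int64_t j = 0; j < d1; ++j) {
            const int64_t si = (gather_axis == 0) ? beam_idx[i] : i;
            const int64_t sj = (gather_axis == 1) ? beam_idx[j] : j;
            std::copy_n(kv.data() + (si * d1 + sj) * rest, rest,
                        out.data() + (i * d1 + j) * rest);
        }
    }
    return out;
}
```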

### Tickets:
 - 137142

---------

Signed-off-by: Andrew Park <andrew.park@intel.com>
andrew-k-park authored May 13, 2024
1 parent dac7b3d commit c9e2ce9
Showing 18 changed files with 108 additions and 21 deletions.
3 changes: 3 additions & 0 deletions src/plugins/intel_gpu/include/intel_gpu/op/indirect_gemm.hpp
@@ -24,6 +24,7 @@ class IndirectGemm : public ov::intel_gpu::op::Gemm {
const ov::Output<Node>& I,
bool indirect_a,
bool indirect_b,
int64_t indirect_axis,
const std::vector<int64_t>& order_a,
const std::vector<int64_t>& order_b,
const std::vector<int64_t>& order_c,
@@ -38,12 +39,14 @@ class IndirectGemm : public ov::intel_gpu::op::Gemm {

bool get_indirect_a() const { return m_indirect_a; }
bool get_indirect_b() const { return m_indirect_b; }
int64_t get_indirect_axis() const { return m_indirect_axis; }

using ov::intel_gpu::op::Gemm::default_order;

protected:
bool m_indirect_a = false;
bool m_indirect_b = false;
int64_t m_indirect_axis = 0;
};

} // namespace op
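
For reference, a hypothetical call site for the extended constructor, mirroring the signature above (all variable names and argument values here are placeholders, not code from this change):

```cpp
// indirect_axis carries the KV cache's gather axis down to the kernel; with
// indirect_b set, input B is addressed through the beam table I.
auto gemm = std::make_shared<ov::intel_gpu::op::IndirectGemm>(
    input_a, input_b, beam_table,   // A, B, I
    /*indirect_a=*/false,
    /*indirect_b=*/true,
    /*indirect_axis=*/1,
    order_a, order_b, order_c,
    ov::element::undefined);
```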
9 changes: 8 additions & 1 deletion src/plugins/intel_gpu/include/intel_gpu/primitives/gemm.hpp
@@ -117,6 +117,7 @@ struct gemm : public primitive_base<gemm> {
const std::vector<int64_t>& output_transpose_order,
bool indirect_a,
bool indirect_b,
int64_t indirect_axis,
const float alpha = 1.0f,
const float beta = 0.0f,
const padding& output_padding = padding())
@@ -130,7 +131,8 @@ struct gemm : public primitive_base<gemm> {
weight_rank(input1_transpose_order.size()),
beam_table(beam_table),
indirect_a(indirect_a),
indirect_b(indirect_b) {
indirect_b(indirect_b),
indirect_axis(indirect_axis) {
if (inputs.size() != 2 && inputs.size() != 3) {
throw std::invalid_argument("Invalid inputs count - gemm expects either two or three inputs");
}
@@ -162,13 +164,15 @@ struct gemm : public primitive_base<gemm> {
input_info beam_table = {};
bool indirect_a = false;
bool indirect_b = false;
int64_t indirect_axis = 0;

size_t hash() const override {
size_t seed = primitive::hash();
seed = hash_combine(seed, transpose_input0);
seed = hash_combine(seed, transpose_input1);
seed = hash_combine(seed, indirect_a);
seed = hash_combine(seed, indirect_b);
seed = hash_combine(seed, indirect_axis);
seed = hash_range(seed, input0_transpose_order.begin(), input0_transpose_order.end());
seed = hash_range(seed, input1_transpose_order.begin(), input1_transpose_order.end());
seed = hash_range(seed, output_transpose_order.begin(), output_transpose_order.end());
@@ -189,6 +193,7 @@ struct gemm : public primitive_base<gemm> {
beta == rhs_casted.beta &&
indirect_a == rhs_casted.indirect_a &&
indirect_b == rhs_casted.indirect_b &&
indirect_axis == rhs_casted.indirect_axis &&
input_rank == rhs_casted.input_rank &&
weight_rank == rhs_casted.weight_rank;
}
@@ -206,6 +211,7 @@ struct gemm : public primitive_base<gemm> {
ob << weight_rank;
ob << indirect_a;
ob << indirect_b;
ob << indirect_axis;
ob << beam_table.pid;
ob << beam_table.idx;
}
@@ -223,6 +229,7 @@ struct gemm : public primitive_base<gemm> {
ib >> weight_rank;
ib >> indirect_a;
ib >> indirect_b;
ib >> indirect_axis;
ib >> beam_table.pid;
ib >> beam_table.idx;
}
4 changes: 3 additions & 1 deletion src/plugins/intel_gpu/src/graph/impls/ocl/gemm.cpp
@@ -138,7 +138,8 @@ struct gemm_impl : multi_stage_primitive<gemm> {
return false;

const auto& params = *inst.get_impl_params();
if (params.input_layouts[get_beam_table_id(desc)].get_partial_shape()[0].get_length() == 1)
const auto indirect_axis = desc->indirect_axis;
if (params.input_layouts[get_beam_table_id(desc)].get_partial_shape()[indirect_axis].get_length() == 1)
return false;

const auto& deps = inst.dependencies();
@@ -226,6 +227,7 @@ struct gemm_impl : multi_stage_primitive<gemm> {

params.indirect_input0 = primitive->indirect_a && indirect;
params.indirect_input1 = primitive->indirect_b && indirect;
params.indirect_axis = primitive->indirect_axis;
if (indirect && (primitive->indirect_a || primitive->indirect_b)) {
OPENVINO_ASSERT(impl_param.input_layouts.size() >= 3, "[GPU] Actual inputs count: ", impl_param.input_layouts.size());
params.inputs.push_back(convert_data_tensor(impl_param.input_layouts[get_beam_table_id(primitive)]));
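
The guard above generalizes the earlier batch-only check: if the beam table has extent 1 along the gather axis, every lookup is the identity and the direct kernel can run. Restated standalone (assumed semantics):

```cpp
#include <cstdint>
#include <vector>

// Returns true only when the beam table can actually reorder entries, i.e.
// it has more than one entry along the configured gather axis.
bool needs_indirect_dispatch(const std::vector<int64_t>& beam_table_dims,
                             int64_t indirect_axis) {
    return beam_table_dims[indirect_axis] > 1;
}
```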
3 changes: 3 additions & 0 deletions src/plugins/intel_gpu/src/graph/impls/ocl/kv_cache.cpp
@@ -243,7 +243,9 @@ struct kv_cache_impl : multi_stage_primitive<kv_cache> {
}

static bt_kernel_params_t get_bt_update_kernel_params(const kernel_impl_params& impl_param, bool is_state_set = false) {
const auto& primitive = impl_param.typed_desc<kv_cache>();
auto params = get_default_params<kernel_selector::beam_table_update_params>(impl_param, true);
auto indirect_axis = primitive->gather_axis;

auto inputs_count = 2;
auto bt_present_layout = impl_param.output_layouts[1];
@@ -260,6 +262,7 @@ struct kv_cache_impl : multi_stage_primitive<kv_cache> {
params.outputs[0] = convert_data_tensor(bt_present_layout);
params.inputs.resize(inputs_count);
params.is_state_set = is_state_set;
params.indirect_axis = indirect_axis;

const auto& in_offsets_map = impl_param.in_port_to_shape_info_offset; // [kv_past, kv_new_token, [beam_idx, beam_table_past]]
const auto& out_offsets_map = impl_param.out_port_to_shape_info_offset; // [kv_present, beam_table_present]
src/plugins/intel_gpu/src/kernel_selector/cl_kernels/beam_table_update_ref.cl
@@ -15,18 +15,40 @@ KERNEL(beam_table_update)(
const unsigned int s = (uint)get_global_id(1);

const unsigned int out_offset = b * OUTPUT_BATCH_PITCH + s;
#if INDIRECT_AXIS == 0
const unsigned int in_offset = beam_idx[b] * INPUT0_BATCH_PITCH + s;
#elif INDIRECT_AXIS == 1
const unsigned int in_offset = b * INPUT0_BATCH_PITCH + beam_idx[s];
#else
# error beam_table_update_ref.cl : Unsupported indirect axis for beam table
#endif

if (s >= OUTPUT_BATCH_PITCH)
return;

if (!is_state_set) {
#if INDIRECT_AXIS == 0
state_new[out_offset] = TO_OUTPUT_TYPE(b);
#elif INDIRECT_AXIS == 1
state_new[out_offset] = TO_OUTPUT_TYPE(s);
#else
# error beam_table_update_ref.cl : Unsupported indirect axis for beam table
#endif
} else {
#if INDIRECT_AXIS == 0
if (s < INPUT0_BATCH_PITCH) {
state_new[out_offset] = state_prev[in_offset];
} else {
state_new[out_offset] = b;
state_new[out_offset] = TO_OUTPUT_TYPE(b);
}
#elif INDIRECT_AXIS == 1
if (b < INPUT0_BATCH_NUM) {
state_new[out_offset] = state_prev[in_offset];
} else {
state_new[out_offset] = TO_OUTPUT_TYPE(s);
}
#else
# error beam_table_update_ref.cl : Unsupported indirect axis for beam table
#endif
}
}
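
A host-side reference of what the kernel above computes may help; this is a sketch under an assumed [batch, sequence] beam-table layout, not code from the change:

```cpp
#include <cstdint>
#include <vector>

// First token (is_state_set == false): fill with identity indices.
// Later tokens: re-gather the previous table through this step's beam
// choices along the indirect axis, and fill freshly appended positions
// with identity indices.
void update_beam_table(std::vector<int32_t>& table_new, int64_t batch, int64_t seq,
                       const std::vector<int32_t>& table_prev,
                       int64_t prev_batch, int64_t prev_seq,
                       const std::vector<int32_t>& beam_idx,
                       bool is_state_set, int indirect_axis) {
    for (int64_t b = 0; b < batch; ++b) {
        for (int64_t s = 0; s < seq; ++s) {
            const int32_t identity = (indirect_axis == 0) ? static_cast<int32_t>(b)
                                                          : static_cast<int32_t>(s);
            const bool in_prev = (indirect_axis == 0) ? (s < prev_seq)
                                                      : (b < prev_batch);
            if (!is_state_set || !in_prev) {
                table_new[b * seq + s] = identity;
            } else {
                const int64_t src = (indirect_axis == 0)
                                        ? beam_idx[b] * prev_seq + s
                                        : b * prev_seq + beam_idx[s];
                table_new[b * seq + s] = table_prev[src];
            }
        }
    }
}
```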
16 changes: 14 additions & 2 deletions src/plugins/intel_gpu/src/kernel_selector/cl_kernels/gemm_ref.cl
@@ -131,10 +131,22 @@ KERNEL(gemm_ref)(
uint b0 = b;
uint b1 = b;
#if INDIRECT_INPUT0
b0 = BEAM_TABLE_BATCH_NUM > 1 ? beam_table[FUNC_CALL(get_bt_index)(OPTIONAL_SHAPE_INFO_TENSOR b, f, w, z, y, ki)] : b;
#if INDIRECT_AXIS == 0
b0 = BEAM_TABLE_BATCH_NUM > 1 ? beam_table[FUNC_CALL(get_bt_index)(OPTIONAL_SHAPE_INFO_TENSOR b, f, w, z, y, ki)] : b;
#elif INDIRECT_AXIS == 1
b0 = BEAM_TABLE_FEATURE_NUM > 1 ? beam_table[FUNC_CALL(get_bt_index)(OPTIONAL_SHAPE_INFO_TENSOR b, f, w, z, y, ki)] : b;
#else
# error gemm_ref.cl : Unsupported indirect axis for beam table
#endif
#endif
#if INDIRECT_INPUT1
b1 = BEAM_TABLE_BATCH_NUM > 1 ? beam_table[FUNC_CALL(get_bt_index)(OPTIONAL_SHAPE_INFO_TENSOR b, f, w, z, ki, x)] : b;
#if INDIRECT_AXIS == 0
b1 = BEAM_TABLE_BATCH_NUM > 1 ? beam_table[FUNC_CALL(get_bt_index)(OPTIONAL_SHAPE_INFO_TENSOR b, f, w, z, ki, x)] : b;
#elif INDIRECT_AXIS == 1
b1 = BEAM_TABLE_FEATURE_NUM > 1 ? beam_table[FUNC_CALL(get_bt_index)(OPTIONAL_SHAPE_INFO_TENSOR b, f, w, z, ki, x)] : b;
#else
# error gemm_ref.cl : Unsupported indirect axis for beam table
#endif
#endif

uint in0_idx = FUNC_CALL(get_input0_index)(OPTIONAL_SHAPE_INFO_TENSOR b0, f, w, z, y, ki);
src/plugins/intel_gpu/src/kernel_selector/cl_kernels/gemm_tiled_opt.cl
@@ -79,14 +79,20 @@ inline uint FUNC(get_bt_index)(OPTIONAL_SHAPE_INFO_ARG uint b, uint f, uint w, u

#if INDIRECT_INPUT0
inline uint FUNC(get_input0_indirect_index)(OPTIONAL_SHAPE_INFO_ARG uint b, uint f, uint w, uint z, uint y, uint x, __global BEAM_TABLE_TYPE* beam_table) {
#if INDIRECT_AXIS == 0
int b_index = BEAM_TABLE_BATCH_NUM > 1 ? beam_table[FUNC_CALL(get_bt_index)(OPTIONAL_SHAPE_INFO_TENSOR b, f, w, z, y, x)] : b;
#elif INDIRECT_AXIS == 1
int b_index = BEAM_TABLE_FEATURE_NUM > 1 ? beam_table[FUNC_CALL(get_bt_index)(OPTIONAL_SHAPE_INFO_TENSOR b, f, w, z, y, x)] : b;
#else
# error gemm_tiled_opt.cl : Unsupported indirect axis for beam table
#endif
return FUNC_CALL(get_input0_index)(OPTIONAL_SHAPE_INFO_TENSOR b_index, f, w, z, y, x);
}
#endif

#if INDIRECT_INPUT1
inline uint FUNC(get_input1_indirect_index)(OPTIONAL_SHAPE_INFO_ARG uint b, uint f, uint w, uint z, uint y, uint x, __global BEAM_TABLE_TYPE* beam_table) {
int b_index = BEAM_TABLE_BATCH_NUM > 1 ? beam_table[FUNC_CALL(get_bt_index)(OPTIONAL_SHAPE_INFO_TENSOR b, f, w, z, y, x)] : b;
int b_index = beam_table[FUNC_CALL(get_bt_index)(OPTIONAL_SHAPE_INFO_TENSOR b, f, w, z, y, x)];
return FUNC_CALL(get_input1_index)(OPTIONAL_SHAPE_INFO_TENSOR b_index, f, w, z, y, x);
}
#endif
@@ -218,7 +224,13 @@ KERNEL(gemm_tiled_opt)(
const uint b_raw_global_id = tile_n_offset + sglid;

#if INDIRECT_INPUT0 || INDIRECT_INPUT1
#if INDIRECT_AXIS == 0
const char do_indirect_load = BEAM_TABLE_BATCH_NUM > 1;
#elif INDIRECT_AXIS == 1
const char do_indirect_load = BEAM_TABLE_FEATURE_NUM > 1;
#else
# error gemm_tiled_opt.cl : Unsupported indirect axis for beam table
#endif
#endif

#if TRANSPOSE_INPUT0 != TRANSPOSE_X_LAST
@@ -344,15 +356,13 @@ KERNEL(gemm_tiled_opt)(
unroll_for (uint b_load_id = 0; b_load_id < TILE_K; b_load_id++) {
uint b_load_offset = (k * TILE_K) + b_load_id;
uint b_idx = FUNC_CALL(get_input1_indirect_index)(OPTIONAL_SHAPE_INFO_TENSOR b, f, w, z, b_load_offset, x, beam_table);
uint bt_idx = FUNC_CALL(get_bt_index)(OPTIONAL_SHAPE_INFO_TENSOR b, f, w, z, b_load_offset, x);
b_tile[b_load_id] = b_raw_global_id >= N ? 0 : input1[b_idx];
}
#else
unroll_for (uint b_elem = 0; b_elem < B_VEC_SIZE; ++b_elem) {
unroll_for (uint b_load_id = 0; b_load_id < TILE_K; b_load_id++) {
uint b_load_offset = k * TILE_K + b_load_id;
uint b_idx = FUNC_CALL(get_input1_indirect_index)(OPTIONAL_SHAPE_INFO_TENSOR b, f, w, z, b_load_offset, x + sglid + SIMD_WIDTH * b_elem, beam_table);
uint bt_idx = FUNC_CALL(get_bt_index)(OPTIONAL_SHAPE_INFO_TENSOR b, f, w, z, b_load_offset, x + sglid + SIMD_WIDTH * b_elem);
b_tile[b_elem][b_load_id] = b_raw_global_id + SIMD_WIDTH * b_elem >= N ? 0 : input1[b_idx];
}
}
@@ -94,16 +94,23 @@ bool BeamTableUpdateKernelRef::Validate(const Params& params) const {
}

JitConstants BeamTableUpdateKernelRef::GetJitConstants(const beam_table_update_params& kernel_params) const {
return MakeBaseParamsJitConstants(kernel_params);
JitConstants jit = MakeBaseParamsJitConstants(kernel_params);
jit.AddConstant({MakeJitConstant("INDIRECT_AXIS", kernel_params.indirect_axis)});
return jit;
}

CommonDispatchData BeamTableUpdateKernelRef::SetDefault(const beam_table_update_params& kernel_params) {
CommonDispatchData dispatch_data;

auto output = kernel_params.outputs[0];
if (!output.is_dynamic()) {
dispatch_data.gws = {output.Batch().v, Align(output.LogicalSize() / output.Batch().v, 16), 1};
dispatch_data.lws = {1, 16, 1};
if (kernel_params.indirect_axis == 0) {
dispatch_data.gws = {output.Batch().v, Align(output.LogicalSize() / output.Batch().v, 16), 1};
dispatch_data.lws = {1, 16, 1};
} else {
dispatch_data.gws = {Align(output.LogicalSize() / output.Feature().v, 16), output.Feature().v, 1};
dispatch_data.lws = {16, 1, 1};
}
}

return dispatch_data;
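
The dispatch sizes above follow the same split: the gathered axis keeps a local size of 1 while the other dimension is padded up to the 16-wide local size. A condensed restatement (parameter names are assumptions):

```cpp
#include <array>
#include <cstddef>
#include <cstdint>

constexpr size_t align_to(size_t v, size_t a) { return (v + a - 1) / a * a; }

// Axis 0: iterate batch entries un-tiled; axis 1: iterate feature entries
// un-tiled; the remaining dimension is aligned to the 16-wide local size.
std::array<size_t, 3> beam_table_gws(size_t batch, size_t feature,
                                     size_t logical_size, int64_t indirect_axis) {
    if (indirect_axis == 0)
        return {batch, align_to(logical_size / batch, 16), 1};
    return {align_to(logical_size / feature, 16), feature, 1};
}
```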
@@ -11,6 +11,7 @@ namespace kernel_selector {
struct beam_table_update_params : base_params {
beam_table_update_params() : base_params(KernelType::BEAM_TABLE_UPDATE) {}
bool is_state_set = false;
int64_t indirect_axis = 0;
};

class BeamTableUpdateKernelRef : public KernelBaseOpenCL {
@@ -184,6 +184,7 @@ JitConstants GemmKernelBase::GetJitConstants(const gemm_params& params) const {
MakeJitConstant("QUANTIZATION_TERM", params.quantization != QuantizationType::NONE),
MakeJitConstant("INDIRECT_INPUT0", params.indirect_input0),
MakeJitConstant("INDIRECT_INPUT1", params.indirect_input1),
MakeJitConstant("INDIRECT_AXIS", params.indirect_axis),
MakeJitConstant("BEAM_TABLE_TERM", params.indirect_input0 || params.indirect_input1),
});

@@ -29,6 +29,7 @@ struct gemm_params : public base_params {
DataTensor beam_table;
bool indirect_input0 = false;
bool indirect_input1 = false;
int64_t indirect_axis = 0;
QuantizationType quantization = QuantizationType::NONE;

ParamsKey GetParamsKey() const override {
@@ -108,7 +108,7 @@ GemmKernelTiledOpt::GemmTuningData GemmKernelTiledOpt::SetTuningParams(const gem
tuning_data.tile_m_size = tuning_data.simd_size;
bool output_ndim_transposed = (params.output_order.size() > 0 && (params.output_order.back() != (static_cast<int>(params.output_order.size()) - 1)));
if ((params.transpose_input0 == 0 /*X_LAST*/) && (params.transpose_input1 == 0 /*X_LAST*/ || params.transpose_input1 == 1 /*Y_LAST*/)
&& (!params.indirect_input0 && !params.inputs[0].has_dynamic_pad())
&& (!params.indirect_input0 && !params.inputs[0].has_dynamic_pad() && params.indirect_axis != 1)
&& (!output_ndim_transposed || params.fused_ops.empty())
&& !params.engineInfo.supports_immad) {
// - Not supports transposed input0 / transposed input1 for OTHER mode yet
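
Reading the tuning condition above, the leftover ("OTHER") tile path appears to assume that any beam-table lookup happens along the batch axis, so feature-axis indirection now falls back to the generic path; this is an inference from the diff, restated here under assumed semantics:

```cpp
#include <cstdint>

// Condensed eligibility check for the tuned leftover path, mirroring the
// fields of gemm_params used in the condition above.
bool leftover_tiling_ok(bool indirect_input0, bool input0_dynamic_pad,
                        int64_t indirect_axis) {
    return !indirect_input0 && !input0_dynamic_pad && indirect_axis != 1;
}
```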
1 change: 1 addition & 0 deletions src/plugins/intel_gpu/src/plugin/ops/matmul.cpp
@@ -205,6 +205,7 @@ static void CreateIndirectGemmOp(ProgramBuilder& p, const std::shared_ptr<ov::in
op->get_output_transpose_order(),
op->get_indirect_a(),
op->get_indirect_b(),
op->get_indirect_axis(),
alpha,
beta);

@@ -49,7 +49,7 @@ IndirectKVCache::IndirectKVCache() {
auto gather_input = wrap_type<ov::intel_gpu::op::ReadValue>();
auto axis_const = wrap_type<ov::op::v0::Constant>(
ov::op::util::constant_predicate<int64_t>([](const std::vector<int64_t>& value) -> bool {
return value.size() == 1 && value[0] == 0;
return value.size() == 1 && (value[0] == 0 || value[0] == 1);
}));
auto gather_past = wrap_type<ov::op::v8::Gather>({gather_input, beam_idx, axis_const});
auto kv_cache = wrap_type<ov::intel_gpu::op::KVCache>({gather_past, any_input()});
@@ -68,14 +68,15 @@ IndirectKVCache::IndirectKVCache() {
auto beam_idx_node = pattern_map.at(beam_idx).get_node_shared_ptr();
auto gather_input_node = pattern_map.at(gather_input).get_node_shared_ptr();
auto gather_node = std::dynamic_pointer_cast<ov::op::v8::Gather>(pattern_map.at(gather_past).get_node_shared_ptr());
auto gather_axis = gather_node->get_axis();
ov::replace_node(gather_node, gather_input_node);

auto indirect_kv_cache = std::make_shared<op::KVCache>(gather_input_node,
kv_cache_node->get_input_node_shared_ptr(1),
beam_idx_node,
kv_cache_node->get_variable(),
kv_cache_node->get_concat_axis(),
gather_node->get_axis(),
gather_axis,
kv_cache_node->get_output_element_type(0));

indirect_kv_cache->set_friendly_name(kv_cache_node->get_friendly_name());
@@ -95,6 +96,7 @@ IndirectKVCache::IndirectKVCache() {
indirect_kv_cache->output(1), // beam table
matmul_kv_cache_index == 0,
matmul_kv_cache_index == 1,
gather_axis,
order_in0,
order_in1,
order_out);
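
Net effect of the transformation change: the Gather on the ReadValue output is folded into the indirect KVCache for axis 1 as well as axis 0, and the matched axis now travels with the fusion into both the KVCache op and the IndirectGemm. The relaxed axis predicate, standalone (a sketch of the matcher condition shown above):

```cpp
#include <cstdint>
#include <vector>

// The fusion now fires when the Gather axis constant is a scalar equal to
// 0 (batch) or 1 (first non-batch dimension), instead of 0 only.
bool is_supported_gather_axis(const std::vector<int64_t>& axis_const) {
    return axis_const.size() == 1 && (axis_const[0] == 0 || axis_const[0] == 1);
}
```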
@@ -14,13 +14,15 @@ IndirectGemm::IndirectGemm(const ov::Output<Node>& A,
const ov::Output<Node>& I,
bool indirect_a,
bool indirect_b,
int64_t indirect_axis,
const std::vector<int64_t>& order_a,
const std::vector<int64_t>& order_b,
const std::vector<int64_t>& order_c,
const ov::element::Type output_type)
: ov::intel_gpu::op::Gemm(A, B, order_a, order_b, order_c, output_type)
, m_indirect_a(indirect_a)
, m_indirect_b(indirect_b) {
, m_indirect_b(indirect_b)
, m_indirect_axis(indirect_axis) {
set_argument(2, I);
OPENVINO_ASSERT((indirect_a && indirect_b) == false, "[GPU] Gemm supports indirect addressing for one input only");
validate_and_infer_types();
@@ -34,6 +36,7 @@ std::shared_ptr<ov::Node> IndirectGemm::clone_with_new_inputs(const ov::OutputVe
new_args.at(2),
m_indirect_a,
m_indirect_b,
m_indirect_axis,
m_order_a,
m_order_b,
m_order_c,
@@ -62,6 +65,7 @@ bool IndirectGemm::visit_attributes(ov::AttributeVisitor &visitor) {
Gemm::visit_attributes(visitor);
visitor.on_attribute("indirect_a", m_indirect_a);
visitor.on_attribute("indirect_b", m_indirect_b);
visitor.on_attribute("indirect_axis", m_indirect_axis);
return true;
}
