Skip to content

Commit

Permalink
[GPU] Fixes for speculative decoding (openvinotoolkit#26578)
Browse files Browse the repository at this point in the history
### Details:
- This PR adds additional kernel variation for proper speculative
decoding handling and minor improvements of paged attention

### Tickets:
 - [CVS-138903](https://jira.devtools.intel.com/browse/CVS-138903)
  • Loading branch information
sshlyapn authored Sep 13, 2024
1 parent 41b502a commit 661cc03
Show file tree
Hide file tree
Showing 7 changed files with 157 additions and 21 deletions.
76 changes: 66 additions & 10 deletions src/plugins/intel_gpu/src/graph/impls/ocl/paged_attention.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,15 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
PA_SDPA,
};

bool requires_update(primitive_inst& inst, const kernel_impl_params& impl_params) const override {
const auto stage = get_paged_attention_stage(impl_params);

// In case of MIXED mode execution Paged Attention may require dispatch data update and internal
// buffers reallocation even if the input shapes haven't been changed. Therefore, check the current execution
// mode and update parameters if needed
return stage == PagedAttentionStage::MIXED;
}

void load(BinaryInputBuffer& ib) override {
parent::load(ib);
if (is_dynamic()) {
Expand Down Expand Up @@ -90,7 +99,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
return layouts;
}

kernel_arguments_data get_arguments(const paged_attention_inst& instance, size_t stage, size_t kernel_idx) const {
kernel_arguments_data get_arguments(const paged_attention_inst& instance, size_t stage, size_t kernel_idx, bool is_mixed_mode) const {
const auto desc = instance.get_node().as<paged_attention>().get_primitive();

kernel_arguments_data args;
Expand Down Expand Up @@ -129,7 +138,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
instance.block_indices_memory_ptr(),
instance.block_indices_begins_memory_ptr() };

if (kernel_idx == 1) {
if (is_mixed_mode) {
// Multi tokens kernel version has additional subsequence_begins_memory memory
// dependency
args.inputs.push_back(instance.subsequence_begins_memory_ptr());
Expand All @@ -140,6 +149,12 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
}
} else {
args.inputs = { instance.past_lens_memory_ptr() };

if (is_mixed_mode) {
// Multi tokens kernel version has additional subsequence_begins_memory memory
// dependency
args.inputs.push_back(instance.subsequence_begins_memory_ptr());
}
}

args.outputs = { instance.output_memory_ptr(0) };
Expand All @@ -153,7 +168,11 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
6, /* PA_SDPA multiple tokens mode */ };
};

void execute_stage(const std::vector<event::ptr>& events, paged_attention_inst& instance, std::vector<event::ptr>& all_events, size_t stage) {
void execute_stage(const std::vector<event::ptr>& events,
paged_attention_inst& instance,
std::vector<event::ptr>& all_events,
size_t stage,
bool is_mixed_mode) {
stream& stream = instance.get_network().get_stream();
std::vector<event::ptr> tmp_events(events);
size_t kernel_offset = 0;
Expand Down Expand Up @@ -181,7 +200,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {

auto& params = _kernels_data[stage].kernels[kd_idx].params;

auto args = get_arguments(instance, stage, kd_idx);
auto args = get_arguments(instance, stage, kd_idx, is_mixed_mode);
args.scalars = &params.scalars;

const auto& intermediate_memories = instance.get_intermediates_memories();
Expand Down Expand Up @@ -211,14 +230,15 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
event::ptr execute_impl(const std::vector<event::ptr>& events, paged_attention_inst& instance) override {
std::vector<event::ptr> res_events;
const auto stage = get_paged_attention_stage(*instance.get_impl_params());
const auto is_mixed_mode = stage == PagedAttentionStage::MIXED;

execute_stage(events, instance, res_events, Stage::KV_CACHE_UPDATE);
execute_stage(events, instance, res_events, Stage::KV_CACHE_UPDATE, is_mixed_mode);

std::vector<event::ptr> dep_events(res_events.begin(), res_events.end());
if (stage == PagedAttentionStage::PREFILL) {
execute_stage(dep_events, instance, res_events, Stage::SDPA);
execute_stage(dep_events, instance, res_events, Stage::SDPA, is_mixed_mode);
} else if (stage == PagedAttentionStage::GENERATE || stage == PagedAttentionStage::MIXED) {
execute_stage(dep_events, instance, res_events, Stage::PA_SDPA);
execute_stage(dep_events, instance, res_events, Stage::PA_SDPA, is_mixed_mode);
}

return instance.get_network().get_stream().aggregate_events(res_events, res_events.size() > 1);
Expand Down Expand Up @@ -248,9 +268,45 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
mem_lock<int32_t, mem_lock_type::read> subsequence_begins_mem_lock(subsequence_begins_mem, *impl_param.strm);

auto aligned_seq_len = 0;
for (size_t i = 0; i < subsequence_begins_mem_lock.size() - 1; i++) {
auto prompt_length = subsequence_begins_mem_lock[i + 1] - subsequence_begins_mem_lock[i];
aligned_seq_len += align_to(prompt_length, target_seq_len_block_size);
if (stage == PagedAttentionStage::MIXED) {
const auto past_lens_idx = 5;
const auto past_lens_mem = input_mem.at(past_lens_idx);
mem_lock<int32_t, mem_lock_type::read> past_lens_mem_lock(past_lens_mem, *impl_param.strm);

for (size_t i = 0; i < subsequence_begins_mem_lock.size() - 1; i++) {
auto past_len = past_lens_mem_lock[i];
auto seq_length = subsequence_begins_mem_lock[i + 1] - subsequence_begins_mem_lock[i];

// Since in MIXED execution mode the present KV-cache can be appended to the past KV-cache at any offset inside block,
// to ensure proper alignment and update_kv_cache kernel scheduling, we need to account for the number of unaligned tokens
// in the first block
// For example, if we need to store values in the following slots:
//
// block0: |O|O|O|O|O|O|O|O|O|O|O|O|U|U|U|U|
// block1: |U|U|U|U|U|U|U|U|U|U|U|U|U|U|U|U|
// block2: |U|U|U|U|U|U|E|E|E|E|E|E|E|E|E|E|
// Where O - occupied slots, U - currently beeing updated slots, E - empty slots
//
// We need to schedule 3 update_kv_cache operations:
// - For ranges of block0: [12-15]
// - For ranges of block1: [0-15]
// - For ranges of block2: [0-5]
//
// Therefore, consider an additional increment of aligned_seq_len to properly process all the blocks

auto occupied_slots_num = past_len % target_seq_len_block_size;
if (past_len != 0 && seq_length + occupied_slots_num > target_seq_len_block_size) {
aligned_seq_len += target_seq_len_block_size;
seq_length -= target_seq_len_block_size - occupied_slots_num;
}

aligned_seq_len += align_to(seq_length, target_seq_len_block_size);
}
} else {
for (size_t i = 0; i < subsequence_begins_mem_lock.size() - 1; i++) {
auto prompt_length = subsequence_begins_mem_lock[i + 1] - subsequence_begins_mem_lock[i];
aligned_seq_len += align_to(prompt_length, target_seq_len_block_size);
}
}

return aligned_seq_len;
Expand Down
5 changes: 5 additions & 0 deletions src/plugins/intel_gpu/src/graph/include/primitive_inst.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,11 @@ struct primitive_impl {
OPENVINO_ASSERT(false, "[GPU] update() is not implemented for dynamic implemenation ", _kernel_name);
}

virtual bool requires_update(primitive_inst& inst, const kernel_impl_params& impl_params) const {
OPENVINO_ASSERT(_is_dynamic, "[GPU] requires_update() is called for static shape implementation ", _kernel_name);
return false;
}

static kernel_impl_params static_canonicalize_shapes(const kernel_impl_params& impl_params);

virtual kernel_impl_params canonicalize_shapes(const kernel_impl_params& impl_params) const {
Expand Down
21 changes: 20 additions & 1 deletion src/plugins/intel_gpu/src/graph/paged_attention.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ void paged_attention_inst::on_execute() {
const auto blocks_indexes_end_idx = 1;
const auto blocked_gws_subseq_mapping_idx = 2;

const auto past_lens_mem = past_lens_memory_ptr();
auto subsequence_begins_mem = subsequence_begins_memory_ptr();
auto blocks_indexes_start_mem = _intermediates_memory[blocks_indexes_start_idx];
auto blocks_indexes_end_mem = _intermediates_memory[blocks_indexes_end_idx];
Expand All @@ -101,6 +102,7 @@ void paged_attention_inst::on_execute() {
OPENVINO_ASSERT(subsequence_begins_mem->get_layout().data_type == data_types::i32);

auto& stream = get_network().get_stream();
mem_lock<int32_t, mem_lock_type::read> past_lens_mem_lock(past_lens_mem, stream);
mem_lock<int32_t, mem_lock_type::read> subsequence_begins_mem_lock(subsequence_begins_mem, stream);
mem_lock<int32_t, mem_lock_type::write> blocks_indexes_start_lock(blocks_indexes_start_mem, stream);
mem_lock<int32_t, mem_lock_type::write> blocks_indexes_end_lock(blocks_indexes_end_mem, stream);
Expand All @@ -120,11 +122,28 @@ void paged_attention_inst::on_execute() {
size_t index = 0;
const auto target_seq_len_block_size = 16; // TODO: Get block size from the impl
for (size_t i = 0; i < subsequence_begins_mem_lock.size() - 1; i++) {
const auto past_len = past_lens_mem_lock[i];
const auto seq_start = subsequence_begins_mem_lock[i];
const auto seq_end = subsequence_begins_mem_lock[i + 1];
const auto seq_length = seq_end - seq_start;

for (int32_t j = 0; j < seq_length; j += target_seq_len_block_size) {
int32_t j = 0;
if (past_len != 0) {
auto block_start_pos = seq_start;
auto empty_slots = target_seq_len_block_size - (past_len % target_seq_len_block_size);
auto block_end_pos = seq_start + std::min(empty_slots, seq_length);

blocks_indexes_start_lock[index] = block_start_pos;
blocks_indexes_end_lock[index] = block_end_pos;
blocked_gws_subseq_mapping_mem_lock[index] = static_cast<int32_t>(i);

index++;

auto added_slots = block_end_pos - block_start_pos;
j += added_slots;
}

for (; j < seq_length; j += target_seq_len_block_size) {
auto block_start_pos = subsequence_begins_mem_lock[i] + j;
auto block_end_pos = std::min(block_start_pos + target_seq_len_block_size, seq_end);

Expand Down
14 changes: 14 additions & 0 deletions src/plugins/intel_gpu/src/graph/primitive_inst.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1648,6 +1648,7 @@ event::ptr primitive_inst::execute(const std::vector<event::ptr>& events) {
// Try update impl if current impl is dynamic because opt kernel may be added to impl cache through async compilation.
// Only try update weight and realloc when impl is updated.
const bool can_use_async_compilation = use_async_compilation();
bool is_updated = false;
if (shape_changed() || !_impl || (!shape_changed() && _impl->is_dynamic() && can_use_async_compilation)) {
if (update_impl(can_use_async_compilation)) {
need_args_update = true;
Expand All @@ -1657,9 +1658,22 @@ event::ptr primitive_inst::execute(const std::vector<event::ptr>& events) {
auto ev_reset = realloc_if_needed();
if (ev_reset)
dependencies.push_back(ev_reset);

is_updated = true;
}
}

// Paged Attention may require dispatch data update and internal buffers reallocation
// even if the input shapes haven't been changed
if (_node->is_type<paged_attention>() && !is_updated && _impl->requires_update(*this, *_impl_params)) {
_impl->update(*this, *_impl_params);

need_args_update = true;
auto ev_reset = realloc_if_needed();
if (ev_reset)
dependencies.push_back(ev_reset);
}

OPENVINO_ASSERT(_impl_params->get_output_layout().is_static(),
"[GPU] Can't execute ", primitive_id, " primitive as output layout is dynamic in runtime");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

#include "include/batch_headers/common.cl"

REQD_SUB_GROUP_SIZE(SUBGROUP_SIZE)
__attribute__((reqd_work_group_size(1, 1, SUBGROUP_SIZE)))
KERNEL(pa_kv_cache_update)(
OPTIONAL_SHAPE_INFO_ARG
__global const INPUT0_TYPE* key_data,
Expand Down Expand Up @@ -77,20 +79,25 @@ KERNEL(pa_kv_cache_update)(
const uint block_start_pos = blocked_indexes_start[block_idx];
const uint block_end_pos = blocked_indexes_end[block_idx];
const uint tokens_num = block_end_pos - block_start_pos;
const uint past_len = past_lens[subsequence_idx];

const uint token_start_pos = (past_len + block_start_pos - subsequence_begin_idx) % PAGED_ATTENTION_BLOCK_SIZE;

uint key_value_in_offset = block_start_pos * KV_HEADS_NUM * HEAD_SIZE +
head_idx * HEAD_SIZE;

const uint cached_blocks_num = past_lens[subsequence_idx] / PAGED_ATTENTION_BLOCK_SIZE;
const uint current_block_idx = (block_start_pos - subsequence_begin_idx) / PAGED_ATTENTION_BLOCK_SIZE;
const uint current_block_idx = (past_len + block_start_pos - subsequence_begin_idx) / PAGED_ATTENTION_BLOCK_SIZE;

const uint block_offset = block_indices_begins[subsequence_idx] + cached_blocks_num + current_block_idx;
const uint block_offset = block_indices_begins[subsequence_idx] + current_block_idx;

uint key_out_offset = block_indices[block_offset] * KV_HEADS_NUM * HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE +
head_idx * HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE;

uint value_out_offset = key_out_offset;

key_out_offset += token_start_pos;
value_out_offset += token_start_pos * HEAD_SIZE;

if (tokens_num == PAGED_ATTENTION_BLOCK_SIZE) {
unroll_for (uint token_num = 0; token_num < PAGED_ATTENTION_BLOCK_SIZE; token_num++) {
uint head_idx_index = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,11 @@ KERNEL(pa_sdpa_opt)(
// TODO: const uint global_data_idx = partition_idx * SEQ_LEN_PARTITION_SIZE + local_data_idx
const uint global_data_idx = partition_idx * SEQ_LEN_PARTITION_SIZE + qk_idx * (SUBGROUPS_PER_WG * SUBGROUP_SIZE) + sgid * SUBGROUP_SIZE + sglid;

#if SEQ_LEN_PARTITION_SIZE % SUBGROUPS_PER_WG * SUBGROUP_SIZE == 0
if (global_data_idx < seq_len) {
#else
if (global_data_idx < seq_len && local_data_idx < SEQ_LEN_PARTITION_SIZE) {
#endif
SOFTMAX_ACCUMULATOR_TYPE qk_new = native_exp(TO_SOFTMAX_ACCUMULATOR_TYPE(slm_qk_vals[local_data_idx]) - qk_max);
slm_qk_vals[local_data_idx] = TO_OUTPUT_TYPE(qk_new);

Expand All @@ -242,7 +246,11 @@ KERNEL(pa_sdpa_opt)(
const uint local_data_idx = qk_idx * (SUBGROUPS_PER_WG * SUBGROUP_SIZE) + sgid * SUBGROUP_SIZE + sglid;
const uint global_data_idx = partition_idx * SEQ_LEN_PARTITION_SIZE + qk_idx * (SUBGROUPS_PER_WG * SUBGROUP_SIZE) + sgid * SUBGROUP_SIZE + sglid;

#if SEQ_LEN_PARTITION_SIZE % SUBGROUPS_PER_WG * SUBGROUP_SIZE == 0
if (global_data_idx < seq_len) {
#else
if (global_data_idx < seq_len && local_data_idx < SEQ_LEN_PARTITION_SIZE) {
#endif
SOFTMAX_ACCUMULATOR_TYPE qk_new = TO_SOFTMAX_ACCUMULATOR_TYPE(slm_qk_vals[local_data_idx]) / exp_sum;
slm_qk_vals[local_data_idx] = TO_OUTPUT_TYPE(qk_new);
}
Expand Down Expand Up @@ -351,17 +359,29 @@ KERNEL(pa_sdpa_opt)(
REQD_SUB_GROUP_SIZE(SUBGROUP_SIZE)
KERNEL(pa_sdpa_finalization_stage)(
const __global INPUT3_TYPE* past_lens,
#if MULTI_TOKENS_PROCESSING
const __global INPUT6_TYPE* subsequence_begins,
#endif
__global OUTPUT_TYPE* output,
const __global SOFTMAX_ACCUMULATOR_TYPE* exp_sums,
const __global SOFTMAX_ACCUMULATOR_TYPE* max_logits,
const __global OUTPUT_TYPE* tmp_out,
#if MULTI_TOKENS_PROCESSING
const __global int* gws_subseq_mapping,
#endif
const uint total_partitions_num) {
const uint seq_idx = get_global_id(0);
const uint head_num_idx = get_global_id(1);
const uint head_size_idx = get_global_id(2);
const uint sglid = get_sub_group_local_id();

#if MULTI_TOKENS_PROCESSING
const int subsequence_idx = gws_subseq_mapping[seq_idx];
const int subsequence_begin = subsequence_begins[subsequence_idx];
const uint seq_len = past_lens[subsequence_idx] + 1 + (seq_idx - subsequence_begin);
#else
const uint seq_len = past_lens[seq_idx] + 1;
#endif

const uint num_of_partitions = CEIL_DIV(seq_len, SEQ_LEN_PARTITION_SIZE);

Expand Down
Loading

0 comments on commit 661cc03

Please sign in to comment.