Skip to content

Commit

Permalink
[GPU] Use parallel sum reduction in RMS BFYX OPT kernel (#25896)
Browse files Browse the repository at this point in the history
### Details:
 - Use parallel sum reduction for RMS BFYX OPT kernel
 - Improve heuristics

### Tickets:
 - 148937

Co-authored-by: Pavel Durandin <pavel.durandin@intel.com>
  • Loading branch information
dnkurek and p-durandin authored Aug 6, 2024
1 parent 5ec4375 commit decdac6
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,15 @@ KERNEL(rms_gpu_bfyx_opt)(
slm_buf[get_sub_group_id()] = rms;

barrier(CLK_LOCAL_MEM_FENCE);
if (in_data_idx == 0) {
for (uint i = 1; i < get_num_sub_groups(); ++i)
{
rms += slm_buf[i];
for (uint offset = get_num_sub_groups() / 2; offset > 0; offset /= 2) {
if (in_data_idx < offset) {
slm_buf[in_data_idx] += slm_buf[in_data_idx + offset];
}
rms = rms / data_size;
barrier(CLK_LOCAL_MEM_FENCE);
}

if (in_data_idx == 0) {
rms = slm_buf[0] / data_size;
slm_buf[0] = native_powr(sqrt(rms + TO_ACCUMULATOR_TYPE(EPSILON)), -1);
}
barrier(CLK_LOCAL_MEM_FENCE);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ RMSKernelBase::DispatchData RMSKernelBfyxOpt::SetDefault(const rms_params& param

dispatchData.itemsNum = dispatchData.dataSize;
// Compute maximum possible LWS that does not exceed device capabilities and optimizes number of global memory reads
while ((dispatchData.itemsNum > 32 || dispatchData.lws[0] < dispatchData.itemsNum) && (2 * dispatchData.lws[0] <= max_lws)) {
while ((dispatchData.itemsNum > 8 || dispatchData.lws[0] < dispatchData.itemsNum) && (2 * dispatchData.lws[0] <= max_lws)) {
dispatchData.lws[0] *= 2;
dispatchData.itemsNum /= 2;
}
Expand Down

0 comments on commit decdac6

Please sign in to comment.