
[GPU] Optimize RMS Stack Size for Better Performance #26515

Merged (11 commits), Sep 11, 2024
@@ -6,6 +6,8 @@
 #include "kernel_selector_utils.h"
 #include <string>

+#define MAX_ITEMS_NUM 8
+
 namespace kernel_selector {
 static constexpr size_t subgroup_size = 16;
 ParamsKey RMSKernelBfyxOpt::GetSupportedKey() const {
@@ -58,9 +60,9 @@ JitConstants RMSKernelBfyxOpt::GetJitConstants(const rms_params& params, Dispatc
 }

     const std::string lws_0 = "get_local_size(0)";
-    // It can be expected that the maximum possible itemsNum will not exceed 32
-    // Therefore, in dynamic shape, stack_size including additional buffer is set to 33
-    constexpr size_t stack_size = 33;
+    // It can be expected that the maximum possible itemsNum will not exceed MAX_ITEMS_NUM
+    // Therefore, in dynamic shape, stack_size including additional buffer is set to MAX_ITEMS_NUM
+    constexpr size_t stack_size = MAX_ITEMS_NUM;
     jit.AddConstants({
         MakeJitConstant("DATA_SIZE", data_size),
         MakeJitConstant("LWS", lws_0),
@@ -120,7 +122,7 @@ RMSKernelBase::DispatchData RMSKernelBfyxOpt::SetDefault(const rms_params& param

dispatchData.itemsNum = dispatchData.dataSize;
@dnkurek (Contributor) commented on Sep 10, 2024:

Hint: maybe check whether dispatchData.dataSize is actually constant and won't change while executing the LLM; then we would know beforehand what dispatchData.itemsNum will be in each execution. In other words, maybe the only thing that changes is dataCount.
The author (Contributor) replied:

In Qwen, dataSize is always equal to 896.

@dnkurek (Contributor) replied on Sep 10, 2024:

I see, so that looks like a possible optimization: since you then know exactly how much private memory you will need, you can make better use of the GPU's resources.

A Contributor added:

You would also know the LWS, and could therefore use reqd_work_group_size.
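To illustrate that suggestion: when the work-group size is known at kernel-build time, it can be injected as a JIT define and pinned with OpenCL's `reqd_work_group_size` kernel attribute, which lets the compiler allocate registers for exactly that launch shape. A hypothetical kernel fragment (the `LWS0` define and kernel signature are assumptions for illustration, not code from this PR):

```c
// Hypothetical OpenCL kernel fragment. LWS0 would be injected as a JIT
// constant at build time; the attribute promises the compiler that the
// kernel is always launched with exactly this work-group size.
__attribute__((reqd_work_group_size(LWS0, 1, 1)))
__kernel void rms_norm_example(__global const half* input,
                               __global half* output) {
    // ... kernel body ...
}
```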

@dnkurek (Contributor) commented on Sep 10, 2024:

Also, 896 is relatively small, so it makes sense that reducing the stack size improves performance: you are essentially freeing up resources that were not being used.

     // Compute maximum possible LWS that does not exceed device capabilities and optimizes number of global memory reads
-    while ((dispatchData.itemsNum > 8 || dispatchData.lws[0] < dispatchData.itemsNum) && (2 * dispatchData.lws[0] <= max_lws)) {
+    while ((dispatchData.itemsNum > MAX_ITEMS_NUM || dispatchData.lws[0] < dispatchData.itemsNum) && (2 * dispatchData.lws[0] <= max_lws)) {
         dispatchData.lws[0] *= 2;
         dispatchData.itemsNum /= 2;
     }