Add microbenchmark for layer normalization and improve latency #22223

Status: Open. Wants to merge 27 commits into base: main; the diff below shows changes from the first 11 commits.

Commits (27):
2b8cd17
Add microbenchmark for layer normalization
amarin16 Sep 25, 2024
0c89631
fix warnings
amarin16 Sep 25, 2024
bca13ca
initialize test input data at compile time
amarin16 Sep 26, 2024
680cf4f
remove unused specialization that fails on pipeline
amarin16 Sep 26, 2024
f0df526
fix build on linux
amarin16 Sep 30, 2024
87725c3
convert all inputs to float efficiently if needed
amarin16 Sep 30, 2024
8aa80da
convert output buffer efficiently in layer_norm_impl
amarin16 Sep 30, 2024
295d652
convert output buffer efficiently in skip_layer_norm
amarin16 Sep 30, 2024
405a0a0
add inline and fix some lint issues
amarin16 Sep 30, 2024
245f298
fix some lint errors
amarin16 Sep 30, 2024
f398b64
fix warning
amarin16 Sep 30, 2024
a483ca4
maybe_unused
amarin16 Oct 1, 2024
19d225a
Fix bug
amarin16 Oct 1, 2024
05b5037
separate MLFloat16 implementation in skip_layer_norm
amarin16 Oct 1, 2024
ab2e5f2
fix linter issues
amarin16 Oct 1, 2024
63e9644
fix precision warning
amarin16 Oct 1, 2024
11eb7fb
cast
amarin16 Oct 2, 2024
46775a7
separate implementation for MLFloat16 inside layer_norm_impl
amarin16 Oct 2, 2024
fd904f6
don't use vectors
amarin16 Oct 2, 2024
a41b802
reuse allocated arrays when possible
amarin16 Oct 2, 2024
6aece95
make_unique instead of new
amarin16 Oct 2, 2024
766c4b2
Revert "make_unique instead of new" for latency
amarin16 Oct 2, 2024
cb55d4b
lint
amarin16 Oct 2, 2024
2895f37
fix bug
amarin16 Oct 2, 2024
f93ccb7
fix bug
amarin16 Oct 2, 2024
4be0255
handle errors
amarin16 Oct 3, 2024
48ce979
remove checks on tensor data
amarin16 Oct 3, 2024
3 changes: 2 additions & 1 deletion cmake/onnxruntime_unittests.cmake
@@ -1128,7 +1128,8 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
       ${BENCHMARK_DIR}/gelu.cc
       ${BENCHMARK_DIR}/activation.cc
       ${BENCHMARK_DIR}/quantize.cc
-      ${BENCHMARK_DIR}/reduceminmax.cc)
+      ${BENCHMARK_DIR}/reduceminmax.cc
+      ${BENCHMARK_DIR}/layer_normalization.cc)
   target_include_directories(onnxruntime_benchmark PRIVATE ${ONNXRUNTIME_ROOT} ${onnxruntime_graph_header} ${ONNXRUNTIME_ROOT}/core/mlas/inc)
   target_compile_definitions(onnxruntime_benchmark PRIVATE BENCHMARK_STATIC_DEFINE)
   if(WIN32)
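The contents of the new layer_normalization.cc are not shown in this diff. For orientation, a minimal sketch of the kind of Google Benchmark harness such a file could contain; the names, shapes, and the reference loop here are hypothetical, not the PR's actual code (the onnxruntime_benchmark target links Google Benchmark, note the BENCHMARK_STATIC_DEFINE definition above):

#include <cmath>
#include <random>
#include <vector>

#include "benchmark/benchmark.h"

static void BM_LayerNormalization(benchmark::State& state) {
  const size_t hidden_size = static_cast<size_t>(state.range(0));
  std::vector<float> input(hidden_size);
  std::vector<float> gamma(hidden_size, 1.0f);
  std::vector<float> beta(hidden_size, 0.0f);
  std::vector<float> output(hidden_size);
  std::mt19937 gen(42);
  std::uniform_real_distribution<float> dist(-1.0f, 1.0f);
  for (auto& v : input) v = dist(gen);
  const float epsilon = 1e-5f;

  for (auto _ : state) {
    // Plain reference implementation of one layer-norm row; the real benchmark
    // would invoke the LayerNormalization kernel through ORT's test utilities.
    float mean = 0.0f;
    float mean_square = 0.0f;
    for (size_t h = 0; h < hidden_size; ++h) {
      mean += input[h];
      mean_square += input[h] * input[h];
    }
    mean /= hidden_size;
    const float inv_std = 1.0f / std::sqrt(mean_square / hidden_size - mean * mean + epsilon);
    for (size_t h = 0; h < hidden_size; ++h) {
      output[h] = (input[h] - mean) * inv_std * gamma[h] + beta[h];
    }
    benchmark::DoNotOptimize(output.data());
  }
}
BENCHMARK(BM_LayerNormalization)->Arg(768)->Arg(4096);
BENCHMARK_MAIN();

Once built, such a benchmark is typically run through the onnxruntime_benchmark binary, selecting cases with Google Benchmark's --benchmark_filter flag.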
133 changes: 88 additions & 45 deletions onnxruntime/contrib_ops/cpu/skip_layer_norm.cc
@@ -2,6 +2,7 @@
 // Licensed under the MIT License.

 #include "core/framework/tensor.h"
+#include "core/mlas/inc/mlas.h"
 #include "core/util/math_cpuonly.h"
 #include "core/providers/common.h"
 #include "core/platform/threadpool.h"
@@ -36,49 +37,56 @@
 REGISTER_KERNEL_TYPED(double)
 REGISTER_KERNEL_TYPED(MLFloat16)

-// Utility to convert from MLFloat16 to float only when the input type is MLFloat16.
-template <typename T, typename Ret>
-ORT_FORCEINLINE Ret ConvertMLFloat16ToDoubleOrFloatIfNeeded(T val);
-
-template <>
-ORT_FORCEINLINE float ConvertMLFloat16ToDoubleOrFloatIfNeeded<MLFloat16, float>(MLFloat16 val) {
-  return val.ToFloat();
-}
-
-template <>
-ORT_FORCEINLINE double ConvertMLFloat16ToDoubleOrFloatIfNeeded<MLFloat16, double>(MLFloat16 val) {
-  return static_cast<double>(ConvertMLFloat16ToDoubleOrFloatIfNeeded<MLFloat16, float>(val));
-}
-
-template <>
-ORT_FORCEINLINE constexpr float ConvertMLFloat16ToDoubleOrFloatIfNeeded<float, float>(float val) {
-  return val;
-}
-
-template <>
-ORT_FORCEINLINE constexpr double ConvertMLFloat16ToDoubleOrFloatIfNeeded<double, double>(double val) {
-  return val;
-}
-
-// Function template that only converts the input value to MLFloat16 if T is MLFloat16.
-template <typename T>
-ORT_FORCEINLINE constexpr typename std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double>, T>
-ConvertDoubleOrFloatToMLFloat16IfNeeded(T val) {
-  return val;
-}
-
-template <typename T>
-ORT_FORCEINLINE constexpr typename std::enable_if_t<std::is_same_v<T, MLFloat16>, T>
-ConvertDoubleOrFloatToMLFloat16IfNeeded(float val) {
-  return MLFloat16(val);
-}
-
-template <typename T>
-ORT_FORCEINLINE constexpr typename std::enable_if_t<std::is_same_v<T, MLFloat16>, T>
-ConvertDoubleOrFloatToMLFloat16IfNeeded(double val) {
-  return MLFloat16(static_cast<float>(val));
-}
+namespace {
+
+ORT_FORCEINLINE double* CreateBufferIfMLFloat16(double* p_output, int num_elems) {
+  return p_output;
+}
+
+ORT_FORCEINLINE float* CreateBufferIfMLFloat16(float* p_output, int num_elems) {
+  return p_output;
+}
+
+ORT_FORCEINLINE float* CreateBufferIfMLFloat16(MLFloat16* p_output, int num_elems) {
+  return p_output == nullptr ? nullptr : new float[num_elems];
+}
+
+template <typename T>
+ORT_FORCEINLINE std::shared_ptr<std::vector<float>> ConvertHalfToFloatBufferIfNeeded(const T* p_input,
+                                                                                     int num_elems);
+
+template <typename T>
+ORT_FORCEINLINE std::shared_ptr<std::vector<float>> ConvertHalfToFloatBufferIfNeeded(
+    const std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double>, T>* p_input, int num_elems) {
+  return nullptr;
+}
+
+template <>
+std::shared_ptr<std::vector<float>> ConvertHalfToFloatBufferIfNeeded<MLFloat16>(const MLFloat16* p_input,
+                                                                                int num_elems) {
+  if (!p_input) {
+    return nullptr;
+  }
+
+  // Efficiently convert all the MLFloat16 values to floats.
+  std::shared_ptr<std::vector<float>> vec = std::make_shared<std::vector<float>>(num_elems);
+  MlasConvertHalfToFloatBuffer(p_input, &(*vec)[0], num_elems);
+
+  return vec;
+}
+
+void ConvertFloatBufferToMLFloat16(const float* output_buffer, MLFloat16* p_output, int num_elems) {
+  if (!output_buffer || !p_output) {
+    return;
+  }
+
+  MlasConvertFloatToHalfBuffer(output_buffer, p_output, num_elems);
+}
+
+}  // namespace

 template <typename T, bool simplified>
 SkipLayerNorm<T, simplified>::SkipLayerNorm(const OpKernelInfo& op_kernel_info)
     : OpKernel(op_kernel_info) {
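The latency gain comes from the restructuring above: the old helpers converted one MLFloat16 value per call, inside the per-element loop, while the new ConvertHalfToFloatBufferIfNeeded converts a whole row with a single MlasConvertHalfToFloatBuffer call. A minimal standalone sketch of that trade-off; Half, HalfToFloat, and ConvertHalfToFloatBulk are illustrative stand-ins for MLFloat16 and the MLAS routines, not ONNX Runtime's actual API:

#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

struct Half { uint16_t bits; };  // stand-in for MLFloat16

// Naive scalar conversion (normal finite values only; zero/subnormals flushed
// to zero, inf/NaN omitted, which is enough for illustration).
inline float HalfToFloat(Half h) {
  uint32_t sign = static_cast<uint32_t>(h.bits & 0x8000u) << 16;
  uint32_t exp_mant = h.bits & 0x7FFFu;
  uint32_t f;
  if (exp_mant < 0x400u) {
    f = sign;
  } else {
    f = sign | ((exp_mant << 13) + (112u << 23));  // rebias exponent 15 -> 127
  }
  float out;
  std::memcpy(&out, &f, sizeof(out));
  return out;
}

// Before: one scalar conversion per element, inside the hot loop.
float SumRowPerElement(const Half* row, size_t n) {
  float sum = 0.0f;
  for (size_t i = 0; i < n; ++i) sum += HalfToFloat(row[i]);
  return sum;
}

// After: convert the whole row once up front, then operate on plain floats.
void ConvertHalfToFloatBulk(const Half* src, float* dst, size_t n) {
  for (size_t i = 0; i < n; ++i) dst[i] = HalfToFloat(src[i]);
}

float SumRowBulk(const Half* row, size_t n) {
  std::vector<float> converted(n);
  ConvertHalfToFloatBulk(row, converted.data(), n);
  float sum = 0.0f;
  for (size_t i = 0; i < n; ++i) sum += converted[i];
  return sum;
}

The bulk pass touches each element exactly once and is trivially vectorizable, which is what the MLAS implementation exploits.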
@@ -145,25 +153,46 @@
       DoubleOrFloat mean(0.0f);
       DoubleOrFloat mean_square(0.0f);

-      std::unique_ptr<DoubleOrFloat[]> output_buffer = std::make_unique<DoubleOrFloat[]>(hidden_size);
-      for (size_t h = 0; h < static_cast<size_t>(hidden_size); h++) {
-        DoubleOrFloat input_value = ConvertMLFloat16ToDoubleOrFloatIfNeeded<T, DoubleOrFloat>(p_input[h]);
-        DoubleOrFloat skip_value = ConvertMLFloat16ToDoubleOrFloatIfNeeded<T, DoubleOrFloat>(p_skip[h]);
-
-        DoubleOrFloat value = input_value + skip_value;
-
-        if (nullptr != bias_data) {
-          value += ConvertMLFloat16ToDoubleOrFloatIfNeeded<T, DoubleOrFloat>(bias_data[h]);
-        }
-
-        output_buffer[h] = value;
-        T converted_value = ConvertDoubleOrFloatToMLFloat16IfNeeded<T>(value);
-        if (nullptr != p_skip_input_bias_add_output_data) {
-          p_skip_input_bias_add_output_data[h] = converted_value;
-        }
-
-        mean += value;
-        mean_square += value * value;
-      }
+      std::shared_ptr<std::vector<float>> float_input = ConvertHalfToFloatBufferIfNeeded<T>(p_input, hidden_size);
+      const DoubleOrFloat* converted_input =
+          float_input == nullptr
+              ? reinterpret_cast<const DoubleOrFloat*>(p_input)
+              : reinterpret_cast<const DoubleOrFloat*>(&(*float_input)[0]);
+      std::shared_ptr<std::vector<float>> float_skip = ConvertHalfToFloatBufferIfNeeded<T>(p_skip, hidden_size);
+      const DoubleOrFloat* converted_skip =
+          float_skip == nullptr
+              ? reinterpret_cast<const DoubleOrFloat*>(p_skip)
+              : reinterpret_cast<const DoubleOrFloat*>(&(*float_skip)[0]);
+      std::shared_ptr<std::vector<float>> float_bias = ConvertHalfToFloatBufferIfNeeded<T>(bias_data, hidden_size);
+      const DoubleOrFloat* converted_bias =
+          float_bias == nullptr
+              ? reinterpret_cast<const DoubleOrFloat*>(bias_data)
+              : reinterpret_cast<const DoubleOrFloat*>(&(*float_bias)[0]);
+
+      // If T is float or double, then output_buffer will be the same as p_output, so we don't allocate new memory.
+      // If T is MLFloat16, then we allocate hidden_size floats in output_buffer.
+      DoubleOrFloat* output_buffer = static_cast<DoubleOrFloat*>(CreateBufferIfMLFloat16(p_output, hidden_size));
+
+      for (size_t h = 0; h < static_cast<size_t>(hidden_size); h++) {
+        DoubleOrFloat val = converted_input[h] + converted_skip[h];
+
+        if (nullptr != bias_data) {
+          val += converted_bias[h];
+        }
+
+        output_buffer[h] = val;
+        mean += val;
+        mean_square += val * val;
+
+        if (nullptr != p_skip_input_bias_add_output_data && (std::is_same_v<T, float> || std::is_same_v<T, double>)) {
+          p_skip_input_bias_add_output_data[h] = *(reinterpret_cast<T*>(&val));
+        }
+      }
+
+      if (nullptr != p_skip_input_bias_add_output_data && std::is_same_v<T, MLFloat16>) {
+        ConvertFloatBufferToMLFloat16(reinterpret_cast<float*>(output_buffer),
+                                      reinterpret_cast<MLFloat16*>(p_skip_input_bias_add_output_data),
+                                      hidden_size);
+      }

       mean = mean / hidden_size;
@@ -173,17 +202,31 @@
         mean_square = sqrt(mean_square / hidden_size - mean * mean + epsilon_);
       }

+      std::shared_ptr<std::vector<float>> float_gamma = ConvertHalfToFloatBufferIfNeeded<T>(gamma_data, hidden_size);
+      const DoubleOrFloat* converted_gamma =
+          float_gamma == nullptr
+              ? reinterpret_cast<const DoubleOrFloat*>(gamma_data)
+              : reinterpret_cast<const DoubleOrFloat*>(&(*float_gamma)[0]);
+      std::shared_ptr<std::vector<float>> float_beta = ConvertHalfToFloatBufferIfNeeded<T>(beta_data, hidden_size);
+      const DoubleOrFloat* converted_beta =
+          float_beta == nullptr
+              ? reinterpret_cast<const DoubleOrFloat*>(beta_data)
+              : reinterpret_cast<const DoubleOrFloat*>(&(*float_beta)[0]);
       for (size_t h = 0; h < static_cast<size_t>(hidden_size); h++) {
-        DoubleOrFloat gamma_value = ConvertMLFloat16ToDoubleOrFloatIfNeeded<T, DoubleOrFloat>(gamma_data[h]);
         if (simplified) {
-          p_output[h] = ConvertDoubleOrFloatToMLFloat16IfNeeded<T>(output_buffer[h] / mean_square * gamma_value);
+          output_buffer[h] = output_buffer[h] / mean_square * converted_gamma[h];
         } else if (nullptr == beta_data) {
-          p_output[h] = ConvertDoubleOrFloatToMLFloat16IfNeeded<T>((output_buffer[h] - mean) / mean_square * gamma_value);
+          output_buffer[h] = (output_buffer[h] - mean) / mean_square * converted_gamma[h];
         } else {
-          DoubleOrFloat beta_value = ConvertMLFloat16ToDoubleOrFloatIfNeeded<T, DoubleOrFloat>(beta_data[h]);
-          p_output[h] = ConvertDoubleOrFloatToMLFloat16IfNeeded<T>((output_buffer[h] - mean) / mean_square * gamma_value + beta_value);
+          output_buffer[h] = (output_buffer[h] - mean) / mean_square * converted_gamma[h] + converted_beta[h];
         }
       }

+      if (std::is_same_v<T, MLFloat16>) {
+        ConvertFloatBufferToMLFloat16(
+            reinterpret_cast<float*>(output_buffer), reinterpret_cast<MLFloat16*>(p_output), hidden_size);
+        delete[] output_buffer;
+      }
     },
     0);
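For reference, the per-row computation these loops implement, with hidden size H and v_h = input_h + skip_h + bias_h, reconstructed from the code above:

\mu = \frac{1}{H} \sum_{h=1}^{H} v_h,
\qquad
\sigma = \sqrt{\frac{1}{H} \sum_{h=1}^{H} v_h^{2} - \mu^{2} + \epsilon},
\qquad
y_h = \frac{v_h - \mu}{\sigma}\, \gamma_h + \beta_h

The simplified branch drops the mean subtraction and the beta term, i.e. y_h = v_h \gamma_h / \sigma (RMS-style normalization); its corresponding assignment to mean_square sits in the lines elided from the hunk above.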