forked from microsoft/onnxruntime
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[webgpu native] Add rotary embedding op
- Loading branch information
Showing
4 changed files
with
200 additions
and
2 deletions.
There are no files selected for viewing
151 changes: 151 additions & 0 deletions
151
onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
// #include "contrib_ops/cpu/bert/rotary_embedding_helper.h" | ||
#include "core/providers/webgpu/shader_helper.h" | ||
#include "core/providers/webgpu/webgpu_supported_types.h" | ||
#include "contrib_ops/webgpu/webgpu_contrib_kernels.h" | ||
#include "contrib_ops/webgpu/bert/rotary_embedding.h" | ||
|
||
namespace onnxruntime { | ||
namespace contrib { | ||
namespace webgpu { | ||
|
||
ONNX_OPERATOR_KERNEL_EX( | ||
RotaryEmbedding, | ||
kMSDomain, | ||
1, | ||
kWebGpuExecutionProvider, | ||
(*KernelDefBuilder::Create()) | ||
.TypeConstraint("T", WebGpuSupportedFloatTypes()) | ||
.TypeConstraint("M", DataTypeImpl::GetTensorType<int64_t>()), | ||
RotaryEmbedding); | ||
|
||
const inline std::string ConvertBoolToString(bool b) { | ||
std::stringstream ss; | ||
ss << std::boolalpha << b; | ||
return ss.str(); | ||
} | ||
|
||
Status RotaryEmbeddingProgram::GenerateShaderCode(ShaderHelper& shader) const { | ||
const auto& input = shader.AddInput("input", ShaderUsage::UseUniform | | ||
ShaderUsage::UseValueTypeAlias); | ||
const auto& position_ids = shader.AddInput("position_ids", ShaderUsage::UseUniform | | ||
ShaderUsage::UseValueTypeAlias | | ||
ShaderUsage::UseShapeAndStride | | ||
ShaderUsage::UseIndicesToOffset); | ||
const auto& cos_cache = shader.AddInput("cos_cache", ShaderUsage::UseUniform | | ||
ShaderUsage::UseValueTypeAlias | | ||
ShaderUsage::UseShapeAndStride | | ||
ShaderUsage::UseIndicesToOffset); | ||
const auto& sin_cache = shader.AddInput("sin_cache", ShaderUsage::UseUniform | | ||
ShaderUsage::UseValueTypeAlias | | ||
ShaderUsage::UseShapeAndStride | | ||
ShaderUsage::UseIndicesToOffset); | ||
const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform); | ||
const auto& output_indices = shader.AddIndices("output_indices"); | ||
const auto interleaved_str = ConvertBoolToString(interleaved_); | ||
shader.SetMainFunctionBody( | ||
" let half_rotary_emb_dim = uniforms.cos_cache_shape[1];\n" | ||
" let bsnh = global_idx / uniforms.global_stride % uniforms.global_shape;\n" | ||
" let size = uniforms.global_shape[0] * uniforms.global_stride[0];\n", | ||
" if (global_idx >= size) { return; }\n" | ||
" if (bsnh[3] < half_rotary_emb_dim) {\n" | ||
" let position_ids_idx = " + | ||
position_ids.BroadcastedIndicesToOffset("bsnh.xy", output_indices) + ";\n" + | ||
" let position_id = u32(" + | ||
position_ids.GetByOffset("position_ids_idx") + ")" + | ||
" + select(0, bsnh[1], position_ids_idx == 0);\n" | ||
" let i = dot(bsnh, uniforms.input_output_stride) + select(0, bsnh[3], " + | ||
interleaved_str + | ||
");\n" | ||
" let j = i + select(half_rotary_emb_dim, 1, " + | ||
interleaved_str + | ||
");\n" | ||
" let re = " + | ||
input.GetByOffset("i") + " * " + cos_cache.GetByIndices("vec2<u32>(position_id, bsnh[3])") + "-" + | ||
input.GetByOffset("j") + " * " + sin_cache.GetByIndices("vec2<u32>(position_id, bsnh[3])") + ";\n" + | ||
" " + output.SetByOffset("i", "re") + "\n" + | ||
" let im = " + input.GetByOffset("i") + " * " + | ||
sin_cache.GetByIndices("vec2<u32>(position_id, bsnh[3])") + | ||
"+ " + input.GetByOffset("j") + | ||
" * " + cos_cache.GetByIndices("vec2<u32>(position_id, bsnh[3])") + | ||
";\n " + output.SetByOffset("j", "im") + | ||
"\n" | ||
" } else { \n" | ||
" let k = dot(bsnh, uniforms.input_output_stride) + half_rotary_emb_dim;\n" + | ||
" " + output.SetByOffset("k", input.GetByOffset("k")) + | ||
"\n" | ||
" }"); | ||
|
||
return Status::OK(); | ||
} | ||
|
||
RotaryEmbedding::RotaryEmbedding(const OpKernelInfo& info) : WebGpuKernel(info) { | ||
scale_ = info.GetAttrOrDefault<float>("scale", 1.0); | ||
rotary_embedding_dim_ = static_cast<int>(info.GetAttrOrDefault<int64_t>("rotary_embedding_dim", 0)); | ||
num_heads_ = static_cast<int>(info.GetAttrOrDefault<int64_t>("num_heads", 0)); | ||
interleaved_ = (info.GetAttrOrDefault<int64_t>("interleaved", 0) == 1); | ||
is_packed_batching_ = (info.GetAttrOrDefault<int64_t>("is_packed_batching", 0) == 1); | ||
} | ||
|
||
Status RotaryEmbedding::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { | ||
const Tensor* input = context.Input<Tensor>(0); | ||
const auto input_shape = input->Shape(); | ||
const Tensor* position_ids = context.Input<Tensor>(1); | ||
const Tensor* cos_cache = context.Input<Tensor>(2); | ||
const Tensor* sin_cache = context.Input<Tensor>(3); | ||
|
||
auto* output = context.Output(0, input_shape); | ||
|
||
const auto batch_size = gsl::narrow_cast<uint32_t>(input->Shape()[0]); | ||
const auto batch_stride = gsl::narrow_cast<uint32_t>(input_shape.SizeFromDimension(1)); | ||
const auto sequence_length = gsl::narrow_cast<uint32_t>(input_shape[input_shape.NumDimensions() - 2]); | ||
const auto hidden_size = batch_stride / sequence_length; | ||
const auto half_rotary_embedding_dim = gsl::narrow_cast<uint32_t>(cos_cache->Shape()[1]); | ||
const auto head_size = rotary_embedding_dim_ == 0 ? half_rotary_embedding_dim * 2 : hidden_size / num_heads_; | ||
|
||
// Rotary embeddings will be calculated in a pair-wise fashion. In accordance, use the shape | ||
// [batch size, sequence length, num of heads, num of pairs to rotate + num of dims to copy] | ||
// to unfold the global index in shader. | ||
const TensorShape global_shape({batch_size, | ||
sequence_length, | ||
hidden_size / head_size, | ||
head_size - half_rotary_embedding_dim}); | ||
|
||
const auto rank = global_shape.NumDimensions(); | ||
std::vector<uint32_t> global_dims(rank); | ||
std::vector<uint32_t> global_strides(rank); | ||
for (size_t j = 0; j < rank; ++j) { | ||
global_dims[j] = gsl::narrow_cast<uint32_t>(global_shape[j]); | ||
global_strides[j] = gsl::narrow_cast<uint32_t>(global_shape.SizeFromDimension(j + 1)); | ||
} | ||
|
||
const auto output_size = gsl::narrow_cast<const uint32_t>(global_shape.Size()); | ||
RotaryEmbeddingProgram program{interleaved_}; | ||
const auto input_output_strides = | ||
input_shape.NumDimensions() == 3 | ||
? std::vector<uint32_t>({batch_stride, hidden_size, head_size, 1}) | ||
: (input_shape.NumDimensions() == 4 | ||
? std::vector<uint32_t>({batch_stride, head_size, sequence_length * head_size, 1}) | ||
: std::vector<uint32_t>({})); | ||
|
||
program | ||
.CacheHint(interleaved_) | ||
.AddInputs({{input, ProgramTensorMetadataDependency::TypeAndRank}, | ||
{position_ids, ProgramTensorMetadataDependency::TypeAndRank}, | ||
{cos_cache, ProgramTensorMetadataDependency::TypeAndRank}, | ||
{sin_cache, ProgramTensorMetadataDependency::TypeAndRank}}) | ||
.AddOutput({output, ProgramTensorMetadataDependency::None, {input_shape.Size()}, 1}) | ||
.SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) | ||
.AddUniformVariables({{scale_}, | ||
{gsl::make_span(global_dims)}, | ||
{gsl::make_span(global_strides)}, | ||
{gsl::make_span(input_output_strides)}}) | ||
.AddIndices(TensorShape{1, 1}); | ||
return context.RunProgram(program); | ||
} | ||
|
||
} // namespace webgpu | ||
} // namespace contrib | ||
} // namespace onnxruntime |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
#pragma once | ||
|
||
#include "core/providers/webgpu/program.h" | ||
#include "core/providers/webgpu/webgpu_kernel.h" | ||
|
||
namespace onnxruntime { | ||
namespace contrib { | ||
namespace webgpu { | ||
|
||
using namespace onnxruntime::webgpu; | ||
using onnxruntime::webgpu::ComputeContext; | ||
|
||
class RotaryEmbeddingProgram final : public Program<RotaryEmbeddingProgram> { | ||
public: | ||
RotaryEmbeddingProgram(bool interleaved) : Program{"RotaryEmbedding"}, interleaved_{interleaved} { | ||
} | ||
|
||
Status GenerateShaderCode(ShaderHelper& sh) const override; | ||
|
||
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"scale", ProgramUniformVariableDataType::Float32}, | ||
{"global_shape", ProgramUniformVariableDataType::Uint32}, | ||
{"global_stride", ProgramUniformVariableDataType::Uint32}, | ||
{"input_output_stride", ProgramUniformVariableDataType::Uint32}); | ||
|
||
private: | ||
const bool interleaved_; | ||
}; | ||
|
||
class RotaryEmbedding final : public WebGpuKernel { | ||
public: | ||
RotaryEmbedding(const OpKernelInfo& info); | ||
Status ComputeInternal(ComputeContext& context) const override; | ||
|
||
protected: | ||
float scale_; | ||
int num_heads_; | ||
int rotary_embedding_dim_; | ||
bool interleaved_; | ||
bool is_packed_batching_; | ||
}; | ||
|
||
} // namespace webgpu | ||
} // namespace contrib | ||
} // namespace onnxruntime |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters