Skip to content

Commit

Permalink
[webgpu native] Add rotary embedding op
Browse files Browse the repository at this point in the history
  • Loading branch information
axinging committed Sep 24, 2024
1 parent cb9f3a4 commit dd40b49
Show file tree
Hide file tree
Showing 4 changed files with 200 additions and 2 deletions.
151 changes: 151 additions & 0 deletions onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

// #include "contrib_ops/cpu/bert/rotary_embedding_helper.h"
#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_supported_types.h"
#include "contrib_ops/webgpu/webgpu_contrib_kernels.h"
#include "contrib_ops/webgpu/bert/rotary_embedding.h"

namespace onnxruntime {
namespace contrib {
namespace webgpu {

ONNX_OPERATOR_KERNEL_EX(
RotaryEmbedding,
kMSDomain,
1,
kWebGpuExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", WebGpuSupportedFloatTypes())
.TypeConstraint("M", DataTypeImpl::GetTensorType<int64_t>()),
RotaryEmbedding);

const inline std::string ConvertBoolToString(bool b) {
std::stringstream ss;
ss << std::boolalpha << b;
return ss.str();
}

Status RotaryEmbeddingProgram::GenerateShaderCode(ShaderHelper& shader) const {
const auto& input = shader.AddInput("input", ShaderUsage::UseUniform |
ShaderUsage::UseValueTypeAlias);
const auto& position_ids = shader.AddInput("position_ids", ShaderUsage::UseUniform |
ShaderUsage::UseValueTypeAlias |
ShaderUsage::UseShapeAndStride |
ShaderUsage::UseIndicesToOffset);
const auto& cos_cache = shader.AddInput("cos_cache", ShaderUsage::UseUniform |
ShaderUsage::UseValueTypeAlias |
ShaderUsage::UseShapeAndStride |
ShaderUsage::UseIndicesToOffset);
const auto& sin_cache = shader.AddInput("sin_cache", ShaderUsage::UseUniform |
ShaderUsage::UseValueTypeAlias |
ShaderUsage::UseShapeAndStride |
ShaderUsage::UseIndicesToOffset);
const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform);
const auto& output_indices = shader.AddIndices("output_indices");
const auto interleaved_str = ConvertBoolToString(interleaved_);
shader.SetMainFunctionBody(
" let half_rotary_emb_dim = uniforms.cos_cache_shape[1];\n"
" let bsnh = global_idx / uniforms.global_stride % uniforms.global_shape;\n"
" let size = uniforms.global_shape[0] * uniforms.global_stride[0];\n",
" if (global_idx >= size) { return; }\n"
" if (bsnh[3] < half_rotary_emb_dim) {\n"
" let position_ids_idx = " +
position_ids.BroadcastedIndicesToOffset("bsnh.xy", output_indices) + ";\n" +
" let position_id = u32(" +
position_ids.GetByOffset("position_ids_idx") + ")" +
" + select(0, bsnh[1], position_ids_idx == 0);\n"
" let i = dot(bsnh, uniforms.input_output_stride) + select(0, bsnh[3], " +
interleaved_str +
");\n"
" let j = i + select(half_rotary_emb_dim, 1, " +
interleaved_str +
");\n"
" let re = " +
input.GetByOffset("i") + " * " + cos_cache.GetByIndices("vec2<u32>(position_id, bsnh[3])") + "-" +
input.GetByOffset("j") + " * " + sin_cache.GetByIndices("vec2<u32>(position_id, bsnh[3])") + ";\n" +
" " + output.SetByOffset("i", "re") + "\n" +
" let im = " + input.GetByOffset("i") + " * " +
sin_cache.GetByIndices("vec2<u32>(position_id, bsnh[3])") +
"+ " + input.GetByOffset("j") +
" * " + cos_cache.GetByIndices("vec2<u32>(position_id, bsnh[3])") +
";\n " + output.SetByOffset("j", "im") +
"\n"
" } else { \n"
" let k = dot(bsnh, uniforms.input_output_stride) + half_rotary_emb_dim;\n" +
" " + output.SetByOffset("k", input.GetByOffset("k")) +
"\n"
" }");

return Status::OK();
}

RotaryEmbedding::RotaryEmbedding(const OpKernelInfo& info) : WebGpuKernel(info) {
scale_ = info.GetAttrOrDefault<float>("scale", 1.0);
rotary_embedding_dim_ = static_cast<int>(info.GetAttrOrDefault<int64_t>("rotary_embedding_dim", 0));
num_heads_ = static_cast<int>(info.GetAttrOrDefault<int64_t>("num_heads", 0));
interleaved_ = (info.GetAttrOrDefault<int64_t>("interleaved", 0) == 1);
is_packed_batching_ = (info.GetAttrOrDefault<int64_t>("is_packed_batching", 0) == 1);
}

Status RotaryEmbedding::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const {
const Tensor* input = context.Input<Tensor>(0);
const auto input_shape = input->Shape();
const Tensor* position_ids = context.Input<Tensor>(1);
const Tensor* cos_cache = context.Input<Tensor>(2);
const Tensor* sin_cache = context.Input<Tensor>(3);

auto* output = context.Output(0, input_shape);

const auto batch_size = gsl::narrow_cast<uint32_t>(input->Shape()[0]);
const auto batch_stride = gsl::narrow_cast<uint32_t>(input_shape.SizeFromDimension(1));
const auto sequence_length = gsl::narrow_cast<uint32_t>(input_shape[input_shape.NumDimensions() - 2]);
const auto hidden_size = batch_stride / sequence_length;
const auto half_rotary_embedding_dim = gsl::narrow_cast<uint32_t>(cos_cache->Shape()[1]);
const auto head_size = rotary_embedding_dim_ == 0 ? half_rotary_embedding_dim * 2 : hidden_size / num_heads_;

// Rotary embeddings will be calculated in a pair-wise fashion. In accordance, use the shape
// [batch size, sequence length, num of heads, num of pairs to rotate + num of dims to copy]
// to unfold the global index in shader.
const TensorShape global_shape({batch_size,
sequence_length,
hidden_size / head_size,
head_size - half_rotary_embedding_dim});

const auto rank = global_shape.NumDimensions();
std::vector<uint32_t> global_dims(rank);
std::vector<uint32_t> global_strides(rank);
for (size_t j = 0; j < rank; ++j) {
global_dims[j] = gsl::narrow_cast<uint32_t>(global_shape[j]);
global_strides[j] = gsl::narrow_cast<uint32_t>(global_shape.SizeFromDimension(j + 1));
}

const auto output_size = gsl::narrow_cast<const uint32_t>(global_shape.Size());
RotaryEmbeddingProgram program{interleaved_};
const auto input_output_strides =
input_shape.NumDimensions() == 3
? std::vector<uint32_t>({batch_stride, hidden_size, head_size, 1})
: (input_shape.NumDimensions() == 4
? std::vector<uint32_t>({batch_stride, head_size, sequence_length * head_size, 1})
: std::vector<uint32_t>({}));

program
.CacheHint(interleaved_)
.AddInputs({{input, ProgramTensorMetadataDependency::TypeAndRank},
{position_ids, ProgramTensorMetadataDependency::TypeAndRank},
{cos_cache, ProgramTensorMetadataDependency::TypeAndRank},
{sin_cache, ProgramTensorMetadataDependency::TypeAndRank}})
.AddOutput({output, ProgramTensorMetadataDependency::None, {input_shape.Size()}, 1})
.SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
.AddUniformVariables({{scale_},
{gsl::make_span(global_dims)},
{gsl::make_span(global_strides)},
{gsl::make_span(input_output_strides)}})
.AddIndices(TensorShape{1, 1});
return context.RunProgram(program);
}

} // namespace webgpu
} // namespace contrib
} // namespace onnxruntime
47 changes: 47 additions & 0 deletions onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/providers/webgpu/program.h"
#include "core/providers/webgpu/webgpu_kernel.h"

namespace onnxruntime {
namespace contrib {
namespace webgpu {

using namespace onnxruntime::webgpu;
using onnxruntime::webgpu::ComputeContext;

class RotaryEmbeddingProgram final : public Program<RotaryEmbeddingProgram> {
public:
RotaryEmbeddingProgram(bool interleaved) : Program{"RotaryEmbedding"}, interleaved_{interleaved} {
}

Status GenerateShaderCode(ShaderHelper& sh) const override;

WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"scale", ProgramUniformVariableDataType::Float32},
{"global_shape", ProgramUniformVariableDataType::Uint32},
{"global_stride", ProgramUniformVariableDataType::Uint32},
{"input_output_stride", ProgramUniformVariableDataType::Uint32});

private:
const bool interleaved_;
};

class RotaryEmbedding final : public WebGpuKernel {
public:
RotaryEmbedding(const OpKernelInfo& info);
Status ComputeInternal(ComputeContext& context) const override;

protected:
float scale_;
int num_heads_;
int rotary_embedding_dim_;
bool interleaved_;
bool is_packed_batching_;
};

} // namespace webgpu
} // namespace contrib
} // namespace onnxruntime
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry) {
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, MatMulNBits)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, MultiHeadAttention)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, QuickGelu)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, RotaryEmbedding)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, RotaryEmbedding)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1,
// SkipLayerNormalization)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1,
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ static void RunTest(
if (tensor_type == TensorType::kFloat && !disable_cpu) {
execution_providers.push_back(DefaultCpuExecutionProvider());
}
execution_providers.push_back(DefaultWebGpuExecutionProvider());
if (execution_providers.size() == 0) {
// Return early if CI pipeline does not support EP (e.g. CUDA EP for CPU CI pipeline)
return;
Expand Down Expand Up @@ -118,7 +119,6 @@ static void RunTest(
} else {
test.SetOutputAbsErr("output", 0.002f);
}

test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}

Expand Down

0 comments on commit dd40b49

Please sign in to comment.