Skip to content

Commit

Permalink
First pass
Browse files Browse the repository at this point in the history
  • Loading branch information
axinging committed Sep 23, 2024
1 parent 2dc3d50 commit 5078503
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 37 deletions.
96 changes: 60 additions & 36 deletions onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,34 +29,45 @@ const std::string BoolToString(bool b) {

Status RotaryEmbeddingProgram::GenerateShaderCode(ShaderHelper& shader) const {
const auto& input = shader.AddInput("input", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
const auto& position_ids = shader.AddInput("position_ids", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
const auto& cos_cache = shader.AddInput("cos_cache", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
const auto& sin_cache = shader.AddInput("sin_cache", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
const auto& position_ids = shader.AddInput("position_ids", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseShapeAndStride | ShaderUsage::UseIndicesToOffset);
const auto& cos_cache = shader.AddInput("cos_cache", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseShapeAndStride | ShaderUsage::UseIndicesToOffset);
const auto& sin_cache = shader.AddInput("sin_cache", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseShapeAndStride | ShaderUsage::UseIndicesToOffset);
const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform);
const auto interleaved = false;
const auto interleaved_str = BoolToString(interleaved);
// outputVariable('', positionIds.type.tensor, 2)

shader.AppendImplementation(
" fn broadcastedIndicesToposition_idsOffset(outputIndices: vec2<u32>) -> u32 {\n"
"return (outputIndices[1] % uniforms.position_ids_shape[1]) + uniforms.position_ids_stride * (outputIndices[0] % uniforms.position_ids_shape[0]);\n"
"}\n");
shader.SetMainFunctionBody(
"let half_rotary_emb_dim = uniforms.cos_cache_shape[1];"
"let bsnh = global_idx / uniforms.global_strides % uniforms.global_shape;"
"let size = uniforms.global_shape[0] * uniforms.global_strides[0];",
shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size"),
"if (bsnh[3] < half_rotary_emb_dim) {"
" let position_ids_idx = " +
position_ids.BroadcastedIndicesToOffset("bsnh.xy", output) +
"let position_id ="
"u32(" +
"let half_rotary_emb_dim = uniforms.cos_cache_shape[1];\n"
"let bsnh = global_idx / uniforms.global_strides % uniforms.global_shape;\n"
"let size = uniforms.global_shape[0] * uniforms.global_strides[0];\n",
"if (global_idx >= size) { return; } //shader.GuardAgainstOutOfBoundsWorkgroupSizes(uniforms.vec_size),\n"
"if (bsnh[3] < half_rotary_emb_dim) {\n"
" let position_ids_idx = broadcastedIndicesToposition_idsOffset(bsnh.xy);\n"
" let position_id ="
"u32(" +
position_ids.GetByOffset("position_ids_idx") + ")" +
" + select(0, bsnh[1], position_ids_idx == 0);"
"let i = dot(bsnh, uniforms.input_output_strides) + select(0, bsnh[3], interleaved_str);"
"let j = i + select(half_rotary_emb_dim, 1, interleaved_str);"
" + select(0, bsnh[1], position_ids_idx == 0);\n"
"let i = dot(bsnh, uniforms.input_output_strides) + select(0, bsnh[3], " +
interleaved_str +
");\n"
"let j = i + select(half_rotary_emb_dim, 1, " +
interleaved_str +
");\n"
"let re = " +
input.GetByOffset("i") + " * " + cos_cache.GetByIndices("vec2<u32>(position_id, bsnh[3])") + "-" +
input.GetByOffset("j") + " * " + sin_cache.GetByIndices("vec2<u32>(position_id, bsnh[3])") + ";" + output.SetByOffset("i", "re") + "let im = " + input.GetByOffset("i") + " * " + sin_cache.GetByIndices("vec2<u32>(position_id, bsnh[3])") + "+ " + input.GetByOffset("j") + " * " + cos_cache.GetByIndices("vec2<u32>(position_id, bsnh[3])") + output.SetByOffset("j", "im") +
"} else { "
"let k = dot(bsnh, uniforms.input_output_strides) + half_rotary_emb_dim;" +
input.GetByOffset("j") + " * " + sin_cache.GetByIndices("vec2<u32>(position_id, bsnh[3])") + ";\n" +
output.SetByOffset("i", "re") + ";\n" +
"let im = " + input.GetByOffset("i") + " * " + sin_cache.GetByIndices("vec2<u32>(position_id, bsnh[3])") + "+ " + input.GetByOffset("j") + " * " + cos_cache.GetByIndices("vec2<u32>(position_id, bsnh[3])") +
";\n" + output.SetByOffset("j", "im") +
"\n"
"} else { \n"
" let k = dot(bsnh, uniforms.input_output_strides) + half_rotary_emb_dim;\n" +
output.SetByOffset("k", input.GetByOffset("k")) +
"\n"
"}");

return Status::OK();
Expand All @@ -83,18 +94,18 @@ const auto GetSizeFromDimensionRange(const TensorShape& dims, uint32_t start, ui
}

InlinedVector<uint32_t> ComputeStrides(InlinedVector<uint32_t>& dims) {
const auto rank = dims.size();
const auto rank = gsl::narrow_cast<const uint32_t>(dims.size());
if (rank == 0) {
// return [];
} else if (rank == 1) {
//return [1];
// return [1];
}
// const strides = new InlinedVector<int64_t, rank>();
InlinedVector<uint32_t> strides;
strides.reserve(rank);
InlinedVector<uint32_t> strides(rank);
// strides.reserve(rank);
strides[rank - 1] = 1;
strides[rank - 2] = dims[rank - 1];
for (auto i = rank - 3; i >= 0; --i) {
for (int32_t i = rank - 3; i >= 0; --i) {
strides[i] = strides[i + 1] * dims[i + 1];
}
return strides;
Expand All @@ -103,14 +114,16 @@ InlinedVector<uint32_t> ComputeStrides(InlinedVector<uint32_t>& dims) {
Status RotaryEmbedding::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const {
const Tensor* input = context.Input<Tensor>(0);
const auto input_shape = input->Shape();
// const Tensor* position_ids = context.Input<Tensor>(1);
const Tensor* position_ids = context.Input<Tensor>(1);
//(void)position_ids;
const Tensor* cos_cache = context.Input<Tensor>(2);
//const Tensor* sin_cache = context.Input<Tensor>(3);
const Tensor* sin_cache = context.Input<Tensor>(3);
//(void)sin_cache;

auto* output = context.Output(0, input_shape);

const auto batchSize = gsl::narrow_cast<uint32_t>(input->Shape()[0]);
const auto batchStride = gsl::narrow_cast<uint32_t>(GetSizeFromDimensionRange(input_shape, 0, 1));
const auto batchStride = 36; // gsl::narrow_cast<uint32_t>(GetSizeFromDimensionRange(input_shape, 0, 1));
const auto sequenceLength = gsl::narrow_cast<uint32_t>(input_shape[input_shape.NumDimensions() - 2]);
const auto hiddenSize = batchStride / sequenceLength;
const auto halfRotaryEmbeddingDim = gsl::narrow_cast<uint32_t>(cos_cache->Shape()[1]); // inputs[2].dims[1];
Expand All @@ -120,9 +133,9 @@ Status RotaryEmbedding::ComputeInternal(onnxruntime::webgpu::ComputeContext& con
// [batch size, sequence length, num of heads, num of pairs to rotate + num of dims to copy]
// to unfold the global index in shader.
auto globalShape = new InlinedVector<uint32_t>({batchSize,
sequenceLength,
hiddenSize / headSize,
headSize - halfRotaryEmbeddingDim});
sequenceLength,
hiddenSize / headSize,
headSize - halfRotaryEmbeddingDim});

InlinedVector<uint32_t> globalStrides = ComputeStrides(*globalShape);
const auto vec_size = gsl::narrow_cast<const uint32_t>(globalShape->size());
Expand All @@ -137,14 +150,25 @@ Status RotaryEmbedding::ComputeInternal(onnxruntime::webgpu::ComputeContext& con
program.AddUniformVariable({scale_});
program.AddUniformVariables({{gsl::make_span(globalShape->data(), 4)}});
program.AddUniformVariables({{gsl::make_span(globalStrides.data(),
globalStrides.size())}});
program.AddInput({input, ProgramTensorMetadataDependency::Type, {vec_size}, 4})
.AddOutput({output, ProgramTensorMetadataDependency::None, {vec_size}, 4})
globalStrides.size())}});
program.AddInputs({{input, ProgramTensorMetadataDependency::TypeAndRank},
{position_ids, ProgramTensorMetadataDependency::TypeAndRank},
{cos_cache, ProgramTensorMetadataDependency::TypeAndRank},
{sin_cache, ProgramTensorMetadataDependency::TypeAndRank}})
.AddOutput({output, ProgramTensorMetadataDependency::None, {input_shape.Size()}, 1})
.SetDispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
.AddUniformVariables({gsl::make_span(input_output_strides.data(), input_output_strides.size())});
/*
program.AddInputs({{input, ProgramTensorMetadataDependency::TypeAndRank, {input_shape.Size()}, 1},
{position_ids, ProgramTensorMetadataDependency::TypeAndRank, {position_ids->Shape().Size()}, 1},
{cos_cache, ProgramTensorMetadataDependency::TypeAndRank, {cos_cache->Shape().Size()}, 1},
{sin_cache, ProgramTensorMetadataDependency::TypeAndRank, {sin_cache->Shape().Size()}, 1}})
.AddOutput({output, ProgramTensorMetadataDependency::None, {input_shape.Size()}, 1})
.SetDispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
.AddUniformVariables({

gsl::make_span(input_output_strides.data(), input_output_strides.size())});

gsl::make_span(input_output_strides.data(), input_output_strides.size())});
*/
program.AddUniformVariable({vec_size});
return context.RunProgram(program);
}

Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry) {
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, MatMulNBits)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, MultiHeadAttention)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, QuickGelu)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, RotaryEmbedding)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, RotaryEmbedding)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1,
// SkipLayerNormalization)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1,
Expand Down

0 comments on commit 5078503

Please sign in to comment.