From 50785030ec596f27c98197b0b8b5f0f1f700d1fd Mon Sep 17 00:00:00 2001 From: Xu Xing Date: Mon, 23 Sep 2024 15:26:03 +0800 Subject: [PATCH] First pass --- .../webgpu/bert/rotary_embedding.cc | 96 ++++++++++++------- .../webgpu/webgpu_contrib_kernels.cc | 2 +- 2 files changed, 61 insertions(+), 37 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc index f0b6b4dc37355..2a4aba1dba3e0 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc @@ -29,34 +29,45 @@ const std::string BoolToString(bool b) { Status RotaryEmbeddingProgram::GenerateShaderCode(ShaderHelper& shader) const { const auto& input = shader.AddInput("input", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); - const auto& position_ids = shader.AddInput("position_ids", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); - const auto& cos_cache = shader.AddInput("cos_cache", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); - const auto& sin_cache = shader.AddInput("sin_cache", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + const auto& position_ids = shader.AddInput("position_ids", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseShapeAndStride | ShaderUsage::UseIndicesToOffset); + const auto& cos_cache = shader.AddInput("cos_cache", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseShapeAndStride | ShaderUsage::UseIndicesToOffset); + const auto& sin_cache = shader.AddInput("sin_cache", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseShapeAndStride | ShaderUsage::UseIndicesToOffset); const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform); const auto interleaved = false; const auto interleaved_str = BoolToString(interleaved); // outputVariable('', positionIds.type.tensor, 2) - + shader.AppendImplementation( + " fn broadcastedIndicesToposition_idsOffset(outputIndices: vec2) -> u32 {\n" + "return (outputIndices[1] % uniforms.position_ids_shape[1]) + uniforms.position_ids_stride * (outputIndices[0] % uniforms.position_ids_shape[0]);\n" + "}\n"); shader.SetMainFunctionBody( - "let half_rotary_emb_dim = uniforms.cos_cache_shape[1];" - "let bsnh = global_idx / uniforms.global_strides % uniforms.global_shape;" - "let size = uniforms.global_shape[0] * uniforms.global_strides[0];", - shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size"), - "if (bsnh[3] < half_rotary_emb_dim) {" - " let position_ids_idx = " + - position_ids.BroadcastedIndicesToOffset("bsnh.xy", output) + - "let position_id =" - "u32(" + + "let half_rotary_emb_dim = uniforms.cos_cache_shape[1];\n" + "let bsnh = global_idx / uniforms.global_strides % uniforms.global_shape;\n" + "let size = uniforms.global_shape[0] * uniforms.global_strides[0];\n", + "if (global_idx >= size) { return; } //shader.GuardAgainstOutOfBoundsWorkgroupSizes(uniforms.vec_size),\n" + "if (bsnh[3] < half_rotary_emb_dim) {\n" + " let position_ids_idx = broadcastedIndicesToposition_idsOffset(bsnh.xy);\n" + " let position_id =" + "u32(" + position_ids.GetByOffset("position_ids_idx") + ")" + - " + select(0, bsnh[1], position_ids_idx == 0);" - "let i = dot(bsnh, uniforms.input_output_strides) + select(0, bsnh[3], interleaved_str);" - "let j = i + select(half_rotary_emb_dim, 1, interleaved_str);" + " + select(0, bsnh[1], position_ids_idx == 0);\n" + "let i = dot(bsnh, uniforms.input_output_strides) + select(0, bsnh[3], " + + interleaved_str + + ");\n" + "let j = i + select(half_rotary_emb_dim, 1, " + + interleaved_str + + ");\n" "let re = " + input.GetByOffset("i") + " * " + cos_cache.GetByIndices("vec2(position_id, bsnh[3])") + "-" + - input.GetByOffset("j") + " * " + sin_cache.GetByIndices("vec2(position_id, bsnh[3])") + ";" + output.SetByOffset("i", "re") + "let im = " + input.GetByOffset("i") + " * " + sin_cache.GetByIndices("vec2(position_id, bsnh[3])") + "+ " + input.GetByOffset("j") + " * " + cos_cache.GetByIndices("vec2(position_id, bsnh[3])") + output.SetByOffset("j", "im") + - "} else { " - "let k = dot(bsnh, uniforms.input_output_strides) + half_rotary_emb_dim;" + + input.GetByOffset("j") + " * " + sin_cache.GetByIndices("vec2(position_id, bsnh[3])") + ";\n" + + output.SetByOffset("i", "re") + ";\n" + + "let im = " + input.GetByOffset("i") + " * " + sin_cache.GetByIndices("vec2(position_id, bsnh[3])") + "+ " + input.GetByOffset("j") + " * " + cos_cache.GetByIndices("vec2(position_id, bsnh[3])") + + ";\n" + output.SetByOffset("j", "im") + + "\n" + "} else { \n" + " let k = dot(bsnh, uniforms.input_output_strides) + half_rotary_emb_dim;\n" + output.SetByOffset("k", input.GetByOffset("k")) + + "\n" "}"); return Status::OK(); @@ -83,18 +94,18 @@ const auto GetSizeFromDimensionRange(const TensorShape& dims, uint32_t start, ui } InlinedVector ComputeStrides(InlinedVector& dims) { - const auto rank = dims.size(); + const auto rank = gsl::narrow_cast(dims.size()); if (rank == 0) { // return []; } else if (rank == 1) { - //return [1]; + // return [1]; } // const strides = new InlinedVector(); - InlinedVector strides; - strides.reserve(rank); + InlinedVector strides(rank); + // strides.reserve(rank); strides[rank - 1] = 1; strides[rank - 2] = dims[rank - 1]; - for (auto i = rank - 3; i >= 0; --i) { + for (int32_t i = rank - 3; i >= 0; --i) { strides[i] = strides[i + 1] * dims[i + 1]; } return strides; @@ -103,14 +114,16 @@ InlinedVector ComputeStrides(InlinedVector& dims) { Status RotaryEmbedding::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { const Tensor* input = context.Input(0); const auto input_shape = input->Shape(); - // const Tensor* position_ids = context.Input(1); + const Tensor* position_ids = context.Input(1); + //(void)position_ids; const Tensor* cos_cache = context.Input(2); - //const Tensor* sin_cache = context.Input(3); + const Tensor* sin_cache = context.Input(3); + //(void)sin_cache; auto* output = context.Output(0, input_shape); const auto batchSize = gsl::narrow_cast(input->Shape()[0]); - const auto batchStride = gsl::narrow_cast(GetSizeFromDimensionRange(input_shape, 0, 1)); + const auto batchStride = 36; // gsl::narrow_cast(GetSizeFromDimensionRange(input_shape, 0, 1)); const auto sequenceLength = gsl::narrow_cast(input_shape[input_shape.NumDimensions() - 2]); const auto hiddenSize = batchStride / sequenceLength; const auto halfRotaryEmbeddingDim = gsl::narrow_cast(cos_cache->Shape()[1]); // inputs[2].dims[1]; @@ -120,9 +133,9 @@ Status RotaryEmbedding::ComputeInternal(onnxruntime::webgpu::ComputeContext& con // [batch size, sequence length, num of heads, num of pairs to rotate + num of dims to copy] // to unfold the global index in shader. auto globalShape = new InlinedVector({batchSize, - sequenceLength, - hiddenSize / headSize, - headSize - halfRotaryEmbeddingDim}); + sequenceLength, + hiddenSize / headSize, + headSize - halfRotaryEmbeddingDim}); InlinedVector globalStrides = ComputeStrides(*globalShape); const auto vec_size = gsl::narrow_cast(globalShape->size()); @@ -137,14 +150,25 @@ Status RotaryEmbedding::ComputeInternal(onnxruntime::webgpu::ComputeContext& con program.AddUniformVariable({scale_}); program.AddUniformVariables({{gsl::make_span(globalShape->data(), 4)}}); program.AddUniformVariables({{gsl::make_span(globalStrides.data(), - globalStrides.size())}}); - program.AddInput({input, ProgramTensorMetadataDependency::Type, {vec_size}, 4}) - .AddOutput({output, ProgramTensorMetadataDependency::None, {vec_size}, 4}) + globalStrides.size())}}); + program.AddInputs({{input, ProgramTensorMetadataDependency::TypeAndRank}, + {position_ids, ProgramTensorMetadataDependency::TypeAndRank}, + {cos_cache, ProgramTensorMetadataDependency::TypeAndRank}, + {sin_cache, ProgramTensorMetadataDependency::TypeAndRank}}) + .AddOutput({output, ProgramTensorMetadataDependency::None, {input_shape.Size()}, 1}) + .SetDispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({gsl::make_span(input_output_strides.data(), input_output_strides.size())}); + /* + program.AddInputs({{input, ProgramTensorMetadataDependency::TypeAndRank, {input_shape.Size()}, 1}, + {position_ids, ProgramTensorMetadataDependency::TypeAndRank, {position_ids->Shape().Size()}, 1}, + {cos_cache, ProgramTensorMetadataDependency::TypeAndRank, {cos_cache->Shape().Size()}, 1}, + {sin_cache, ProgramTensorMetadataDependency::TypeAndRank, {sin_cache->Shape().Size()}, 1}}) + .AddOutput({output, ProgramTensorMetadataDependency::None, {input_shape.Size()}, 1}) .SetDispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) .AddUniformVariables({ - - gsl::make_span(input_output_strides.data(), input_output_strides.size())}); - + gsl::make_span(input_output_strides.data(), input_output_strides.size())}); + */ + program.AddUniformVariable({vec_size}); return context.RunProgram(program); } diff --git a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc index def104b6cb108..01c8a28d45069 100644 --- a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc @@ -47,7 +47,7 @@ Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry) { // BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo