From 49bc1af9df9a939a0511b722255e43af79383d1f Mon Sep 17 00:00:00 2001
From: Yulong Wang <7679871+fs-eire@users.noreply.github.com>
Date: Thu, 19 Sep 2024 16:47:47 -0700
Subject: [PATCH] fix Where to use indices helper.

---
 .../core/providers/webgpu/tensor/where.cc | 136 ++++++++----------
 1 file changed, 62 insertions(+), 74 deletions(-)

diff --git a/onnxruntime/core/providers/webgpu/tensor/where.cc b/onnxruntime/core/providers/webgpu/tensor/where.cc
index 99e7398e8a13..3dc543950f5c 100644
--- a/onnxruntime/core/providers/webgpu/tensor/where.cc
+++ b/onnxruntime/core/providers/webgpu/tensor/where.cc
@@ -5,28 +5,11 @@
 #include "core/providers/webgpu/tensor/where.h"
 #include "core/providers/cpu/tensor/utils.h"
 #include "core/providers/webgpu/shader_helper.h"
+#include "core/providers/webgpu/webgpu_supported_types.h"
 
 namespace onnxruntime {
 namespace webgpu {
 
-ONNX_OPERATOR_VERSIONED_KERNEL_EX(
-    Where,
-    kOnnxDomain,
-    9, 15,
-    kCudaExecutionProvider,
-    (*KernelDefBuilder::Create())
-        .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()),
-    Where);
-
-ONNX_OPERATOR_KERNEL_EX(
-    Where,
-    kOnnxDomain,
-    16,
-    kCudaExecutionProvider,
-    (*KernelDefBuilder::Create())
-        .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()),
-    Where);
-
 // Compute where operator output shape based upon three way broad-casting.
 Status ComputeOutputShape(const TensorShape& cond_shape,
                           const TensorShape& x_shape, const TensorShape& y_shape, TensorShape& output_shape) {
@@ -72,45 +55,14 @@ Status ComputeOutputShape(const TensorShape& cond_shape,
 }
 
 Status WhereProgram::GenerateShaderCode(ShaderHelper& shader) const {
-  const auto a_name{"a_data"};
-  const auto b_name{"b_data"};
-  const auto c_name{"c_data"};
-  const auto output_name{"output_data"};
-  const auto& c_input = shader.AddInput(c_name,
-                                        ShaderUsage::UseUniform);
-  const auto& a_input = shader.AddInput(a_name,
-                                        ShaderUsage::UseUniform);
-  const auto& b_input = shader.AddInput(b_name,
-                                        ShaderUsage::UseUniform);
-  const auto& output = shader.AddOutput(output_name,
-                                        ShaderUsage::UseUniform);
+  const auto& c_input = shader.AddInput("c_data", ShaderUsage::UseUniform);
+  const auto& a_input = shader.AddInput("a_data", ShaderUsage::UseUniform);
+  const auto& b_input = shader.AddInput("b_data", ShaderUsage::UseUniform);
+  const auto& output = shader.AddOutput("output_data", ShaderUsage::UseUniform);
+
   auto expression = [](const std::string& a, const std::string& b, const std::string& c) -> const auto {
     return "select(" + b + ", " + a + ", " + c + ")";
   };
-  auto single_assignment =
-      [expression, &output, &a_input, &b_input, &c_input](
-          const std::string& rest_str, const std::string& x, const std::string& type_cast = "")
-      -> const auto {
-    const std::string a_expression = "a_data[index_a" + x + "][component_a" + x + "]";
-    const std::string b_expression = "b_data[index_b" + x + "][component_b" + x + "]";
-    const std::string c_expression = "bool(c_data[index_c" + x + "] & (0xffu << (component_c" + x + " * 8)))";
-
-    std::ostringstream ss;
-    ss.imbue(std::locale::classic());
-    ss << "let output_indices" + x + " = " << output.OffsetToIndices("global_idx * 4u + " + x + "u") << ";\n";
-    ss << "let offset_a" + x + " = " + a_input.BroadcastedIndicesToOffset("output_indices" + x, output) + ";\n";
-    ss << "let offset_b" + x + " = " + b_input.BroadcastedIndicesToOffset("output_indices" + x, output) + ";\n";
-    ss << "let offset_c" + x + " = " + c_input.BroadcastedIndicesToOffset("output_indices" + x, output) + ";\n";
-    ss << "let index_a" + x + " = offset_a" + x + " / 4u;\n";
-    ss << "let index_b" + x + " = offset_b" + x + " / 4u;\n";
-    ss << "let index_c" + x + " = offset_c" + x + " / 4u;\n";
-    ss << "let component_a" + x + " = offset_a" + x + " % 4u;\n";
-    ss << "let component_b" + x + " = offset_b" + x + " % 4u;\n";
-    ss << "let component_c" + x + " = offset_c" + x + " % 4u;\n";
-    ss << rest_str + "[" + x + "] = " + type_cast + "(" + expression(a_expression, b_expression, c_expression) + ");\n";
-    return ss.str();
-  };
-
   std::string assignment;
   if (!is_broadcast_) {
     assignment = output.SetByOffset(
@@ -118,6 +70,35 @@ Status WhereProgram::GenerateShaderCode(ShaderHelper& shader) const {
         expression(a_input.GetByOffset("global_idx"), b_input.GetByOffset("global_idx"), c_input.GetByOffset("global_idx")));
 
   } else {
+    const auto& c_indices = shader.AddIndices("c_indices");
+    const auto& a_indices = shader.AddIndices("a_indices");
+    const auto& b_indices = shader.AddIndices("b_indices");
+    const auto& output_indices = shader.AddIndices("output_indices");
+
+    auto single_assignment =
+        [&expression, &output_indices, &a_indices, &b_indices, &c_indices](
+            const std::string& rest_str, const std::string& x, const std::string& type_cast = "")
+        -> const auto {
+      const std::string a_expression = "a_data[index_a" + x + "][component_a" + x + "]";
+      const std::string b_expression = "b_data[index_b" + x + "][component_b" + x + "]";
+      const std::string c_expression = "bool(c_data[index_c" + x + "] & (0xffu << (component_c" + x + " * 8)))";
+
+      std::ostringstream ss;
+      ss.imbue(std::locale::classic());
+      ss << "let output_indices" + x + " = " << output_indices.OffsetToIndices("global_idx * 4u + " + x + "u") << ";\n";
+      ss << "let offset_a" + x + " = " + a_indices.BroadcastedIndicesToOffset("output_indices" + x, output_indices) + ";\n";
+      ss << "let offset_b" + x + " = " + b_indices.BroadcastedIndicesToOffset("output_indices" + x, output_indices) + ";\n";
+      ss << "let offset_c" + x + " = " + c_indices.BroadcastedIndicesToOffset("output_indices" + x, output_indices) + ";\n";
+      ss << "let index_a" + x + " = offset_a" + x + " / 4u;\n";
+      ss << "let index_b" + x + " = offset_b" + x + " / 4u;\n";
+      ss << "let index_c" + x + " = offset_c" + x + " / 4u;\n";
+      ss << "let component_a" + x + " = offset_a" + x + " % 4u;\n";
+      ss << "let component_b" + x + " = offset_b" + x + " % 4u;\n";
+      ss << "let component_c" + x + " = offset_c" + x + " % 4u;\n";
+      ss << rest_str + "[" + x + "] = " + type_cast + "(" + expression(a_expression, b_expression, c_expression) + ");\n";
+      return ss.str();
+    };
+
     if (Outputs()[0].tensor->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_BOOL) {
       assignment =
           "var data = vec4<u32>(0); \n" +
@@ -157,34 +138,41 @@ Status Where::ComputeInternal(ComputeContext& context) const {
   WhereProgram program{is_broadcast};
   program
       .CacheHint(is_broadcast)
-      .AddInputs({{cond_tensor, ProgramTensorMetadataDependency::Rank, component},
-                  {x_tensor, ProgramTensorMetadataDependency::Rank, component},
-                  {y_tensor, ProgramTensorMetadataDependency::Rank, component}})
-      .AddOutputs({{output_tensor,
-                    ProgramTensorMetadataDependency::Rank |
-                        ProgramTensorMetadataDependency::Type,
-                    component}})
       .SetDispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
+      .AddInputs({{cond_tensor, ProgramTensorMetadataDependency::None, {(cond_shape.Size() + 3) / 4}, 4},
+                  {x_tensor, ProgramTensorMetadataDependency::None, {(x_shape.Size() + 3) / 4}, 4},
+                  {y_tensor, ProgramTensorMetadataDependency::None, {(y_shape.Size() + 3) / 4}, 4}})
+      .AddOutput({output_tensor, ProgramTensorMetadataDependency::Type, {vec_size}, 4})
       .AddUniformVariables({
          {static_cast<uint32_t>(vec_size)},
       });
+  if (is_broadcast) {
+    program
+        .AddIndices(cond_shape)
+        .AddIndices(x_shape)
+        .AddIndices(y_shape)
+        .AddIndices(output_tensor->Shape());
+  }
   return context.RunProgram(program);
 }
 
-#define WEBGPU_TRANSPOSE_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS, TYPE) \
-  ONNX_OPERATOR_KERNEL_EX(                                            \
-      OP_TYPE, kOnnxDomain, VERSION, kWebGpuExecutionProvider,        \
-      KernelDefBuilder().TypeConstraint("T", TYPE),                   \
-      KERNEL_CLASS);
-
-#define WEBGPU_TRANSPOSE_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS, TYPE) \
-  ONNX_OPERATOR_VERSIONED_KERNEL_EX(                                                             \
-      OP_TYPE, kOnnxDomain, VERSION_FROM, VERSION_TO, kWebGpuExecutionProvider,                  \
-      KernelDefBuilder().TypeConstraint("T", TYPE),                                              \
-      KERNEL_CLASS);
+ONNX_OPERATOR_VERSIONED_KERNEL_EX(
+    Where,
+    kOnnxDomain,
+    9, 15,
+    kWebGpuExecutionProvider,
+    (*KernelDefBuilder::Create())
+        .TypeConstraint("T", WebGpuSupportedFloatTypes()),
+    Where);
 
-WEBGPU_TRANSPOSE_VERSIONED_KERNEL(Where, 9, 15, Where, WebGpuSupportedFloatTypes())
-WEBGPU_TRANSPOSE_KERNEL(Where, 16, Where, WebGpuSupportedFloatTypes())
+ONNX_OPERATOR_KERNEL_EX(
+    Where,
+    kOnnxDomain,
+    16,
+    kWebGpuExecutionProvider,
+    (*KernelDefBuilder::Create())
+        .TypeConstraint("T", WebGpuSupportedFloatTypes()),
+    Where);
 
 }  // namespace webgpu
 }  // namespace onnxruntime