fix Where to use indices helper.
fs-eire committed Sep 19, 2024
1 parent 5edd041 commit 49bc1af
Showing 1 changed file with 62 additions and 74 deletions.
136 changes: 62 additions & 74 deletions onnxruntime/core/providers/webgpu/tensor/where.cc
@@ -5,28 +5,11 @@
#include "core/providers/webgpu/tensor/where.h"
#include "core/providers/cpu/tensor/utils.h"
#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_supported_types.h"

namespace onnxruntime {
namespace webgpu {

ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Where,
kOnnxDomain,
9, 15,
kCudaExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()),
Where);

ONNX_OPERATOR_KERNEL_EX(
Where,
kOnnxDomain,
16,
kCudaExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()),
Where);

// Compute the Where operator's output shape based on three-way broadcasting.
Status ComputeOutputShape(const TensorShape& cond_shape,
const TensorShape& x_shape, const TensorShape& y_shape, TensorShape& output_shape) {
@@ -72,52 +55,50 @@ Status ComputeOutputShape(const TensorShape& cond_shape,
}
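
// Illustrative use of ComputeOutputShape (a hand-written sketch, not part of this change;
// assumes TensorShape's initializer-list constructor). Three-way broadcasting keeps the larger
// compatible size in each dimension, so e.g. cond {1, 4}, X {3, 1}, Y {3, 4} -> output {3, 4}:
//
//   TensorShape output_shape;
//   ORT_RETURN_IF_ERROR(ComputeOutputShape(TensorShape{1, 4}, TensorShape{3, 1},
//                                          TensorShape{3, 4}, output_shape));
//   // output_shape is now {3, 4}; mismatched non-1 sizes would return an error instead.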

Status WhereProgram::GenerateShaderCode(ShaderHelper& shader) const {
const auto a_name{"a_data"};
const auto b_name{"b_data"};
const auto c_name{"c_data"};
const auto output_name{"output_data"};
const auto& c_input = shader.AddInput(c_name,
ShaderUsage::UseUniform);
const auto& a_input = shader.AddInput(a_name,
ShaderUsage::UseUniform);
const auto& b_input = shader.AddInput(b_name,
ShaderUsage::UseUniform);
const auto& output = shader.AddOutput(output_name,
ShaderUsage::UseUniform);
const auto& c_input = shader.AddInput("c_data", ShaderUsage::UseUniform);
const auto& a_input = shader.AddInput("a_data", ShaderUsage::UseUniform);
const auto& b_input = shader.AddInput("b_data", ShaderUsage::UseUniform);
const auto& output = shader.AddOutput("output_data", ShaderUsage::UseUniform);

auto expression = [](const std::string& a, const std::string& b, const std::string& c) -> const auto {
return "select(" + b + ", " + a + ", " + c + ")";
};
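// Note: WGSL select(f, t, cond) returns `t` where `cond` is true, so the string built above
// reads as `c ? a : b` per element. A minimal C++ analog of that semantics (sketch only):
//   auto where_element = [](bool c, float a, float b) { return c ? a : b; };  // ~ select(b, a, c)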
auto single_assignment =
[expression, &output, &a_input, &b_input, &c_input](
const std::string& rest_str, const std::string& x, const std::string& type_cast = "")
-> const auto {
const std::string a_expression = "a_data[index_a" + x + "][component_a" + x + "]";
const std::string b_expression = "b_data[index_b" + x + "][component_b" + x + "]";
const std::string c_expression = "bool(c_data[index_c" + x + "] & (0xffu << (component_c" + x + " * 8)))";

std::ostringstream ss;
ss.imbue(std::locale::classic());
ss << "let output_indices" + x + " = " << output.OffsetToIndices("global_idx * 4u + " + x + "u") << ";\n";
ss << "let offset_a" + x + " = " + a_input.BroadcastedIndicesToOffset("output_indices" + x, output) + ";\n";
ss << "let offset_b" + x + " = " + b_input.BroadcastedIndicesToOffset("output_indices" + x, output) + ";\n";
ss << "let offset_c" + x + " = " + c_input.BroadcastedIndicesToOffset("output_indices" + x, output) + ";\n";
ss << "let index_a" + x + " = offset_a" + x + " / 4u;\n";
ss << "let index_b" + x + " = offset_b" + x + " / 4u;\n";
ss << "let index_c" + x + " = offset_c" + x + " / 4u;\n";
ss << "let component_a" + x + " = offset_a" + x + " % 4u;\n";
ss << "let component_b" + x + " = offset_b" + x + " % 4u;\n";
ss << "let component_c" + x + " = offset_c" + x + " % 4u;\n";
ss << rest_str + "[" + x + "] = " + type_cast + "(" + expression(a_expression, b_expression, c_expression) + ");\n";
return ss.str();
};

std::string assignment;
if (!is_broadcast_) {
assignment = output.SetByOffset(
"global_idx",
expression(a_input.GetByOffset("global_idx"), b_input.GetByOffset("global_idx"), c_input.GetByOffset("global_idx")));

} else {
const auto& c_indices = shader.AddIndices("c_indices");
const auto& a_indices = shader.AddIndices("a_indices");
const auto& b_indices = shader.AddIndices("b_indices");
const auto& output_indices = shader.AddIndices("output_indices");

auto single_assignment =
[&expression, &output_indices, &a_indices, &b_indices, &c_indices](
const std::string& rest_str, const std::string& x, const std::string& type_cast = "")
-> const auto {
const std::string a_expression = "a_data[index_a" + x + "][component_a" + x + "]";
const std::string b_expression = "b_data[index_b" + x + "][component_b" + x + "]";
const std::string c_expression = "bool(c_data[index_c" + x + "] & (0xffu << (component_c" + x + " * 8)))";

std::ostringstream ss;
ss.imbue(std::locale::classic());
ss << "let output_indices" + x + " = " << output_indices.OffsetToIndices("global_idx * 4u + " + x + "u") << ";\n";
ss << "let offset_a" + x + " = " + a_indices.BroadcastedIndicesToOffset("output_indices" + x, output_indices) + ";\n";
ss << "let offset_b" + x + " = " + b_indices.BroadcastedIndicesToOffset("output_indices" + x, output_indices) + ";\n";
ss << "let offset_c" + x + " = " + c_indices.BroadcastedIndicesToOffset("output_indices" + x, output_indices) + ";\n";
ss << "let index_a" + x + " = offset_a" + x + " / 4u;\n";
ss << "let index_b" + x + " = offset_b" + x + " / 4u;\n";
ss << "let index_c" + x + " = offset_c" + x + " / 4u;\n";
ss << "let component_a" + x + " = offset_a" + x + " % 4u;\n";
ss << "let component_b" + x + " = offset_b" + x + " % 4u;\n";
ss << "let component_c" + x + " = offset_c" + x + " % 4u;\n";
ss << rest_str + "[" + x + "] = " + type_cast + "(" + expression(a_expression, b_expression, c_expression) + ");\n";
return ss.str();
};
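
// The condition expression above addresses c_data as 32-bit words holding four bool bytes each:
// offset_c / 4u selects the word and (offset_c % 4u) * 8 selects the byte that the 0xffu mask tests.
// Hand-worked example (illustrative only): offset_c = 6 -> index_c = 1, component_c = 2,
// so the check becomes bool(c_data[1] & (0xffu << 16)), i.e. byte 2 of the second word.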

if (Outputs()[0].tensor->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_BOOL) {
assignment =
"var data = vec4<u32>(0); \n" +
@@ -157,34 +138,41 @@ Status Where::ComputeInternal(ComputeContext& context) const {
WhereProgram program{is_broadcast};
program
.CacheHint(is_broadcast)
.AddInputs({{cond_tensor, ProgramTensorMetadataDependency::Rank, component},
{x_tensor, ProgramTensorMetadataDependency::Rank, component},
{y_tensor, ProgramTensorMetadataDependency::Rank, component}})
.AddOutputs({{output_tensor,
ProgramTensorMetadataDependency::Rank |
ProgramTensorMetadataDependency::Type,
component}})
.SetDispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
.AddInputs({{cond_tensor, ProgramTensorMetadataDependency::None, {(cond_shape.Size() + 3) / 4}, 4},
{x_tensor, ProgramTensorMetadataDependency::None, {(x_shape.Size() + 3) / 4}, 4},
{y_tensor, ProgramTensorMetadataDependency::None, {(y_shape.Size() + 3) / 4}, 4}})
.AddOutput({output_tensor, ProgramTensorMetadataDependency::Type, {vec_size}, 4})
.AddUniformVariables({
{static_cast<uint32_t>(vec_size)},
});
if (is_broadcast) {
program
.AddIndices(cond_shape)
.AddIndices(x_shape)
.AddIndices(y_shape)
.AddIndices(output_tensor->Shape());
}
return context.RunProgram(program);
}
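
// Sizing sketch for the dispatch above (hand-worked, hypothetical numbers; assumes vec_size is the
// output element count divided by 4 and rounded up, and that WORKGROUP_SIZE is 64):
//   uint32_t output_size = 1000;                   // hypothetical element count
//   uint32_t vec_size = (output_size + 3) / 4;     // 250: each invocation handles one vec4 (4 elements)
//   uint32_t dispatch = (vec_size + 64 - 1) / 64;  // 4 workgroups
//   // mirrors SetDispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) above.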

#define WEBGPU_TRANSPOSE_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS, TYPE) \
ONNX_OPERATOR_KERNEL_EX( \
OP_TYPE, kOnnxDomain, VERSION, kWebGpuExecutionProvider, \
KernelDefBuilder().TypeConstraint("T", TYPE), \
KERNEL_CLASS);

#define WEBGPU_TRANSPOSE_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS, TYPE) \
ONNX_OPERATOR_VERSIONED_KERNEL_EX( \
OP_TYPE, kOnnxDomain, VERSION_FROM, VERSION_TO, kWebGpuExecutionProvider, \
KernelDefBuilder().TypeConstraint("T", TYPE), \
KERNEL_CLASS);
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Where,
kOnnxDomain,
9, 15,
kWebGpuExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", WebGpuSupportedFloatTypes()),
Where);

WEBGPU_TRANSPOSE_VERSIONED_KERNEL(Where, 9, 15, Where, WebGpuSupportedFloatTypes())
WEBGPU_TRANSPOSE_KERNEL(Where, 16, Where, WebGpuSupportedFloatTypes())
ONNX_OPERATOR_KERNEL_EX(
Where,
kOnnxDomain,
16,
kWebGpuExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", WebGpuSupportedFloatTypes()),
Where);

} // namespace webgpu
} // namespace onnxruntime
