diff --git a/onnxruntime/core/providers/webgpu/tensor/transpose.cc b/onnxruntime/core/providers/webgpu/tensor/transpose.cc
index b620e83843b2f..855826b9298cb 100644
--- a/onnxruntime/core/providers/webgpu/tensor/transpose.cc
+++ b/onnxruntime/core/providers/webgpu/tensor/transpose.cc
@@ -47,11 +47,11 @@ ONNX_OPERATOR_KERNEL_EX(
         .TypeConstraint("T", WebGpuSupportedNumberTypes()),
     Transpose);
 
-const std::string AppendPermFunction(gsl::span<const size_t> perm) {
+const std::string AppendPermFunction(gsl::span<const size_t> perm) {
   std::ostringstream ss;
   ss.imbue(std::locale::classic());
-  ss << "fn perm(i: y_indices_t)->x_indices_t {\n"
-        "  var a: x_indices_t;\n";
+  ss << "fn perm(i: output_indices_t)->a_indices_t {\n"
+        "  var a: a_indices_t;\n";
   for (auto i = 0; i < perm.size(); ++i) {
     ss << "  a[" << perm[i] << "] = i[" << i << "];\n";
   }
@@ -60,21 +60,62 @@ const std::string AppendPermFunction(gsl::span<const size_t> perm) {
   return ss.str();
 }
 
+// Drop size-1 dimensions from the shape and the corresponding entries from the permutation.
+auto SqueezeShape(const gsl::span<const int64_t>& shape, const gsl::span<const size_t>& adjusted_perm, InlinedVector<int64_t>& new_shape, InlinedVector<size_t>& new_perm) {
+  for (auto i = 0; i < shape.size(); ++i) {
+    if (shape[i] != 1) {
+      new_shape.push_back(shape[i]);
+    }
+    if (shape[adjusted_perm[i]] != 1) {
+      new_perm.push_back(adjusted_perm[i]);
+    }
+  }
+};
+
 Status TransposeProgram::GenerateShaderCode(ShaderHelper& shader) const {
-  const auto& input = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
-  const auto& output = shader.AddOutput("y", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
-  shader.AppendImplementation(AppendPermFunction(this->perm_));
-  shader.SetMainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size"),
-                             "  let indices = ", output.OffsetToIndices("global_idx"),
-                             ";\n"
-                             "  let x_indices = perm(indices); \n"
-                             "  ",
-                             output.SetByOffset("global_idx", input.GetByIndices("x_indices")));
+  const auto& input = shader.AddInput("a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
+  const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
+
+  if (use_shared_) {
+    const auto tile_size = std::to_string(tile_size_);
+    shader.AppendImplementation("var<workgroup> tile : array<array<output_value_t, " + tile_size + ">, " + tile_size + ">;\n");
+    shader.SetMainFunctionBody(
+        "  let stride = (uniforms.output_shape[1] - 1) / " + tile_size + " + 1;\n"
+        "  let workgroup_id_x = workgroup_idx % stride;\n"
+        "  let workgroup_id_y = workgroup_idx / stride;\n"
+        "  let input_col = workgroup_id_y * " + tile_size + "u + local_id.x;\n"
+        "  let input_row = workgroup_id_x * " + tile_size + "u + local_id.y;\n"
+        "  if (input_row < uniforms.a_shape[0] && input_col < uniforms.a_shape[1]) {\n"
+        "    tile[local_id.y][local_id.x] = " + input.GetByIndices("a_indices_t(input_row, input_col)") + ";\n"
+        "  }\n"
+        "  workgroupBarrier();\n"
+        "  let output_col = workgroup_id_x * " + tile_size + "u + local_id.x;\n"
+        "  let output_row = workgroup_id_y * " + tile_size + "u + local_id.y;\n"
+        "  if (output_row < uniforms.output_shape[0] && output_col < uniforms.output_shape[1]) {\n    " +
+        output.SetByIndices("output_indices_t(output_row, output_col)", "tile[local_id.x][local_id.y]") + "\n  }\n");
+  } else {
+    shader.AppendImplementation(AppendPermFunction(this->perm_));
+    shader.SetMainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size"),
+                               "  let indices = ", output.OffsetToIndices("global_idx"),
+                               ";\n"
+                               "  let x_indices = perm(indices); \n"
+                               "  ",
+                               output.SetByOffset("global_idx", input.GetByIndices("x_indices")));
input.GetByIndices("x_indices"))); + } return Status::OK(); } Status Transpose::ComputeInternal(ComputeContext& context) const { - // TODO: there is an optimized version of transpose to port. const auto* input_tensor = context.Input(0); const TensorShape& input_shape = input_tensor->Shape(); int32_t rank = gsl::narrow_cast(input_shape.NumDimensions()); @@ -86,16 +127,41 @@ Status Transpose::ComputeInternal(ComputeContext& context) const { TensorShape output_shape(output_dims); auto* output_tensor = context.Output(0, output_shape); + InlinedVector new_shape{}; + InlinedVector new_perm{}; + SqueezeShape(input_shape.GetDims(), *p_perm, new_shape, new_perm); + const auto channels_last = new_perm == InlinedVector({2, 3, 1}); + const auto channels_first = new_perm == InlinedVector({3, 1, 2}); + const auto use_shared = (new_shape.size() == 2 && new_perm[0] > new_perm[1]) || channels_last || channels_first; + auto new_input_shape = use_shared ? new_shape : input_shape; + auto new_output_shape = output_dims; + if (use_shared) { + new_input_shape = channels_last + ? InlinedVector({new_shape[0], new_shape[1] * new_shape[2]}) + : channels_first + ? InlinedVector({new_shape[0] * new_shape[1], new_shape[2]}) + : new_shape; + new_output_shape = InlinedVector({new_input_shape[1], new_input_shape[0]}); + } + uint32_t output_size = gsl::narrow_cast(input_tensor->Shape().Size()); - TransposeProgram program{*p_perm}; + const auto tile_size = 16; + TransposeProgram program{*p_perm, use_shared, tile_size}; + if (use_shared) { + program.SetWorkgroupSize(tile_size, tile_size, 1); + } + program .CacheHint(absl::StrJoin(*p_perm, "-")) - .AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank}}) - .AddOutputs({output_tensor}) - .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank, new_input_shape, 1}}) + .AddOutputs({{output_tensor, ProgramTensorMetadataDependency::TypeAndRank, new_output_shape, 1}}) + .SetDispatchGroupSize(static_cast((new_output_shape[1] + tile_size - 1) / tile_size), static_cast(((new_output_shape[0] + tile_size - 1) / tile_size))) .AddUniformVariables({ {static_cast(output_size)}, }); + + use_shared ? 
+  use_shared ? program.SetDispatchGroupSize(static_cast<uint32_t>((new_output_shape[1] + tile_size - 1) / tile_size),
+                                            static_cast<uint32_t>((new_output_shape[0] + tile_size - 1) / tile_size))
+             : program.SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE);
   return context.RunProgram(program);
 }
diff --git a/onnxruntime/core/providers/webgpu/tensor/transpose.h b/onnxruntime/core/providers/webgpu/tensor/transpose.h
index 3ca5674d5dfab..254f935438d96 100644
--- a/onnxruntime/core/providers/webgpu/tensor/transpose.h
+++ b/onnxruntime/core/providers/webgpu/tensor/transpose.h
@@ -13,8 +13,8 @@ namespace webgpu {
 
 class TransposeProgram final : public Program<TransposeProgram> {
  public:
-  TransposeProgram(const gsl::span<const size_t>& permutations)
-      : Program{"Transpose"}, perm_(permutations.begin(), permutations.end()) {
+  TransposeProgram(const gsl::span<const size_t>& permutations, bool use_shared, const int tile_size)
+      : Program{"Transpose"}, perm_(permutations.begin(), permutations.end()), use_shared_(use_shared), tile_size_(tile_size) {
   }
 
   Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -22,7 +22,9 @@ class TransposeProgram final : public Program<TransposeProgram> {
   WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"output_size", ProgramUniformVariableDataType::Uint32});
 
  private:
-  InlinedVector<size_t> perm_;
+  InlinedVector<size_t> perm_;
+  const bool use_shared_;
+  const int tile_size_;
 };
 
 class Transpose final : public WebGpuKernel, public TransposeBase {