From 2da80ba81c31579ac80aff3ec91fb11260e74b48 Mon Sep 17 00:00:00 2001 From: trivedivivek <5340687+trivedivivek@users.noreply.github.com> Date: Fri, 3 Jan 2025 16:45:55 -0600 Subject: [PATCH] [ET-VK] [ET-VK] Reduced int precision for all int storage in conv pw op to improve performance. Differential Revision: D67674212 Pull Request resolved: https://github.com/pytorch/executorch/pull/7447 --- .../graph/ops/glsl/conv2d_dw_output_tile.glsl | 22 ++++++++++--------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl index 21760eca0e..57ae98eb85 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl @@ -32,12 +32,14 @@ ${layout_declare_ubo(8, "float", "out_min", "float", "out_max")} layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; +#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require + /* * Computes a depthwise convolution. Each shader invocation calculates the * output at a single output location. */ void main() { - const ivec3 pos = ivec3(gl_GlobalInvocationID); + const u16vec3 pos = u16vec3(gl_GlobalInvocationID); if (any(greaterThanEqual(pos, out_limits))) { return; @@ -45,22 +47,22 @@ void main() { // Compute the index of the top-left element of the overlay region. Negative // indices indicate that the top-left element is in a region added by padding. - const ivec2 ipos = pos.xy * stride - padding; + const u16vec2 ipos = pos.xy * u16vec2(stride) - u16vec2(padding); // Compute the start and end of the input indices to load. Padding is assumed // to be constant 0 padding, so any reads from the padding region is skipped. - const ivec2 start = ipos; - const ivec2 end = ipos + overlay_region.xy; + const u16vec2 start = ipos; + const u16vec2 end = ipos + u16vec2(overlay_region.xy); - VEC4_T sum = texelFetch(t_bias, ivec2(pos.z, 0), 0); - int kx = 0; - for (int y = start.y, i = 0; i < TILE_SIZE; y += dilation.y, i++) { - for (int x = start.x, j = 0; j < TILE_SIZE; x += dilation.x, j++) { + VEC4_T sum = texelFetch(t_bias, u16vec2(pos.z, 0), 0); + uint16_t kx = uint16_t(0); + for (uint16_t y = start.y, i = uint16_t(0); i < uint16_t(TILE_SIZE); y += uint16_t(dilation.y), i++) { + for (uint16_t x = start.x, j = uint16_t(0); j < uint16_t(TILE_SIZE); x += uint16_t(dilation.x), j++) { // The weight kernel was rearranged such that every NxN filter is // flattened to fit in one row. Each filter was then stacked on top of // each other vertically. - const vec4 in_texel = texelFetch(t_in, ivec3(x, y, pos.z), 0); - sum = fma(in_texel, texelFetch(t_kernel, ivec2(kx, pos.z), 0), sum); + const vec4 in_texel = texelFetch(t_in, u16vec3(x, y, pos.z), 0); + sum = fma(in_texel, texelFetch(t_kernel, u16vec2(kx, pos.z), 0), sum); kx++; } }