diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl index ad4ff245a1..cd385718ce 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl @@ -14,6 +14,8 @@ #define TILE_SIZE ${TILE_SIZE} +#define STRIDE_EQ_DILATION ${STRIDE_EQ_DILATION} + #define BATCH_SIZE_X ${BATCH_SIZE_X} #define BATCH_SIZE_Y ${BATCH_SIZE_Y} @@ -40,6 +42,8 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; * Computes a depthwise convolution. Each shader invocation calculates the * output at a single output location. */ + +#if STRIDE_EQ_DILATION void main() { // x and y are divided by batch size to determine 3d position // since work size is calculated by x * ((y + B_Y - 1) / B_Y) * z @@ -121,3 +125,42 @@ void main() { } } } + +#else +void main() { + const uint div_by_x = gl_GlobalInvocationID.x / out_limits.x; + const ivec3 pos = ivec3( + gl_GlobalInvocationID.x % out_limits.x, + div_by_x % out_limits.y, + div_by_x / out_limits.y); + + if (any(greaterThanEqual(pos, out_limits))) { + return; + } + + // Compute the index of the top-left element of the overlay region. Negative + // indices indicate that the top-left element is in a region added by padding. + const ivec2 ipos = pos.xy * stride - padding; + + // Compute the start and end of the input indices to load. Padding is assumed + // to be constant 0 padding, so any reads from the padding region is skipped. + const ivec2 start = ipos; + const ivec2 end = ipos + overlay_region.xy; + + VEC4_T sum = texelFetch(t_bias, ivec2(pos.z, 0), 0); + int kx = 0; + for (int y = start.y, i = 0; i < TILE_SIZE; y += dilation.y, i++) { + for (int x = start.x, j = 0; j < TILE_SIZE; x += dilation.x, j++) { + // The weight kernel was rearranged such that every NxN filter is + // flattened to fit in one row. Each filter was then stacked on top of + // each other vertically. + const vec4 in_texel = texelFetch(t_in, ivec3(x, y, pos.z), 0); + sum = fma(in_texel, texelFetch(t_kernel, ivec2(kx, pos.z), 0), sum); + kx++; + } + } + + imageStore(t_out, pos, op(sum, out_min, out_max)); +} + +#endif diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.yaml b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.yaml index 9cf6c22c6c..d3672f5ec2 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.yaml @@ -12,6 +12,7 @@ conv2d_dw_output_tile: TILE_SIZE: 3 BATCH_SIZE_X: 4 BATCH_SIZE_Y: 2 + STRIDE_EQ_DILATION: 0 generate_variant_forall: DTYPE: - VALUE: half @@ -25,3 +26,15 @@ conv2d_dw_output_tile: - NAME: conv2d_dw_output_tile_5x5_clamp OPERATOR: clamp(X, A, B) TILE_SIZE: 5 + - NAME: conv2d_dw_sed_output_tile_3x3 + STRIDE_EQ_DILATION: 1 + - NAME: conv2d_dw_sed_output_tile_3x3_clamp + OPERATOR: clamp(X, A, B) + STRIDE_EQ_DILATION: 1 + - NAME: conv2d_dw_sed_output_tile_5x5 + TILE_SIZE: 5 + STRIDE_EQ_DILATION: 1 + - NAME: conv2d_dw_sed_output_tile_5x5_clamp + OPERATOR: clamp(X, A, B) + TILE_SIZE: 5 + STRIDE_EQ_DILATION: 1 diff --git a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp index 64c145fb7e..a7c11cc853 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp @@ -126,13 +126,17 @@ vkapi::ShaderInfo get_conv2d_shader( const bool prepack_weights, const Conv2dMethod method, const ValueRef weight, - const bool clamp_out = false) { + const bool clamp_out = false, + const bool stride_equals_dilation = false) { std::string kernel_name; kernel_name.reserve(kShaderNameReserve); switch (method) { case Conv2dMethod::Depthwise: kernel_name = "conv2d_dw"; if (!prepack_weights) { + if (stride_equals_dilation) { + kernel_name += "_sed"; + } const auto& weight_sizes = graph.get_tref(weight)->sizes; if (weight_sizes.at(2) == 3 && weight_sizes.at(3) == 3) { kernel_name += "_output_tile_3x3"; @@ -286,22 +290,37 @@ Conv2dMethod get_conv2d_method( return Conv2dMethod::SlidingWindow; } +utils::uvec2 get_conv2d_dw_dispatch_divisor( + const std::vector& weight_sizes) { + if (weight_sizes.at(2) == 3 && weight_sizes.at(3) == 3) { + return {4u, 2u}; + } + if (weight_sizes.at(2) == 5 && weight_sizes.at(3) == 5) { + return {4u, 2u}; + } + return {4u, 2u}; +} + utils::uvec3 create_conv2d_global_wg_size( ComputeGraph& graph, const Conv2dMethod method, - const ValueRef out) { + const ValueRef out, + const ValueRef weight_data, + const bool stride_equals_dilation) { if (method == Conv2dMethod::Pointwise) { const utils::uvec3 image_extents = graph.logical_limits_of(out); return { utils::div_up(image_extents[0u], 2u), utils::div_up(image_extents[1u], 2u), image_extents[2u]}; - } else if (method == Conv2dMethod::Depthwise) { - const utils::uvec3 image_extents = graph.logical_limits_of(out); + } else if (method == Conv2dMethod::Depthwise && stride_equals_dilation) { + const utils::uvec3 image_extents = graph.create_global_wg_size(out); + const utils::uvec2 div = + get_conv2d_dw_dispatch_divisor(graph.get_tref(weight_data)->sizes); return { - utils::div_up(image_extents[0u], 4u), - utils::div_up(image_extents[1u], 2u), - image_extents[2u]}; + utils::div_up(image_extents[0], div[0]), + utils::div_up(image_extents[1], div[1]), + image_extents[2]}; } else { return graph.create_global_wg_size(out); } @@ -364,6 +383,10 @@ void add_conv2d_node( Conv2dParams extra_params = create_conv2d_params(graph, weight_data, kernel_params, transposed_val); + const bool stride_equals_dilation = + (kernel_params.stride[0] == kernel_params.dilation[0] && + kernel_params.stride[1] == kernel_params.dilation[1]); + OutputParams out_params = {out_min_val, out_max_val}; check_conv2d_params(kernel_params, transposed_val); @@ -374,9 +397,11 @@ void add_conv2d_node( /*prepack_weights = */ false, method, weight_data, - clamp_out); + clamp_out, + stride_equals_dilation); - utils::uvec3 wg_size = create_conv2d_global_wg_size(graph, method, out); + utils::uvec3 wg_size = create_conv2d_global_wg_size( + graph, method, out, weight_data, stride_equals_dilation); if (method == Conv2dMethod::Pointwise || method == Conv2dMethod::Depthwise) { wg_size = {wg_size[0] * wg_size[1] * wg_size[2], 1, 1}; diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py index 85732d7701..d32fa71573 100644 --- a/backends/vulkan/test/op_tests/cases.py +++ b/backends/vulkan/test/op_tests/cases.py @@ -348,6 +348,39 @@ def get_conv_inputs(): [0, 0], 1, ), + ( + (1, 4, 234, 234), + (4, 1, 3, 3), + (4,), + [2, 1], + [1, 1], + [1, 1], + False, + [0, 0], + 4, + ), + ( + (1, 4, 234, 234), + (4, 1, 3, 3), + (4,), + [1, 2], + [1, 1], + [1, 1], + False, + [0, 0], + 4, + ), + ( + (1, 4, 234, 234), + (4, 1, 3, 3), + (4,), + [2, 2], + [1, 1], + [1, 1], + False, + [0, 0], + 4, + ), ] ) return test_suite