[ET-VK] Fixing conv2d dw incorrect output when stride != dilation issue. (#7628)

Pull Request resolved: #7595

This diff turns the current conv2d dw implementation in the ExecuTorch Vulkan backend into a special case that is used only when stride equals dilation, since that is the only situation in which this kind of caching is possible.

If stride does not equal dilation, the old implementation is used.

Additional test cases are added to ensure computation is correct when stride != dilation.
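
For intuition: the special-case shader batches several neighboring outputs and reuses the input texels it has already fetched, and that reuse only lines up when the shift between outputs (the stride) equals the spacing between kernel taps (the dilation). A minimal Python sketch, not part of the diff, of the input columns a 3-tap depthwise kernel reads for consecutive outputs:

# Illustrative sketch only; names are hypothetical, not from this diff.
def input_columns(out_x, stride, dilation, kernel=3, padding=1):
    base = out_x * stride - padding
    return [base + j * dilation for j in range(kernel)]

# stride == dilation: neighboring outputs slide by exactly one tap spacing,
# so one cached run of input texels serves the whole batch.
print([input_columns(x, stride=1, dilation=1) for x in range(3)])
# [[-1, 0, 1], [0, 1, 2], [1, 2, 3]]

# stride != dilation: the window jumps by stride, which no longer matches the
# tap spacing, so the cached-run indexing used by the batched shader breaks.
print([input_columns(x, stride=2, dilation=1) for x in range(3)])
# [[-1, 0, 1], [1, 2, 3], [3, 4, 5]]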
ghstack-source-id: 261183385
@exported-using-ghexport

Differential Revision: [D67908916](https://our.internmc.facebook.com/intern/diff/D67908916/)

Co-authored-by: Vivek Trivedi <5340687+trivedivivek@users.noreply.github.com>
pytorchbot and trivedivivek authored Jan 13, 2025
1 parent 8cd3afd commit d229513
Showing 4 changed files with 123 additions and 9 deletions.
43 changes: 43 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl
@@ -14,6 +14,8 @@

#define TILE_SIZE ${TILE_SIZE}

#define STRIDE_EQ_DILATION ${STRIDE_EQ_DILATION}

#define BATCH_SIZE_X ${BATCH_SIZE_X}

#define BATCH_SIZE_Y ${BATCH_SIZE_Y}
@@ -40,6 +42,8 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
* Computes a depthwise convolution. Each shader invocation calculates the
* output at a single output location.
*/

#if STRIDE_EQ_DILATION
void main() {
// x and y are divided by batch size to determine 3d position
// since work size is calculated by x * ((y + B_Y - 1) / B_Y) * z
@@ -121,3 +125,42 @@ void main() {
}
}
}

#else
void main() {
const uint div_by_x = gl_GlobalInvocationID.x / out_limits.x;
const ivec3 pos = ivec3(
gl_GlobalInvocationID.x % out_limits.x,
div_by_x % out_limits.y,
div_by_x / out_limits.y);

if (any(greaterThanEqual(pos, out_limits))) {
return;
}

// Compute the index of the top-left element of the overlay region. Negative
// indices indicate that the top-left element is in a region added by padding.
const ivec2 ipos = pos.xy * stride - padding;

// Compute the start and end of the input indices to load. Padding is assumed
// to be constant 0 padding, so any reads from the padding region is skipped.
const ivec2 start = ipos;
const ivec2 end = ipos + overlay_region.xy;

VEC4_T sum = texelFetch(t_bias, ivec2(pos.z, 0), 0);
int kx = 0;
for (int y = start.y, i = 0; i < TILE_SIZE; y += dilation.y, i++) {
for (int x = start.x, j = 0; j < TILE_SIZE; x += dilation.x, j++) {
// The weight kernel was rearranged such that every NxN filter is
// flattened to fit in one row. Each filter was then stacked on top of
// each other vertically.
const vec4 in_texel = texelFetch(t_in, ivec3(x, y, pos.z), 0);
sum = fma(in_texel, texelFetch(t_kernel, ivec2(kx, pos.z), 0), sum);
kx++;
}
}

imageStore(t_out, pos, op(sum, out_min, out_max));
}

#endif
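
For reference, the fallback branch above computes each output texel independently: it unflattens the 1-D invocation index into a 3-D position, then accumulates a TILE_SIZE x TILE_SIZE window at dilation spacing against the flattened kernel row for that channel. A rough Python model of that per-element math, assuming zero padding and a kernel flattened to one row per channel (a sketch, not the shader; the real code works on packed vec4 texels and applies the clamp operator via op()):

import numpy as np

def conv2d_dw_fallback(x, w_flat, bias, stride, padding, dilation, tile_size):
    # x: (C, H, W), w_flat: (C, tile_size * tile_size), bias: (C,)
    C, H, W = x.shape
    out_h = (H + 2 * padding[0] - dilation[0] * (tile_size - 1) - 1) // stride[0] + 1
    out_w = (W + 2 * padding[1] - dilation[1] * (tile_size - 1) - 1) // stride[1] + 1
    out = np.zeros((C, out_h, out_w), dtype=x.dtype)
    for c in range(C):                    # pos.z
        for oy in range(out_h):           # pos.y
            for ox in range(out_w):       # pos.x
                iy0 = oy * stride[0] - padding[0]  # top-left of the overlay region
                ix0 = ox * stride[1] - padding[1]
                acc = bias[c]
                kx = 0
                for i in range(tile_size):
                    for j in range(tile_size):
                        y, xx = iy0 + i * dilation[0], ix0 + j * dilation[1]
                        if 0 <= y < H and 0 <= xx < W:  # reads from the padding region contribute 0
                            acc += x[c, y, xx] * w_flat[c, kx]
                        kx += 1
                out[c, oy, ox] = acc
    return out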
13 changes: 13 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.yaml
@@ -12,6 +12,7 @@ conv2d_dw_output_tile:
TILE_SIZE: 3
BATCH_SIZE_X: 4
BATCH_SIZE_Y: 2
STRIDE_EQ_DILATION: 0
generate_variant_forall:
DTYPE:
- VALUE: half
@@ -25,3 +26,15 @@
- NAME: conv2d_dw_output_tile_5x5_clamp
OPERATOR: clamp(X, A, B)
TILE_SIZE: 5
- NAME: conv2d_dw_sed_output_tile_3x3
STRIDE_EQ_DILATION: 1
- NAME: conv2d_dw_sed_output_tile_3x3_clamp
OPERATOR: clamp(X, A, B)
STRIDE_EQ_DILATION: 1
- NAME: conv2d_dw_sed_output_tile_5x5
TILE_SIZE: 5
STRIDE_EQ_DILATION: 1
- NAME: conv2d_dw_sed_output_tile_5x5_clamp
OPERATOR: clamp(X, A, B)
TILE_SIZE: 5
STRIDE_EQ_DILATION: 1
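
The new variants with the _sed ("stride equals dilation") infix correspond to the name-building logic in Convolution.cpp below. Roughly, the depthwise shader name is assembled as in this Python paraphrase (illustrative; the helper name is hypothetical, and the real code also handles weight prepacking and the other convolution methods):

def conv2d_dw_shader_name(weight_hw, stride_equals_dilation, clamp_out):
    # Mirrors the Depthwise branch of get_conv2d_shader for the 3x3/5x5 shapes shown here.
    name = "conv2d_dw"
    if stride_equals_dilation:
        name += "_sed"
    if weight_hw == (3, 3):
        name += "_output_tile_3x3"
    elif weight_hw == (5, 5):
        name += "_output_tile_5x5"
    if clamp_out:
        name += "_clamp"
    return name

# e.g. stride == dilation, 3x3 kernel, clamped output:
assert conv2d_dw_shader_name((3, 3), True, True) == "conv2d_dw_sed_output_tile_3x3_clamp"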
43 changes: 34 additions & 9 deletions backends/vulkan/runtime/graph/ops/impl/Convolution.cpp
@@ -126,13 +126,17 @@ vkapi::ShaderInfo get_conv2d_shader(
const bool prepack_weights,
const Conv2dMethod method,
const ValueRef weight,
const bool clamp_out = false) {
const bool clamp_out = false,
const bool stride_equals_dilation = false) {
std::string kernel_name;
kernel_name.reserve(kShaderNameReserve);
switch (method) {
case Conv2dMethod::Depthwise:
kernel_name = "conv2d_dw";
if (!prepack_weights) {
if (stride_equals_dilation) {
kernel_name += "_sed";
}
const auto& weight_sizes = graph.get_tref(weight)->sizes;
if (weight_sizes.at(2) == 3 && weight_sizes.at(3) == 3) {
kernel_name += "_output_tile_3x3";
@@ -286,22 +290,37 @@ Conv2dMethod get_conv2d_method(
return Conv2dMethod::SlidingWindow;
}

utils::uvec2 get_conv2d_dw_dispatch_divisor(
const std::vector<int64_t>& weight_sizes) {
if (weight_sizes.at(2) == 3 && weight_sizes.at(3) == 3) {
return {4u, 2u};
}
if (weight_sizes.at(2) == 5 && weight_sizes.at(3) == 5) {
return {4u, 2u};
}
return {4u, 2u};
}

utils::uvec3 create_conv2d_global_wg_size(
ComputeGraph& graph,
const Conv2dMethod method,
const ValueRef out) {
const ValueRef out,
const ValueRef weight_data,
const bool stride_equals_dilation) {
if (method == Conv2dMethod::Pointwise) {
const utils::uvec3 image_extents = graph.logical_limits_of(out);
return {
utils::div_up(image_extents[0u], 2u),
utils::div_up(image_extents[1u], 2u),
image_extents[2u]};
} else if (method == Conv2dMethod::Depthwise) {
const utils::uvec3 image_extents = graph.logical_limits_of(out);
} else if (method == Conv2dMethod::Depthwise && stride_equals_dilation) {
const utils::uvec3 image_extents = graph.create_global_wg_size(out);
const utils::uvec2 div =
get_conv2d_dw_dispatch_divisor(graph.get_tref(weight_data)->sizes);
return {
utils::div_up(image_extents[0u], 4u),
utils::div_up(image_extents[1u], 2u),
image_extents[2u]};
utils::div_up(image_extents[0], div[0]),
utils::div_up(image_extents[1], div[1]),
image_extents[2]};
} else {
return graph.create_global_wg_size(out);
}
Expand Down Expand Up @@ -364,6 +383,10 @@ void add_conv2d_node(
Conv2dParams extra_params =
create_conv2d_params(graph, weight_data, kernel_params, transposed_val);

const bool stride_equals_dilation =
(kernel_params.stride[0] == kernel_params.dilation[0] &&
kernel_params.stride[1] == kernel_params.dilation[1]);

OutputParams out_params = {out_min_val, out_max_val};

check_conv2d_params(kernel_params, transposed_val);
@@ -374,9 +397,11 @@
/*prepack_weights = */ false,
method,
weight_data,
clamp_out);
clamp_out,
stride_equals_dilation);

utils::uvec3 wg_size = create_conv2d_global_wg_size(graph, method, out);
utils::uvec3 wg_size = create_conv2d_global_wg_size(
graph, method, out, weight_data, stride_equals_dilation);

if (method == Conv2dMethod::Pointwise || method == Conv2dMethod::Depthwise) {
wg_size = {wg_size[0] * wg_size[1] * wg_size[2], 1, 1};
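For the stride == dilation depthwise path, create_conv2d_global_wg_size now divides the x/y extents by the per-shader batch sizes returned by get_conv2d_dw_dispatch_divisor (currently {4, 2} for every kernel shape), and add_conv2d_node then flattens the result into a 1-D dispatch. A small Python model of that arithmetic, illustrative only:

def div_up(a, b):
    return (a + b - 1) // b

def conv2d_dw_sed_wg_size(image_extents, divisor=(4, 2)):
    # Divide x/y by the shader batch sizes (BATCH_SIZE_X = 4, BATCH_SIZE_Y = 2),
    # then flatten to a 1-D work-group count as add_conv2d_node does.
    x, y, z = image_extents
    return (div_up(x, divisor[0]) * div_up(y, divisor[1]) * z, 1, 1)

# e.g. extents of (59, 59, 1): ceil(59/4) * ceil(59/2) * 1 = 15 * 30 = 450
print(conv2d_dw_sed_wg_size((59, 59, 1)))  # (450, 1, 1)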
33 changes: 33 additions & 0 deletions backends/vulkan/test/op_tests/cases.py
@@ -348,6 +348,39 @@ def get_conv_inputs():
[0, 0],
1,
),
(
(1, 4, 234, 234),
(4, 1, 3, 3),
(4,),
[2, 1],
[1, 1],
[1, 1],
False,
[0, 0],
4,
),
(
(1, 4, 234, 234),
(4, 1, 3, 3),
(4,),
[1, 2],
[1, 1],
[1, 1],
False,
[0, 0],
4,
),
(
(1, 4, 234, 234),
(4, 1, 3, 3),
(4,),
[2, 2],
[1, 1],
[1, 1],
False,
[0, 0],
4,
),
]
)
return test_suite
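
The three new cases exercise strides of [2, 1], [1, 2], and [2, 2] against a dilation of [1, 1], i.e. exactly the stride != dilation combinations the fix targets, on a depthwise convolution (groups equals the 4 input channels). Reading one tuple with the usual at::convolution argument order (the field names below are annotations, not part of the file):

# Annotated copy of one new case; the comments are explanatory assumptions.
(
    (1, 4, 234, 234),  # input shape  (N, C, H, W)
    (4, 1, 3, 3),      # weight shape (C, 1, kH, kW)
    (4,),              # bias shape
    [2, 1],            # stride      (differs from dilation)
    [1, 1],            # padding
    [1, 1],            # dilation
    False,             # transposed
    [0, 0],            # output_padding
    4,                 # groups == C, i.e. depthwise
),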