Skip to content

Commit

Permalink
[ET-VK] Modify conv 2d pw op shader and dispatch settings to linearly…
Browse files Browse the repository at this point in the history
… dispatch work accounting for linearity texture to improve performance. (#7501)

Pull Request resolved: #7452

This diff modifies the convolution 2D pointwise op shader and dispatch settings to linearly dispatch work accounting for linearity texture to improve performance.
ghstack-source-id: 260166247
@exported-using-ghexport

Differential Revision: [D67683411](https://our.internmc.facebook.com/intern/diff/D67683411/)

Co-authored-by: Vivek Trivedi <5340687+trivedivivek@users.noreply.github.com>
  • Loading branch information
pytorchbot and trivedivivek authored Jan 4, 2025
1 parent f139e39 commit 6c11356
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 1 deletion.
7 changes: 6 additions & 1 deletion backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,12 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
* size is only 1x1, making it easier to re-use loaded texels from t_kernel.
*/
void main() {
const u16vec3 gpos = u16vec3(gl_GlobalInvocationID);
const uint16_t out_limits_y_scaled = uint16_t((out_limits.y + TILE_SIZE - 1) / TILE_SIZE);

const u16vec3 gpos = u16vec3(
gl_GlobalInvocationID.x / (out_limits_y_scaled * out_limits.z),
(gl_GlobalInvocationID.x / out_limits.z) % out_limits_y_scaled,
gl_GlobalInvocationID.x % out_limits.z);

// Output position for TILE_SIZE = 2
// +--------+--------+
Expand Down
4 changes: 4 additions & 0 deletions backends/vulkan/runtime/graph/ops/impl/Convolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,10 @@ void add_conv2d_node(

utils::uvec3 wg_size = create_conv2d_global_wg_size(graph, method, out);

if (method == Conv2dMethod::Pointwise) {
wg_size = {wg_size[0] * wg_size[1] * wg_size[2], 1, 1};
}

graph.execute_nodes().emplace_back(new DispatchNode(
graph,
shader,
Expand Down

0 comments on commit 6c11356

Please sign in to comment.