Commit

Use threads_per_warp=16 for 06-fused-attention.py (#1146)
Enable DPAS (Intel's matrix-engine instruction) in the 06 tutorial by launching the attention kernels with threads_per_warp=16.
chengjunlu authored May 17, 2024
1 parent 5e2256f commit b8c32f0
Showing 1 changed file with 3 additions and 1 deletion.
python/tutorials/06-fused-attention.py (3 additions, 1 deletion)
@@ -468,6 +468,7 @@ def forward(ctx, q, k, v, causal, sm_scale):
             N_CTX=q.shape[2],  #
             HEAD_DIM=HEAD_DIM_K,  #
             STAGE=stage,  #
+            threads_per_warp=16,  #
             **extra_kern_args)
 
         ctx.save_for_backward(q, k, v, o, M)
@@ -514,7 +515,8 @@ def backward(ctx, do):
             BLK_SLICE_FACTOR=BLK_SLICE_FACTOR,  #
             HEAD_DIM=ctx.HEAD_DIM,  #
             num_warps=NUM_WARPS,  #
-            num_stages=NUM_STAGES  #
+            num_stages=NUM_STAGES,  #
+            threads_per_warp=16  #
         )
 
         return dq, dk, dv, None, None
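For context, threads_per_warp is passed straight through the kernel launch as a keyword argument, alongside num_warps and num_stages, rather than being set inside the kernel body. Below is a minimal sketch of the same launch pattern on a standalone kernel, assuming an Intel XPU build of Triton that accepts threads_per_warp at launch time (as the diff above does); the "xpu" device string, tensor sizes, and kernel itself are illustrative, not part of the commit:

import torch
import triton
import triton.language as tl


@triton.jit
def scale_kernel(x_ptr, out_ptr, n_elements, scale, BLOCK_SIZE: tl.constexpr):
    # Each program instance handles one BLOCK_SIZE-wide slice of the input.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x * scale, mask=mask)


x = torch.randn(4096, device="xpu")  # assumes an Intel XPU device is available
out = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 1024),)
# threads_per_warp=16 requests the 16-lane sub-group size, the same setting
# the commit adds to the fused-attention launches so DPAS codegen can kick in.
scale_kernel[grid](x, out, x.numel(), 2.0,
                   BLOCK_SIZE=1024,
                   num_warps=4,
                   threads_per_warp=16)

The design choice mirrors the diff: the change is purely a launch-time configuration, so the kernel source in the tutorial stays untouched and only the call sites in forward and backward gain the extra keyword.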
