From b8c32f083047ec45abef3f4810def18ed59808b7 Mon Sep 17 00:00:00 2001
From: chengjunlu
Date: Fri, 17 May 2024 12:31:22 +0800
Subject: [PATCH] Use threads_per_warp=16 for 06-fused-attention.py (#1146)

To enable DPAS in the 06 tutorial.
---
 python/tutorials/06-fused-attention.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py
index 1e39b6a9c3..7eac24b1d7 100644
--- a/python/tutorials/06-fused-attention.py
+++ b/python/tutorials/06-fused-attention.py
@@ -468,6 +468,7 @@ def forward(ctx, q, k, v, causal, sm_scale):
             N_CTX=q.shape[2],  #
             HEAD_DIM=HEAD_DIM_K,  #
             STAGE=stage,  #
+            threads_per_warp=16,  #
             **extra_kern_args)

         ctx.save_for_backward(q, k, v, o, M)
@@ -514,7 +515,8 @@ def backward(ctx, do):
             BLK_SLICE_FACTOR=BLK_SLICE_FACTOR,  #
             HEAD_DIM=ctx.HEAD_DIM,  #
             num_warps=NUM_WARPS,  #
-            num_stages=NUM_STAGES  #
+            num_stages=NUM_STAGES,  #
+            threads_per_warp=16  #
         )

         return dq, dk, dv, None, None
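
For context, here is a minimal sketch (not part of the patch) of how threads_per_warp is forwarded as a launch-time option to a @triton.jit kernel. The vector-add kernel, the tensor sizes, and the "xpu" device string are illustrative assumptions; only the threads_per_warp=16 keyword comes from the patch itself.

import torch
import triton
import triton.language as tl


@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Hypothetical elementwise kernel standing in for the tutorial's
    # attention kernels; each program handles one BLOCK_SIZE-wide slice.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)


x = torch.rand(4096, device="xpu")  # assumes an Intel XPU build of PyTorch/Triton
y = torch.rand(4096, device="xpu")
out = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 1024),)
# threads_per_warp=16 requests the 16-wide sub-group size that Intel's DPAS
# (dot-product accumulate systolic) instructions operate on; backends that do
# not support this option may reject the keyword.
add_kernel[grid](x, y, out, x.numel(), BLOCK_SIZE=1024, threads_per_warp=16)

As in the patch, the option is passed alongside the other launch keywords (num_warps, num_stages) rather than inside the kernel body, so the same kernel source can be launched with different sub-group sizes per backend.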