Skip to content

Commit

Permalink
[BUG] fix crash on flashinfer backend with cudagraph disabled, when attention group_size not in [1,2,4,8] (vllm-project#7509)
Browse files Browse the repository at this point in the history
  • Loading branch information
learninmou authored Aug 21, 2024
1 parent c75363f commit 53328d7
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 4 deletions.
7 changes: 5 additions & 2 deletions tests/kernels/test_flashinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest
import torch

NUM_HEADS = [(16, 16), (32, 8), (64, 8)]
NUM_HEADS = [(16, 16), (32, 8), (64, 8), (6, 1)]
HEAD_SIZES = [128, 256]
BLOCK_SIZES = [16, 32]
DTYPES = [torch.float16, torch.bfloat16]
Expand Down Expand Up @@ -123,7 +123,10 @@ def test_flashinfer_decode_with_paged_kv(kv_lens: List[int],

workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
wrapper = flashinfer.\
BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD")
BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD",
use_tensor_cores=(
(num_query_heads//num_kv_heads) not in (1, 2, 4, 8))
)
wrapper.begin_forward(kv_indptr,
kv_indices,
kv_last_page_lens,
Expand Down
6 changes: 4 additions & 2 deletions vllm/attention/backends/flashinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,8 @@ def _get_decode_wrapper(self):
self.runner.parallel_config))
num_kv_heads = self.runner.model_config.get_num_kv_heads(
self.runner.parallel_config)
use_tensor_cores = num_qo_heads // num_kv_heads >= 4
use_tensor_cores = (num_qo_heads // num_kv_heads) not in \
(1, 2, 4, 8)
self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
self._get_workspace_buffer(),
"NHD",
Expand Down Expand Up @@ -171,7 +172,8 @@ def graph_capture_get_metadata_for_batch(self, batch_size: int):
self.runner.parallel_config))
num_kv_heads = self.runner.model_config.get_num_kv_heads(
self.runner.parallel_config)
use_tensor_cores = num_qo_heads // num_kv_heads >= 4
use_tensor_cores = (num_qo_heads // num_kv_heads) not in \
(1, 2, 4, 8)
self._graph_decode_wrapper = \
CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
self._graph_decode_workspace_buffer, _indptr_buffer,
Expand Down

0 comments on commit 53328d7

Please sign in to comment.