limit sm
tianleiwu committed Jul 22, 2024
1 parent 6f399a6 commit 1f6738e
Showing 1 changed file with 20 additions and 19 deletions.
39 changes: 20 additions & 19 deletions onnxruntime/test/python/transformers/test_mha.py
@@ -150,15 +150,16 @@ def get_provider_support_info(provider: str, use_kv_cache: bool):
     return device, dtype, formats


-def has_cuda_support():
+def get_compute_capability():
     if torch.cuda.is_available() and "CUDAExecutionProvider" in onnxruntime.get_available_providers():
-        major, _ = torch.cuda.get_device_capability()
-        return major >= 6
-    return False
+        major, minor = torch.cuda.get_device_capability()
+        sm = major * 10 + minor
+        return sm
+    return 0


 def no_kv_cache_test_cases(provider: str, comprehensive: bool):
-    if provider == "CUDAExecutionProvider" and not has_cuda_support():
+    if provider == "CUDAExecutionProvider" and get_compute_capability() < 60:
         return
         yield

@@ -221,7 +222,7 @@ def no_kv_cache_test_cases(provider: str, comprehensive: bool):


 def kv_cache_test_cases(provider: str, comprehensive: bool):
-    if provider == "CUDAExecutionProvider" and not has_cuda_support():
+    if provider == "CUDAExecutionProvider" and get_compute_capability() < 60:
         return
         yield

@@ -292,7 +293,7 @@ def mha_test_cases(provider: str, comprehensive: bool):


 def no_kv_cache_multi_thread_test_cases(provider: str, comprehensive: bool):
-    if provider == "CUDAExecutionProvider" and not has_cuda_support():
+    if provider == "CUDAExecutionProvider" and get_compute_capability() < 60:
         return
         yield

Check warning from Code scanning / CodeQL: Unreachable code (Warning, test) - This statement is unreachable.

@@ -331,7 +332,7 @@ def no_kv_cache_multi_thread_test_cases(provider: str, comprehensive: bool):


 def kv_cache_multi_thread_test_cases(provider: str, comprehensive: bool):
-    if provider == "CUDAExecutionProvider" and not has_cuda_support():
+    if provider == "CUDAExecutionProvider" and get_compute_capability() < 60:
         return
         yield

Check warning from Code scanning / CodeQL: Unreachable code (Warning, test) - This statement is unreachable.

@@ -473,14 +474,14 @@ def parity_check_mha_multi_threading(
     config = test_inputs[0]["config"]
     # For now, MHA CUDA kernel does not support causal so skip such test cases.
     if config.causal and config.provider == "CUDAExecutionProvider":
-        return
+        return None
     # Some kernel does not support certain input format.
     if sdpa_kernel not in [
         SdpaKernel.DEFAULT,
         SdpaKernel.FLASH_ATTENTION,
         SdpaKernel.EFFICIENT_ATTENTION,
     ] and config.input_format in [InputFormats.Q_KV_BSNH_BSN2H]:
-        return
+        return None
     if verbose:
         print(f"create a shared session with {vars(config)}")
     onnx_model_str = create_multi_head_attention_onnx_model(config, use_symbolic_shape=True)

@@ -582,6 +583,7 @@ def check_parity_with_config(i: int):
         except AssertionError as e:
             print(f"Failed with {vars(config)}: {e}")
             return e
+
         if verbose:
             print(f"Passed: {vars(config)}")
         return None

@@ -630,19 +632,18 @@ def run_mha_cuda_multi_threading(self, spda_kernel):
     def test_mha_cuda_multi_threading(self):
         self.run_mha_cuda_multi_threading(SdpaKernel.DEFAULT)

-    def test_mha_cuda_multi_threading_flash(self):
-        self.run_mha_cuda_multi_threading(SdpaKernel.FLASH_ATTENTION)
-
     def test_mha_cuda_multi_threading_efficient(self):
         self.run_mha_cuda_multi_threading(SdpaKernel.EFFICIENT_ATTENTION)

     def test_mha_cuda_multi_threading_trt(self):
-        self.run_mha_cuda_multi_threading(
-            SdpaKernel.TRT_FUSED_ATTENTION
-            | SdpaKernel.TRT_FLASH_ATTENTION
-            | SdpaKernel.TRT_CROSS_ATTENTION
-            | SdpaKernel.TRT_CAUSAL_ATTENTION
-        )
+        sm = get_compute_capability()
+        if sm in [75, 80, 86, 89]:
+            self.run_mha_cuda_multi_threading(
+                SdpaKernel.TRT_FUSED_ATTENTION
+                | SdpaKernel.TRT_FLASH_ATTENTION
+                | SdpaKernel.TRT_CROSS_ATTENTION
+                | SdpaKernel.TRT_CAUSAL_ATTENTION
+            )


 if __name__ == "__main__":

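For reference, a minimal standalone sketch of the SM gating pattern used in this commit, assuming only that torch and onnxruntime are installed. The helper mirrors the new get_compute_capability from the diff, and the thresholds (SM 60 for the generic CUDA cases, the [75, 80, 86, 89] allow-list for the TRT fused-attention test) are taken from the diff; the surrounding print statements are illustrative only.

import onnxruntime
import torch


def get_compute_capability():
    # SM number from torch's (major, minor) capability, e.g. (8, 6) -> 86;
    # 0 when CUDA or the CUDA execution provider is unavailable.
    if torch.cuda.is_available() and "CUDAExecutionProvider" in onnxruntime.get_available_providers():
        major, minor = torch.cuda.get_device_capability()
        return major * 10 + minor
    return 0


sm = get_compute_capability()
if sm < 60:
    print("CUDA test cases would be skipped (no CUDA device, or older than SM 60).")
if sm in [75, 80, 86, 89]:
    print("The TRT fused-attention multi-threading test would run on this device.")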