
[CUDA] FusedMHARunnerFP16v2 thread-safe #21420

Merged — 5 commits merged into main from tlwu/thread_safe_attention on Jul 22, 2024
Conversation

tianleiwu (Contributor) commented on Jul 19, 2024

Description

  • Rewrite FusedMHARunnerFP16v2 to make it thread-safe.
  • Add multi-threading tests

Previously, the kernel parameters params were stored as a member of the MHA runner, which meant that different threads could change params at the same time and affect each other.

For example, if batch_size and seq_len were changed to larger values by another thread in setup(...), a buffer overrun could happen in run(...) because the kernel might read/write memory outside the range of the allocated buffers.

In the new implementation, I change the API and remove the mutable member variables to make the runner thread-safe. Below is a summary of the change:

Before:

class FusedMHARunnerFP16v2::mhaImpl {
   void setup(int seq_len, int batch_size) {
      // change scalar params
   }

   void run(input, output) {
      // change params for input and output pointers
      // launch kernel using params
   }

   Fused_multihead_attention_params_v2 params;  // mutable member: not thread-safe
};

After:

class FusedMHARunnerFP16v2::FmhaImpl {
   void setup(int seq_len, int batch_size, Fused_multihead_attention_params_v2& params) {
      // change params
   }

   void run(params, input, output) {
      // change params with input and output pointers
      // launch kernel using params
   }
};
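
To make the thread-safety argument concrete, here is a minimal standalone C++ sketch of the same pattern. This is not the onnxruntime code: Params, Runner, Setup, and Run are simplified stand-ins for Fused_multihead_attention_params_v2 and FmhaImpl, and the "kernel" is just a copy loop. It illustrates that once all per-invocation state lives in a caller-owned struct, several threads can share one runner with different batch_size/seq_len values without interfering:

// Hypothetical, simplified sketch of the per-call params pattern (not the real classes).
#include <cstdio>
#include <thread>
#include <vector>

// Stand-in for Fused_multihead_attention_params_v2: plain data, owned by the caller.
struct Params {
  int batch_size = 0;
  int seq_len = 0;
  const float* input = nullptr;
  float* output = nullptr;
};

class Runner {
 public:
  // Fills the caller-owned params; the runner itself stays immutable.
  void Setup(int seq_len, int batch_size, Params& params) const {
    params.seq_len = seq_len;
    params.batch_size = batch_size;
  }

  // "Launches the kernel" using only the caller-owned params (here: a copy loop).
  void Run(Params& params, const float* input, float* output) const {
    params.input = input;
    params.output = output;
    const int n = params.batch_size * params.seq_len;
    for (int i = 0; i < n; ++i) params.output[i] = params.input[i];
  }
};

int main() {
  Runner runner;  // shared across threads, but has no mutable members
  std::vector<std::thread> threads;
  for (int t = 1; t <= 4; ++t) {
    threads.emplace_back([&runner, t] {
      const int batch = t;
      const int seq = 8 * t;  // each thread uses different shapes
      std::vector<float> in(static_cast<size_t>(batch) * seq, 1.0f);
      std::vector<float> out(static_cast<size_t>(batch) * seq, 0.0f);
      Params params;  // per-call, stack-local: no sharing between threads
      runner.Setup(seq, batch, params);
      runner.Run(params, in.data(), out.data());
      std::printf("thread %d: batch=%d seq_len=%d out[0]=%.1f\n", t, batch, seq, out[0]);
    });
  }
  for (auto& th : threads) th.join();
  return 0;
}

This mirrors the new setup(seq_len, batch_size, params) / run(params, input, output) signatures above: the runner holds only immutable configuration, so no synchronization is needed around concurrent calls.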

Motivation and Context

#18854
#21413

tianleiwu marked this pull request as draft on July 19, 2024 16:34
Comment on lines +464 to +471
def parity_check_mha_multi_threading(
    test_inputs: List[Dict],
    rtol: float = 1e-3,
    atol: float = 1e-3,
    sdpa_kernel: int = SdpaKernel.DEFAULT,
    max_threads: int = 5,
    verbose: bool = False,
):

Check notice — Code scanning / CodeQL (note, test):

Explicit returns mixed with implicit (fall-through) returns

Mixing implicit and explicit returns may indicate an error, as implicit returns always return None.

onnxruntime/test/python/transformers/test_mha.py — dismissed
tianleiwu marked this pull request as ready for review on July 22, 2024 07:06
tianleiwu merged commit a6c5e2c into main on Jul 22, 2024
95 of 98 checks passed
tianleiwu deleted the tlwu/thread_safe_attention branch on July 22, 2024 17:41