Add xformers support #20
Comments
Simple op availability assertion:

```python
import xformers.ops.fmha.flash3

# Check whether the Flash-Attention 3 forward op is available in this build.
op = xformers.ops.fmha.flash3.FwOp
HAS_FLASH = op.is_available()
print(f"xformers_ops_fmha_flash3 supported: {HAS_FLASH}")
```

xformers reports the FA3 operators as available:

```
memory_efficient_attention.fa3F@0.0.0: available
memory_efficient_attention.fa3B@0.0.0: available
```
#23 should fix this.
Looks good to me, but building xformers from source with FA3 support might trigger recompilation in the existing environment and overlap with a previous Flash Attention v3 installation. A colleague (@ohwi) and I found a point of conflict between the xformers FA3 Torch custom op wrapper logic and the pre-existing Flash Attention v3 install, which surfaces as:

```
TypeError: fwd(): incompatible function arguments. The following argument types are supported:
    1. (arg0: torch.Tensor, arg1: torch.Tensor, arg2: torch.Tensor, arg3: Optional[torch.Tensor], arg4: float, arg5: Optional[torch.Tensor], arg6: Optional[torch.Tensor], arg7: Optional[torch.Tensor], arg8: bool, arg9: int, arg10: int) -> list[torch.Tensor]
```

Our understanding of the conflict: [...]

And there is a code block in [...]. Therefore, although xformers prints FLASH3 as an available operator, we need to further assert its execution (see the sketch below). I made it work with [...]. This might be worth filing as a proper issue in the xformers repo, what do you think @xuzhao9?
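A minimal sketch of the kind of execution check we mean, assuming a CUDA-capable Hopper device; the helper name `flash3_really_runs` and the tiny problem shape are our own, not from xformers:

```python
import torch
import xformers.ops.fmha as fmha
import xformers.ops.fmha.flash3 as flash3


def flash3_really_runs() -> bool:
    """Return True only if the FLASH3 forward op both reports available and executes."""
    op = flash3.FwOp
    if not op.is_available():
        return False
    try:
        # Tiny bf16 problem in xformers' (batch, seq_len, num_heads, head_dim) layout.
        q = torch.randn(1, 64, 8, 64, device="cuda", dtype=torch.bfloat16)
        k = torch.randn_like(q)
        v = torch.randn_like(q)
        # Force the FLASH3 forward op; a broken build surfaces here,
        # e.g. as the fwd() TypeError shown above.
        fmha.memory_efficient_attention_forward(q, k, v, op=op)
        return True
    except (TypeError, RuntimeError):
        return False


if __name__ == "__main__":
    print(f"xformers FLASH3 executes: {flash3_really_runs()}")
```

Returning False on TypeError is deliberate: that is exactly the signature-mismatch failure mode described above, which `is_available()` alone does not catch.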
Yes, I think it is a valid issue to post to the xformers repo @antferdom.
I took a further look at the xformers code. Since we are compiling it from source, I think it should use the [...]
Actually, the xformers FA3 build is broken; submitted facebookresearch/xformers#1157. After this patch, we can install xformers+FA3 and standalone FA3 as separate kernel implementations for flash_attention, and they won't conflict.
Hi. I'm Hwigeon, working with @antferdom. If my memory serves me correctly (apologies, it's been a month since I last tried this), I moved to a newer version of FA with fallback because I also couldn't compile FA with xFormers, which led to the problem mentioned above.
@ohwi Can you please try with facebookresearch/xformers#1157 to see if FA works with xformers?
Sure, I will try your solution within a day.
Are we going to use [...]?
@antferdom Yes, we are going to patch xformers until it updates FA3 and CUTLASS, see #61.

Linking to the CUDA driver is not a problem for us because we have fine-grained control over the CI infra. At compile time, we first install the NVIDIA driver package, build xformers/FA3 linked against libcuda, then purge the driver files. The driver files will be mapped onto the H100 CI runners at test/benchmark time.

Compile time: https://github.com/pytorch-labs/tritonbench/blob/main/docker/tritonbench-nightly.dockerfile#L47
@xuzhao9 Alright, thanks for the clarification. It does make sense for us as well, since we have complete control over the machines. I have used the Tritonbench Dockerfile as the main reference so that we have an almost identical environment to yours, and I can validate that with your patch, xformers works:

```
$ python -c "import xformers._C_flashattention3"
```
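On top of that import, a hedged side-by-side smoke test could confirm the two installs no longer clash; the standalone module name `flash_attn_interface` is an assumption and may differ per FA3 build:

```python
# Hypothetical sanity check: both FA3 bindings should import in the same environment.
import importlib

for mod in ("xformers._C_flashattention3", "flash_attn_interface"):
    try:
        importlib.import_module(mod)
        print(f"{mod}: OK")
    except ImportError as exc:
        print(f"{mod}: not importable ({exc})")
```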
Add xformers built from source, similar to fbgemm: https://github.com/facebookresearch/xformers

Make sure FA3 is available.