
Add xformers support #20

Closed
xuzhao9 opened this issue Oct 28, 2024 · 12 comments

@xuzhao9
Contributor

xuzhao9 commented Oct 28, 2024

Add xformers, built from source (similar to fbgemm): https://github.com/facebookresearch/xformers

Make sure fa3 is available.

@antferdom

A simple op-availability assertion:

from xformers.ops.fmha import flash3

op = flash3.FwOp
HAS_FLASH3 = op.is_available()  # True when the FA3 binding was built and can be imported
print(f"xformers_ops_fmha_flash3 supported: {HAS_FLASH3}")

Reference output (e.g. as printed by python -m xformers.info):

memory_efficient_attention.fa3F@0.0.0:             available
memory_efficient_attention.fa3B@0.0.0:             available

@xuzhao9
Contributor Author

xuzhao9 commented Oct 29, 2024

#23 should fix this

@antferdom

antferdom commented Oct 29, 2024

Looks good to me, but building xformers from source with FA3 support might trigger recompilation in the existing environment and overlap with a previous Flash Attention v3 installation.

A colleague (@ohwi) and I found a point of conflict between the xformers FA3 Torch custom-op wrapper logic and flashattn_hopper_cuda, which led to errors such as:

TypeError: fwd(): incompatible function arguments. The following argument types are supported:
    1. (arg0: torch.Tensor, arg1: torch.Tensor, arg2: torch.Tensor, arg3: Optional[torch.Tensor], arg4: float, arg5: Optional[torch.Tensor], arg6: Optional[torch.Tensor], arg7: Optional[torch.Tensor], arg8: bool, arg9: int, arg10: int) -> list[torch.Tensor]

Our understanding of the conflict:

  • The current version of the fwd function in flashattn_hopper_cuda requires the non-optional arguments window_size_left and window_size_right, but the custom mha_fwd op registered by xformers does not include this update.

There is also a code block in xformers that imports flashattn_hopper_cuda as a fallback, which makes only one of xformers or flash-attn usable at a time.
See: https://github.com/Dao-AILab/flash-attention/blob/main/hopper/flash_api.cpp#L463-L475
and
https://github.com/facebookresearch/xformers/blob/68b7fd14df5eb1d2558c52842b4206a14d2d20e9/xformers/ops/fmha/flash3.py#L48-L82
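
A quick way to confirm which signature the locally installed binding actually exposes is to print its pybind11-generated docstring (a minimal diagnostic sketch; it assumes flashattn_hopper_cuda is importable in the current environment):

# Diagnostic sketch: pybind11 embeds the accepted argument types in the
# docstring, which is the same list the TypeError above reproduces.
import flashattn_hopper_cuda

print(flashattn_hopper_cuda.fwd.__doc__)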

Therefore, although xformers reports FLASH3 as an available operator, we still need to assert that it actually executes. I made it work with:
flashattn-hopper==3.0.0b1
torch==2.4.1+cu124
xformers==0.0.29
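
A minimal execution check along those lines (a rough sketch, not from any repo; the shapes are arbitrary and it assumes an H100-class GPU with fp16 inputs):

# Sketch: go beyond is_available() by actually running a tiny forward pass
# pinned to the FA3 forward op.
import torch
from xformers.ops import memory_efficient_attention_forward
from xformers.ops.fmha import flash3

# layout is (batch, seqlen, heads, head_dim)
q, k, v = (torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.float16) for _ in range(3))
out = memory_efficient_attention_forward(q, k, v, op=flash3.FwOp)
print("FA3 forward executed, output shape:", tuple(out.shape))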

This might be worth filing as a proper issue in the xformers repo. What do you think, @xuzhao9?

@xuzhao9
Contributor Author

xuzhao9 commented Oct 29, 2024

Yes, I think it is a valid issue to post to the xformers repo, @antferdom.

@xuzhao9
Contributor Author

xuzhao9 commented Nov 20, 2024

I took a further look at the xformers code. Since we are compiling it from source, it should use the xformers._C_flashattention3 plugin (https://github.com/facebookresearch/xformers/blob/6e10bd21ac6fc878657b24684723ccd05e41d385/setup.py#L321C19-L321C46) and therefore not fall back to the pre-installed flashattn_hopper_cuda package. We should treat them separately.
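
A quick way to see which bindings a given environment actually provides (a small sketch; both module names are the ones discussed in this thread):

# Sketch: check which FA3 bindings are importable in the current environment.
# xformers._C_flashattention3 is the extension xformers builds itself;
# flashattn_hopper_cuda comes from a separately installed FA3 (Hopper) build.
import importlib.util

for mod in ("xformers._C_flashattention3", "flashattn_hopper_cuda"):
    found = importlib.util.find_spec(mod) is not None
    print(f"{mod}: {'present' if found else 'missing'}")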

@xuzhao9
Contributor Author

xuzhao9 commented Nov 21, 2024

Actually, the xformers FA3 build is broken; submitted facebookresearch/xformers#1157.

After this patch, we can install xformers+FA3 and FA3 as separate kernel implementations for flash_attention, and they won't conflict.

@ohwi

ohwi commented Nov 21, 2024

Hi. I'm Hwigeon, working with @antferdom.

If my memory serves me correctly (apologies, it's been a month since I last tried this), I moved to a newer version of FA with fallback because I also couldn't compile FA with xFormers, which led to the problem mentioned above.

@xuzhao9
Contributor Author

xuzhao9 commented Nov 21, 2024

@ohwi Can you please try with facebookresearch/xformers#1157 to see if FA works with xformers?

@ohwi

ohwi commented Nov 21, 2024

Sure, I will try your solution within a day.

@antferdom

> Actually, the xformers FA3 build is broken; submitted facebookresearch/xformers#1157.
> After this patch, we can install xformers+FA3 and FA3 as separate kernel implementations for flash_attention, and they won't conflict.

Are we going to use the libraries=["cuda"] patch workaround until upstream FA3 is fixed and CUTLASS is upgraded? Should we then forward this issue to the FlashAttention repo? @lw explained it in great detail, showing that the underlying problem goes beyond what we thought. I will try it with @ohwi.

@xuzhao9
Contributor Author

xuzhao9 commented Nov 21, 2024

@antferdom Yes, we are going to patch xformers until it updates FA3 and CUTLASS; see #61.

Linking to the CUDA driver is not a problem for us because we have fine-grained control over the CI infra. At compile time, we first install the NVIDIA driver package, build xformers/FA3 with the libcuda link, and then purge the driver files. The driver files are mapped onto the H100 CI runners at test/benchmark time.

Compile time: https://github.com/pytorch-labs/tritonbench/blob/main/docker/tritonbench-nightly.dockerfile#L47
Run time: https://github.com/pytorch-labs/tritonbench/blob/main/docker/infra/values.yaml#L227
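
For context, the libraries=["cuda"] workaround mentioned above boils down to adding the driver library to the FA3 extension's link line. A hypothetical sketch of what that looks like in a setup.py (illustrative names and source paths, not the actual xformers build code):

# Hypothetical sketch of the libraries=["cuda"] workaround (illustrative only;
# the real change lives in the FA3 extension definition in xformers' setup.py).
from torch.utils.cpp_extension import CUDAExtension

flash3_ext = CUDAExtension(
    name="xformers._C_flashattention3",
    sources=["third_party/flash-attention/hopper/flash_api.cpp"],  # placeholder path
    libraries=["cuda"],  # link against libcuda (the CUDA driver API) at build time
)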

@antferdom

@xuzhao9 Alright, thanks for the clarification. It makes sense for us as well, since we have complete control over the machines. I have used the Tritonbench Dockerfile as the main reference so that our environment is almost identical to yours, and I can confirm that with your patch, xformers works:

$ python -c "import xformers._C_flashattention3"
