
[FA3] Link to cuda library to fix the FA3 extension build #1157

Closed
wants to merge 1 commit

Conversation

@xuzhao9 commented Nov 21, 2024

What does this PR do?

Fix the FA3 extension build by linking against the CUDA driver library (libcuda.so).

The upstream flash-attn repo notes that -lcuda is required to build and install the FA3 library: https://github.com/Dao-AILab/flash-attention/blob/main/hopper/setup.py#L222C13-L223C31

We need to add the same library to the xformers build so that the resulting .so resolves all CUDA driver symbols.
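
For context, here is a minimal sketch of the kind of change involved, assuming the extension is declared through torch's CUDAExtension in setup.py; the extension name matches the import in the test plan below, but the package name and source list are placeholders, not xformers' actual build configuration:

# Hypothetical sketch, not the actual diff: add the CUDA driver library to
# the FA3 extension's link step so that driver symbols such as
# cuTensorMapEncodeTiled resolve at link time.
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name="fa3-link-sketch",  # placeholder package name
    ext_modules=[
        CUDAExtension(
            name="xformers._C_flashattention3",
            sources=["hopper/flash_api.cpp"],  # placeholder; real sources differ
            extra_link_args=["-lcuda"],  # the proposed fix (libraries=["cuda"] is equivalent)
        )
    ],
    cmdclass={"build_ext": BuildExtension},
)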

Test Plan:

Before this PR:

$ pip install -e .
# install xformers from source...
$ python -c "import xformers._C_flashattention3"
ImportError: /data/users/xzhao9/tritonbench/submodules/xformers/xformers/_C_flashattention3.so: undefined symbol: cuTensorMapEncodeTiled

After this PR:

$ pip install -e .
# install xformers from source...
$ python -c "import xformers._C_flashattention3"
# success!

Before submitting

  • Did you have fun?
    • Make sure you had fun coding 🙃
  • Did you read the contributor guideline?
  • Was this discussed/approved via a GitHub issue? (no need for typos, doc improvements)
    • N/A
  • Did you make sure to update the docs?
    • N/A
  • Did you write any new necessary tests?
    • N/A
  • Did you update the changelog? (if needed)
    • N/A

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues there's a high chance it will not be merged.

@facebook-github-bot added the CLA Signed label Nov 21, 2024
@xuzhao9 (Author) commented Nov 21, 2024

The CI workflow failed because the CI runner does not ship with libcuda.so. To provide it, we need to install the NVIDIA driver package, as done in https://github.com/pytorch-labs/tritonbench/blob/main/docker/tritonbench-nightly.dockerfile#L47
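
One way to check whether a runner can see the driver at all before attempting the import (an illustrative snippet, not part of the CI setup):

import ctypes.util

# Returns something like "libcuda.so.1" when the NVIDIA driver's userspace
# libraries are installed, and None on a bare CI runner without them.
print(ctypes.util.find_library("cuda"))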

@lw (Contributor) left a comment

I don't think we want this change. Libraries are typically expected to link to the CUDA runtime (libcudart.so), but not directly to the CUDA driver (libcuda.so). The reason is that the CUDA runtime can be installed in userspace (e.g., conda) or even bundled with the library (statically or dynamically), whereas the CUDA driver must be installed system-wide by an administrator. Linking against the CUDA driver would thus prevent loading xFormers on misconfigured systems, or on systems without GPUs. The CUDA runtime, on the other hand, loads the CUDA driver lazily/dynamically, so it needs the driver only if and when CUDA features are requested.

This is the approach followed by PyTorch and, indeed, they had to go out of their way to avoid linking with the CUDA driver, such as in https://github.com/pytorch/pytorch/blob/ecf3bae40a6f2f0f3b237bde1fc4b2492765ab13/aten/src/ATen/native/cuda/RowwiseScaledMM.cu#L16-L58.
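
To illustrate the lazy-loading idea (a hypothetical Python sketch, not code from PyTorch or xFormers): dlopen the driver only when a CUDA symbol is first needed, so merely importing the library succeeds on driverless machines.

import ctypes

def load_driver_symbol(name):
    # Resolve a CUDA driver symbol at call time instead of link time.
    try:
        libcuda = ctypes.CDLL("libcuda.so.1")  # fails only where no driver exists
    except OSError:
        return None  # no driver: fine, as long as CUDA features are never used
    return getattr(libcuda, name, None)

# Only this call needs the driver; importing the module above never does.
fn = load_driver_symbol("cuTensorMapEncodeTiled")
print("driver available" if fn else "no CUDA driver on this system")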

I believe the problem in FlashAttention3 comes from its use of CUTLASS, which indeed made an "eager" call to the CUDA driver. This has been reported and was fixed in CUTLASS v3.5.1. The ideal fix, then, is to update CUTLASS within FlashAttention3 and then update the submodule within xFormers.

@xuzhao9 (Author) commented Nov 21, 2024

Thanks for explaining the details! I will close this PR and wait for upstream FA3 to pick up the CUTLASS fix.

@xuzhao9 xuzhao9 closed this Nov 21, 2024
@xuzhao9 xuzhao9 deleted the xz9/fix-fa3-build branch November 21, 2024 12:50
facebook-github-bot pushed a commit to pytorch-labs/tritonbench that referenced this pull request Nov 21, 2024
Summary:
This is to patch xformers, as its FA3 extension build fails due to the lack of linking against libcuda.so: facebookresearch/xformers#1157

Fixes #20

Pull Request resolved: #61

Reviewed By: FindHao

Differential Revision: D66273474

Pulled By: xuzhao9

fbshipit-source-id: 81898ccd005750937ac3cfd639c2303975ef1abe
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
3 participants