Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Multi-Head Attention support for Vitis #1163

Open
wants to merge 65 commits into
base: main
Choose a base branch
from

Conversation

rianbrooksflynn
Copy link

Description

This PR adds support for Multi-Head Attention using either Keras or PyTorch with the Vitis backend in io_parallel mode.

Tests have been added for both Keras and Pytorch parsing.

Credit is due to @Ethan0Jiang and @LostEcho365 (Zhixing Jiang and Dennis Yin) for their original implementation and Keras parsing support; my contributions were implementing PyTorch support and adding unit tests. (Here's a link to their pre-print.) The original code authors have given permission for their code to be merged into hls4ml.

There are some important notes for PyTorch (TODO: add documentation to this effect):

  • Need to set batch_first=True when instantiating nn.MultiheadAttention so that the inputs match up ((batch_size, seq_len, embed_dim) instead of (seq_len, batch_size, embed_dim)).
  • Need to set channels_last_conversion='off' when calling config_from_pytorch_model() since batch-first PyTorch and Keras use the same input shape.
  • Keras lets you call MultiHeadAttention using just two inputs (or even just one input for self-attention), but PyTorch insists that you give it all three of query, key, and value; hls4ml currently only supports the case where key and value are the same; thus, you must give PyTorch the same data for the second input and the third input.

Type of change

  • New feature (non-breaking change which adds functionality)
  • A new research paper code implementation

Tests

Two unit tests added: test/pytest/test_multiheadattention.py and test/pytest/test_multiheadattention_pytorch.py

Checklist

  • I have read the guidelines for contributing.
  • I have commented my code, particularly in hard-to-understand areas.
  • I have made corresponding changes to the documentation.
  • My changes generate no new warnings.
  • I have installed and run pre-commit on the files I edited or added.
  • I have added tests that prove my fix is effective or that my feature works.

@Ethan0Jiang
Copy link

Thank you so much for merging it to the main!

@rianbrooksflynn
Copy link
Author

pre-commit.ci autofix

@JanFSchulte JanFSchulte added the please test Trigger testing by creating local PR branch label Jan 14, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
please test Trigger testing by creating local PR branch
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants