Add GroupedQueryAttention layer (keras-team#18488)
* add: gqa barebone

* update: avoid abbreviations in the signature

* update: use fully-spelled snake_case

* add: initializers, regularizers, constraint

* update: mask to attention_mask

* add: `dropout`

* add: `softmax` from `ops`

* update: use softmax for masking

Masking is handled the same way as in `MultiHeadAttention`.
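
A standalone sketch (hypothetical shapes) of that masking scheme: the boolean mask is folded into the `Softmax` layer, which adds a large negative bias to masked positions so they receive near-zero weight.

```python
import numpy as np
from keras import layers

scores = np.random.rand(2, 4, 3, 5).astype("float32")        # (batch, heads, Tq, Tk)
attention_mask = np.tril(np.ones((1, 1, 3, 5), dtype=bool))  # True = attend

softmax = layers.Softmax(axis=-1)
weights = softmax(scores, mask=attention_mask)  # masked keys get ~0 probability
```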

* add: import in `__init__`

* update: filename

* add: `compute_output_shape`

* add: `query-key-value` & `causal` mask
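
A small self-contained sketch of the usual causal-mask construction with `keras.ops`; this shows the general pattern, not the layer's exact code.

```python
from keras import ops

def compute_causal_mask(query_length, key_length):
    # Query position i may only attend to key positions j <= i.
    row = ops.expand_dims(ops.arange(query_length), axis=-1)  # (Tq, 1)
    col = ops.expand_dims(ops.arange(key_length), axis=0)     # (1, Tk)
    return ops.less_equal(col, row)                           # (Tq, Tk), lower-triangular

mask = compute_causal_mask(4, 4)
# [[ True, False, False, False],
#  [ True,  True, False, False],
#  [ True,  True,  True, False],
#  [ True,  True,  True,  True]]
```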

* update: use `EinsumDense` and `einsum`
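
A hedged illustration (equation letters and variable names chosen for this sketch, not quoted from the layer) of how `EinsumDense` can project queries onto more heads than keys/values:

```python
import numpy as np
from keras import layers

num_query_heads, num_key_value_heads, head_dim = 8, 2, 16

query_dense = layers.EinsumDense(
    "bqd,dnh->bqnh", output_shape=(None, num_query_heads, head_dim)
)
key_dense = layers.EinsumDense(
    "bkd,dmh->bkmh", output_shape=(None, num_key_value_heads, head_dim)
)

query = query_dense(np.random.rand(2, 10, 32).astype("float32"))  # (2, 10, 8, 16)
key = key_dense(np.random.rand(2, 6, 32).astype("float32"))       # (2, 6, 2, 16)
```

Using a different letter for the key/value head axis (`m`) than for the query head axis (`n`) keeps the two head counts independent.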

* update: remove `Dense` import

* fix: `__init__` in layer for `isort`

* update: docstring

* add: simple test

* update: `support_masking` False in test

* fix: error due to query & key-value seq_len mismatch

* Revert "update: `support_masking` False in test"

This reverts commit b361be5.

* update: `use_bias` = True as default

* add: `support_masking`

* add: more tests

* update: code format with `black`

* remove: high dim attention test

GQA does not support high-dimensional attention inputs yet.

* fix: remove undefined arg `num_head`

* add: shape mismatch test

* add: initializer test

* update: `softmax` for mask propagation

* update: code format

* add: mask propagation test

* add: masking test

* add: correctness test

* update: output shape test for mqa, gqa & mha

mqa -> multi-query attention
gqa -> grouped-query attention
mha -> multi-head attention
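
A hedged sketch of the relationship the test exercises, assuming the constructor arguments `head_dim`, `num_query_heads`, and `num_key_value_heads` (the head-count names appear in the commit messages above; `head_dim` is assumed):

```python
import numpy as np
from keras import layers

query = np.random.rand(2, 10, 32).astype("float32")  # (batch, target_seq, dim)
value = np.random.rand(2, 6, 32).astype("float32")   # (batch, source_seq, dim)

gqa = layers.GroupedQueryAttention(head_dim=16, num_query_heads=8, num_key_value_heads=2)
mqa = layers.GroupedQueryAttention(head_dim=16, num_query_heads=8, num_key_value_heads=1)       # multi-query case
mha_like = layers.GroupedQueryAttention(head_dim=16, num_query_heads=8, num_key_value_heads=8)  # multi-head case

for layer in (gqa, mqa, mha_like):
    assert layer(query, value).shape == query.shape  # output keeps the query shape
```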

* fix: code format for `isort`

* add: divisible error check

What happens if `num_query_heads` is not divisible by `num_key_value_heads`?
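
A minimal sketch of that kind of guard; the layer's actual error message may be worded differently.

```python
num_query_heads, num_key_value_heads = 8, 3  # 8 % 3 != 0, so this should fail

if num_query_heads % num_key_value_heads != 0:
    raise ValueError(
        "`num_query_heads` must be divisible by `num_key_value_heads`. "
        f"Received: num_query_heads={num_query_heads}, "
        f"num_key_value_heads={num_key_value_heads}"
    )
```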

* add: shape of `attention_scores`

* update: different letters for query and key-value heads

Use different letters to denote the `num_query_heads` and `num_key_value_heads` axes.
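
A hedged illustration of that distinct-letter convention and of an `attention_scores` shape it yields; the letters and grouping here are chosen for the example, not quoted from the layer.

```python
import numpy as np
from keras import ops

batch, q_len, k_len, head_dim = 2, 5, 7, 16
num_query_heads, num_key_value_heads = 8, 2
groups = num_query_heads // num_key_value_heads  # query heads per key/value head

query = np.random.rand(batch, q_len, num_key_value_heads, groups, head_dim).astype("float32")
key = np.random.rand(batch, k_len, num_key_value_heads, head_dim).astype("float32")

# `m` indexes key/value heads, `g` the groups of query heads sharing them:
scores = ops.einsum("bqmgh,bkmh->bmgqk", query, key)
# scores shape: (batch, num_key_value_heads, groups, q_len, k_len) == (2, 2, 4, 5, 7)
```
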
awsaf49 authored Oct 22, 2023
1 parent 15926e1 commit a35287a
Showing 3 changed files with 643 additions and 0 deletions.
keras/layers/__init__.py (1 change: 1 addition, 0 deletions)
@@ -7,6 +7,7 @@
 from keras.layers.activations.softmax import Softmax
 from keras.layers.attention.additive_attention import AdditiveAttention
 from keras.layers.attention.attention import Attention
+from keras.layers.attention.grouped_query_attention import GroupedQueryAttention
 from keras.layers.attention.multi_head_attention import MultiHeadAttention
 from keras.layers.convolutional.conv1d import Conv1D
 from keras.layers.convolutional.conv1d_transpose import Conv1DTranspose