Commit
Add `GroupedQueryAttention` layer (keras-team#18488)
* add: gqa barebone
* update: avoid abbr. in signature
* update: use fully-spelled snake_case
* add: initializers, regularizers, constraint
* update: mask to attention_mask
* add: `dropout`
* add: `softmax` from `ops`
* update: use softmax for masking
  Similar to multi-head attention
* add: import in `__init__`
* update: filename
* add: `compute_output_shape`
* add: `query-key-value` & `causal` mask
* update: use `EinsumDense` and `einsum`
* update: remove `Dense` import
* fix: `__init__` in layer for `isort`
* update: docstring
* add: simple test
* update: `support_masking` False in test
* fix: error due to query & key-value seq_len mismatch
* Revert "update: `support_masking` False in test"
  This reverts commit b361be5.
* update: `use_bias` = True as default
* add: `support_masking`
* add: more tests
* update: code format with `black`
* remove: high dim attention test
  GQA does not support high dim attention yet
* fix: remove undefined arg `num_head`
* add: shape mismatch test
* add: initializer test
* update: `softmax` for mask propagation
* update: code format
* add: mask propagation test
* add: masking test
* add: correctness test
* update: output shape test for mqa, gqa & mha
  mqa -> multi query attention
  gqa -> grouped query attention
  mha -> multi head attention
* fix: code format for `isort`
* add: divisible error check
  What happens if `num_query_heads` is not divisible by `num_key_value_heads`?
* add: shape of `attention_scores`
* update: different letters for query and key-value heads
  use different letters to denote `num_query_heads` vs `num_key_value_heads`
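The divisibility check and the mqa/gqa/mha cases mentioned above can be illustrated with a minimal pure-Python sketch. This is not the code from this commit: the helper name `query_head_to_key_value_head` is hypothetical, and the sketch assumes that consecutive query heads share one key-value head, which is one common grouping convention.

```python
def query_head_to_key_value_head(num_query_heads, num_key_value_heads):
    """Return, for each query head, the index of the key-value head it shares.

    Hypothetical helper for illustration only; it assumes consecutive
    query heads are grouped onto the same key-value head.
    """
    if num_query_heads % num_key_value_heads != 0:
        # Without this check the last group would be incomplete.
        raise ValueError(
            "`num_query_heads` must be divisible by `num_key_value_heads`. "
            f"Received: num_query_heads={num_query_heads}, "
            f"num_key_value_heads={num_key_value_heads}"
        )
    group_size = num_query_heads // num_key_value_heads
    return [head // group_size for head in range(num_query_heads)]


# mqa: every query head shares a single key-value head
print(query_head_to_key_value_head(4, 1))  # [0, 0, 0, 0]
# gqa: query heads are split into groups, one key-value head per group
print(query_head_to_key_value_head(8, 2))  # [0, 0, 0, 0, 1, 1, 1, 1]
# mha: each query head has its own key-value head
print(query_head_to_key_value_head(4, 4))  # [0, 1, 2, 3]
```

Under this assumption, multi query attention and multi head attention are the two extremes of grouped query attention (`num_key_value_heads` equal to 1 and to `num_query_heads`, respectively), which is why a single output shape test can cover all three.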