Longer docstring for mha function
ae-foster committed Oct 28, 2024
1 parent d6f7fc4 commit 8041ab3
Showing 1 changed file with 13 additions and 1 deletion.
14 changes: 13 additions & 1 deletion folx/experimental/pallas/mha.py
@@ -25,7 +25,19 @@ def mha(
    num_warps: int = 2,
    num_stages: int = 2,
) -> jax.Array:
r"""Pallas implementation of masked multi-head attention."""
r"""Pallas implementation of masked multi-head attention.
Note: the dimensions of the tensor inputs to this function must have dimensions that are
powers of 2, and any dimension that will participate in a matrix multiplication must
have dimension at least 16.
By default, when using pallas, we will run the operation in parallel over the batch and head
dimensions of the inputs. This is implemented by creating a pallas grid of the relevant size
and distributing the necessary submatrices to each streaming multiprocessor (SM).
At some point, the sequence length may become too large to run the entire computation for
one head on a single SM. In this case, by changing `q_block_len`, we distribute different
blocks of queries to different SMs.
"""
    del input_mask  # Only used in the forward Laplacian
    batch_len, seq_len, num_heads, head_len = q.shape
    q_block_len, kv_block_len = compute_q_and_kv_block_len(seq_len, q_block_len)
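
The shape constraints and the `q_block_len` blocking described in the new docstring can be illustrated with a small usage sketch. The full signature of `mha` is not visible in this hunk, so the `k`, `v`, and import lines below are assumptions inferred from the file path and the visible code; only `q`'s four-dimensional layout and the `input_mask`, `q_block_len`, `num_warps`, and `num_stages` parameters are confirmed by the diff.

import jax
import jax.numpy as jnp

from folx.experimental.pallas.mha import mha  # module path assumed from the file path above

# All dimensions are powers of 2, and the dimensions that enter a matrix
# multiplication (seq_len = 128 and head_len = 32) are at least 16.
batch_len, seq_len, num_heads, head_len = 2, 128, 4, 32
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (batch_len, seq_len, num_heads, head_len))
k = jax.random.normal(kk, (batch_len, seq_len, num_heads, head_len))  # k, v shapes assumed to match q
v = jax.random.normal(kv, (batch_len, seq_len, num_heads, head_len))
mask = jnp.ones((batch_len, seq_len), dtype=bool)  # mask shape is an assumption

# q_block_len=64 splits the 128 queries into two blocks, so different
# blocks can be scheduled on different SMs.
out = mha(q, k, v, input_mask=mask, q_block_len=64)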
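The parallelism the docstring describes amounts to simple grid arithmetic. The sketch below is illustrative only, assuming one grid cell per (batch element, head, query block); the kernel's actual grid layout is not shown in this diff.

# One program instance per (batch element, head, query block); the grid
# axes and their order here are assumptions, not taken from the kernel.
batch_len, seq_len, num_heads = 2, 1024, 4
q_block_len = 128                      # each block of 128 queries goes to its own SM
num_q_blocks = seq_len // q_block_len  # 1024 // 128 = 8
grid = (batch_len, num_heads, num_q_blocks)  # 2 * 4 * 8 = 64 program instances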
