Skip to content

Commit

Permalink
remove a redundant mask; more tolerance for v_vjp in test
Browse files Browse the repository at this point in the history
  • Loading branch information
ae-foster committed Oct 30, 2024
1 parent c6fda6a commit 13163de
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 9 deletions.
6 changes: 3 additions & 3 deletions folx/experimental/pallas/attention/mhsa.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,9 @@ def mhsa_kernel(
q_mask = pl.load(mask_ref, (q_slice,))
square_mask = q_mask[:, None] * kv_mask[None, :]
# Forward pass
q = jnp.where(q_mask[:, None], q_ref[:, :], 0.0)
k = jnp.where(kv_mask[:, None], k_ref[:, :], 0.0)
v = jnp.where(kv_mask[:, None], v_ref[:, :], 0.0)
q = q_ref[:, :]
k = k_ref[:, :]
v = v_ref[:, :]
s = jnp.where(square_mask, pl.dot(q, k, trans_b=True), -big_number(q.dtype))
p = jax.nn.softmax(s, axis=1)
o = pl.dot(p, v)
Expand Down
9 changes: 3 additions & 6 deletions test/experimental/pallas/test_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,23 +119,20 @@ def test_vjp(rng, batch_dim, sequence_dim, num_heads, head_dim, max_sequence, wi
jax_o, jax_mhsa_vjp_fn = jax.vjp(jax_fn, q, k, v)
jax_q_vjp, jax_k_vjp, jax_v_vjp = jax_mhsa_vjp_fn(o_vjp)

print("ours", mask_array(k_vjp, mask))
print("jax", mask_array(jax_k_vjp, mask))
print("ref", mask_array(ref_k_vjp, mask))
assert jnp.allclose(mask_array(o, mask), mask_array(ref_o, mask), atol=1e-6)
assert jnp.allclose(mask_array(q_vjp, mask), mask_array(ref_q_vjp, mask), atol=1e-6)
assert jnp.allclose(mask_array(k_vjp, mask), mask_array(ref_k_vjp, mask), atol=1e-6)
assert jnp.allclose(mask_array(v_vjp, mask), mask_array(ref_v_vjp, mask))
assert jnp.allclose(mask_array(v_vjp, mask), mask_array(ref_v_vjp, mask), atol=1e-6)

assert jnp.allclose(mask_array(o, mask), mask_array(jax_o, mask), atol=1e-6)
assert jnp.allclose(mask_array(q_vjp, mask), mask_array(jax_q_vjp, mask), atol=1e-6)
assert jnp.allclose(mask_array(k_vjp, mask), mask_array(jax_k_vjp, mask), atol=1e-6)
assert jnp.allclose(mask_array(v_vjp, mask), mask_array(jax_v_vjp, mask))
assert jnp.allclose(mask_array(v_vjp, mask), mask_array(jax_v_vjp, mask), atol=1e-6)

assert jnp.allclose(mask_array(ref_o, mask), mask_array(jax_o, mask), atol=1e-6)
assert jnp.allclose(mask_array(ref_q_vjp, mask), mask_array(jax_q_vjp, mask), atol=1e-6)
assert jnp.allclose(mask_array(ref_k_vjp, mask), mask_array(jax_k_vjp, mask), atol=1e-6)
assert jnp.allclose(mask_array(ref_v_vjp, mask), mask_array(jax_v_vjp, mask))
assert jnp.allclose(mask_array(ref_v_vjp, mask), mask_array(jax_v_vjp, mask), atol=1e-6)


@pytest.mark.parametrize(
Expand Down

0 comments on commit 13163de

Please sign in to comment.