Skip to content

Commit

Permalink
faster sum jvp
Browse files Browse the repository at this point in the history
  • Loading branch information
n-gao committed Aug 12, 2024
1 parent 0fab2f8 commit 67bf7c9
Showing 1 changed file with 17 additions and 5 deletions.
22 changes: 17 additions & 5 deletions folx/jvp.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
broadcast_dim,
broadcast_except,
broadcast_mask_to_jacobian,
compact_repeated_dims_except,
extend_jacobians,
get_jacobian_for_reduction,
np_concatenate_brdcast,
Expand Down Expand Up @@ -89,11 +90,22 @@ def compute_outdeps(arr: np.ndarray, axis: int):

# segment sum on the jacobian
jac = jnp.transpose(x_jac.data, axes_order).reshape(-1, out_size)
out_jac = jax.vmap(
functools.partial(jax.ops.segment_sum, num_segments=out_dim),
in_axes=(1, 1),
out_axes=1,
)(jac, idx)
jac = jac.reshape(-1, *out_shape)
idx = idx.reshape(-1, *out_shape)
idx, repeated_dims = compact_repeated_dims_except(idx, 0)
vmapped_axes = tuple(i for i in range(idx.ndim) if i not in repeated_dims and i > 0)
new_order = (0, *vmapped_axes, *repeated_dims)
inv_order = np.argsort(new_order)
idx = np.transpose(idx, new_order)[..., *(0,) * len(repeated_dims)]
jac = jnp.transpose(jac, new_order)
jac_in_shape = jac.shape
idx = idx.reshape(idx.shape[0], -1)
jac = jac.reshape(*idx.shape[:2], -1)
seg_sum = functools.partial(jax.ops.segment_sum, num_segments=out_dim)
out_jac = jax.vmap(seg_sum, in_axes=1, out_axes=1)(jac, idx).reshape(
out_dim, *jac_in_shape[1:]
)
out_jac = np.transpose(out_jac, inv_order)
out_jac = out_jac.reshape(out_dim, *out_shape)
return y, FwdJacobian(out_jac, idx_out), y_lapl

Expand Down

0 comments on commit 67bf7c9

Please sign in to comment.