Skip to content

Commit

Permalink
fix summation index
Browse files Browse the repository at this point in the history
  • Loading branch information
n-gao committed Aug 10, 2024
1 parent 02b5af0 commit 0fab2f8
Showing 1 changed file with 27 additions and 23 deletions.
50 changes: 27 additions & 23 deletions folx/jvp.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import functools
import logging
from typing import TypeVar
from typing import Sequence, TypeVar

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -34,7 +34,7 @@
R = TypeVar('R', bound=PyTree[Array])


def sparse_to_dense_sum_jvp(
def sparse_sum_jvp(
laplace_args: FwdLaplArgs,
axes: Axes,
kwargs,
Expand All @@ -56,14 +56,14 @@ def sparse_to_dense_sum_jvp(
# for the sparse jacobian, we will use a segment sum
out_shape = y.shape
out_size = np.prod(out_shape, dtype=int)
jac_axes = tuple(i + (i >= JAC_DIM) for i in axes) + (JAC_DIM,)
non_reduced_axes = tuple(i for i in range(x_jac.ndim) if i not in jac_axes)
reduced_dims = (JAC_DIM,) + tuple(i + (i >= JAC_DIM) for i in axes)
non_reduced_axes = tuple(i for i in range(x_jac.ndim) if i not in reduced_dims)
assert x_jac.x0_idx is not None
axes_order = jac_axes + non_reduced_axes
axes_order = reduced_dims + non_reduced_axes

def compute_outdeps(arr: np.ndarray, axis: int):
A_sorted = np.sort(arr, axis=axis)
max_out = (np.diff(A_sorted, axis=axis) > 0).sum().max() + 1
max_out = (np.diff(A_sorted, axis=axis) > 0).sum(axis).max() + 1
# move axis to back so we can use vectorize
A_sorted = np.moveaxis(A_sorted, axis, -1)
with jax.ensure_compile_time_eval():
Expand Down Expand Up @@ -141,7 +141,7 @@ def sparse_jvp(

# Summation
if FunctionFlags.SUMMATION in flags:
return sparse_to_dense_sum_jvp(laplace_args, axes, kwargs, sparsity_threshold)
return sparse_sum_jvp(laplace_args, axes, kwargs, sparsity_threshold)

grad_tan, out_mask = get_jacobian_for_reduction(laplace_args.jacobian, axes)
if out_mask.shape[JAC_DIM] > sparsity_threshold:
Expand Down Expand Up @@ -508,15 +508,17 @@ def parallel_jvp(args: FwdLaplArgs, kwargs):
def one_by_one_jvp(args: FwdLaplArgs, kwargs) -> tuple[Array, FwdJacobian, Array]:
y, grad, lapl = None, None, None
for i, x in enumerate(args.arrays):
static_args = list(args.x)

def merged_fwd(arg: Array):
return fwd(
*merge(
tuple(static_args[:i] + [arg] + static_args[i + 1 :]),
extra_args,
)
)
static_args = tuple(args.x)
n_static = len(static_args) - 1
new_extra = extra_args + static_args[:i] + static_args[i + 1 :]

def new_merge(args: Sequence[Array], extra: Sequence[Array]):
assert len(args) == 1, 'Only one argument is expected.'
extra, static = extra[:-n_static], extra[-n_static:]
return merge(tuple(static[:i] + (args[0],) + static[i:]), extra)

def merged_fwd(*args: Array):
return fwd(*new_merge(args, new_extra))

merged_fwd.__name__ = fwd.__name__

Expand All @@ -530,8 +532,8 @@ def _jvp(args: FwdLaplArgs, kwargs):
return sparse_jvp(
merged_fwd,
args,
extra_args,
merge,
new_extra,
new_merge,
axes=in_axes,
kwargs=kwargs,
sparsity_threshold=sparsity_threshold,
Expand All @@ -548,11 +550,13 @@ def _jvp(args: FwdLaplArgs, kwargs):
return y, grad, lapl # type: ignore

def jvp(args: FwdLaplArgs, kwargs) -> tuple[Array, FwdJacobian, Array]:
if (
(not args.any_jacobian_weak)
or (FunctionFlags.INDEXING in flags)
or (in_axes == ())
or (len(args) == 1)
# If everything is dense, we do it in parallel. Otherwise, we call the simpler code
# if only a single argument has a jacobian or it is an elementwise/indexing operation.
if (not args.any_jacobian_weak) or (
args.all_jacobian_weak
and (
(FunctionFlags.INDEXING in flags) or (in_axes == ()) or (len(args) == 1)
)
):
return parallel_jvp(args, kwargs)
else:
Expand Down

0 comments on commit 0fab2f8

Please sign in to comment.