Skip to content

Commit

Permalink
use jax.ensure_compile_time_eval() instead of tracer overwrite
Browse files Browse the repository at this point in the history
  • Loading branch information
n-gao committed Mar 15, 2024
1 parent bd75ec5 commit 0def91e
Show file tree
Hide file tree
Showing 7 changed files with 39 additions and 47 deletions.
28 changes: 19 additions & 9 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,10 +1,20 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.1.14
hooks:
# Run the linter.
- id: ruff
args: [ --fix ]
# Run the formatter.
- id: ruff-format
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
hooks:
- id: check-case-conflict
- id: check-toml
- id: check-xml
- id: check-yaml
- id: check-added-large-files
- id: trailing-whitespace

- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.3.2
hooks:
# Run the linter.
- id: ruff
args: [ --fix ]
# Run the formatter.
- id: ruff-format
30 changes: 10 additions & 20 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ This submodule implements the forward laplacian from https://arxiv.org/abs/2307.

## Install

Either clone repo and install locally via
Either clone repo and install locally via
```bash
poetry install
```
Expand Down Expand Up @@ -34,7 +34,7 @@ result.laplacian # tr(H_f(x)) 6
```

## Introduction
To avoid custom wrappers for all of JAX's commands, the forward laplacian is implemented as custom interpreter for Jaxpr.
To avoid custom wrappers for all of JAX's commands, the forward laplacian is implemented as custom interpreter for Jaxpr.
This means if you have a function
```python
class Fn(Protocol):
Expand All @@ -47,7 +47,7 @@ class LaplacianFn(Protocol):
def __call__(self, *args: PyTree[Array]) -> PyTree[FwdLaplArray]:
...
```
where `FwdLaplArray` is a triplet of
where `FwdLaplArray` is a triplet of
```python
FwdLaplArray.x # jax.Array f(x) f(x).shape
FwdLaplArray.jacobian # FwdJacobian J_f(x)
Expand Down Expand Up @@ -77,7 +77,7 @@ But instead of using the [standard evaluation pipeline](https://github.com/googl

### Package structure
The general structure of the package is
* `interpreter.py` contains the evaluation of jaxpr and exported function decorator.
* `interpreter.py` contains the evaluation of jaxpr and exported function decorator.
* `wrapper.py` contains subfunction decorator that maps a function that takes `jax.Array`s to a function that accepts `FwdLaplArray`s instead.
* `wrapped_functions.py` contains a registry of predefined functions as well as utility functions to add new functions to the registry.
* `jvp.py` contains logic for jacobian vector products.
Expand All @@ -86,13 +86,13 @@ The general structure of the package is
* `api.py` contains general interfaces shared in the package.
* `operators.py` contains a forward laplacian operator as well as alternatives.
* `utils.py` contains several small utility functions.
* `tree_utils.py` contains several utility functions for PyTrees.
* `tree_utils.py` contains several utility functions for PyTrees.
* `vmap.py` contains a batched vmap implementation to reduce memory usage by going through a batch sequentially in chunks.


### Function Annotations
There is a default interpreter that will simply apply the rules outlined above but if additional information about a function is available, e.g., that it applies elementwise like `jnp.tanh`, we can do better.
These additional annotations are available in `wrapped_functions.py`'s `_LAPLACE_FN_REGISTRY`.
These additional annotations are available in `wrapped_functions.py`'s `_LAPLACE_FN_REGISTRY`.
Specifically, to augment a function `fn` to accept `FwdLaplArray` instead of regular `jax.Array`, we wrap it with `wrap_forward_laplacian` from `fwd_laplacian.py`:
```python
wrap_forward_laplacian(jnp.tanh, in_axes=())
Expand Down Expand Up @@ -142,11 +142,12 @@ g(jnp.ones(())).laplacian # 10
Sparsity is detected at compile time, this has the advantage of avoiding expensive index computations at runtime and enables efficient reductions. However, it completely prohibits dynamic indexing, i.e., if indices are data-dependent we will simply default to full jacobians.

As we know a lot about the sparsity structure apriori, e.g., that we are only sparse in one dimension, we use a custom sparsity operations that are more efficient than relying on JAX's default `BCOO` (further, at the time of writing, the support for `jax.experimental.sparse` is quite bad).
So, the sparsity data format is implemented in `FwdJacobian` in `api.py`. Instead of storing a dense array `(m, n)` for a function `f:R^n -> R^m`, we store only the non-zero data in a `(m,d)` array where `d<n` is the maximum number of non-zero inputs any output depends on.
So, the sparsity data format is implemented in `FwdJacobian` in `api.py`. Instead of storing a dense array `(m, n)` for a function `f:R^n -> R^m`, we store only the non-zero data in a `(m,d)` array where `d<n` is the maximum number of non-zero inputs any output depends on.
To be able to recreate the larger `(m,n)` array from the `(m,d)` array, we additional keep track of the indices in the last dimension in a mask `(m,d)` dimensional array of integers `0<mask_ij<n`.

Masks are treated as compile time static and will be traced automatically. If the tracing is not possible, e.g., due to data dependent indexing, we will fall back to a dense implementation. These propagation rules are implemented in `jvp.py`.


### Memory efficiency
The forward laplacian uses more GPU memory due to the full materialization of the Jacobian matrix.
To compensate for this, it is recommended to loop over the batch size (while other implementations typically loop over the Hessian).
Expand All @@ -160,17 +161,6 @@ def f(x):
batched_f = batched_vmap(f, max_batch_size=64)
```

##### Omnistaging
If arrays do not depend on the initial input, they are typically still traced to better optimize the final program. This is called [omnistaging](https://github.com/google/jax/pull/3370). While this generally is beneficial, it does not allow us to perform indexing as tracer hide the actual data.
So, if we use sparsity we want to compute all arrays that do not explicitly depend on the input such that we could use them for index operations.
While this is not documented, it can be accomplished by overwriting the global trace via:
```python
from jax import core

with core.new_main(core.EvalTrace, dynamic=True):
...
```

## Citation
If you find work helpful, please consider citing it as
```
Expand Down Expand Up @@ -208,8 +198,8 @@ contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additio

## Trademarks

This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft
trademarks or logos is subject to and must follow
This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft
trademarks or logos is subject to and must follow
[Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general).
Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship.
Any use of third-party trademarks or logos are subject to those third-party's policies.
15 changes: 5 additions & 10 deletions folx/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import jax.numpy as jnp
import numpy as np
import numpy.typing as npt
from jax import core
from jaxtyping import Array, PyTree

T = TypeVar('T', bound=PyTree[Array])
Expand Down Expand Up @@ -131,7 +130,7 @@ def get_indices(mask, out_mask):
return indices

if isinstance(outputs, np.ndarray):
with core.new_main(core.EvalTrace, dynamic=True):
with jax.ensure_compile_time_eval():
result = np.asarray(get_indices(flat_mask, flat_outputs), dtype=int).T
else:
result = get_indices(flat_mask, flat_outputs).T
Expand Down Expand Up @@ -390,8 +389,7 @@ def __len__(self) -> int:


class MergeFn(Protocol):
def __call__(self, args: Arrays, extra: ExtraArgs) -> Arrays:
...
def __call__(self, args: Arrays, extra: ExtraArgs) -> Arrays: ...


class ForwardLaplacianFns(NamedTuple):
Expand All @@ -403,8 +401,7 @@ class ForwardLaplacianFns(NamedTuple):


class JvpFn(Protocol):
def __call__(self, primals: Arrays, tangents: Arrays) -> tuple[Array, Array]:
...
def __call__(self, primals: Arrays, tangents: Arrays) -> tuple[Array, Array]: ...


class CustomTraceJacHessianJac(Protocol):
Expand All @@ -414,8 +411,7 @@ def __call__(
extra_args: ExtraArgs,
merge: MergeFn,
materialize_idx: Array,
) -> PyTree[Array]:
...
) -> PyTree[Array]: ...


class ForwardLaplacian(Protocol):
Expand All @@ -424,8 +420,7 @@ def __call__(
args: tuple[ArrayOrFwdLaplArray],
kwargs: dict[str, Any],
sparsity_threshold: int,
) -> PyTree[ArrayOrFwdLaplArray]:
...
) -> PyTree[ArrayOrFwdLaplArray]: ...


class FunctionFlags(IntFlag):
Expand Down
3 changes: 1 addition & 2 deletions folx/hessian.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import jax.tree_util as jtu
import jaxlib.xla_extension
import numpy as np
from jax import core

from .api import (
JAC_DIM,
Expand Down Expand Up @@ -206,7 +205,7 @@ def find_materialization_idx(
return None
# TODO: Rewrite this!! This is quity messy and inefficient.
# it assumes that we're only interested in the last dimension.
with core.new_main(core.EvalTrace, dynamic=True):
with jax.ensure_compile_time_eval():
vmap_seq, (inp,) = vmap_sequences_and_squeeze(
([j.mask for j in lapl_args.jacobian],),
(
Expand Down
2 changes: 1 addition & 1 deletion folx/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def eval_laplacian(eqn: core.JaxprEqn, invals):
# https://github.com/google/jax/pull/3370
if all(not isinstance(x, core.Tracer) for x in invals) and enable_sparsity:
try:
with core.new_main(core.EvalTrace, dynamic=True):
with jax.ensure_compile_time_eval():
outvals = eqn.primitive.bind(*subfuns, *invals, **bind_params)
except Exception as e:
with LoggingPrefix(f'({summarize(eqn.source_info)})'):
Expand Down
5 changes: 2 additions & 3 deletions folx/jvp.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import TypeVar

import jax
import jax.core as core
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
Expand Down Expand Up @@ -192,7 +191,7 @@ def sparse_index_jvp(
# An index operation is expected to be static. If it is not, we will default to
# materializing everything.
# https://github.com/google/jax/pull/3370
with core.new_main(core.EvalTrace, dynamic=True):
with jax.ensure_compile_time_eval():
extra_filled = jtu.tree_map(
lambda x: jnp.full(x.shape, -1, dtype=jnp.int32), extra_args
)
Expand Down Expand Up @@ -287,7 +286,7 @@ def sparse_scatter_jvp(

updates: FwdLaplArray = updates
n = updates.jacobian.max_n + 1
with core.new_main(core.EvalTrace, dynamic=True):
with jax.ensure_compile_time_eval():
one_hot_mask = jax.nn.one_hot(
updates.jacobian.x0_idx, n, axis=-1, dtype=jnp.int32
).sum(0)
Expand Down
3 changes: 1 addition & 2 deletions folx/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
from jax import core

from .api import (
JAC_DIM,
Expand Down Expand Up @@ -563,7 +562,7 @@ def broadcast(m: np.ndarray, j: Array):
def brdcast(x):
return jnp.broadcast_to(x, target_shape)

with core.new_main(core.EvalTrace, dynamic=True):
with jax.ensure_compile_time_eval():
return np.asarray(brdcast(m), dtype=m.dtype)

return jtu.tree_map(broadcast, mask, jacobian)
Expand Down

0 comments on commit 0def91e

Please sign in to comment.