Skip to content

Commit

Permalink
add example on sparsity to the readme
Browse files Browse the repository at this point in the history
  • Loading branch information
n-gao authored Mar 23, 2024
1 parent 0def91e commit ea01902
Showing 1 changed file with 31 additions and 0 deletions.
31 changes: 31 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ pip install folx
```

## Example
### Dense example
For simple usage, one can decorate any function with `forward_laplacian`.
```python
import numpy as np
Expand All @@ -32,6 +33,36 @@ result.x # f(x) 3
result.jacobian.dense_array # J_f(x) [0, 2, 4]
result.laplacian # tr(H_f(x)) 6
```
### Sparsity example
A big feature of `folx` is to automatically work with sparse jacobians to accelerate computations. Note that the results are still **exact**. To enable this feature simply supply a maximum sparsity threshold. Compile times may increase significantly as tracing the sparsity patterns of the jacobians is a lengthy process. Here is an example with an MLP operating indepdently on individual node features.
```python
import folx
import jax
import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):
@nn.compact
def __call__(self, x):
for _ in range(10):
x = nn.Dense(100)(x)
x = nn.silu(x)
return nn.Dense(1)(x).sum()

mlp = MLP()
x = jnp.ones((20, 100, 4))
params = mlp.init(jax.random.PRNGKey(0), x)
def fwd(x):
return mlp.apply(params, x)

fwd_lapl = jax.jit(jax.vmap(folx.forward_laplacian(fwd, sparsity_threshold=4)))
%time jax.block_until_ready(fwd_lapl(x)) # Wall time: 5.05 s
%timeit jax.block_until_ready(fwd_lapl(x)) # 2.59 ms ± 15.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
fwd_lapl = jax.jit(jax.vmap(folx.forward_laplacian(fwd, sparsity_threshold=0)))
%time jax.block_until_ready(fwd_lapl(x)) # Wall time: 2.66 s
%timeit jax.block_until_ready(fwd_lapl(x)) # 48.7 ms ± 42.1 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
```
For electronic wave function like FermiNet or PsiFormer, `sparsity_threshold=6` is a recommended value. But, tuning this hyperparameter may accelerate computations.

## Introduction
To avoid custom wrappers for all of JAX's commands, the forward laplacian is implemented as custom interpreter for Jaxpr.
Expand Down

0 comments on commit ea01902

Please sign in to comment.