add loop to example
n-gao committed Mar 23, 2024
1 parent ea01902 commit 62a5740
Showing 3 changed files with 24 additions and 13 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -11,7 +11,7 @@ repos:

  - repo: https://github.com/astral-sh/ruff-pre-commit
    # Ruff version.
-   rev: v0.3.2
+   rev: v0.3.4
    hooks:
      # Run the linter.
      - id: ruff
22 changes: 15 additions & 7 deletions README.md
@@ -48,19 +48,27 @@ class MLP(nn.Module):
        x = nn.Dense(100)(x)
        x = nn.silu(x)
        return nn.Dense(1)(x).sum()

mlp = MLP()
x = jnp.ones((20, 100, 4))
params = mlp.init(jax.random.PRNGKey(0), x)
def fwd(x):
    return mlp.apply(params, x)

-fwd_lapl = jax.jit(jax.vmap(folx.forward_laplacian(fwd, sparsity_threshold=4)))
-%time jax.block_until_ready(fwd_lapl(x)) # Wall time: 5.05 s
-%timeit jax.block_until_ready(fwd_lapl(x)) # 2.59 ms ± 15.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
-fwd_lapl = jax.jit(jax.vmap(folx.forward_laplacian(fwd, sparsity_threshold=0)))
-%time jax.block_until_ready(fwd_lapl(x)) # Wall time: 2.66 s
-%timeit jax.block_until_ready(fwd_lapl(x)) # 48.7 ms ± 42.1 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
+# Traditional loop implementation
+lapl = jax.jit(jax.vmap(folx.LoopLaplacianOperator()(fwd)))
+%time jax.block_until_ready(lapl(x)) # Wall time: 1.42 s
+%timeit jax.block_until_ready(lapl(x)) # 224 ms ± 54 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
+
+# Forward laplacian without sparsity
+lapl = jax.jit(jax.vmap(folx.ForwardLaplacianOperator(0)(fwd)))
+%time jax.block_until_ready(lapl(x)) # Wall time: 2.66 s
+%timeit jax.block_until_ready(lapl(x)) # 48.7 ms ± 42.1 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
+
+# Forward laplacian with sparsity
+lapl = jax.jit(jax.vmap(folx.ForwardLaplacianOperator(4)(fwd)))
+%time jax.block_until_ready(lapl(x)) # Wall time: 5.05 s
+%timeit jax.block_until_ready(lapl(x)) # 2.59 ms ± 15.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
```
For electronic wave functions like FermiNet or PsiFormer, `sparsity_threshold=6` is a recommended value, but tuning this hyperparameter may accelerate computations further.
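
Since the best threshold is workload-dependent, it pays to measure. Below is a minimal sketch for doing so (not part of the repository; `benchmark_thresholds` is a hypothetical helper), reusing `fwd` and `x` from the example above:

```python
# Hypothetical helper (not part of folx): times several candidate
# sparsity thresholds for a given function and input batch.
import time

import jax
import folx


def benchmark_thresholds(fwd, x, thresholds=(0, 2, 4, 6, 8)):
    for t in thresholds:
        lapl = jax.jit(jax.vmap(folx.ForwardLaplacianOperator(t)(fwd)))
        jax.block_until_ready(lapl(x))  # warm-up run, excludes compilation
        start = time.perf_counter()
        for _ in range(10):
            jax.block_until_ready(lapl(x))
        per_call = (time.perf_counter() - start) / 10
        print(f"sparsity_threshold={t}: {per_call * 1e3:.2f} ms per call")
```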

13 changes: 8 additions & 5 deletions folx/operators.py
@@ -18,13 +18,11 @@


class Laplacian(Protocol):
-    def __call__(self, x: Array) -> tuple[Array, Array]:
-        ...
+    def __call__(self, x: Array) -> tuple[Array, Array]: ...


class LaplacianOperator(Protocol):
-    def __call__(self, f: Callable[[Array], Array]) -> Laplacian:
-        ...
+    def __call__(self, f: Callable[[Array], Array]) -> Laplacian: ...


@dataclass(frozen=True)
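
The two protocols above define the whole interface: a `LaplacianOperator` maps a function to a `Laplacian`, which maps an input to a pair of arrays. As an illustration only (not part of folx, assuming a scalar function of a 1-D input and a `(laplacian, gradient)` return order), a conforming operator can be written with a dense Hessian:

```python
import jax
import jax.numpy as jnp
from jax import Array


class DenseLaplacianOperator:
    """Illustrative only: satisfies LaplacianOperator via a dense Hessian.

    Assumes f maps a 1-D array to a scalar; materializing the (n, n)
    Hessian costs O(n^2) memory, which is what folx's operators avoid.
    """

    def __call__(self, f):
        def laplacian(x: Array) -> tuple[Array, Array]:
            hess = jax.hessian(f)(x)  # full (n, n) Hessian at x
            return jnp.trace(hess), jax.grad(f)(x)

        return laplacian
```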
@@ -46,10 +44,15 @@ class LoopLaplacianOperator(LaplacianOperator):
    def __call__(f):
        @jax.jit
        def laplacian(x: jax.Array):
+           x_shape = x.shape
+           x = x.reshape(-1)
            n = x.shape[0]
            eye = jnp.eye(n)
-           grad_f = jax.grad(f)

+           def f_(x):
+               return f(x.reshape(x_shape))
+
+           grad_f = jax.grad(f_)
            jacobian, dgrad_f = jax.linearize(grad_f, x)

            _, laplacian = jax.lax.scan(
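
The hunk is truncated at the `scan` call. For orientation, here is a self-contained sketch of how the loop-based Laplacian plausibly completes; the `scan` body and the return value are reconstructed assumptions, not taken from the file:

```python
import jax
import jax.numpy as jnp


def loop_laplacian(f):
    """Sketch of a loop-based Laplacian; the scan body is an assumption."""

    @jax.jit
    def laplacian(x: jax.Array):
        x_shape = x.shape
        x = x.reshape(-1)
        n = x.shape[0]
        eye = jnp.eye(n)

        def f_(x):
            return f(x.reshape(x_shape))

        grad_f = jax.grad(f_)
        # Linearize once: dgrad_f(v) is a Hessian-vector product at x.
        jacobian, dgrad_f = jax.linearize(grad_f, x)

        # Scan over basis vectors; step i extracts the Hessian diagonal
        # entry d^2 f / dx_i^2, and their sum is the trace, i.e. the Laplacian.
        _, diag = jax.lax.scan(lambda i, basis: (i + 1, dgrad_f(basis)[i]), 0, eye)
        return diag.sum(), jacobian

    return laplacian
```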
