Block-Jacobi preconditioning, Eisenstat-Walker for inexact steps (#19)
* Block-Jacobi preconditioning

* Nits

* Add missing jdc.Static[]

* Fix

* Implement Eisenstat-Walker
brentyi authored Oct 12, 2024
1 parent 416e788 commit 8576a1e
Showing 3 changed files with 224 additions and 36 deletions.
17 changes: 10 additions & 7 deletions README.md
@@ -2,7 +2,7 @@

[![pyright](https://github.com/brentyi/jaxls/actions/workflows/pyright.yml/badge.svg)](https://github.com/brentyi/jaxls/actions/workflows/pyright.yml)

_status: working! see limitations [here](#limitations)_

**`jaxls`** is a library for nonlinear least squares in JAX.

@@ -11,18 +11,20 @@
problems. We accelerate optimization by analyzing the structure of graphs:
repeated factor and variable types are vectorized, and the sparsity of adjacency
in the graph is translated into sparse matrix operations.

-Features:
+Currently supported:

- Automatic sparse Jacobians.
- Optimization on manifolds; SO(2), SO(3), SE(2), and SE(3) implementations
  included.
- Nonlinear solvers: Levenberg-Marquardt and Gauss-Newton.
-- Linear solvers: both direct (sparse Cholesky via CHOLMOD, on CPU) and
-  iterative (Jacobi-preconditioned Conjugate Gradient).
+- Direct linear solves via sparse Cholesky / CHOLMOD, on CPU.
+- Iterative linear solves via Conjugate Gradient.
+- Preconditioning: block and point Jacobi.
+- Inexact Newton via Eisenstat-Walker.

-Use cases are primarily in least squares problems that are inherently (1) sparse
-and (2) inefficient to solve with gradient-based methods. In robotics, these are
-ubiquitous across classical approaches to perception, planning, and control.
+Use cases are primarily in least squares problems that are inherently (1)
+sparse and (2) inefficient to solve with gradient-based methods. These are
+common in robotics.

For the first iteration of this library, written for
[IROS 2021](https://github.com/brentyi/dfgo), see
@@ -122,6 +124,7 @@
print("Pose 1", solution[pose_vars[1]])
### Limitations

There are many practical features that we don't currently support:

- GPU accelerated Cholesky factorization. (for CHOLMOD we wrap [scikit-sparse](https://scikit-sparse.readthedocs.io/en/latest/), which runs on CPU only)
- Covariance estimation / marginalization.
- Incremental solves.
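For orientation, here is a minimal sketch of how the options introduced in this commit might be configured. The `ConjugateGradientLinearSolver` fields are taken from the `src/jaxls/_solvers.py` diff below; importing the class from the package root and passing it to `solve()` are assumptions about the public API, not something this commit shows.

import jaxls

# Hypothetical usage; field names and defaults come from src/jaxls/_solvers.py below.
cg_solver = jaxls.ConjugateGradientLinearSolver(
    tolerance_min=1e-7,
    tolerance_max=1e-2,
    eisenstat_walker_gamma=0.9,  # how aggressively the CG tolerance tightens
    eisenstat_walker_alpha=2.0,  # sensitivity to residual reduction
    preconditioner="block-jacobi",  # or "point-jacobi", or None to disable
)

# Assumed entry point; check the actual solve() signature:
# solution = graph.solve(linear_solver=cg_solver)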
126 changes: 126 additions & 0 deletions src/jaxls/_preconditioning.py
@@ -0,0 +1,126 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Callable

import jax
from jax import numpy as jnp

if TYPE_CHECKING:
    from ._factor_graph import FactorGraph
    from ._sparse_matrices import BlockRowSparseMatrix


def make_point_jacobi_precoditioner(
    A_blocksparse: BlockRowSparseMatrix,
) -> Callable[[jax.Array], jax.Array]:
    """Returns a point Jacobi (diagonal) preconditioner."""
    ATA_diagonals = jnp.zeros(A_blocksparse.shape[1])

    for block_row in A_blocksparse.block_rows:
        (n_blocks, rows, cols_concat) = block_row.blocks_concat.shape
        del rows
        del cols_concat
        assert block_row.blocks_concat.ndim == 3  # (N_block, rows, cols)
        assert block_row.start_cols[0].shape == (n_blocks,)
        block_l2_cols = jnp.sum(block_row.blocks_concat**2, axis=1).flatten()
        indices = jnp.concatenate(
            [
                (start_col[:, None] + jnp.arange(width)[None, :])
                for start_col, width in zip(
                    block_row.start_cols, block_row.block_widths
                )
            ],
            axis=1,
        ).flatten()
        ATA_diagonals = ATA_diagonals.at[indices].add(block_l2_cols)

    return lambda vec: vec / ATA_diagonals


def make_block_jacobi_precoditioner(
    graph: FactorGraph, A_blocksparse: BlockRowSparseMatrix
) -> Callable[[jax.Array], jax.Array]:
    """Returns a block Jacobi preconditioner."""

    # This list will store block diagonal gram matrices corresponding to each
    # variable.
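    # Each block is seeded with a small multiple of the identity so that it stays
    # invertible when jnp.linalg.inv is applied below.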
    gram_diagonal_blocks = list[jax.Array]()
    for var_type, ids in graph.tangent_ordering.ordered_dict_items(
        graph.sorted_ids_from_var_type
    ):
        (num_vars,) = ids.shape
        gram_diagonal_blocks.append(
            jnp.zeros((num_vars, var_type.tangent_dim, var_type.tangent_dim))
            + jnp.eye(var_type.tangent_dim) * 1e-6
        )

    assert len(graph.stacked_factors) == len(A_blocksparse.block_rows)
    for factor, block_row in zip(graph.stacked_factors, A_blocksparse.block_rows):
        assert block_row.blocks_concat.ndim == 3  # (N_block, rows, cols)

        # Current index we're looking at in the blocks_concat array.
        start_concat_col = 0

        for var_type, ids in graph.tangent_ordering.ordered_dict_items(
            factor.sorted_ids_from_var_type
        ):
            (num_factors, num_vars) = ids.shape
            var_type_idx = graph.tangent_ordering.order_from_type[var_type]

            # Extract the blocks corresponding to the current variable type.
            end_concat_col = start_concat_col + num_vars * var_type.tangent_dim
            A_blocks = block_row.blocks_concat[
                :, :, start_concat_col:end_concat_col
            ].reshape(
                (
                    num_factors,
                    factor.residual_dim,
                    num_vars,
                    var_type.tangent_dim,
                )
            )

            # f: factor, r: residual, v: variable, t/a: tangent
            gram_blocks = jnp.einsum("frvt,frva->fvta", A_blocks, A_blocks)
            assert gram_blocks.shape == (
                num_factors,
                num_vars,
                var_type.tangent_dim,
                var_type.tangent_dim,
            )

            start_concat_col = end_concat_col
            del end_concat_col

            gram_diagonal_blocks[var_type_idx] = (
                gram_diagonal_blocks[var_type_idx]
                .at[jnp.searchsorted(graph.sorted_ids_from_var_type[var_type], ids)]
                .add(gram_blocks)
            )

    inv_block_diagonals = [
        jnp.linalg.inv(batched_block) for batched_block in gram_diagonal_blocks
    ]

    def preconditioner(vec: jax.Array) -> jax.Array:
        """Compute block Jacobi preconditioning."""
        precond_parts = []
        offset = 0
        for inv_batched_block in inv_block_diagonals:
            num_blocks, block_dim, block_dim_ = inv_batched_block.shape
            assert block_dim == block_dim_
            precond_parts.append(
                jnp.einsum(
                    "bij,bj->bi",
                    inv_batched_block,
                    vec[offset : offset + num_blocks * block_dim].reshape(
                        (num_blocks, block_dim)
                    ),
                ).flatten()
            )
            offset += num_blocks * block_dim
        out = jnp.concatenate(precond_parts, axis=0)
        assert out.shape == vec.shape
        return out

    return preconditioner
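
To make the two preconditioners concrete, here is a small standalone dense sketch (illustration only, not part of this commit): point Jacobi divides by the diagonal of the Gram matrix A^T A, while block Jacobi inverts its per-variable diagonal blocks. The matrix and block sizes below are made up for the example.

import jax.numpy as jnp

A = jnp.arange(12.0).reshape(4, 3) + jnp.eye(4, 3)  # toy dense Jacobian: 4 residuals, 3 tangent dims
v = jnp.ones(3)
ATA = A.T @ A

# Point Jacobi: scale each entry by the inverse diagonal of A^T A.
point_jacobi_out = v / jnp.diag(ATA)

# Block Jacobi: solve against each diagonal block of A^T A (one block per variable).
block_sizes = [2, 1]  # hypothetical: a 2D variable followed by a 1D variable
parts, offset = [], 0
for size in block_sizes:
    block = ATA[offset : offset + size, offset : offset + size]
    parts.append(jnp.linalg.solve(block, v[offset : offset + size]))
    offset += size
block_jacobi_out = jnp.concatenate(parts)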
117 changes: 88 additions & 29 deletions src/jaxls/_solvers.py
@@ -1,7 +1,7 @@
from __future__ import annotations

import functools
-from typing import TYPE_CHECKING, Callable, Hashable, cast
+from typing import TYPE_CHECKING, Callable, Hashable, Literal, assert_never, cast

import jax
import jax.experimental.sparse
@@ -12,7 +12,12 @@
import sksparse.cholmod
from jax import numpy as jnp

-from ._sparse_matrices import SparseCooMatrix, SparseCsrMatrix
+from jaxls._preconditioning import (
+    make_block_jacobi_precoditioner,
+    make_point_jacobi_precoditioner,
+)
+
+from ._sparse_matrices import BlockRowSparseMatrix, SparseCsrMatrix
from ._variables import VarTypeOrdering, VarValues
from .utils import jax_log

@@ -75,26 +80,73 @@ def _solve_on_host(


@jdc.pytree_dataclass
-class ConjugateGradientLinearSolver:
-    """Iterative solver for sparse linear systems. Can run on CPU or GPU."""
+class ConjugateGradientState:
+    """State used for Eisenstat-Walker criterion in ConjugateGradientLinearSolver."""

+    ATb_norm_prev: float | jax.Array
+    """Previous norm of ATb."""
+    eta: float | jax.Array
+    """Current tolerance."""

-    tolerance: float = 1e-7
-    inexact_step_eta: float | None = 1e-2
-    """Forcing sequence parameter for inexact Newton steps. CG tolerance is set to
-    `eta / iteration #`.
-
-    For reference, see AN INEXACT LEVENBERG-MARQUARDT METHOD FOR LARGE SPARSE NONLINEAR
-    LEAST SQUARES, Wright & Holt 1983."""
+
+@jdc.pytree_dataclass
+class ConjugateGradientLinearSolver:
+    """Iterative solver for sparse linear systems. Can run on CPU or GPU.
+
+    For inexact steps, we use the Eisenstat-Walker criterion. For reference, see
+    "Choosing the Forcing Terms in an Inexact Newton Method", Eisenstat & Walker, 1996.
+    """

+    tolerance_min: float = 1e-7
+    tolerance_max: float = 1e-2
+
+    eisenstat_walker_gamma: float = 0.9
+    """Eisenstat-Walker criterion gamma term. Controls how quickly the tolerance
+    decreases. Typical values range from 0.5 to 0.9. Higher values lead to more
+    aggressive tolerance reduction."""
+    eisenstat_walker_alpha: float = 2.0
+    """Eisenstat-Walker criterion alpha term. Determines rate at which the
+    tolerance changes based on residual reduction. Typical values are 1.5 or
+    2.0. Higher values make the tolerance more sensitive to residual changes."""
+
+    preconditioner: jdc.Static[Literal["block-jacobi", "point-jacobi"] | None] = (
+        "block-jacobi"
+    )
+    """Preconditioner to use for linear solves."""

    def _solve(
        self,
+        graph: FactorGraph,
+        A_blocksparse: BlockRowSparseMatrix,
        ATA_multiply: Callable[[jax.Array], jax.Array],
-        ATA_diagonals: jax.Array,
        ATb: jax.Array,
-        iterations: int | jax.Array,
-    ) -> jax.Array:
+        prev_linear_state: ConjugateGradientState,
+    ) -> tuple[jax.Array, ConjugateGradientState]:
        assert len(ATb.shape) == 1, "ATb should be 1D!"

+        # Preconditioning setup.
+        if self.preconditioner == "block-jacobi":
+            preconditioner = make_block_jacobi_precoditioner(graph, A_blocksparse)
+        elif self.preconditioner == "point-jacobi":
+            preconditioner = make_point_jacobi_precoditioner(A_blocksparse)
+        elif self.preconditioner is None:
+            preconditioner = lambda x: x
+        else:
+            assert_never(self.preconditioner)

+        # Calculate tolerance using Eisenstat-Walker criterion.
+        ATb_norm = jnp.linalg.norm(ATb)
+        current_eta = jnp.minimum(
+            self.eisenstat_walker_gamma
+            * (ATb_norm / (prev_linear_state.ATb_norm_prev + 1e-7))
+            ** self.eisenstat_walker_alpha,
+            self.tolerance_max,
+        )
+        current_eta = jnp.maximum(
+            self.tolerance_min, jnp.minimum(current_eta, prev_linear_state.eta)
+        )
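        # Restating the update above in equation form (up to the small epsilon guard):
        #   eta_k = min(gamma * (||A^T b||_k / ||A^T b||_{k-1}) ** alpha, tolerance_max)
        #   eta_k = max(tolerance_min, min(eta_k, eta_{k-1}))
        # so CG is solved loosely in early outer iterations and more tightly as the
        # residual shrinks.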

        # Solve with conjugate gradient.
        initial_x = jnp.zeros(ATb.shape)
        solution_values, _ = jax.scipy.sparse.linalg.cg(
@@ -103,15 +155,12 @@ def _solve(
            x0=initial_x,
            # https://en.wikipedia.org/wiki/Conjugate_gradient_method#Convergence_properties
            maxiter=len(initial_x),
-            tol=cast(
-                float,
-                jnp.maximum(self.tolerance, self.inexact_step_eta / (iterations + 1)),
-            )
-            if self.inexact_step_eta is not None
-            else self.tolerance,
-            M=lambda x: x / ATA_diagonals,  # Jacobi preconditioner.
+            tol=cast(float, current_eta),
+            M=preconditioner,
        )
+        return solution_values, ConjugateGradientState(
+            ATb_norm_prev=ATb_norm, eta=current_eta
+        )
-        return solution_values


# Nonlinear solvers.
@@ -126,6 +175,8 @@ class NonlinearSolverState:
    done: bool | jax.Array
    lambd: float | jax.Array

+    linear_state: ConjugateGradientState | None


@jdc.pytree_dataclass
class NonlinearSolver:
@@ -149,6 +200,11 @@ def solve(self, graph: FactorGraph, initial_vals: VarValues) -> VarValues:
            lambd=self.trust_region.lambda_initial
            if self.trust_region is not None
            else 0.0,
+            linear_state=None
+            if isinstance(self.linear_solver, CholmodLinearSolver)
+            else ConjugateGradientState(
+                ATb_norm_prev=0.0, eta=self.linear_solver.tolerance_max
+            ),
        )

# Optimization.
@@ -190,18 +246,17 @@ def step(
        # Compute right-hand side of normal equation.
        ATb = -AT_multiply(state.residual_vector)

+        linear_state = None
        if isinstance(self.linear_solver, ConjugateGradientLinearSolver):
-            # Get diagonals of ATA for preconditioning.
-            ATA_diagonals = (
-                jnp.zeros_like(ATb).at[graph.jac_coords_coo.cols].add(jac_values**2)
-            )
-            local_delta = self.linear_solver._solve(
+            assert isinstance(state.linear_state, ConjugateGradientState)
+            local_delta, linear_state = self.linear_solver._solve(
+                graph,
+                A_blocksparse,
                # We could also use (lambd * ATA_diagonals * vec) for
                # scale-invariant damping. But this is hard to match with CHOLMOD.
                lambda vec: AT_multiply(A_multiply(vec)) + state.lambd * vec,
-                ATA_diagonals,
-                ATb,
-                iterations=state.iterations,
+                ATb=ATb,
+                prev_linear_state=state.linear_state,
            )
        elif isinstance(self.linear_solver, CholmodLinearSolver):
            A_csr = SparseCsrMatrix(jac_values, graph.jac_coords_csr)
@@ -239,6 +294,10 @@
        proposed_residual_vector = graph.compute_residual_vector(vals)
        proposed_cost = jnp.sum(proposed_residual_vector**2)

+        # Update ATb_norm for Eisenstat-Walker criterion.
+        if linear_state is not None:
+            state_next.linear_state = linear_state
+
        # Always accept Gauss-Newton steps.
        if self.trust_region is None:
            state_next.vals = vals
