Remove scatters from block-sparse matrix multiplication (#18)
* Scatter-free block-sparse matrices

* Nit

* Cleanup

* Refactor block-row implementation to remove manual reduce/add during matvec multiply

* Cleanup
brentyi authored Oct 12, 2024
1 parent a8a508e commit 416e788
Showing 3 changed files with 122 additions and 119 deletions.
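The gist of the change, roughly: instead of scattering each Jacobian block's contribution into the output vector with an `.at[...].add(...)`, block-rows are stored contiguously, so a matvec reduces to gathering per-block slices of the input, a batched einsum, and a concatenation. A standalone sketch of the two approaches with toy blocks and offsets (illustration only, not jaxls code):

import jax
import jax.numpy as jnp

# Two 2x3 blocks, one per "factor"; each block starts at its own row/column offset.
blocks = jnp.array([[[1.0, 0.0, 2.0], [0.0, 1.0, 0.0]],
                    [[3.0, 1.0, 0.0], [0.0, 0.0, 1.0]]])  # (n_block, 2, 3)
start_rows = jnp.array([0, 2])
start_cols = jnp.array([0, 3])
x = jnp.arange(6.0)

# Gather the slice of x that each block touches, then multiply blockwise.
x_slices = jax.vmap(lambda c: jax.lax.dynamic_slice_in_dim(x, c, 3))(start_cols)
products = jnp.einsum("bij,bj->bi", blocks, x_slices)  # (n_block, 2)

# Scatter-based A @ x: add each block's product into its output rows.
rows = start_rows[:, None] + jnp.arange(2)[None, :]
out_scatter = jnp.zeros(4).at[rows].add(products)

# Scatter-free A @ x: block-rows are disjoint and ordered, so concatenating suffices.
out_concat = products.reshape(-1)

assert jnp.allclose(out_scatter, out_concat)

The scatter-free form is essentially what the new `BlockRowSparseMatrix.multiply` in this commit does.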
86 changes: 36 additions & 50 deletions src/jaxls/_factor_graph.py
@@ -20,8 +20,8 @@
TrustRegionConfig,
)
from ._sparse_matrices import (
BlockSparseMatrix,
MatrixBlock,
BlockRowSparseMatrix,
MatrixBlockRow,
SparseCooCoordinates,
SparseCsrCoordinates,
)
@@ -83,11 +83,8 @@ def compute_residual_vector(self, vals: VarValues) -> jax.Array:
residual_slices.append(stacked_residual_slice.reshape((-1,)))
return jnp.concatenate(residual_slices, axis=0)

def _compute_jac_values(
self, vals: VarValues
) -> tuple[jax.Array, BlockSparseMatrix]:
jac_vals = []
blocks = dict[tuple[int, int], list[MatrixBlock]]()
def _compute_jac_values(self, vals: VarValues) -> BlockRowSparseMatrix:
block_rows = list[MatrixBlockRow]()
residual_offset = 0

for factor in self.stacked_factors:
@@ -116,72 +113,61 @@ def compute_jac_with_perturb(factor: _AnalyzedFactor) -> jax.Array:
)
)(jnp.zeros((val_subset._get_tangent_dim(),)))

# Compute Jacobian for each factor.
stacked_jac = jax.vmap(compute_jac_with_perturb)(factor)
(num_factor,) = factor._get_batch_axes()
assert stacked_jac.shape == (
num_factor,
factor.residual_dim,
stacked_jac.shape[-1], # Tangent dimension.
)
jac_vals.append(stacked_jac.flatten())

start_col = 0
# Compute block-row representation for sparse Jacobian.
stacked_jac_start_col = 0
start_cols = list[jax.Array]()
block_widths = list[int]()
for var_type, ids in self.tangent_ordering.ordered_dict_items(
# This ordering shouldn't actually matter!
factor.sorted_ids_from_var_type
):
block_shape = (factor.residual_dim, var_type.tangent_dim)
(num_factor_, num_vars) = ids.shape
assert num_factor == num_factor_
end_col = start_col + num_vars * var_type.tangent_dim

block_vals = jnp.moveaxis(
stacked_jac[:, :, start_col:end_col].reshape(
(
num_factor_,
factor.residual_dim,
num_vars,
var_type.tangent_dim,

# Get one block for each variable.
for var_idx in range(ids.shape[-1]):
start_cols.append(
jnp.searchsorted(
self.sorted_ids_from_var_type[var_type], ids[..., var_idx]
)
),
2,
1,
).reshape(
(num_factor_ * num_vars, factor.residual_dim, var_type.tangent_dim)
)
blocks.setdefault(block_shape, []).append(
MatrixBlock(
start_row=residual_offset
+ jnp.repeat(
jnp.arange(num_factor_) * factor.residual_dim, num_vars
),
start_col=(
jnp.searchsorted(
self.sorted_ids_from_var_type[var_type], ids.flatten()
)
* var_type.tangent_dim
+ self.tangent_start_from_var_type[var_type]
),
values=block_vals,
* var_type.tangent_dim
+ self.tangent_start_from_var_type[var_type]
)
)
start_col = end_col
block_widths.append(var_type.tangent_dim)
assert start_cols[-1].shape == (num_factor_,)

assert stacked_jac.shape[-1] == start_col
stacked_jac_start_col = (
stacked_jac_start_col + num_vars * var_type.tangent_dim
)
assert stacked_jac.shape[-1] == stacked_jac_start_col

block_rows.append(
MatrixBlockRow(
start_row=jnp.arange(num_factor) * factor.residual_dim
+ residual_offset,
start_cols=tuple(start_cols),
block_widths=tuple(block_widths),
blocks_concat=stacked_jac,
)
)

residual_offset += factor.residual_dim * num_factor
assert residual_offset == self.residual_dim

bsparse_jacobian = BlockSparseMatrix(
blocks={
shape: jax.tree.map(lambda *x: jnp.concatenate(x, axis=0), *blocklist)
for shape, blocklist in blocks.items()
},
bsparse_jacobian = BlockRowSparseMatrix(
block_rows=tuple(block_rows),
shape=(self.residual_dim, self.tangent_dim),
)
jac_vals = jnp.concatenate(jac_vals, axis=0)
assert jac_vals.shape == (self.jac_coords_coo.rows.shape[0],)
return jac_vals, bsparse_jacobian
return bsparse_jacobian

@staticmethod
def make(
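For context on the `jnp.searchsorted` calls above: they map each factor's variable ids to start columns in the global tangent vector. A toy sketch with made-up ids and dimensions (only the `searchsorted(...) * tangent_dim + tangent_start` pattern mirrors the hunk):

import jax.numpy as jnp

sorted_ids = jnp.array([2, 5, 7, 11])  # sorted ids of all variables of one type
tangent_dim = 3                        # tangent dimension of that variable type
tangent_start = 10                     # column where this type's segment begins

factor_ids = jnp.array([7, 2, 11])     # ids referenced by a batch of factors
start_cols = jnp.searchsorted(sorted_ids, factor_ids) * tangent_dim + tangent_start
# -> [16, 10, 19]: position within the type, times tangent_dim, plus the type's offset.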
57 changes: 36 additions & 21 deletions src/jaxls/_solvers.py
@@ -78,8 +78,8 @@ def _solve_on_host(
class ConjugateGradientLinearSolver:
"""Iterative solver for sparse linear systems. Can run on CPU or GPU."""

tolerance: float = 1e-5
inexact_step_eta: float | None = None
tolerance: float = 1e-7
inexact_step_eta: float | None = 1e-2
"""Forcing sequence parameter for inexact Newton steps. CG tolerance is set to
`eta / iteration #`.
@@ -88,16 +88,13 @@ class ConjugateGradientLinearSolver:

def _solve(
self,
A_coo: jax.experimental.sparse.BCOO,
ATA_multiply: Callable[[jax.Array], jax.Array],
ATA_diagonals: jax.Array,
ATb: jax.Array,
iterations: int | jax.Array,
) -> jax.Array:
assert len(ATb.shape) == 1, "ATb should be 1D!"

# Get diagonals of ATA for preconditioning.
ATA_diagonals = jnp.zeros_like(ATb).at[A_coo.indices[:, 1]].add(A_coo.data**2)

# Solve with conjugate gradient.
initial_x = jnp.zeros(ATb.shape)
solution_values, _ = jax.scipy.sparse.linalg.cg(
@@ -171,34 +168,48 @@ def solve(self, graph: FactorGraph, initial_vals: VarValues) -> VarValues:
def step(
self, graph: FactorGraph, state: NonlinearSolverState
) -> NonlinearSolverState:
jac_values, A_blocksparse = graph._compute_jac_values(state.vals)
A_coo = SparseCooMatrix(jac_values, graph.jac_coords_coo).as_jax_bcoo()
A_multiply = A_blocksparse.multiply
AT_multiply = A_blocksparse.transpose().multiply
# Get nonzero values of Jacobian.
A_blocksparse = graph._compute_jac_values(state.vals)

# Get flattened version for COO/CSR matrices.
jac_values = jnp.concatenate(
[
block_row.blocks_concat.flatten()
for block_row in A_blocksparse.block_rows
],
axis=0,
)

# Equivalently:
# AT_multiply = lambda vec: jax.linear_transpose(
# A_blocksparse.multiply, jnp.zeros((A_blocksparse.shape[1],))
# )(vec)[0]
# linear_transpose() will return a tuple, with one element per primal.
A_multiply = A_blocksparse.multiply
AT_multiply_ = jax.linear_transpose(
A_multiply, jnp.zeros((A_blocksparse.shape[1],))
)
AT_multiply = lambda vec: AT_multiply_(vec)[0]

# Compute right-hand side of normal equation.
ATb = -AT_multiply(state.residual_vector)

if isinstance(self.linear_solver, ConjugateGradientLinearSolver):
tangent = self.linear_solver._solve(
A_coo,
# Get diagonals of ATA for preconditioning.
ATA_diagonals = (
jnp.zeros_like(ATb).at[graph.jac_coords_coo.cols].add(jac_values**2)
)
local_delta = self.linear_solver._solve(
# We could also use (lambd * ATA_diagonals * vec) for
# scale-invariant damping. But this is hard to match with CHOLMOD.
lambda vec: AT_multiply(A_multiply(vec)) + state.lambd * vec,
ATA_diagonals,
ATb,
iterations=state.iterations,
)
elif isinstance(self.linear_solver, CholmodLinearSolver):
A_csr = SparseCsrMatrix(jac_values, graph.jac_coords_csr)
tangent = self.linear_solver._solve(A_csr, ATb, lambd=state.lambd)
local_delta = self.linear_solver._solve(A_csr, ATb, lambd=state.lambd)
else:
assert False

vals = state.vals._retract(tangent, graph.tangent_ordering)
vals = state.vals._retract(local_delta, graph.tangent_ordering)
if self.verbose:
jax_log(
" step #{i}: cost={cost:.4f} lambd={lambd:.4f}",
@@ -237,7 +248,11 @@ def step(
# For Levenberg-Marquardt, we need to evaluate the step quality.
else:
step_quality = (proposed_cost - state.cost) / (
jnp.sum((A_coo @ tangent + state.residual_vector) ** 2) - state.cost
jnp.sum(
(A_blocksparse.multiply(local_delta) + state.residual_vector)
** 2
)
- state.cost
)
accept_flag = step_quality >= self.trust_region.step_quality_min

@@ -268,7 +283,7 @@ def step(
state_next.done = self.termination._check_convergence(
state,
cost_updated=state_next.cost,
tangent=tangent,
tangent=local_delta,
tangent_ordering=graph.tangent_ordering,
ATb=ATb,
)
@@ -296,7 +311,7 @@ class TerminationConfig:
max_iterations: int = 100
cost_tolerance: float = 1e-6
"""We terminate if `|cost change| / cost < cost_tolerance`."""
gradient_tolerance: float = 1e-8
gradient_tolerance: float = 1e-7
"""We terminate if `norm_inf(x - rplus(x, linear delta)) < gradient_tolerance`."""
gradient_tolerance_start_step: int = 10
"""When to start checking the gradient tolerance condition. Helps solve precision
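Two details in the solver changes that may be worth spelling out. First, `jax.linear_transpose` turns the `A @ v` closure into an `A^T @ v` closure and returns a tuple with one cotangent per primal input, hence the trailing `[0]`. Second, the Jacobi preconditioner diagonal `diag(A^T A)` is a column-wise sum of squared Jacobian entries, which is what the `.at[cols].add(data**2)` accumulation computes. A small sketch with toy values (not jaxls code):

import jax
import jax.numpy as jnp

A = jnp.array([[1.0, 2.0, 0.0],
               [0.0, 3.0, 4.0]])
A_multiply = lambda x: A @ x

# Transpose of the linear map; the transposed function returns a one-element tuple.
AT_multiply_ = jax.linear_transpose(A_multiply, jnp.zeros((3,)))
AT_multiply = lambda y: AT_multiply_(y)[0]
assert jnp.allclose(AT_multiply(jnp.ones(2)), A.T @ jnp.ones(2))

# Jacobi preconditioner diagonal from COO values: diag(A^T A)[j] = sum_i A[i, j]**2.
cols = jnp.array([0, 1, 1, 2])          # COO column indices of A's nonzeros
data = jnp.array([1.0, 2.0, 3.0, 4.0])  # corresponding values
ATA_diagonals = jnp.zeros(3).at[cols].add(data**2)  # -> [1., 13., 16.]
assert jnp.allclose(ATA_diagonals, jnp.diag(A.T @ A))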
98 changes: 50 additions & 48 deletions src/jaxls/_sparse_matrices.py
@@ -1,70 +1,72 @@
from __future__ import annotations

from typing import Hashable

import jax
import jax.experimental.sparse
import jax_dataclasses as jdc
from jax import numpy as jnp


@jdc.pytree_dataclass
class MatrixBlock:
class MatrixBlockRow:
start_row: jax.Array
start_col: jax.Array
values: jax.Array
"""Row indices of the start of each block. Shape should be `(num_blocks,)`."""
start_cols: tuple[jax.Array, ...]
"""Column indices of the start of each block."""
block_widths: jdc.Static[tuple[int, ...]]
"""Width of each block in the block-row."""
blocks_concat: jax.Array
"""Blocks of matrix, concatenated along the column axis. Shape in tuple should be `(num_blocks, rows, cols)`."""

def treedef(self) -> Hashable:
return tuple(block.shape for block in self.blocks_concat)


@jdc.pytree_dataclass
class BlockSparseMatrix:
blocks: dict[tuple[int, int], MatrixBlock]
"""Map from block shape to block (values, start row, start col)."""
class BlockRowSparseMatrix:
block_rows: tuple[MatrixBlockRow, ...]
"""Batched block-rows, ordered. Each element in the tuple has a leading
axis, which represents consecutive block-rows."""
shape: jdc.Static[tuple[int, int]]
"""Shape of matrix."""

def transpose(self) -> BlockSparseMatrix:
new_blocks = {}
for block_shape, block in self.blocks.items():
new_block = MatrixBlock(
start_row=block.start_col,
start_col=block.start_row,
values=jnp.swapaxes(block.values, -1, -2),
)
new_blocks[block_shape[::-1]] = new_block
return BlockSparseMatrix(new_blocks, (self.shape[1], self.shape[0]))

def multiply(self, target: jax.Array) -> jax.Array:
result = jnp.zeros(self.shape[0])
for block_shape, block in self.blocks.items():
start_row, start_col = block.start_row, block.start_col
assert len(start_row.shape) == 1
assert len(start_col.shape) == 1
values = block.values
assert values.shape == (len(start_row), *block_shape)

def multiply_one_block(col, vals) -> jax.Array:
target_slice = jax.lax.dynamic_slice_in_dim(
target, col, block_shape[1], axis=0
)
return jnp.einsum("ij,j->i", vals, target_slice)

update_indices = start_row[:, None] + jnp.arange(block_shape[0])[None, :]
result = result.at[update_indices].add(
jax.vmap(multiply_one_block)(start_col, values)
"""Sparse-dense multiplication."""
assert target.ndim == 1

out_slices = []
for block_row in self.block_rows:
# Do matrix multiplies for all blocks in block-row.
(n_block, block_rows, block_nz_cols) = block_row.blocks_concat.shape
del block_rows

# Get slices corresponding to nonzero terms in block-row.
assert len(block_row.start_cols) == len(block_row.block_widths)
target_slice_parts = list[jax.Array]()
for start_cols, width in zip(block_row.start_cols, block_row.block_widths):
assert start_cols.shape == (n_block,)
assert isinstance(width, int)
slice_part = jax.vmap(
lambda start_col: jax.lax.dynamic_slice_in_dim(
target, start_index=start_col, slice_size=width, axis=0
)
)(start_cols)
assert slice_part.shape == (n_block, width)
target_slice_parts.append(slice_part)

# Concatenate slices to form target slice.
target_slice = jnp.concatenate(target_slice_parts, axis=1)
assert target_slice.shape == (n_block, block_nz_cols)

# Multiply block-rows with target slice.
out_slices.append(
jnp.einsum(
"bij,bj->bi", block_row.blocks_concat, target_slice
).flatten()
)
return result

def todense(self) -> jax.Array:
result = jnp.zeros(self.shape)
for block_shape, block in self.blocks.items():
start_row, start_col = block.start_row, block.start_col
assert len(start_row.shape) == 1
assert len(start_col.shape) == 1
values = block.values
assert values.shape == (len(start_row), *block_shape)

row_indices = start_row[:, None] + jnp.arange(block_shape[0])[None, :]
col_indices = start_col[:, None] + jnp.arange(block_shape[1])[None, :]
result = result.at[row_indices, col_indices].set(values)

result = jnp.concatenate(out_slices, axis=0)
return result
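As a sanity check on the multiply above: consecutive block-rows cover disjoint, consecutive output rows, so the batched einsum followed by a flatten reproduces a dense matvec with no scatter/add. A toy check assuming one equal-width block per block-row (hypothetical sizes, not library code):

import jax.numpy as jnp

n_block, rows, cols = 3, 2, 4
blocks = jnp.arange(n_block * rows * cols, dtype=jnp.float32).reshape(n_block, rows, cols)
x = jnp.ones(n_block * cols)

# Each block-row multiplies only its own slice of x.
x_slices = x.reshape(n_block, cols)
out = jnp.einsum("bij,bj->bi", blocks, x_slices).flatten()

# Dense reference: the same blocks laid out block-diagonally.
dense = jnp.zeros((n_block * rows, n_block * cols))
for b in range(n_block):
    dense = dense.at[b * rows:(b + 1) * rows, b * cols:(b + 1) * cols].set(blocks[b])
assert jnp.allclose(out, dense @ x)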


