**Mici** is a Python package providing implementations of *Markov chain Monte
@@ -34,6 +28,10 @@ Key features include
extend the package,
* a pure Python code base with minimal dependencies,
allowing easy integration within other code,
+ * built-in support for several automatic differentiation frameworks, including
+ [JAX](https://jax.readthedocs.io/en/latest/) and
+ [Autograd](https://github.com/HIPS/autograd), or the option to supply your own
+ derivative functions,
* implementations of MCMC methods for sampling from distributions on embedded
manifolds implicitly-defined by a constraint equation and distributions on
Riemannian manifolds with a user-specified metric,
@@ -45,7 +43,7 @@ Key features include
## Installation
-To install and use Mici the minimal requirements are a Python 3.9+ environment
+To install and use Mici the minimal requirements are a Python 3.10+ environment
with [NumPy](http://www.numpy.org/) and [SciPy](https://www.scipy.org)
installed. The latest Mici release on PyPI (and its dependencies) can be
installed in the current Python environment by running
@@ -63,6 +61,14 @@ pip install git+https://github.com/matt-graham/mici
If available in the installed Python environment the following additional
packages provide extra functionality and features
+ * [ArviZ](https://python.arviz.org/en/latest/index.html): if ArviZ is
+ available the traces (dictionary) output of a sampling run can be directly
+ converted to an `arviz.InferenceData` container object using
+ `arviz.convert_to_inference_data` or implicitly converted by passing the
+ traces dictionary as the `data` argument
+ [to ArviZ API functions](https://python.arviz.org/en/latest/api/index.html),
+ allowing straightforward use of ArviZ's extensive visualisation and
+ diagnostic functions.
* [Autograd](https://github.com/HIPS/autograd): if available Autograd will
be used to automatically compute the required derivatives of the model
functions (providing they are specified using functions from the
@@ -74,15 +80,22 @@ packages provide extra functionality and features
serialisation (via [dill](https://github.com/uqfoundation/dill)) of a much
wider range of types, including of Autograd generated functions. Both
Autograd and multiprocess can be installed alongside Mici by running `pip
- install mici[autodiff]`.
- * [ArviZ](https://python.arviz.org/en/latest/index.html): if ArviZ is
- available the traces (dictionary) output of a sampling run can be directly
- converted to an `arviz.InferenceData` container object using
- `arviz.convert_to_inference_data` or implicitly converted by passing the
- traces dictionary as the `data` argument
- [to ArviZ API functions](https://python.arviz.org/en/latest/api/index.html),
- allowing straightforward use of the ArviZ's extensive visualisation and
- diagnostic functions.
+ install mici[autograd]`.
+ * [JAX](https://jax.readthedocs.io/en/latest/): if available JAX will be used to
+ automatically compute the required derivatives of the model functions (providing
+ they are specified using functions from the [`jax`
+ interface](https://jax.readthedocs.io/en/latest/jax.html)). To sample chains
+ in parallel using JAX functions you also need to install
+ [multiprocess](https://github.com/uqfoundation/multiprocess), though note that, as
+ JAX uses multithreading, which [is incompatible with forking child
+ processes](https://docs.python.org/3/library/os.html#os.fork), this can result in
+ deadlocks. Both JAX and multiprocess can be installed alongside Mici by running `pip
+ install mici[jax]`.
+ * [SymNum](https://github.com/matt-graham/symnum): if available SymNum will be used to
+ automatically compute the required derivatives of the model functions (providing
+ they are specified using functions from the [`symnum.numpy`
+ interface](https://matt-graham.github.io/symnum/symnum.numpy.html)). SymNum can be
+ installed alongside Mici by running `pip install mici[symnum]`. A sketch of selecting
+ a backend is shown below.
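+
+A minimal sketch of selecting one of these backends, here JAX (assuming it is
+installed), when constructing a Hamiltonian system:
+
+```Python
+import jax.numpy as jnp
+import mici
+
+
+# Define negative log density using JAX primitives so that its gradient can be
+# constructed automatically by the JAX backend
+def neg_log_dens(q):
+    return 0.5 * jnp.sum(q**2)
+
+
+# With backend="jax" there is no need to pass grad_neg_log_dens explicitly
+system = mici.systems.EuclideanMetricSystem(neg_log_dens, backend="jax")
+```
+
+The `autograd` and `symnum` backends are selected in the same way, with the model
+functions written using the corresponding `autograd.numpy` or `symnum.numpy`
+interfaces.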
## Why Mici?
@@ -122,7 +135,7 @@ chains in Python can dominate the computational cost, making sampling much
slower than packages which outsource the sampling loop to an efficient compiled
implementation.
- ## Overview of package
+## Overview of package
API documentation for the package is available
[here](https://matt-graham.github.io/mici/). The three main user-facing
@@ -257,22 +270,21 @@ The manifold MCMC methods implemented in Mici have been used in several research
-A simple complete example of using the package to compute approximate samples
-from a distribution on a two-dimensional torus embedded in a three-dimensional
-space is given below. The computed samples are visualized in the animation
-above. Here we use `autograd` to automatically construct functions to calculate
-the required derivatives (gradient of negative log density of target
-distribution and Jacobian of constraint function), sample four chains in
-parallel using `multiprocess`, use `arviz` to calculate diagnostics and use
-`matplotlib` to plot the samples.
-
-> ⚠️ **If you do not have [`multiprocess`](https://github.com/uqfoundation/multiprocess) installed the example code below will hang or raise an error when sampling the chains as the inbuilt `multiprocessing` module does not support pickling Autograd functions.**
+A simple complete example of using the package to compute approximate samples from a
+distribution on a two-dimensional torus embedded in a three-dimensional space is given
+below. The computed samples are visualized in the animation above. Here we use
+[SymNum](https://github.com/matt-graham/symnum) to automatically construct functions to
+calculate the required derivatives (gradient of negative log density of target
+distribution and Jacobian of constraint function), sample four chains in parallel using
+`multiprocessing`, use [ArviZ](https://python.arviz.org/en/stable/) to calculate
+diagnostics and use [Matplotlib](https://matplotlib.org/) to plot the samples.
```Python
-from mici import systems, integrators, samplers
-import autograd.numpy as np
+import mici
+import numpy as np
+import symnum
+import symnum.numpy as snp
import matplotlib.pyplot as plt
-from mpl_toolkits.mplot3d import Axes3D
import matplotlib.animation as animation
import arviz
@@ -281,44 +293,62 @@ R = 1.0 # toroidal radius ∈ (0, ∞)
r = 0.5 # poloidal radius ∈ (0, R)
α = 0.9 # density fluctuation amplitude ∈ [0, 1)
+# State dimension
+dim_q = 3
+
+
# Define constraint function such that the set {q : constr(q) == 0} is a torus
+@symnum.numpify(dim_q)
def constr(q):
- x, y, z = q.T
- return np.stack([((x**2 + y**2)**0.5 - R)**2 + z**2 - r**2], -1)
+ x, y, z = q
+ return snp.array([((x**2 + y**2) ** 0.5 - R) ** 2 + z**2 - r**2])
+
# Define negative log density for the target distribution on torus
# (with respect to 2D 'area' measure for torus)
+@symnum.numpify(dim_q)
def neg_log_dens(q):
- x, y, z = q.T
- θ = np.arctan2(y, x)
- ϕ = np.arctan2(z, x / np.cos(θ) - R)
- return np.log1p(r * np.cos(ϕ) / R) - np.log1p(np.sin(4*θ) * np.cos(ϕ) * α)
+ x, y, z = q
+ θ = snp.arctan2(y, x)
+ ϕ = snp.arctan2(z, x / snp.cos(θ) - R)
+ return snp.log1p(r * snp.cos(ϕ) / R) - snp.log1p(snp.sin(4 * θ) * snp.cos(ϕ) * α)
+
# Specify constrained Hamiltonian system with default identity metric
-system = systems.DenseConstrainedEuclideanMetricSystem(neg_log_dens, constr)
+system = mici.systems.DenseConstrainedEuclideanMetricSystem(
+ neg_log_dens,
+ constr,
+ backend="symnum",
+)
# System is constrained therefore use constrained leapfrog integrator
-integrator = integrators.ConstrainedLeapfrogIntegrator(system)
+integrator = mici.integrators.ConstrainedLeapfrogIntegrator(system)
# Seed a random number generator
rng = np.random.default_rng(seed=1234)
# Use dynamic integration-time HMC implementation as MCMC sampler
-sampler = samplers.DynamicMultinomialHMC(system, integrator, rng)
+sampler = mici.samplers.DynamicMultinomialHMC(system, integrator, rng)
# Sample initial positions on torus using parameterisation (θ, ϕ) ∈ [0, 2π)²
# x, y, z = (R + r * cos(ϕ)) * cos(θ), (R + r * cos(ϕ)) * sin(θ), r * sin(ϕ)
n_chain = 4
θ_init, ϕ_init = rng.uniform(0, 2 * np.pi, size=(2, n_chain))
-q_init = np.stack([
- (R + r * np.cos(ϕ_init)) * np.cos(θ_init),
- (R + r * np.cos(ϕ_init)) * np.sin(θ_init),
- r * np.sin(ϕ_init)], -1)
+q_init = np.stack(
+ [
+ (R + r * np.cos(ϕ_init)) * np.cos(θ_init),
+ (R + r * np.cos(ϕ_init)) * np.sin(θ_init),
+ r * np.sin(ϕ_init),
+ ],
+ -1,
+)
+
# Define function to extract variables to trace during sampling
def trace_func(state):
x, y, z = state.pos
- return {'x': x, 'y': y, 'z': z}
+ return {"x": x, "y": y, "z": z}
+
# Sample 4 chains in parallel with 500 adaptive warm up iterations in which the
# integrator step size is tuned, followed by 2000 non-adaptive iterations
@@ -327,7 +357,7 @@ final_states, traces, stats = sampler.sample_chains(
n_main_iter=2000,
init_states=q_init,
n_process=4,
- trace_funcs=[trace_func]
+ trace_funcs=[trace_func],
)
# Print average accept probability and number of integrator steps per chain
@@ -340,19 +370,24 @@ for c in range(n_chain):
print(arviz.summary(traces))
# Visualize concatenated chain samples as animated 3D scatter plot
-fig = plt.figure(figsize=(4, 4))
-ax = Axes3D(fig, [0., 0., 1., 1.], proj_type='ortho')
-points_3d, = ax.plot(*(np.concatenate(traces[k]) for k in 'xyz'), '.', ms=0.5)
-ax.axis('off')
+fig, ax = plt.subplots(
+ figsize=(4, 4),
+ subplot_kw={"projection": "3d", "proj_type": "ortho"},
+)
+(points_3d,) = ax.plot(*(np.concatenate(traces[k]) for k in "xyz"), ".", ms=0.5)
+ax.axis("off")
for set_lim in [ax.set_xlim, ax.set_ylim, ax.set_zlim]:
set_lim((-1, 1))
+
def update(i):
angle = 45 * (np.sin(2 * np.pi * i / 60) + 1)
ax.view_init(elev=angle, azim=angle)
return (points_3d,)
-anim = animation.FuncAnimation(fig, update, frames=60, interval=100, blit=True)
+
+anim = animation.FuncAnimation(fig, update, frames=60, interval=100)
+plt.show()
```
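+
+If ArviZ is installed, the traces dictionary returned by `sample_chains` can also be
+converted explicitly to an `arviz.InferenceData` container object, for example (a
+sketch continuing on from the example above):
+
+```Python
+# Convert traces to an InferenceData object to use further ArviZ diagnostics and plots
+inference_data = arviz.convert_to_inference_data(traces)
+arviz.plot_trace(inference_data)
+plt.show()
+```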
## References
diff --git a/pyproject.toml b/pyproject.toml
index f50e056..16a6225 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -17,9 +17,9 @@ classifiers = [
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3 :: Only",
- "Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
+ "Programming Language :: Python :: 3.12",
"Typing :: Typed",
]
dependencies = [
@@ -38,7 +38,7 @@ keywords = [
]
name = "mici"
readme = "README.md"
-requires-python = ">=3.9"
+requires-python = ">=3.10"
license.file = "LICENCE"
urls.homepage = "https://github.com/matt-graham/mici"
urls.documentation = "https://matt-graham.github.io/mici"
@@ -54,9 +54,16 @@ dev = [
"tox>=4",
"twine",
]
-autodiff = [
+autograd = [
"autograd>=1.3",
- "multiprocess>=0.7.0"
+ "multiprocess>=0.7.0",
+]
+jax = [
+ "jax>=0.4.1",
+ "multiprocess>=0.7.0",
+]
+symnum = [
+ "symnum>=0.2.1",
]
[tool.coverage]
@@ -132,7 +139,7 @@ select = [
"W",
"YTT",
]
-target-version = "py39"
+target-version = "py310"
isort.known-first-party = [
"mici",
]
@@ -152,21 +159,25 @@ write_to = "src/mici/_version.py"
legacy_tox_ini = """
[gh-actions]
python =
- 3.9: py39
3.10: py310
3.11: py311
+ 3.12: py312
[testenv]
commands =
pytest --cov {posargs}
+ extras =
+ autograd
+ jax
+ symnum
deps =
pytest
pytest-cov
- autograd>=1.3
- multiprocess>=0.7.0
pystan
pymc>=5
arviz
+ set_env =
+ JAX_ENABLE_X64=1
[testenv:docs]
commands =
@@ -178,7 +189,7 @@ legacy_tox_ini = """
[tox]
env_list =
- py39
py310
py311
+ py312
"""
diff --git a/src/mici/autodiff.py b/src/mici/autodiff.py
deleted file mode 100644
index ce69121..0000000
--- a/src/mici/autodiff.py
+++ /dev/null
@@ -1,70 +0,0 @@
-"""Automatic differentation fallback for constructing derivative functions."""
-
-from __future__ import annotations
-
-from typing import TYPE_CHECKING
-
-from mici import autograd_wrapper
-
-if TYPE_CHECKING:
- from typing import Callable, Optional
-
-
-"""List of names of valid differential operators.
-
-Any automatic differentiation framework wrapper module will need to provide all of these
-operators as callables (with a single function as argument) to fully support all of the
-required derivative functions.
-"""
-DIFF_OPS = [
- # vector Jacobian product and value
- "vjp_and_value",
- # gradient and value for scalar valued functions
- "grad_and_value",
- # Hessian matrix, gradient and value for scalar valued functions
- "hessian_grad_and_value",
- # matrix Tressian product, gradient and value for scalar valued functions
- "mtp_hessian_grad_and_value",
- # Jacobian matrix and value for vector valued functions
- "jacobian_and_value",
- # matrix Hessian product, Jacobian matrix and value for vector valued functions
- "mhp_jacobian_and_value",
-]
-
-
-def autodiff_fallback(
- diff_func: Optional[Callable],
- func: Callable,
- diff_op_name: str,
- name: str,
-) -> Callable:
- """Generate derivative function automatically if not provided.
-
- Uses automatic differentiation to generate a function corresponding to a
- differential operator applied to a function if an alternative implementation of the
- derivative function has not been provided.
-
- Args:
- diff_func: Either a callable implementing the required derivative function or
- `None` if none was provided.
- func: Function to differentiate.
- diff_op_name: String specifying name of differential operator from automatic
- differentiation framework wrapper to use to generate required derivative
- function.
- name: Name of derivative function to use in error message.
-
- Returns:
- `diff_func` value if not `None` otherwise generated derivative of `func` by
- applying named differential operator.
- """
- if diff_func is not None:
- return diff_func
- elif diff_op_name not in DIFF_OPS:
- msg = f"Differential operator {diff_op_name} is not defined."
- raise ValueError(msg)
- elif autograd_wrapper.AUTOGRAD_AVAILABLE:
- return getattr(autograd_wrapper, diff_op_name)(func)
- elif not autograd_wrapper.AUTOGRAD_AVAILABLE:
- msg = f"Autograd not available therefore {name} must be provided."
- raise ValueError(msg)
- return None
diff --git a/src/mici/autodiff/__init__.py b/src/mici/autodiff/__init__.py
new file mode 100644
index 0000000..4879534
--- /dev/null
+++ b/src/mici/autodiff/__init__.py
@@ -0,0 +1,180 @@
+"""Automatic differentation support for constructing derivative functions.
+
+Multiple automatic differentiation backends are supported:
+
+* `jax`: High performance array computing framework, with support for running
+ computations on accelerator devices and just-in-time (JIT) compilation. To
+ differentiate a function using the `jax` backend, the function must be defined in
+ terms of JAX primitives, for example using the functions in the `jax.numpy` API. By
+ default the derivative functions produced will be JIT compiled; a `jax_nojit` variant
+ is available if the function (or its derivative) is not compatible with JIT
+ compilation.
+* `autograd`: Autograd can automatically differentiate native Python and NumPy code. To
+ differentiate a function using the `autograd` backend it should be defined in terms
+ of functions from the `autograd.numpy` and `autograd.scipy` APIs. Compared to JAX,
+ Autograd lacks JIT compilation and features such as automatic vectorisation, and so
+ is generally slower; JAX will usually be a better choice.
+* `symnum`: SymNum is a Python package that acts as a bridge between NumPy and SymPy,
+ providing a NumPy-like interface that can be used to symbolically define functions
+ which take arrays as arguments and return arrays or scalars as values. To
+ differentiate a function using the `symnum` backend it should be defined in terms of
+ functions from the `symnum.numpy` API, and should have been decorated with
+ ``symnum.numpify`` with specified argument shapes. SymNum is intended for use in
+ generating the derivatives of 'simple' functions which compose a relatively small
+ number of operations and act on small array inputs. By reducing interpreter overheads
+ it can produce code which is cheaper to evaluate than corresponding Autograd or JAX
+ functions (including those using JIT compilation) in such cases, and which can be
+ serialised with the inbuilt Python pickle library allowing use for example in
+ libraries which use multiprocessing to implement parallelisation across multiple
+ processes.
+"""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, NamedTuple
+
+from mici.autodiff import autograd_wrapper, jax_wrapper, symnum_wrapper
+
+if TYPE_CHECKING:
+ from types import ModuleType
+ from typing import Callable, Optional
+
+
+"""Names of valid differential operators.
+
+Any automatic differentiation framework wrapper module will need to provide all of these
+operators as callables (with a single function as argument) to fully support all of the
+required derivative functions.
+"""
+DIFF_OPS = (
+ # vector Jacobian product and value
+ "vjp_and_value",
+ # gradient and value for scalar valued functions
+ "grad_and_value",
+ # Hessian matrix, gradient and value for scalar valued functions
+ "hessian_grad_and_value",
+ # matrix Tressian product, gradient and value for scalar valued functions
+ "mtp_hessian_grad_and_value",
+ # Jacobian matrix and value for vector valued functions
+ "jacobian_and_value",
+ # matrix Hessian product, Jacobian matrix and value for vector valued functions
+ "mhp_jacobian_and_value",
+)
+
+
+class AutodiffBackend(NamedTuple):
+ """Automatic differentiation backend framework.
+
+ Consists of a module defining differential operators, a boolean flag indicating if
+ the backend is available in the current environment and, optionally, a function
+ wrapper which applies any post-processing required to functions.
+ """
+
+ module: ModuleType
+ available: bool
+ function_wrapper: Optional[Callable] = None
+
+
+"""Available autodifferentiation framework backends."""
+_REGISTERED_BACKENDS = {
+ "jax": AutodiffBackend(
+ jax_wrapper,
+ jax_wrapper.JAX_AVAILABLE,
+ jax_wrapper.jit_and_return_numpy_arrays,
+ ),
+ "jax_nojit": AutodiffBackend(
+ jax_wrapper,
+ jax_wrapper.JAX_AVAILABLE,
+ jax_wrapper.return_numpy_arrays,
+ ),
+ "autograd": AutodiffBackend(autograd_wrapper, autograd_wrapper.AUTOGRAD_AVAILABLE),
+ "symnum": AutodiffBackend(symnum_wrapper, symnum_wrapper.SYMNUM_AVAILABLE),
+}
+
+
+def _get_backend(name: str) -> AutodiffBackend:
+ # Normalize name string to all lowercase to make invariant to capitalization
+ name = name.lower()
+ if name not in _REGISTERED_BACKENDS:
+ msg = (
+ f"Selected autodiff backend {name} not recognised: "
+ f"available options are {tuple(_REGISTERED_BACKENDS)}."
+ )
+ raise ValueError(msg)
+ return _REGISTERED_BACKENDS[name]
+
+
+def wrap_function(function: Callable, backend: Optional[str]) -> Callable:
+ """Apply function wrapper for automatic differentiation backend to a function.
+
+ Backends may define a function wrapper which applies any post-processing required to
+ functions using the framework - for example, ensuring the function returns NumPy
+ arrays or just-in-time compiling the function.
+
+ Args:
+ function: Function to wrap.
+ backend: Name of automatic differentiation framework backend to use. If `None`
+ function is returned unchanged.
+
+ Returns:
+ Wrapped function.
+ """
+ if backend is None:
+ return function
+ backend = _get_backend(backend)
+ if backend.function_wrapper is not None:
+ return backend.function_wrapper(function)
+ else:
+ return function
+
+
+def autodiff_fallback(
+ diff_func: Optional[Callable],
+ func: Callable,
+ diff_op_name: str,
+ name: str,
+ backend: Optional[str],
+) -> Callable:
+ """Generate derivative function automatically if not provided.
+
+ Uses automatic differentiation to generate a function corresponding to a
+ differential operator applied to a function if an alternative implementation of the
+ derivative function has not been provided.
+
+ Args:
+ diff_func: Either a callable implementing the required derivative function or
+ `None` if none was provided.
+ func: Function to differentiate.
+ diff_op_name: String specifying name of differential operator from automatic
+ differentiation framework wrapper to use to generate required derivative
+ function.
+ name: Name of derivative function to use in error message.
+ backend: Name of automatic differentiation framework backend to use. If `None`
+ `diff_func` must be provided.
+
+ Returns:
+ `diff_func` value if not `None` otherwise generated derivative of `func` by
+ applying named differential operator from automatic differentiation backend.
+ """
+ if diff_func is not None:
+ return diff_func
+ elif backend is None:
+ msg = (
+ f"Automatic differentiation backend specified as `None` so {name} must "
+ "be provided directly."
+ )
+ raise ValueError(msg)
+ elif diff_op_name not in DIFF_OPS:
+ msg = f"Differential operator {diff_op_name} is not defined."
+ raise ValueError(msg)
+ else:
+ autodiff_backend = _get_backend(backend)
+ if autodiff_backend.available:
+ diff_func = getattr(autodiff_backend.module, diff_op_name)(func)
+ return wrap_function(diff_func, backend)
+ else:
+ msg = (
+ f"{backend} selected as autodiff backend but is not available in "
+ f"current environment therefore {name} must be provided directly."
+ )
+ raise ValueError(msg)
diff --git a/src/mici/autodiff/autograd_wrapper.py b/src/mici/autodiff/autograd_wrapper.py
new file mode 100644
index 0000000..a3dfb3b
--- /dev/null
+++ b/src/mici/autodiff/autograd_wrapper.py
@@ -0,0 +1,146 @@
+"""Autograd differential operators."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+AUTOGRAD_AVAILABLE = True
+try:
+ import autograd.numpy as np
+ from autograd.builtins import tuple as atuple
+ from autograd.core import make_vjp
+ from autograd.extend import vspace
+except ImportError:
+ AUTOGRAD_AVAILABLE = False
+
+if TYPE_CHECKING:
+
+ from mici.types import (
+ ArrayFunction,
+ GradientFunction,
+ HessianFunction,
+ JacobianFunction,
+ MatrixHessianProductFunction,
+ MatrixTressianProductFunction,
+ ScalarFunction,
+ VectorJacobianProductFunction,
+ )
+
+
+def grad_and_value(func: ScalarFunction) -> GradientFunction:
+ """Makes a function that returns both gradient and value of a function."""
+
+ def grad_and_value_func(x):
+ vjp, val = make_vjp(func, x)
+ if vspace(val).size != 1:
+ msg = "grad_and_value only applies to real scalar-output functions."
+ raise TypeError(msg)
+ return vjp(vspace(val).ones()), val
+
+ return grad_and_value_func
+
+
+def vjp_and_value(func: ArrayFunction) -> VectorJacobianProductFunction:
+ """
+ Makes a function that returns vector-Jacobian-product and value of a function.
+
+ For a vector-valued function `fun` the vector-Jacobian-product (VJP) is here
+ defined as a function of a vector `v` corresponding to
+
+ vjp(v) = v @ j
+
+ where `j` is the Jacobian of `f = fun(x)` wrt `x` i.e. the rank-2
+ tensor of first-order partial derivatives of the vector-valued function,
+ such that
+
+ j[i, k] = ∂f[i] / ∂x[k]
+ """
+
+ def vjp_and_value_func(x):
+ return make_vjp(func, x)
+
+ return vjp_and_value_func
+
+
+def jacobian_and_value(func: ArrayFunction) -> JacobianFunction:
+ """Makes a function that returns both the Jacobian and value of a function."""
+
+ def jacobian_and_value_func(x):
+ vjp, val = make_vjp(func, x)
+ val_vspace = vspace(val)
+ jacobian_shape = val_vspace.shape + vspace(x).shape
+ jacobian_rows = map(vjp, val_vspace.standard_basis())
+ return np.reshape(np.stack(jacobian_rows), jacobian_shape), val
+
+ return jacobian_and_value_func
+
+
+def mhp_jacobian_and_value(func: ArrayFunction) -> MatrixHessianProductFunction:
+ """
+ Makes a function that returns MHP, Jacobian and value of a function.
+
+ For a vector-valued function `fun` the matrix-Hessian-product (MHP) is here
+ defined as a function of a matrix `m` corresponding to
+
+ mhp(m) = sum(m[:, :, None] * h[:, :, :], axis=(0, 1))
+
+ where `h` is the vector-Hessian of `f = fun(x)` wrt `x` i.e. the rank-3
+ tensor of second-order partial derivatives of the vector-valued function,
+ such that
+
+ h[i, j, k] = ∂²f[i] / (∂x[j] ∂x[k])
+ """
+
+ def mhp_jacobian_and_value_func(x):
+ mhp, (jacob, val) = make_vjp(lambda x: atuple(jacobian_and_value(func)(x)), x)
+ return lambda m: mhp((m, vspace(val).zeros())), jacob, val
+
+ return mhp_jacobian_and_value_func
+
+
+def hessian_grad_and_value(func: ScalarFunction) -> HessianFunction:
+ """Makes a function that returns the Hessian, gradient & value of a function."""
+
+ def grad_func(x):
+ vjp, val = make_vjp(func, x)
+ return vjp(vspace(val).ones()), val
+
+ def hessian_grad_and_value_func(x):
+ x_vspace = vspace(x)
+ vjp_grad, (grad, val) = make_vjp(lambda x: atuple(grad_func(x)), x)
+ hessian_shape = x_vspace.shape + x_vspace.shape
+ zeros = vspace(val).zeros()
+ hessian_rows = (vjp_grad((v, zeros)) for v in x_vspace.standard_basis())
+ return np.reshape(np.stack(hessian_rows), hessian_shape), grad, val
+
+ return hessian_grad_and_value_func
+
+
+def mtp_hessian_grad_and_value(func: ScalarFunction) -> MatrixTressianProductFunction:
+ """
+ Makes a function that returns MTP, Hessian, gradient and value of a function.
+
+ For a scalar-valued function `fun` the matrix-Tressian-product (MTP) is
+ here defined as a function of a matrix `m` corresponding to
+
+ mtp(m) = sum(m[:, :] * t[:, :, :], axis=(-1, -2))
+
+ where `t` is the 'Tressian' of `f = fun(x)` wrt `x` i.e. the 3D array of
+ third-order partial derivatives of the scalar-valued function such that
+
+ t[i, j, k] = ∂³f / (∂x[i] ∂x[j] ∂x[k])
+ """
+
+ def mtp_hessian_grad_and_value_func(x):
+ mtp, (hessian, grad, val) = make_vjp(
+ lambda x: atuple(hessian_grad_and_value(func)(x)),
+ x,
+ )
+ return (
+ lambda m: mtp((m, vspace(grad).zeros(), vspace(val).zeros())),
+ hessian,
+ grad,
+ val,
+ )
+
+ return mtp_hessian_grad_and_value_func
diff --git a/src/mici/autodiff/jax_wrapper.py b/src/mici/autodiff/jax_wrapper.py
new file mode 100644
index 0000000..ae0b064
--- /dev/null
+++ b/src/mici/autodiff/jax_wrapper.py
@@ -0,0 +1,199 @@
+"""JAX differential operators and helper functions."""
+
+from __future__ import annotations
+
+from functools import partial
+from typing import TYPE_CHECKING
+
+import numpy as np
+
+JAX_AVAILABLE = True
+try:
+ import jax
+ import jax.numpy as jnp
+except ImportError:
+ JAX_AVAILABLE = False
+
+if TYPE_CHECKING:
+ from mici.types import (
+ ArrayFunction,
+ ArrayLike,
+ GradientFunction,
+ JacobianFunction,
+ MatrixHessianProductFunction,
+ MatrixTressianProduct,
+ ScalarFunction,
+ ScalarLike,
+ VectorJacobianProductFunction,
+ )
+
+
+def jit_and_return_numpy_arrays(function, **jit_kwargs):
+ """Wrap a JIT compiled function returning JAX arrays to instead return NumPy arrays.
+
+ Args:
+ function: Function to wrap. Should return one of: a single JAX array, a callable
+ returning one or more JAX arrays, or a tuple of one or more JAX arrays or
+ functions returning one or more JAX arrays.
+ **jit_kwargs: Any keyword arguments to pass to `jax.jit` operator.
+
+ Returns:
+ Wrapped function. Any values returned by original function which are JAX arrays
+ will instead be NumPy arrays, while any values which are callables returning
+ JAX arrays will instead return NumPy arrays.
+ """
+ jitted_function = jax.jit(function, **jit_kwargs)
+ return return_numpy_arrays(jitted_function)
+
+
+def return_numpy_arrays(function: callable) -> callable:
+ """Wrap a function returning JAX arrays to instead return NumPy arrays.
+
+ Args:
+ function: Function to wrap. Should return one of: a single JAX array, a callable
+ returning one or more JAX arrays, or a tuple of one or more JAX arrays or
+ functions returning one or more JAX arrays.
+
+ Returns:
+ Wrapped function. Any values returned by original function which are JAX arrays
+ will instead be NumPy arrays, while any values which are callables returning
+ JAX arrays will instead return NumPy arrays.
+ """
+
+ def as_numpy_array(value):
+ if callable(value):
+ return return_numpy_arrays(value)
+ elif isinstance(value, jax.Array):
+ return np.asarray(value)
+ else:
+ return value
+
+ def function_returning_numpy_arrays(*args, **kwargs):
+ return_value = function(*args, **kwargs)
+ if isinstance(return_value, tuple):
+ return tuple(as_numpy_array(value) for value in return_value)
+ else:
+ return as_numpy_array(return_value)
+
+ return function_returning_numpy_arrays
+
+
+def grad_and_value(func: ScalarFunction) -> GradientFunction:
+ """Makes a function that returns both the Jacobian and value of a function."""
+
+ def grad_and_value_func(x):
+ value, grad = jax.value_and_grad(func)(x)
+ return grad, value
+
+ return grad_and_value_func
+
+
+def _detuple_vjp(vjp_func):
+ """Transform a VJP of function with one return value so it returns an array."""
+ return jax.tree_util.Partial(
+ lambda *args, **kwargs: vjp_func.func(*args, **kwargs)[0],
+ *vjp_func.args,
+ **vjp_func.keywords,
+ )
+
+
+def vjp_and_value(func: ArrayFunction) -> VectorJacobianProductFunction:
+ """Makes a function that returns vector-Jacobian-product and value of a function.
+
+ For a vector-valued function `fun` the vector-Jacobian-product (VJP) is here
+ defined as a function of a vector `v` corresponding to
+
+ vjp(v) = v @ j
+
+ where `j` is the Jacobian of `f = fun(x)` wrt `x` i.e. the rank-2
+ tensor of first-order partial derivatives of the vector-valued function,
+ such that
+
+ j[i, k] = ∂f[i] / ∂x[k]
+ """
+
+ def vjp_and_value_func(x):
+ value, vjp = jax.vjp(func, x)
+ return _detuple_vjp(vjp), value
+
+ return vjp_and_value_func
+
+
+def jacobian_and_value(func: ArrayFunction) -> JacobianFunction:
+ """Makes a function that returns both the Jacobian and value of a function."""
+
+ def jacobian_and_value_func(x):
+ value, pullback = jax.vjp(func, x)
+ basis = jnp.eye(value.size, dtype=value.dtype)
+ (jac,) = jax.vmap(pullback)(basis)
+ return jac, value
+
+ return jacobian_and_value_func
+
+
+def mhp_jacobian_and_value(func: ArrayFunction) -> MatrixHessianProductFunction:
+ """
+ Makes a function that returns MHP, Jacobian and value of a function.
+
+ For a vector-valued function `fun` the matrix-Hessian-product (MHP) is here
+ defined as a function of a matrix `m` corresponding to
+
+ mhp(m) = sum(m[:, :, None] * h[:, :, :], axis=(0, 1))
+
+ where `h` is the vector-Hessian of `f = fun(x)` wrt `x` i.e. the rank-3
+ tensor of second-order partial derivatives of the vector-valued function,
+ such that
+
+ h[i, j, k] = ∂²f[i] / (∂x[j] ∂x[k])
+ """
+
+ def mhp_jacobian_and_value_func(x):
+ jac, mhp, value = jax.vjp(jacobian_and_value(func), x, has_aux=True)
+ return _detuple_vjp(mhp), jac, value
+
+ return mhp_jacobian_and_value_func
+
+
+def hessian_grad_and_value(
+ func: ScalarFunction,
+) -> tuple[ArrayLike, ArrayLike, ScalarLike]:
+ """Makes a function that returns the Hessian, gradient and value of a function."""
+
+ def hessian_grad_and_value_func(x):
+ basis = jnp.eye(x.size, dtype=x.dtype)
+ grad_and_value_func = grad_and_value(func)
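+ # Forward-over-reverse: JVPs of the gradient along each basis vector give the Hessian columns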
+ pushforward = partial(jax.jvp, grad_and_value_func, (x,), has_aux=True)
+ grad, hessian, value = jax.vmap(pushforward, out_axes=(None, -1, None))(
+ (basis,),
+ )
+ return hessian, grad, value
+
+ return hessian_grad_and_value_func
+
+
+def mtp_hessian_grad_and_value(
+ func: ScalarFunction,
+) -> tuple[MatrixTressianProduct, ArrayLike, ArrayLike, ScalarLike]:
+ """
+ Makes a function that returns MTP, Hessian, gradient and value of a function.
+
+ For a scalar-valued function `fun` the matrix-Tressian-product (MTP) is
+ here defined as a function of a matrix `m` corresponding to
+
+ mtp(m) = sum(m[:, :] * t[:, :, :], axis=(-1, -2))
+
+ where `t` is the 'Tressian' of `f = fun(x)` wrt `x` i.e. the 3D array of
+ third-order partial derivatives of the scalar-valued function such that
+
+ t[i, j, k] = ∂³f / (∂x[i] ∂x[j] ∂x[k])
+ """
+
+ def hessian_and_aux_func(x):
+ hessian, grad, value = hessian_grad_and_value(func)(x)
+ return hessian, (grad, value)
+
+ def mtp_hessian_grad_and_value_func(x):
+ hessian, mtp, (grad, value) = jax.vjp(hessian_and_aux_func, x, has_aux=True)
+ return _detuple_vjp(mtp), hessian, grad, value
+
+ return mtp_hessian_grad_and_value_func
diff --git a/src/mici/autodiff/symnum_wrapper.py b/src/mici/autodiff/symnum_wrapper.py
new file mode 100644
index 0000000..f10d0d6
--- /dev/null
+++ b/src/mici/autodiff/symnum_wrapper.py
@@ -0,0 +1,94 @@
+"""SymNum differential operators and helper functions."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+SYMNUM_AVAILABLE = True
+try:
+ import symnum
+except ImportError:
+ SYMNUM_AVAILABLE = False
+
+if TYPE_CHECKING:
+ from mici.types import (
+ ArrayFunction,
+ ArrayLike,
+ GradientFunction,
+ JacobianFunction,
+ MatrixHessianProductFunction,
+ MatrixTressianProduct,
+ ScalarFunction,
+ ScalarLike,
+ VectorJacobianProductFunction,
+ )
+
+
+def grad_and_value(func: ScalarFunction) -> GradientFunction:
+ """Makes a function that returns both the Jacobian and value of a function."""
+ return symnum.grad(func, return_aux=True)
+
+
+def vjp_and_value(func: ArrayFunction) -> VectorJacobianProductFunction:
+ """Makes a function that returns vector-Jacobian-product and value of a function.
+
+ For a vector-valued function `fun` the vector-Jacobian-product (VJP) is here
+ defined as a function of a vector `v` corresponding to
+
+ vjp(v) = v @ j
+
+ where `j` is the Jacobian of `f = fun(x)` wrt `x` i.e. the rank-2
+ tensor of first-order partial derivatives of the vector-valued function,
+ such that
+
+ j[i, k] = ∂f[i] / ∂x[k]
+ """
+ return symnum.vector_jacobian_product(func, return_aux=True)
+
+
+def jacobian_and_value(func: ArrayFunction) -> JacobianFunction:
+ """Makes a function that returns both the Jacobian and value of a function."""
+ return symnum.jacobian(func, return_aux=True)
+
+
+def mhp_jacobian_and_value(func: ArrayFunction) -> MatrixHessianProductFunction:
+ """Makes a function that returns MHP, Jacobian and value of a function.
+
+ For a vector-valued function `fun` the matrix-Hessian-product (MHP) is here
+ defined as a function of a matrix `m` corresponding to
+
+ mhp(m) = sum(m[:, :, None] * h[:, :, :], axis=(0, 1))
+
+ where `h` is the vector-Hessian of `f = fun(x)` wrt `x` i.e. the rank-3
+ tensor of second-order partial derivatives of the vector-valued function,
+ such that
+
+ h[i, j, k] = ∂²f[i] / (∂x[j] ∂x[k])
+ """
+ return symnum.matrix_hessian_product(func, return_aux=True)
+
+
+def hessian_grad_and_value(
+ func: ScalarFunction,
+) -> tuple[ArrayLike, ArrayLike, ScalarLike]:
+ """Makes a function that returns the Hessian, gradient and value of a function."""
+ return symnum.hessian(func, return_aux=True)
+
+
+def mtp_hessian_grad_and_value(
+ func: ScalarFunction,
+) -> tuple[MatrixTressianProduct, ArrayLike, ArrayLike, ScalarLike]:
+ """
+ Makes a function that returns MTP, Hessian, gradient and value of a function.
+
+ For a scalar-valued function `fun` the matrix-Tressian-product (MTP) is
+ here defined as a function of a matrix `m` corresponding to
+
+ mtp(m) = sum(m[:, :] * t[:, :, :], axis=(-1, -2))
+
+ where `t` is the 'Tressian' of `f = fun(x)` wrt `x` i.e. the 3D array of
+ third-order partial derivatives of the scalar-valued function such that
+
+ t[i, j, k] = ∂³f / (∂x[i] ∂x[j] ∂x[k])
+ """
+ return symnum.matrix_tressian_product(func, return_aux=True)
diff --git a/src/mici/autograd_wrapper.py b/src/mici/autograd_wrapper.py
deleted file mode 100644
index 540f8b3..0000000
--- a/src/mici/autograd_wrapper.py
+++ /dev/null
@@ -1,153 +0,0 @@
-"""Additional autograd differential operators."""
-
-from __future__ import annotations
-
-from functools import wraps
-from typing import TYPE_CHECKING
-
-AUTOGRAD_AVAILABLE = True
-try:
- import autograd.numpy as np
- from autograd.builtins import tuple as atuple
- from autograd.core import make_vjp
- from autograd.extend import vspace
- from autograd.wrap_util import unary_to_nary
-except ImportError:
- AUTOGRAD_AVAILABLE = False
-
-if TYPE_CHECKING:
- from typing import Callable
-
- from mici.types import (
- ArrayFunction,
- ArrayLike,
- MatrixHessianProduct,
- MatrixTressianProduct,
- ScalarFunction,
- ScalarLike,
- )
-
-
-def _wrapped_unary_to_nary(func: Callable) -> Callable:
- """Use functools.wraps with unary_to_nary decorator."""
- if AUTOGRAD_AVAILABLE:
- return wraps(func)(unary_to_nary(func))
- else:
- return func
-
-
-@_wrapped_unary_to_nary
-def grad_and_value(fun: ScalarFunction, x: ArrayLike) -> tuple[ArrayLike, ScalarLike]:
- """Makes a function that returns both gradient and value of a function."""
- vjp, val = make_vjp(fun, x)
- if vspace(val).size != 1:
- msg = "grad_and_value only applies to real scalar-output functions."
- raise TypeError(msg)
- return vjp(vspace(val).ones()), val
-
-
-@_wrapped_unary_to_nary
-def jacobian_and_value(fun: ArrayFunction, x: ArrayLike) -> tuple[ArrayLike, ArrayLike]:
- """
- Makes a function that returns both the Jacobian and value of a function.
-
- Assumes that the function `fun` broadcasts along the first dimension of the
- input being differentiated with respect to such that a batch of outputs can
- be computed concurrently for a batch of inputs.
- """
- val = fun(x)
- v_vspace = vspace(val)
- x_vspace = vspace(x)
- x_rep = np.tile(x, (v_vspace.size,) + (1,) * x_vspace.ndim)
- vjp_rep, _ = make_vjp(fun, x_rep)
- jacobian_shape = v_vspace.shape + x_vspace.shape
- basis_vectors = np.array(list(v_vspace.standard_basis()))
- jacobian = vjp_rep(basis_vectors)
- return np.reshape(jacobian, jacobian_shape), val
-
-
-@_wrapped_unary_to_nary
-def mhp_jacobian_and_value(
- fun: ArrayFunction,
- x: ArrayLike,
-) -> tuple[MatrixHessianProduct, ArrayLike, ArrayLike]:
- """
- Makes a function that returns MHP, Jacobian and value of a function.
-
- For a vector-valued function `fun` the matrix-Hessian-product (MHP) is here
- defined as a function of a matrix `m` corresponding to
-
- mhp(m) = sum(m[:, :, None] * h[:, :, :], axis=(0, 1))
-
- where `h` is the vector-Hessian of `f = fun(x)` wrt `x` i.e. the rank-3
- tensor of second-order partial derivatives of the vector-valued function,
- such that
-
- h[i, j, k] = ∂²f[i] / (∂x[j] ∂x[k])
-
- Assumes that the function `fun` broadcasts along the first dimension of the
- input being differentiated with respect to such that a batch of outputs can
- be computed concurrently for a batch of inputs.
- """
- mhp, (jacob, val) = make_vjp(lambda x: atuple(jacobian_and_value(fun)(x)), x)
- return lambda m: mhp((m, vspace(val).zeros())), jacob, val
-
-
-@_wrapped_unary_to_nary
-def hessian_grad_and_value(
- fun: ArrayFunction,
- x: ArrayLike,
-) -> tuple[ArrayLike, ArrayLike, ScalarLike]:
- """
- Makes a function that returns the Hessian, gradient & value of a function.
-
- Assumes that the function `fun` broadcasts along the first dimension of the
- input being differentiated with respect to such that a batch of outputs can
- be computed concurrently for a batch of inputs.
- """
-
- def grad_fun(x):
- vjp, val = make_vjp(fun, x)
- return vjp(vspace(val).ones()), val
-
- x_vspace = vspace(x)
- x_rep = np.tile(x, (x_vspace.size,) + (1,) * x_vspace.ndim)
- vjp_grad, (grad, val) = make_vjp(lambda x: atuple(grad_fun(x)), x_rep)
- hessian_shape = x_vspace.shape + x_vspace.shape
- basis_vectors = np.array(list(x_vspace.standard_basis()))
- hessian = vjp_grad((basis_vectors, vspace(val).zeros()))
- return np.reshape(hessian, hessian_shape), grad[0], val[0]
-
-
-@_wrapped_unary_to_nary
-def mtp_hessian_grad_and_value(
- fun: ArrayFunction,
- x: ArrayLike,
-) -> tuple[MatrixTressianProduct, ArrayLike, ArrayLike, ScalarLike]:
- """
- Makes a function that returns MTP, Jacobian and value of a function.
-
- For a scalar-valued function `fun` the matrix-Tressian-product (MTP) is
- here defined as a function of a matrix `m` corresponding to
-
- mtp(m) = sum(m[:, :] * t[:, :, :], axis=(-1, -2))
-
- where `t` is the 'Tressian' of `f = fun(x)` wrt `x` i.e. the 3D array of
- third-order partial derivatives of the scalar-valued function such that
-
- t[i, j, k] = ∂³f / (∂x[i] ∂x[j] ∂x[k])
-
- Assumes that the function `fun` broadcasts along the first dimension of the
- input being differentiated with respect to such that a batch of outputs can
- be computed concurrently for a batch of inputs.
- """
- mtp, (hessian, grad, val) = make_vjp(
- lambda x: atuple(hessian_grad_and_value(fun)(x)),
- x,
- )
- return (
- lambda m: mtp((m, vspace(grad).zeros(), vspace(val).zeros())),
- hessian,
- grad,
- val,
- )
diff --git a/src/mici/interop.py b/src/mici/interop.py
index 805eb52..abf07b3 100644
--- a/src/mici/interop.py
+++ b/src/mici/interop.py
@@ -226,6 +226,7 @@ def sample_pymc_model(
system = mici.systems.EuclideanMetricSystem(
neg_log_dens=neg_log_dens,
grad_neg_log_dens=grad_neg_log_dens,
+ backend=None,
)
integrator = mici.integrators.LeapfrogIntegrator(system)
@@ -423,6 +424,7 @@ def sample_stan_model(
system = mici.systems.EuclideanMetricSystem(
neg_log_dens=neg_log_dens,
grad_neg_log_dens=grad_neg_log_dens,
+ backend=None,
)
integrator = mici.integrators.LeapfrogIntegrator(system, step_size=stepsize)
diff --git a/src/mici/systems.py b/src/mici/systems.py
index dbc523b..cb5c598 100644
--- a/src/mici/systems.py
+++ b/src/mici/systems.py
@@ -8,7 +8,7 @@
import numpy as np
from mici import matrices
-from mici.autodiff import autodiff_fallback
+from mici.autodiff import autodiff_fallback, wrap_function
from mici.states import cache_in_state, cache_in_state_with_aux
if TYPE_CHECKING:
@@ -62,6 +62,7 @@ def __init__(
neg_log_dens: ScalarFunction,
*,
grad_neg_log_dens: Optional[GradientFunction] = None,
+ backend: Optional[str] = None,
):
"""
Args:
@@ -78,13 +79,19 @@ def __init__(
position array. If `None` is passed (the default) an automatic
differentiation fallback will be used to attempt to construct the
derivative of `neg_log_dens` automatically.
+ backend: Name of automatic differentiation backend to use. See
+ :py:mod:`.autodiff` subpackage documentation for details of available
+ backends. If `None` (the default) no automatic differentiation fallback
+ will be used and so all derivative functions must be specified
+ explicitly.
"""
- self._neg_log_dens = neg_log_dens
+ self._neg_log_dens = wrap_function(neg_log_dens, backend)
self._grad_neg_log_dens = autodiff_fallback(
grad_neg_log_dens,
neg_log_dens,
"grad_and_value",
"grad_neg_log_dens",
+ backend,
)
@cache_in_state("pos")
@@ -285,6 +292,7 @@ def __init__(
*,
metric: Optional[MetricLike] = None,
grad_neg_log_dens: Optional[GradientFunction] = None,
+ backend: Optional[str] = None,
):
"""
Args:
@@ -311,8 +319,17 @@ def __init__(
position array. If `None` is passed (the default) an automatic
differentiation fallback will be used to attempt to construct the
derivative of `neg_log_dens` automatically.
+ backend: Name of automatic differentiation backend to use. See
+ :py:mod:`.autodiff` subpackage documentation for details of available
+ backends. If `None` (the default) no automatic differentiation fallback
+ will be used and so all derivative functions must be specified
+ explicitly.
"""
- super().__init__(neg_log_dens=neg_log_dens, grad_neg_log_dens=grad_neg_log_dens)
+ super().__init__(
+ neg_log_dens=neg_log_dens,
+ grad_neg_log_dens=grad_neg_log_dens,
+ backend=backend,
+ )
if metric is None:
self.metric = matrices.IdentityMatrix()
elif isinstance(metric, np.ndarray):
@@ -392,6 +409,7 @@ def __init__(
*,
metric: Optional[MetricLike] = None,
grad_neg_log_dens: Optional[GradientFunction] = None,
+ backend: Optional[str] = None,
):
"""
Args:
@@ -418,11 +436,17 @@ def __init__(
position array. If `None` is passed (the default) an automatic
differentiation fallback will be used to attempt to construct the
derivative of `neg_log_dens` automatically.
+ backend: Name of automatic differentiation backend to use. See
+ :py:mod:`.autodiff` subpackage documentation for details of available
+ backends. If `None` (the default) no automatic differentiation fallback
+ will be used and so all derivative functions must be specified
+ explicitly.
"""
super().__init__(
neg_log_dens=neg_log_dens,
metric=metric,
grad_neg_log_dens=grad_neg_log_dens,
+ backend=backend,
)
def h2(self, state: ChainState) -> ScalarLike:
@@ -667,6 +691,7 @@ def __init__(
dens_wrt_hausdorff: bool = True,
grad_neg_log_dens: Optional[GradientFunction] = None,
jacob_constr: Optional[JacobianFunction] = None,
+ backend: Optional[str] = None,
):
r"""
Args:
@@ -692,7 +717,7 @@ def __init__(
where :code:`jacob_constr` is the Jacobian of the constraint function
:code:`constr` and :code:`metric` is the matrix representation of the
metric on the ambient space.
- constr: Function which given a position rray return as a 1D array the value
+ constr: Function which given a position array returns as a 1D array the value
of the (vector-valued) constraint function, the zero level-set of which
implicitly defines the manifold the dynamic is simulated on.
metric: Matrix object corresponding to matrix representation of metric on
@@ -737,19 +762,26 @@ def __init__(
used to attempt to construct a function to compute the Jacobian (and
value) of :code:`constr`
automatically.
+ backend: Name of automatic differentiation backend to use. See
+ :py:mod:`.autodiff` subpackage documentation for details of available
+ backends. If `None` (the default) no automatic differentiation fallback
+ will be used and so all derivative functions must be specified
+ explicitly.
"""
super().__init__(
neg_log_dens=neg_log_dens,
metric=metric,
grad_neg_log_dens=grad_neg_log_dens,
+ backend=backend,
)
- self._constr = constr
+ self._constr = wrap_function(constr, backend)
self.dens_wrt_hausdorff = dens_wrt_hausdorff
self._jacob_constr = autodiff_fallback(
jacob_constr,
constr,
"jacobian_and_value",
"jacob_constr",
+ backend,
)
@cache_in_state("pos")
@@ -861,6 +893,7 @@ def __init__(
grad_neg_log_dens: Optional[GradientFunction] = None,
jacob_constr: Optional[JacobianFunction] = None,
mhp_constr: Optional[MatrixHessianProductFunction] = None,
+ backend: Optional[str] = None,
):
"""
Args:
@@ -886,7 +919,7 @@ def __init__(
where :code:`jacob_constr` is the Jacobian of the constraint function
:code:`constr` and :code:`metric` is the matrix representation of the
metric on the ambient space.
- constr: Function which given a position rray return as a 1D array the value
+ constr: Function which given a position array returns as a 1D array the value
of the (vector-valued) constraint function, the zero level-set of which
implicitly defines the manifold the dynamic is simulated on.
metric: Matrix object corresponding to matrix representation of metric on
@@ -949,6 +982,11 @@ def __init__(
fallback will be used to attempt to construct a function which
calculates the MHP (and Jacobian and value) of :code:`constr`
automatically.
+ backend: Name of automatic differentiation backend to use. See
+ :py:mod:`.autodiff` subpackage documentation for details of available
+ backends. If `None` (the default) no automatic differentiation fallback
+ will be used and so all derivative functions must be specified
+ explicitly.
"""
super().__init__(
neg_log_dens=neg_log_dens,
@@ -957,6 +995,7 @@ def __init__(
dens_wrt_hausdorff=dens_wrt_hausdorff,
grad_neg_log_dens=grad_neg_log_dens,
jacob_constr=jacob_constr,
+ backend=backend,
)
if not dens_wrt_hausdorff:
self._mhp_constr = autodiff_fallback(
@@ -964,6 +1003,7 @@ def __init__(
constr,
"mhp_jacobian_and_value",
"mhp_constr",
+ backend,
)
@cache_in_state_with_aux("pos", ("jacob_constr", "constr"))
@@ -1015,6 +1055,7 @@ def __init__(
grad_neg_log_dens: Optional[GradientFunction] = None,
jacob_constr: Optional[JacobianFunction] = None,
mhp_constr: Optional[MatrixHessianProductFunction] = None,
+ backend: Optional[str] = None,
):
"""
Args:
@@ -1099,6 +1140,11 @@ def __init__(
passed (the default) an automatic differentiation fallback will be used
to attempt to construct a function which calculates the MHP (and
Jacobian and value) of `constr` automatically.
+ backend: Name of automatic differentiation backend to use. See
+ :py:mod:`.autodiff` subpackage documentation for details of available
+ backends. If `None` (the default) no automatic differentiation fallback
+ will be used and so all derivative functions must be specified
+ explicitly.
"""
DenseConstrainedEuclideanMetricSystem.__init__(
self,
@@ -1109,6 +1155,7 @@ def __init__(
grad_neg_log_dens=grad_neg_log_dens,
jacob_constr=jacob_constr,
mhp_constr=mhp_constr,
+ backend=backend,
)
def jacob_constr_inner_product(
@@ -1192,6 +1239,7 @@ def __init__(
vjp_metric_func: Optional[VectorJacobianProductFunction] = None,
grad_neg_log_dens: Optional[GradientFunction] = None,
metric_kwargs: Optional[dict[str, Any]] = None,
+ backend: Optional[str] = None,
):
"""
Args:
@@ -1258,14 +1306,20 @@ def __init__(
derivative of `neg_log_dens` automatically.
metric_kwargs: An optional dictionary of any additional keyword arguments to
the initializer of `metric_matrix_class`.
+ backend: Name of automatic differentiation backend to use. See
+ :py:mod:`.autodiff` subpackage documentation for details of available
+ backends. If `None` (the default) no automatic differentiation fallback
+ will be used and so all derivative functions must be specified
+ explicitly.
"""
self._metric_matrix_class = metric_matrix_class
- self._metric_func = metric_func
+ self._metric_func = wrap_function(metric_func, backend)
self._vjp_metric_func = autodiff_fallback(
vjp_metric_func,
metric_func,
"vjp_and_value",
"vjp_metric_func",
+ backend,
)
self._metric_kwargs = {} if metric_kwargs is None else metric_kwargs
super().__init__(neg_log_dens, grad_neg_log_dens=grad_neg_log_dens)
@@ -1376,6 +1430,7 @@ def __init__(
*,
vjp_metric_scalar_func: Optional[VectorJacobianProductFunction] = None,
grad_neg_log_dens: Optional[GradientFunction] = None,
+ backend: Optional[str] = None,
):
"""
Args:
@@ -1420,6 +1475,11 @@ def __init__(
position array. If `None` is passed (the default) an automatic
differentiation fallback will be used to attempt to construct the
derivative of `neg_log_dens` automatically.
+ backend: Name of automatic differentiation backend to use. See
+ :py:mod:`.autodiff` subpackage documentation for details of available
+ backends. If `None` (the default) no automatic differentiation fallback
+ will be used and so all derivative functions must be specified
+ explicitly.
"""
super().__init__(
neg_log_dens=neg_log_dens,
@@ -1427,6 +1487,7 @@ def __init__(
metric_func=metric_scalar_func,
vjp_metric_func=vjp_metric_scalar_func,
grad_neg_log_dens=grad_neg_log_dens,
+ backend=backend,
)
@cache_in_state("pos")
@@ -1457,6 +1518,7 @@ def __init__(
*,
vjp_metric_diagonal_func: Optional[VectorJacobianProductFunction] = None,
grad_neg_log_dens: Optional[GradientFunction] = None,
+ backend: Optional[str] = None,
):
"""
Args:
@@ -1501,6 +1563,11 @@ def __init__(
position array. If `None` is passed (the default) an automatic
differentiation fallback will be used to attempt to construct the
derivative of `neg_log_dens` automatically.
+ backend: Name of automatic differentiation backend to use. See
+ :py:mod:`.autodiff` subpackage documentation for details of available
+ backends. If `None` (the default) no automatic differentiation fallback
+ will be used and so all derivative functions must be specified
+ explicitly.
"""
super().__init__(
neg_log_dens=neg_log_dens,
@@ -1508,6 +1575,7 @@ def __init__(
metric_func=metric_diagonal_func,
vjp_metric_func=vjp_metric_diagonal_func,
grad_neg_log_dens=grad_neg_log_dens,
+ backend=backend,
)
@@ -1530,6 +1598,7 @@ def __init__(
*,
vjp_metric_chol_func: Optional[VectorJacobianProductFunction] = None,
grad_neg_log_dens: Optional[GradientFunction] = None,
+ backend: Optional[str] = None,
):
"""
Args:
@@ -1575,6 +1644,11 @@ def __init__(
position array. If `None` is passed (the default) an automatic
differentiation fallback will be used to attempt to construct the
derivative of `neg_log_dens` automatically.
+ backend: Name of automatic differentiation backend to use. See
+ :py:mod:`.autodiff` subpackage documentation for details of available
+ backends. If `None` (the default) no automatic differentiation fallback
+ will be used and so all derivative functions must be specified
+ explicitly.
"""
super().__init__(
neg_log_dens=neg_log_dens,
@@ -1583,6 +1657,7 @@ def __init__(
vjp_metric_func=vjp_metric_chol_func,
grad_neg_log_dens=grad_neg_log_dens,
metric_kwargs={"factor_is_lower": True},
+ backend=backend,
)
@@ -1605,6 +1680,7 @@ def __init__(
*,
vjp_metric_func: Optional[VectorJacobianProductFunction] = None,
grad_neg_log_dens: Optional[GradientFunction] = None,
+ backend: Optional[str] = None,
):
"""
Args:
@@ -1650,6 +1726,11 @@ def __init__(
position array. If `None` is passed (the default) an automatic
differentiation fallback will be used to attempt to construct the
derivative of `neg_log_dens` automatically.
+ backend: Name of automatic differentiation backend to use. See
+ :py:mod:`.autodiff` subpackage documentation for details of available
+ backends. If `None` (the default) no automatic differentiation fallback
+ will be used and so all derivative functions must be specified
+ explicitly.
"""
super().__init__(
neg_log_dens=neg_log_dens,
@@ -1657,6 +1738,7 @@ def __init__(
metric_func=metric_func,
vjp_metric_func=vjp_metric_func,
grad_neg_log_dens=grad_neg_log_dens,
+ backend=backend,
)
@@ -1706,6 +1788,7 @@ def __init__(
hess_neg_log_dens: Optional[HessianFunction] = None,
mtp_neg_log_dens: Optional[MatrixTressianProductFunction] = None,
softabs_coeff: ScalarLike = 1.0,
+ backend: Optional[str] = None,
):
"""
Args:
@@ -1762,18 +1845,25 @@ def __init__(
to absolute value used to regularize Hessian eigenvalues in metric
matrix representation. As the value tends to infinity the approximation
becomes increasingly close to the absolute function.
+ backend: Name of automatic differentiation backend to use. See
+ :py:mod:`.autodiff` subpackage documentation for details of available
+ backends. If `None` (the default) no automatic differentiation fallback
+ will be used and so all derivative functions must be specified
+ explicitly.
"""
self._hess_neg_log_dens = autodiff_fallback(
hess_neg_log_dens,
neg_log_dens,
"hessian_grad_and_value",
"neg_log_dens",
+ backend,
)
self._mtp_neg_log_dens = autodiff_fallback(
mtp_neg_log_dens,
neg_log_dens,
"mtp_hessian_grad_and_value",
"mtp_neg_log_dens",
+ backend,
)
super().__init__(
neg_log_dens=neg_log_dens,
@@ -1782,6 +1872,7 @@ def __init__(
vjp_metric_func=self._mtp_neg_log_dens,
grad_neg_log_dens=grad_neg_log_dens,
metric_kwargs={"softabs_coeff": softabs_coeff},
+ backend=backend,
)
def metric_func(self, state: ChainState) -> ArrayLike:
diff --git a/tests/test_autodiff.py b/tests/test_autodiff.py
new file mode 100644
index 0000000..71b697e
--- /dev/null
+++ b/tests/test_autodiff.py
@@ -0,0 +1,326 @@
+import numpy as np
+import pytest
+
+from mici.autodiff import (
+ _REGISTERED_BACKENDS,
+ DIFF_OPS,
+ _get_backend,
+ autodiff_fallback,
+ wrap_function,
+)
+
+N_POINTS_TO_TEST = 5
+
+BACKENDS_AVAILABLE = [
+ name for name, backend in _REGISTERED_BACKENDS.items() if backend.available
+]
+
+SCALAR_FUNCTION_DIFF_OPS = [
+ "grad_and_value",
+ "hessian_grad_and_value",
+ "mtp_hessian_grad_and_value",
+]
+VECTOR_FUNCTION_DIFF_OPS = ["jacobian_and_value", "mhp_jacobian_and_value"]
+
+
+def torus_function_and_derivatives(numpy_module):
+ toroidal_rad = 1.0
+ poloidal_rad = 0.5
+
+ def constr(q):
+ x, y, z = q
+ return numpy_module.array(
+ [((x**2 + y**2) ** 0.5 - toroidal_rad) ** 2 + z**2 - poloidal_rad**2],
+ )
+
+ def jacob_constr(q):
+ x, y, z = q
+ r = (x**2 + y**2) ** 0.5
+ return np.array(
+ [[2 * x * (r - toroidal_rad) / r, 2 * y * (r - toroidal_rad) / r, 2 * z]],
+ )
+
+ def mhp_constr(q):
+ x, y, z = q
+ r = (x**2 + y**2) ** 0.5
+ r_cubed = r**3
+ return lambda m: np.array(
+ [
+ 2 * (toroidal_rad / r_cubed) * (m[0, 0] * x**2 + m[0, 1] * x * y)
+ + 2 * m[0, 0] * (1 - toroidal_rad / r),
+ 2 * (toroidal_rad / r_cubed) * (m[0, 1] * y**2 + m[0, 0] * x * y)
+ + 2 * m[0, 1] * (1 - toroidal_rad / r),
+ 2 * m[0, 2],
+ ],
+ )
+
+ return {
+ "function": constr,
+ "jacobian_function": jacob_constr,
+ "mhp_function": mhp_constr,
+ }
+
+
+def linear_function_and_derivatives(_):
+
+ constr_matrix = np.array([[1.0, -1.0, 2.0, 3.0], [-3.0, 2.0, 0.0, 5.0]])
+
+ def constr(q):
+ return constr_matrix @ q
+
+ def jacob_constr(_):
+ return constr_matrix
+
+ def mhp_constr(_):
+ return lambda _: np.zeros(constr_matrix.shape[1])
+
+ return {
+ "function": constr,
+ "jacobian_function": jacob_constr,
+ "mhp_function": mhp_constr,
+ }
+
+
+def quadratic_form_function_and_derivatives(_):
+
+ matrix = np.array([[1.3, -0.2], [-0.2, 2.5]])
+
+ def quadratic_form(q):
+ return q @ matrix @ q / 2
+
+ def grad_quadratic_form(q):
+ return matrix @ q
+
+ def hessian_quadratic_form(_):
+ return matrix
+
+ def mtp_quadratic_form(_):
+ return lambda _: np.zeros(matrix.shape[0])
+
+ return {
+ "function": quadratic_form,
+ "grad_function": grad_quadratic_form,
+ "hessian_function": hessian_quadratic_form,
+ "mtp_function": mtp_quadratic_form,
+ }
+
+
+def cubic_function_and_derivatives(_):
+
+ def cubic(q):
+ return (q**3).sum() / 6
+
+ def grad_cubic(q):
+ return q**2 / 2
+
+ def hessian_cubic(q):
+ return np.diag(q)
+
+ def mtp_cubic(_):
+ return lambda m: m.diagonal()
+
+ return {
+ "function": cubic,
+ "grad_function": grad_cubic,
+ "hessian_function": hessian_cubic,
+ "mtp_function": mtp_cubic,
+ }
+
+
+def quartic_function_and_derivatives(_):
+
+ def quartic(q):
+ return (q**4).sum() / 24
+
+ def grad_quartic(q):
+ return q**3 / 6
+
+ def hessian_quartic(q):
+ return np.diag(q**2 / 2)
+
+ def mtp_quartic(q):
+ return lambda m: m.diagonal() * q
+
+ return {
+ "function": quartic,
+ "grad_function": grad_quartic,
+ "hessian_function": hessian_quartic,
+ "mtp_function": mtp_quartic,
+ }
+
+
+def numpify_function(function_and_derivatives, *arg_shapes):
+ import symnum
+
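+ # SymNum requires functions to be wrapped (numpified) with static argument shapes before differentiation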
+ function_and_derivatives["function"] = symnum.numpify_func(
+ function_and_derivatives["function"],
+ *arg_shapes,
+ )
+
+
+@pytest.mark.parametrize("backend_name", BACKENDS_AVAIALBLE)
+def test_module_defines_all_diffops(backend_name):
+ backend = _get_backend(backend_name)
+ for diff_op_name in DIFF_OPS:
+ assert hasattr(backend.module, diff_op_name)
+ assert callable(getattr(backend.module, diff_op_name))
+
+
+def get_numpy_module(backend):
+ if backend in ("jax", "jax_nojit"):
+ import jax.numpy
+
+ return jax.numpy
+ elif backend == "autograd":
+ import autograd.numpy
+
+ return autograd.numpy
+ elif backend == "symnum":
+ import symnum.numpy
+
+ return symnum.numpy
+ else:
+ msg = f"Unrecognised backend {backend}"
+ raise ValueError(msg)
+
+
+@pytest.fixture
+def rng():
+ return np.random.default_rng(1234)
+
+
+@pytest.mark.parametrize("diff_op_name", DIFF_OPS)
+def test_autodiff_fallback_with_no_backend_raises(diff_op_name):
+ with pytest.raises(ValueError, match="None"):
+ autodiff_fallback(
+ None, lambda q: q, diff_op_name, diff_op_name + "_function", None,
+ )
+
+
+@pytest.mark.parametrize("backend_name", BACKENDS_AVAIALBLE)
+def test_autodiff_fallback_with_invalid_diff_op_raises(backend_name):
+ with pytest.raises(ValueError, match="not defined"):
+ autodiff_fallback(None, lambda q: q, "foo", "bar", backend_name)
+
+
+@pytest.mark.parametrize("backend_name", BACKENDS_AVAIALBLE)
+def test_wrap_function(backend_name):
+
+ def function(q):
+ return q**2
+
+ assert wrap_function(function, None) is function
+ wrapped_function = wrap_function(function, backend_name)
+ assert callable(wrapped_function)
+ test_input = np.arange(5)
+ output = wrapped_function(test_input)
+ assert isinstance(output, np.ndarray)
+ assert np.allclose(wrapped_function(test_input), function(test_input))
+
+
+@pytest.mark.parametrize("diff_op_name", VECTOR_FUNCTION_DIFF_OPS)
+@pytest.mark.parametrize("backend_name", BACKENDS_AVAIALBLE)
+@pytest.mark.parametrize(
+ "function_and_derivatives_and_dim",
+ [(torus_function_and_derivatives, 1, 3), (linear_function_and_derivatives, 2, 4)],
+ ids=lambda p: p[0].__name__,
+)
+def test_vector_function_diff_ops(
+ diff_op_name, backend_name, function_and_derivatives_and_dim, rng,
+):
+ construct_function_and_derivatives, dim_c, dim_q = function_and_derivatives_and_dim
+ numpy_module = get_numpy_module(backend_name)
+ function_and_derivatives = construct_function_and_derivatives(numpy_module)
+ if backend_name == "symnum":
+ numpify_function(function_and_derivatives, dim_q)
+ diff_op_function = autodiff_fallback(
+ None,
+ function_and_derivatives["function"],
+ diff_op_name,
+ diff_op_name + "_function",
+ backend_name,
+ )
+ for _ in range(N_POINTS_TO_TEST):
+ q = rng.standard_normal(dim_q)
+ derivatives_and_values = diff_op_function(q)
+ # diff_op_function returns derivatives in descending order (of derivative) while
+ # derivatives are in increasing order in function_and_derivatives
+ for (function_name, expected_value_function), test_value in zip(
+ function_and_derivatives.items(),
+ reversed(derivatives_and_values),
+ ):
+ if function_name.startswith("mhp"):
+ m = rng.standard_normal((dim_c, dim_q))
+ assert np.allclose(test_value(m), expected_value_function(q)(m))
+ else:
+ assert np.allclose(test_value, expected_value_function(q))
+
+
+@pytest.mark.parametrize("diff_op_name", SCALAR_FUNCTION_DIFF_OPS)
+@pytest.mark.parametrize("backend_name", BACKENDS_AVAIALBLE)
+@pytest.mark.parametrize(
+ "function_and_derivatives_and_dim_q",
+ [
+ (quadratic_form_function_and_derivatives, 2),
+ (cubic_function_and_derivatives, 1),
+ (cubic_function_and_derivatives, 3),
+ (quartic_function_and_derivatives, 2),
+ ],
+ ids=lambda p: p[0].__name__,
+)
+def test_scalar_function_diff_ops(
+ diff_op_name, backend_name, function_and_derivatives_and_dim_q, rng,
+):
+ construct_function_and_derivatives, dim_q = function_and_derivatives_and_dim_q
+ numpy_module = get_numpy_module(backend_name)
+ function_and_derivatives = construct_function_and_derivatives(numpy_module)
+ if backend_name == "symnum":
+ numpify_function(function_and_derivatives, dim_q)
+ diff_op_function = autodiff_fallback(
+ None,
+ function_and_derivatives["function"],
+ diff_op_name,
+ diff_op_name + "_function",
+ backend_name,
+ )
+ for _ in range(N_POINTS_TO_TEST):
+ q = rng.standard_normal(dim_q)
+ derivatives_and_values = diff_op_function(q)
+ # diff_op_function returns derivatives in descending order (of derivative) while
+ # derivatives are in increasing order in function_and_derivatives
+ for (function_name, expected_value_function), test_value in zip(
+ function_and_derivatives.items(),
+ reversed(derivatives_and_values),
+ ):
+ if function_name.startswith("mtp"):
+ m = rng.standard_normal((dim_q, dim_q))
+ assert np.allclose(test_value(m), expected_value_function(q)(m))
+ else:
+ assert np.allclose(test_value, expected_value_function(q))
+
+
+@pytest.mark.parametrize("backend_name", BACKENDS_AVAIALBLE)
+@pytest.mark.parametrize(
+ "function_and_derivatives_and_dim",
+ [(torus_function_and_derivatives, 1, 3), (linear_function_and_derivatives, 2, 4)],
+ ids=lambda p: p[0].__name__,
+)
+def test_vjp_and_value(backend_name, function_and_derivatives_and_dim, rng):
+ construct_function_and_derivatives, dim_c, dim_q = function_and_derivatives_and_dim
+ numpy_module = get_numpy_module(backend_name)
+ function_and_derivatives = construct_function_and_derivatives(numpy_module)
+ if backend_name == "symnum":
+ numpify_function(function_and_derivatives, dim_q)
+ vjp_and_value_function = autodiff_fallback(
+ None,
+ function_and_derivatives["function"],
+ "vjp_and_value",
+ "vjp_and_value_function",
+ backend_name,
+ )
+ for _ in range(N_POINTS_TO_TEST):
+ q = rng.standard_normal(dim_q)
+ vjp, value = vjp_and_value_function(q)
+ assert np.allclose(function_and_derivatives["function"](q), value)
+ v = rng.standard_normal(value.shape)
+ assert np.allclose(v @ function_and_derivatives["jacobian_function"](q), vjp(v))