diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index f9827f2..8f08243 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -20,7 +20,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.x cache: "pip" cache-dependency-path: "pyproject.toml" - name: Install tox @@ -35,4 +35,4 @@ jobs: publish_dir: docs/_build/html publish_branch: gh-pages user_name: "github-actions[bot]" - user_email: "github-actions[bot]@users.noreply.github.com" \ No newline at end of file + user_email: "github-actions[bot]@users.noreply.github.com" diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 19acb15..4471de9 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -21,9 +21,9 @@ jobs: strategy: matrix: python-version: - - "3.9" - "3.10" - "3.11" + - "3.12" steps: - name: Checkout source diff --git a/README.md b/README.md index a876515..23044cc 100644 --- a/README.md +++ b/README.md @@ -1,23 +1,17 @@
-<!-- centred HTML block: Mici logo image -->
-<!-- HTML badge links: PyPI version, DOI, Test status, Documentation status -->
+<!-- centred HTML block: Mici logo image -->
+ +[![PyPI version](https://badge.fury.io/py/mici.svg)](https://pypi.org/project/mici) +[![Zenodo DOI](https://zenodo.org/badge/52494384.svg)](https://zenodo.org/badge/latestdoi/52494384) +[![Test status](https://github.com/matt-graham/mici/actions/workflows/tests.yml/badge.svg)](https://github.com/matt-graham/mici/actions/workflows/tests.yml) +[![Docs status](https://github.com/matt-graham/mici/actions/workflows/docs.yml/badge.svg)](https://matt-graham.github.io/mici) + +
**Mici** is a Python package providing implementations of *Markov chain Monte @@ -34,6 +28,10 @@ Key features include extend the package, * a pure Python code base with minimal dependencies, allowing easy integration within other code, + * built-in support for several automatic differentiation frameworks, including + [JAX](https://jax.readthedocs.io/en/latest/) and + [Autograd](https://github.com/HIPS/autograd), or the option to supply your own + derivative functions, * implementations of MCMC methods for sampling from distributions on embedded manifolds implicitly-defined by a constraint equation and distributions on Riemannian manifolds with a user-specified metric, @@ -45,7 +43,7 @@ Key features include ## Installation -To install and use Mici the minimal requirements are a Python 3.9+ environment +To install and use Mici the minimal requirements are a Python 3.10+ environment with [NumPy](http://www.numpy.org/) and [SciPy](https://www.scipy.org) installed. The latest Mici release on PyPI (and its dependencies) can be installed in the current Python environment by running @@ -63,6 +61,14 @@ pip install git+https://github.com/matt-graham/mici If available in the installed Python environment the following additional packages provide extra functionality and features + * [ArviZ](https://python.arviz.org/en/latest/index.html): if ArviZ is + available the traces (dictionary) output of a sampling run can be directly + converted to an `arviz.InferenceData` container object using + `arviz.convert_to_inference_data` or implicitly converted by passing the + traces dictionary as the `data` argument + [to ArviZ API functions](https://python.arviz.org/en/latest/api/index.html), + allowing straightforward use of the ArviZ's extensive visualisation and + diagnostic functions. * [Autograd](https://github.com/HIPS/autograd): if available Autograd will be used to automatically compute the required derivatives of the model functions (providing they are specified using functions from the @@ -74,15 +80,22 @@ packages provide extra functionality and features serialisation (via [dill](https://github.com/uqfoundation/dill)) of a much wider range of types, including of Autograd generated functions. Both Autograd and multiprocess can be installed alongside Mici by running `pip - install mici[autodiff]`. - * [ArviZ](https://python.arviz.org/en/latest/index.html): if ArviZ is - available the traces (dictionary) output of a sampling run can be directly - converted to an `arviz.InferenceData` container object using - `arviz.convert_to_inference_data` or implicitly converted by passing the - traces dictionary as the `data` argument - [to ArviZ API functions](https://python.arviz.org/en/latest/api/index.html), - allowing straightforward use of the ArviZ's extensive visualisation and - diagnostic functions. + install mici[autograd]`. + * [JAX](https://jax.readthedocs.io/en/latest/): if available JAX will be used to + automatically compute the required derivatives of the model functions (providing + they are specified using functions from the [`jax` + interface](https://jax.readthedocs.io/en/latest/jax.html)). To sample chains + parallel using JAX functions you also need to install + [multiprocess](https://github.com/uqfoundation/multiprocess), though note due to + JAX's use of multithreading which [is incompatible with forking child + processes](https://docs.python.org/3/library/os.html#os.fork), this can result in + deadlock. 
Both JAX and multiprocess can be installed alongside Mici by running `pip + install mici[jax]`. + * [SymNum](https://github.com/matt-graham/symnum): if available SymNum will be used to + automatically compute the required derivatives of the model functions (providing + they are specified using functions from the [`symnum.numpy` + interface](https://matt-graham.github.io/symnum/symnum.numpy.html)). Symnum can be + installed alongside Mici by running `pip install mici[symnum]`. ## Why Mici? @@ -122,7 +135,7 @@ chains in Python can dominate the computational cost, making sampling much slower than packages which outsource the sampling loop to a efficient compiled implementation. - ## Overview of package +## Overview of package API documentation for the package is available [here](https://matt-graham.github.io/mici/). The three main user-facing @@ -257,22 +270,21 @@ The manifold MCMC methods implemented in Mici have been used in several research -A simple complete example of using the package to compute approximate samples -from a distribution on a two-dimensional torus embedded in a three-dimensional -space is given below. The computed samples are visualized in the animation -above. Here we use `autograd` to automatically construct functions to calculate -the required derivatives (gradient of negative log density of target -distribution and Jacobian of constraint function), sample four chains in -parallel using `multiprocess`, use `arviz` to calculate diagnostics and use -`matplotlib` to plot the samples. - -> ⚠️ **If you do not have [`multiprocess`](https://github.com/uqfoundation/multiprocess) installed the example code below will hang or raise an error when sampling the chains as the inbuilt `multiprocessing` module does not support pickling Autograd functions.** +A simple complete example of using the package to compute approximate samples from a +distribution on a two-dimensional torus embedded in a three-dimensional space is given +below. The computed samples are visualized in the animation above. Here we use +[SymNum](https://github.com/matt-graham/symnum) to automatically construct functions to +calculate the required derivatives (gradient of negative log density of target +distribution and Jacobian of constraint function), sample four chains in parallel using +`multiprocessing`, use [ArviZ](https://python.arviz.org/en/stable/) to calculate +diagnostics and use [Matplotlib](https://matplotlib.org/) to plot the samples. 
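As a brief aside before the full example, the ArviZ integration described in the optional-packages list above amounts to a single call: the traces dictionary returned by `sampler.sample_chains` can be converted to an `arviz.InferenceData` object explicitly, or passed directly to ArviZ functions as done below. A minimal sketch, using a stand-in traces dictionary whose `(n_chain, n_iteration)` layout is assumed here rather than taken from the example:

```Python
import arviz
import numpy as np

# Stand-in for the traces dictionary returned by sampler.sample_chains:
# one entry per traced variable with shape (n_chain, n_iteration).
traces = {"x": np.random.default_rng(0).standard_normal((4, 100))}

# Explicit conversion to an InferenceData container for use with ArviZ's
# visualisation and diagnostic functions.
inference_data = arviz.convert_to_inference_data(traces)
print(arviz.summary(inference_data))
```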
```Python -from mici import systems, integrators, samplers -import autograd.numpy as np +import mici +import numpy as np +import symnum +import symnum.numpy as snp import matplotlib.pyplot as plt -from mpl_toolkits.mplot3d import Axes3D import matplotlib.animation as animation import arviz @@ -281,44 +293,62 @@ R = 1.0 # toroidal radius ∈ (0, ∞) r = 0.5 # poloidal radius ∈ (0, R) α = 0.9 # density fluctuation amplitude ∈ [0, 1) +# State dimension +dim_q = 3 + + # Define constraint function such that the set {q : constr(q) == 0} is a torus +@symnum.numpify(dim_q) def constr(q): - x, y, z = q.T - return np.stack([((x**2 + y**2)**0.5 - R)**2 + z**2 - r**2], -1) + x, y, z = q + return snp.array([((x**2 + y**2) ** 0.5 - R) ** 2 + z**2 - r**2]) + # Define negative log density for the target distribution on torus # (with respect to 2D 'area' measure for torus) +@symnum.numpify(dim_q) def neg_log_dens(q): - x, y, z = q.T - θ = np.arctan2(y, x) - ϕ = np.arctan2(z, x / np.cos(θ) - R) - return np.log1p(r * np.cos(ϕ) / R) - np.log1p(np.sin(4*θ) * np.cos(ϕ) * α) + x, y, z = q + θ = snp.arctan2(y, x) + ϕ = snp.arctan2(z, x / snp.cos(θ) - R) + return snp.log1p(r * snp.cos(ϕ) / R) - snp.log1p(snp.sin(4 * θ) * snp.cos(ϕ) * α) + # Specify constrained Hamiltonian system with default identity metric -system = systems.DenseConstrainedEuclideanMetricSystem(neg_log_dens, constr) +system = mici.systems.DenseConstrainedEuclideanMetricSystem( + neg_log_dens, + constr, + backend="symnum", +) # System is constrained therefore use constrained leapfrog integrator -integrator = integrators.ConstrainedLeapfrogIntegrator(system) +integrator = mici.integrators.ConstrainedLeapfrogIntegrator(system) # Seed a random number generator rng = np.random.default_rng(seed=1234) # Use dynamic integration-time HMC implementation as MCMC sampler -sampler = samplers.DynamicMultinomialHMC(system, integrator, rng) +sampler = mici.samplers.DynamicMultinomialHMC(system, integrator, rng) # Sample initial positions on torus using parameterisation (θ, ϕ) ∈ [0, 2π)² # x, y, z = (R + r * cos(ϕ)) * cos(θ), (R + r * cos(ϕ)) * sin(θ), r * sin(ϕ) n_chain = 4 θ_init, ϕ_init = rng.uniform(0, 2 * np.pi, size=(2, n_chain)) -q_init = np.stack([ - (R + r * np.cos(ϕ_init)) * np.cos(θ_init), - (R + r * np.cos(ϕ_init)) * np.sin(θ_init), - r * np.sin(ϕ_init)], -1) +q_init = np.stack( + [ + (R + r * np.cos(ϕ_init)) * np.cos(θ_init), + (R + r * np.cos(ϕ_init)) * np.sin(θ_init), + r * np.sin(ϕ_init), + ], + -1, +) + # Define function to extract variables to trace during sampling def trace_func(state): x, y, z = state.pos - return {'x': x, 'y': y, 'z': z} + return {"x": x, "y": y, "z": z} + # Sample 4 chains in parallel with 500 adaptive warm up iterations in which the # integrator step size is tuned, followed by 2000 non-adaptive iterations @@ -327,7 +357,7 @@ final_states, traces, stats = sampler.sample_chains( n_main_iter=2000, init_states=q_init, n_process=4, - trace_funcs=[trace_func] + trace_funcs=[trace_func], ) # Print average accept probability and number of integrator steps per chain @@ -340,19 +370,24 @@ for c in range(n_chain): print(arviz.summary(traces)) # Visualize concatentated chain samples as animated 3D scatter plot -fig = plt.figure(figsize=(4, 4)) -ax = Axes3D(fig, [0., 0., 1., 1.], proj_type='ortho') -points_3d, = ax.plot(*(np.concatenate(traces[k]) for k in 'xyz'), '.', ms=0.5) -ax.axis('off') +fig, ax = plt.subplots( + figsize=(4, 4), + subplot_kw={"projection": "3d", "proj_type": "ortho"}, +) +(points_3d,) = 
ax.plot(*(np.concatenate(traces[k]) for k in "xyz"), ".", ms=0.5) +ax.axis("off") for set_lim in [ax.set_xlim, ax.set_ylim, ax.set_zlim]: set_lim((-1, 1)) + def update(i): angle = 45 * (np.sin(2 * np.pi * i / 60) + 1) ax.view_init(elev=angle, azim=angle) return (points_3d,) -anim = animation.FuncAnimation(fig, update, frames=60, interval=100, blit=True) + +anim = animation.FuncAnimation(fig, update, frames=60, interval=100) +plt.show() ``` ## References diff --git a/pyproject.toml b/pyproject.toml index f50e056..16a6225 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,9 +17,9 @@ classifiers = [ "License :: OSI Approved :: MIT License", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3 :: Only", - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", "Typing :: Typed", ] dependencies = [ @@ -38,7 +38,7 @@ keywords = [ ] name = "mici" readme = "README.md" -requires-python = ">=3.9" +requires-python = ">=3.10" license.file = "LICENCE" urls.homepage = "https://github.com/matt-graham/mici" urls.documentation = "https://matt-graham.github.io/mici" @@ -54,9 +54,16 @@ dev = [ "tox>=4", "twine", ] -autodiff = [ +autograd = [ "autograd>=1.3", - "multiprocess>=0.7.0" + "multiprocess>=0.7.0", +] +jax = [ + "jax>=0.4.1", + "multiprocess>=0.7.0", +] +symnum = [ + "symnum>=0.2.1", ] [tool.coverage] @@ -132,7 +139,7 @@ select = [ "W", "YTT", ] -target-version = "py39" +target-version = "py310" isort.known-first-party = [ "mici", ] @@ -152,21 +159,25 @@ write_to = "src/mici/_version.py" legacy_tox_ini = """ [gh-actions] python = - 3.9: py39 3.10: py310 3.11: py311 + 3.12: py312 [testenv] commands = pytest --cov {posargs} + extras = + autograd + jax + symnum deps = pytest pytest-cov - autograd>=1.3 - multiprocess>=0.7.0 pystan pymc>=5 arviz + set_env = + JAX_ENABLE_X64=1 [testenv:docs] commands = @@ -178,7 +189,7 @@ legacy_tox_ini = """ [tox] env_list = - py39 py310 py311 + py312 """ diff --git a/src/mici/autodiff.py b/src/mici/autodiff.py deleted file mode 100644 index ce69121..0000000 --- a/src/mici/autodiff.py +++ /dev/null @@ -1,70 +0,0 @@ -"""Automatic differentation fallback for constructing derivative functions.""" - -from __future__ import annotations - -from typing import TYPE_CHECKING - -from mici import autograd_wrapper - -if TYPE_CHECKING: - from typing import Callable, Optional - - -"""List of names of valid differential operators. - -Any automatic differentiation framework wrapper module will need to provide all of these -operators as callables (with a single function as argument) to fully support all of the -required derivative functions. -""" -DIFF_OPS = [ - # vector Jacobian product and value - "vjp_and_value", - # gradient and value for scalar valued functions - "grad_and_value", - # Hessian matrix, gradient and value for scalar valued functions - "hessian_grad_and_value", - # matrix Tressian product, gradient and value for scalar valued functions - "mtp_hessian_grad_and_value", - # Jacobian matrix and value for vector valued functions - "jacobian_and_value", - # matrix Hessian product, Jacobian matrix and value for vector valued functions - "mhp_jacobian_and_value", -] - - -def autodiff_fallback( - diff_func: Optional[Callable], - func: Callable, - diff_op_name: str, - name: str, -) -> Callable: - """Generate derivative function automatically if not provided. 
- - Uses automatic differentiation to generate a function corresponding to a - differential operator applied to a function if an alternative implementation of the - derivative function has not been provided. - - Args: - diff_func: Either a callable implementing the required derivative function or - `None` if none was provided. - func: Function to differentiate. - diff_op_name: String specifying name of differential operator from automatic - differentiation framework wrapper to use to generate required derivative - function. - name: Name of derivative function to use in error message. - - Returns: - `diff_func` value if not `None` otherwise generated derivative of `func` by - applying named differential operator. - """ - if diff_func is not None: - return diff_func - elif diff_op_name not in DIFF_OPS: - msg = f"Differential operator {diff_op_name} is not defined." - raise ValueError(msg) - elif autograd_wrapper.AUTOGRAD_AVAILABLE: - return getattr(autograd_wrapper, diff_op_name)(func) - elif not autograd_wrapper.AUTOGRAD_AVAILABLE: - msg = f"Autograd not available therefore {name} must be provided." - raise ValueError(msg) - return None diff --git a/src/mici/autodiff/__init__.py b/src/mici/autodiff/__init__.py new file mode 100644 index 0000000..4879534 --- /dev/null +++ b/src/mici/autodiff/__init__.py @@ -0,0 +1,180 @@ +"""Automatic differentation support for constructing derivative functions. + +Multiple automatic differentiation backends are supported: + +* `jax`: High performance array computing framework, with support for running + computations on accelerator devices and just-in-time (JIT) compilation. To + differentiate a function using the `jax` backend, the function must be defined in + terms of JAX primitives, for example using the functions in the `jax.numpy` API. By + default the derivative functions produced will be JIT compiled; a `jax_nojit` variant + is available if the function (or its derivative) is not compatible with JIT + compilation. +* `autograd`: Autograd can automatically differentiate native Python and NumPy code. To + differentiate a function using the `autograd` backend it should be defined in terms + of functions from the `autograd.numpy` and `autograd.scipy` APIs. Compared to JAX, + the lack of JIT compilation in Autograd and features such as automatic vectorisation + make Autograd slower, and so JAX will generally be a better choice. +* `symnum`: SymNum is a Python package that acts a bridge between NumPy and SymPy, + providing a NumPy-like interface that can be used to symbolically define functions + which take arrays as arguments and return arrays or scalars as values. To + differentiate a function using the `symnum` backend it should be defined in terms of + functions from the `symnum.numpy` API, and should have been decorated with + ``symnum.numpify`` with specified argument shapes. SymNum is intended for use in + generating the derivatives of 'simple' functions which compose a relatively small + number of operations and act on small array inputs. By reducing interpreter overheads + it can produce code which is cheaper to evaluate than corresponding Autograd or JAX + functions (including those using JIT compilation) in such cases, and which can be + serialised with the inbuilt Python pickle library allowing use for example in + libraries which use multiprocessing to implement parallelisation across multiple + processes. 
+""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, NamedTuple + +from mici.autodiff import autograd_wrapper, jax_wrapper, symnum_wrapper + +if TYPE_CHECKING: + from types import ModuleType + from typing import Callable, Optional + + +"""Names of valid differential operators. + +Any automatic differentiation framework wrapper module will need to provide all of these +operators as callables (with a single function as argument) to fully support all of the +required derivative functions. +""" +DIFF_OPS = ( + # vector Jacobian product and value + "vjp_and_value", + # gradient and value for scalar valued functions + "grad_and_value", + # Hessian matrix, gradient and value for scalar valued functions + "hessian_grad_and_value", + # matrix Tressian product, gradient and value for scalar valued functions + "mtp_hessian_grad_and_value", + # Jacobian matrix and value for vector valued functions + "jacobian_and_value", + # matrix Hessian product, Jacobian matrix and value for vector valued functions + "mhp_jacobian_and_value", +) + + +class AutodiffBackend(NamedTuple): + """Automatic differentiation backend framework. + + Consists of a module defining differential operators, a boolean flag indicating if + backend is available in current environment and optionally a function wrapper which + applies any post processing required to functions. + """ + + module: ModuleType + available: bool + function_wrapper: Optional[Callable] = None + + +"""Available autodifferentiation framework backends.""" +_REGISTERED_BACKENDS = { + "jax": AutodiffBackend( + jax_wrapper, + jax_wrapper.JAX_AVAILABLE, + jax_wrapper.jit_and_return_numpy_arrays, + ), + "jax_nojit": AutodiffBackend( + jax_wrapper, + jax_wrapper.JAX_AVAILABLE, + jax_wrapper.return_numpy_arrays, + ), + "autograd": AutodiffBackend(autograd_wrapper, autograd_wrapper.AUTOGRAD_AVAILABLE), + "symnum": AutodiffBackend(symnum_wrapper, symnum_wrapper.SYMNUM_AVAILABLE), +} + + +def _get_backend(name: str): + # Normalize name string to all lowercase to make invariant to capitalization + name = name.lower() + if name not in _REGISTERED_BACKENDS: + msg = ( + f"Selected autodiff backend {name} not recognised: " + f"available options are {tuple(_REGISTERED_BACKENDS)}." + ) + raise ValueError(msg) + return _REGISTERED_BACKENDS[name] + + +def wrap_function(function: Callable, backend: Optional[str]): + """Apply function wrapper for automatic differentiation backend to a function. + + Backends may define a function wrapper which applies any post processing required to + functions using framework - for example ensuring the function returns NumPy arrays + or just-in-time compiling the function. + + Args: + function: Function to wrap. + backend: Name of automatic differentiation framework backend to use. If `None` + function is returned unchanged. + + Returns: + Wrapped function. + """ + if backend is None: + return function + backend = _get_backend(backend) + if backend.function_wrapper is not None: + return backend.function_wrapper(function) + else: + return function + + +def autodiff_fallback( + diff_func: Optional[Callable], + func: Callable, + diff_op_name: str, + name: str, + backend: Optional[str], +) -> Callable: + """Generate derivative function automatically if not provided. + + Uses automatic differentiation to generate a function corresponding to a + differential operator applied to a function if an alternative implementation of the + derivative function has not been provided. 
+ + Args: + diff_func: Either a callable implementing the required derivative function or + `None` if none was provided. + func: Function to differentiate. + diff_op_name: String specifying name of differential operator from automatic + differentiation framework wrapper to use to generate required derivative + function. + name: Name of derivative function to use in error message. + backend: Name of automatic differentiation framework backend to use. If `None` + `diff_func` must be provided. + + Returns: + `diff_func` value if not `None` otherwise generated derivative of `func` by + applying named differential operator from automatic differentiation backend. + """ + if diff_func is not None: + return diff_func + elif diff_func is None and backend is None: + msg = ( + f"Automatic differentiation backend specified as `None` so {name} must" + "be provided directly." + ) + raise ValueError(msg) + elif diff_op_name not in DIFF_OPS: + msg = f"Differential operator {diff_op_name} is not defined." + raise ValueError(msg) + else: + autodiff_backend = _get_backend(backend) + if autodiff_backend.available: + diff_func = getattr(autodiff_backend.module, diff_op_name)(func) + return wrap_function(diff_func, backend) + else: + msg = ( + f"{backend} selected as autodiff backend but is not available in " + f"current environment therefore {name} must be provided directly." + ) + raise ValueError(msg) diff --git a/src/mici/autodiff/autograd_wrapper.py b/src/mici/autodiff/autograd_wrapper.py new file mode 100644 index 0000000..a3dfb3b --- /dev/null +++ b/src/mici/autodiff/autograd_wrapper.py @@ -0,0 +1,146 @@ +"""Autograd differential operators.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +AUTOGRAD_AVAILABLE = True +try: + import autograd.numpy as np + from autograd.builtins import tuple as atuple + from autograd.core import make_vjp + from autograd.extend import vspace +except ImportError: + AUTOGRAD_AVAILABLE = False + +if TYPE_CHECKING: + + from mici.types import ( + ArrayFunction, + GradientFunction, + HessianFunction, + JacobianFunction, + MatrixHessianProductFunction, + MatrixTressianProductFunction, + ScalarFunction, + VectorJacobianProductFunction, + ) + + +def grad_and_value(func: ScalarFunction) -> GradientFunction: + """Makes a function that returns both gradient and value of a function.""" + + def grad_and_value_func(x): + vjp, val = make_vjp(func, x) + if vspace(val).size != 1: + msg = "grad_and_value only applies to real scalar-output functions." + raise TypeError(msg) + return vjp(vspace(val).ones()), val + + return grad_and_value_func + + +def vjp_and_value(func: ScalarFunction) -> VectorJacobianProductFunction: + """ + Makes a function that returns vector-Jacobian-product and value of a function. + + For a vector-valued function `fun` the vector-Jacobian-product (VJP) is here + defined as a function of a vector `v` corresponding to + + vjp(v) = v @ j + + where `j` is the Jacobian of `f = fun(x)` wrt `x` i.e. 
the rank-2 + tensor of first-order partial derivatives of the vector-valued function, + such that + + j[i, k] = ∂f[i] / ∂x[k] + """ + + def vjp_and_value_func(x): + return make_vjp(func, x) + + return vjp_and_value_func + + +def jacobian_and_value(func: ArrayFunction) -> JacobianFunction: + """Makes a function that returns both the Jacobian and value of a function.""" + + def jacobian_and_value_func(x): + vjp, val = make_vjp(func, x) + val_vspace = vspace(val) + jacobian_shape = val_vspace.shape + vspace(x).shape + jacobian_rows = map(vjp, val_vspace.standard_basis()) + return np.reshape(np.stack(jacobian_rows), jacobian_shape), val + + return jacobian_and_value_func + + +def mhp_jacobian_and_value(func: ArrayFunction) -> MatrixHessianProductFunction: + """ + Makes a function that returns MHP, Jacobian and value of a function. + + For a vector-valued function `fun` the matrix-Hessian-product (MHP) is here + defined as a function of a matrix `m` corresponding to + + mhp(m) = sum(m[:, :, None] * h[:, :, :], axis=(0, 1)) + + where `h` is the vector-Hessian of `f = fun(x)` wrt `x` i.e. the rank-3 + tensor of second-order partial derivatives of the vector-valued function, + such that + + h[i, j, k] = ∂²f[i] / (∂x[j] ∂x[k]) + """ + + def mhp_jacobian_and_value_func(x): + mhp, (jacob, val) = make_vjp(lambda x: atuple(jacobian_and_value(func)(x)), x) + return lambda m: mhp((m, vspace(val).zeros())), jacob, val + + return mhp_jacobian_and_value_func + + +def hessian_grad_and_value(func: ArrayFunction) -> HessianFunction: + """Makes a function that returns the Hessian, gradient & value of a function.""" + + def grad_func(x): + vjp, val = make_vjp(func, x) + return vjp(vspace(val).ones()), val + + def hessian_grad_and_value_func(x): + x_vspace = vspace(x) + vjp_grad, (grad, val) = make_vjp(lambda x: atuple(grad_func(x)), x) + hessian_shape = x_vspace.shape + x_vspace.shape + zeros = vspace(val).zeros() + hessian_rows = (vjp_grad((v, zeros)) for v in x_vspace.standard_basis()) + return np.reshape(np.stack(hessian_rows), hessian_shape), grad, val + + return hessian_grad_and_value_func + + +def mtp_hessian_grad_and_value(func: ArrayFunction) -> MatrixTressianProductFunction: + """ + Makes a function that returns MTP, Jacobian and value of a function. + + For a scalar-valued function `fun` the matrix-Tressian-product (MTP) is + here defined as a function of a matrix `m` corresponding to + + mtp(m) = sum(m[:, :] * t[:, :, :], axis=(-1, -2)) + + where `t` is the 'Tressian' of `f = fun(x)` wrt `x` i.e. 
the 3D array of + third-order partial derivatives of the scalar-valued function such that + + t[i, j, k] = ∂³f / (∂x[i] ∂x[j] ∂x[k]) + """ + + def mtp_hessian_grad_and_value_func(x): + mtp, (hessian, grad, val) = make_vjp( + lambda x: atuple(hessian_grad_and_value(func)(x)), + x, + ) + return ( + lambda m: mtp((m, vspace(grad).zeros(), vspace(val).zeros())), + hessian, + grad, + val, + ) + + return mtp_hessian_grad_and_value_func diff --git a/src/mici/autodiff/jax_wrapper.py b/src/mici/autodiff/jax_wrapper.py new file mode 100644 index 0000000..ae0b064 --- /dev/null +++ b/src/mici/autodiff/jax_wrapper.py @@ -0,0 +1,199 @@ +"""JAX differential operators and helper functions.""" + +from __future__ import annotations + +from functools import partial +from typing import TYPE_CHECKING + +import numpy as np + +JAX_AVAILABLE = True +try: + import jax + import jax.numpy as jnp +except ImportError: + JAX_AVAILABLE = False + +if TYPE_CHECKING: + from mici.types import ( + ArrayFunction, + ArrayLike, + GradientFunction, + JacobianFunction, + MatrixHessianProductFunction, + MatrixTressianProduct, + ScalarFunction, + ScalarLike, + VectorJacobianProductFunction, + ) + + +def jit_and_return_numpy_arrays(function, **jit_kwargs): + """Wrap a JIT compiled function returning JAX arrays to instead return NumPy arrays. + + Args: + function: Function to wrap. Should return one of: a single JAX array, a callable + returning one or more JAX array or a tuple of one or more JAX arrays or + functions returning one or more JAX arrays. + **jit_kwargs: Any keyword arguments to pass to `jax.jit` operator. + + Returns: + Wrapped function. Any values returned by original function which are JAX arrays + will instead be NumPy arrays, while any values which are callables returning + JAX arrays will instead return NumPy arrays. + """ + jitted_function = jax.jit(function, **jit_kwargs) + return return_numpy_arrays(jitted_function) + + +def return_numpy_arrays(function: callable) -> callable: + """Wrap a function returning JAX arrays to instead return NumPy arrays. + + Args: + function: Function to wrap. Should return one of: a single JAX array, a callable + returning one or more JAX array or a tuple of one or more JAX arrays or + functions returning one or more JAX arrays. + + Returns: + Wrapped function. Any values returned by original function which are JAX arrays + will instead be NumPy arrays, while any values which are callables returning + JAX arrays will instead return NumPy arrays. 
+ """ + + def as_numpy_array(value): + if callable(value): + return return_numpy_arrays(value) + elif isinstance(value, jax.Array): + return np.asarray(value) + else: + return value + + def function_returning_numpy_arrays(*args, **kwargs): + return_value = function(*args, **kwargs) + if isinstance(return_value, tuple): + return tuple(as_numpy_array(value) for value in return_value) + else: + return as_numpy_array(return_value) + + return function_returning_numpy_arrays + + +def grad_and_value(func: ScalarFunction) -> GradientFunction: + """Makes a function that returns both the Jacobian and value of a function.""" + + def grad_and_value_func(x): + value, grad = jax.value_and_grad(func)(x) + return grad, value + + return grad_and_value_func + + +def _detuple_vjp(vjp_func): + """Transform a VJP of function with one return value so it returns an array.""" + return jax.tree_util.Partial( + lambda *args, **kwargs: vjp_func.func(*args, **kwargs)[0], + *vjp_func.args, + **vjp_func.keywords, + ) + + +def vjp_and_value(func: ArrayFunction) -> VectorJacobianProductFunction: + """Makes a function that returns vector-Jacobian-product and value of a function. + + For a vector-valued function `fun` the vector-Jacobian-product (VJP) is here + defined as a function of a vector `v` corresponding to + + vjp(v) = v @ j + + where `j` is the Jacobian of `f = fun(x)` wrt `x` i.e. the rank-2 + tensor of first-order partial derivatives of the vector-valued function, + such that + + j[i, k] = ∂f[i] / ∂x[k] + """ + + def vjp_and_value_func(x): + value, vjp = jax.vjp(func, x) + return _detuple_vjp(vjp), value + + return vjp_and_value_func + + +def jacobian_and_value(func: ArrayFunction) -> JacobianFunction: + """Makes a function that returns both the Jacobian and value of a function.""" + + def jacobian_and_value_func(x): + value, pullback = jax.vjp(func, x) + basis = jnp.eye(value.size, dtype=value.dtype) + (jac,) = jax.vmap(pullback)(basis) + return jac, value + + return jacobian_and_value_func + + +def mhp_jacobian_and_value(func: ArrayFunction) -> MatrixHessianProductFunction: + """ + Makes a function that returns MHP, Jacobian and value of a function. + + For a vector-valued function `fun` the matrix-Hessian-product (MHP) is here + defined as a function of a matrix `m` corresponding to + + mhp(m) = sum(m[:, :, None] * h[:, :, :], axis=(0, 1)) + + where `h` is the vector-Hessian of `f = fun(x)` wrt `x` i.e. the rank-3 + tensor of second-order partial derivatives of the vector-valued function, + such that + + h[i, j, k] = ∂²f[i] / (∂x[j] ∂x[k]) + """ + + def mhp_jacobian_and_value_func(x): + jac, mhp, value = jax.vjp(jacobian_and_value(func), x, has_aux=True) + return _detuple_vjp(mhp), jac, value + + return mhp_jacobian_and_value_func + + +def hessian_grad_and_value( + func: ScalarFunction, +) -> tuple[ArrayLike, ArrayLike, ScalarLike]: + """Makes a function that returns the Hessian, gradient and value of a function.""" + + def hessian_grad_and_value_func(x): + basis = jnp.eye(x.size, dtype=x.dtype) + grad_and_value_func = grad_and_value(func) + pushforward = partial(jax.jvp, grad_and_value_func, (x,), has_aux=True) + grad, hessian, value = jax.vmap(pushforward, out_axes=(None, -1, None))( + (basis,), + ) + return hessian, grad, value + + return hessian_grad_and_value_func + + +def mtp_hessian_grad_and_value( + func: ScalarFunction, +) -> tuple[MatrixTressianProduct, ArrayLike, ArrayLike, ScalarLike]: + """ + Makes a function that returns MTP, Jacobian and value of a function. 
+ + For a scalar-valued function `fun` the matrix-Tressian-product (MTP) is + here defined as a function of a matrix `m` corresponding to + + mtp(m) = sum(m[:, :] * t[:, :, :], axis=(-1, -2)) + + where `t` is the 'Tressian' of `f = fun(x)` wrt `x` i.e. the 3D array of + third-order partial derivatives of the scalar-valued function such that + + t[i, j, k] = ∂³f / (∂x[i] ∂x[j] ∂x[k]) + """ + + def hessian_and_aux_func(x): + hessian, grad, value = hessian_grad_and_value(func)(x) + return hessian, (grad, value) + + def mtp_hessian_grad_and_value_func(x): + hessian, mtp, (grad, value) = jax.vjp(hessian_and_aux_func, x, has_aux=True) + return _detuple_vjp(mtp), hessian, grad, value + + return mtp_hessian_grad_and_value_func diff --git a/src/mici/autodiff/symnum_wrapper.py b/src/mici/autodiff/symnum_wrapper.py new file mode 100644 index 0000000..f10d0d6 --- /dev/null +++ b/src/mici/autodiff/symnum_wrapper.py @@ -0,0 +1,94 @@ +"""SymNum differential operators and helper functions.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +SYMNUM_AVAILABLE = True +try: + import symnum +except ImportError: + SYMNUM_AVAILABLE = False + +if TYPE_CHECKING: + from mici.types import ( + ArrayFunction, + ArrayLike, + GradientFunction, + JacobianFunction, + MatrixHessianProductFunction, + MatrixTressianProduct, + ScalarFunction, + ScalarLike, + VectorJacobianProductFunction, + ) + + +def grad_and_value(func: ScalarFunction) -> GradientFunction: + """Makes a function that returns both the Jacobian and value of a function.""" + return symnum.grad(func, return_aux=True) + + +def vjp_and_value(func: ArrayFunction) -> VectorJacobianProductFunction: + """Makes a function that returns vector-Jacobian-product and value of a function. + + For a vector-valued function `fun` the vector-Jacobian-product (VJP) is here + defined as a function of a vector `v` corresponding to + + vjp(v) = v @ j + + where `j` is the Jacobian of `f = fun(x)` wrt `x` i.e. the rank-2 + tensor of first-order partial derivatives of the vector-valued function, + such that + + j[i, k] = ∂f[i] / ∂x[k] + """ + return symnum.vector_jacobian_product(func, return_aux=True) + + +def jacobian_and_value(func: ArrayFunction) -> JacobianFunction: + """Makes a function that returns both the Jacobian and value of a function.""" + return symnum.jacobian(func, return_aux=True) + + +def mhp_jacobian_and_value(func: ArrayFunction) -> MatrixHessianProductFunction: + """Makes a function that returns MHP, Jacobian and value of a function. + + For a vector-valued function `fun` the matrix-Hessian-product (MHP) is here + defined as a function of a matrix `m` corresponding to + + mhp(m) = sum(m[:, :, None] * h[:, :, :], axis=(0, 1)) + + where `h` is the vector-Hessian of `f = fun(x)` wrt `x` i.e. the rank-3 + tensor of second-order partial derivatives of the vector-valued function, + such that + + h[i, j, k] = ∂²f[i] / (∂x[j] ∂x[k]) + """ + return symnum.matrix_hessian_product(func, return_aux=True) + + +def hessian_grad_and_value( + func: ScalarFunction, +) -> tuple[ArrayLike, ArrayLike, ScalarLike]: + """Makes a function that returns the Hessian, gradient and value of a function.""" + return symnum.hessian(func, return_aux=True) + + +def mtp_hessian_grad_and_value( + func: ScalarFunction, +) -> tuple[MatrixTressianProduct, ArrayLike, ArrayLike, ScalarLike]: + """ + Makes a function that returns MTP, Jacobian and value of a function. 
+ + For a scalar-valued function `fun` the matrix-Tressian-product (MTP) is + here defined as a function of a matrix `m` corresponding to + + mtp(m) = sum(m[:, :] * t[:, :, :], axis=(-1, -2)) + + where `t` is the 'Tressian' of `f = fun(x)` wrt `x` i.e. the 3D array of + third-order partial derivatives of the scalar-valued function such that + + t[i, j, k] = ∂³f / (∂x[i] ∂x[j] ∂x[k]) + """ + return symnum.matrix_tressian_product(func, return_aux=True) diff --git a/src/mici/autograd_wrapper.py b/src/mici/autograd_wrapper.py deleted file mode 100644 index 540f8b3..0000000 --- a/src/mici/autograd_wrapper.py +++ /dev/null @@ -1,153 +0,0 @@ -"""Additional autograd differential operators.""" - -from __future__ import annotations - -from functools import wraps -from typing import TYPE_CHECKING - -AUTOGRAD_AVAILABLE = True -try: - import autograd.numpy as np - from autograd.builtins import tuple as atuple - from autograd.core import make_vjp - from autograd.extend import vspace - from autograd.wrap_util import unary_to_nary -except ImportError: - AUTOGRAD_AVAILABLE = False - -if TYPE_CHECKING: - from typing import Callable - - from mici.types import ( - ArrayFunction, - ArrayLike, - MatrixHessianProduct, - MatrixTressianProduct, - ScalarFunction, - ScalarLike, - ) - - -def _wrapped_unary_to_nary(func: Callable) -> Callable: - """Use functools.wraps with unary_to_nary decorator.""" - if AUTOGRAD_AVAILABLE: - return wraps(func)(unary_to_nary(func)) - else: - return func - - -@_wrapped_unary_to_nary -def grad_and_value(fun: ScalarFunction, x: ArrayLike) -> tuple[ArrayLike, ScalarLike]: - """Makes a function that returns both gradient and value of a function.""" - vjp, val = make_vjp(fun, x) - if vspace(val).size != 1: - msg = "grad_and_value only applies to real scalar-output functions." - raise TypeError(msg) - return vjp(vspace(val).ones()), val - - -@_wrapped_unary_to_nary -def jacobian_and_value(fun: ArrayFunction, x: ArrayLike) -> tuple[ArrayLike, ArrayLike]: - """ - Makes a function that returns both the Jacobian and value of a function. - - Assumes that the function `fun` broadcasts along the first dimension of the - input being differentiated with respect to such that a batch of outputs can - be computed concurrently for a batch of inputs. - """ - val = fun(x) - v_vspace = vspace(val) - x_vspace = vspace(x) - x_rep = np.tile(x, (v_vspace.size,) + (1,) * x_vspace.ndim) - vjp_rep, _ = make_vjp(fun, x_rep) - jacobian_shape = v_vspace.shape + x_vspace.shape - basis_vectors = np.array(list(v_vspace.standard_basis())) - jacobian = vjp_rep(basis_vectors) - return np.reshape(jacobian, jacobian_shape), val - - -@_wrapped_unary_to_nary -def mhp_jacobian_and_value( - fun: ArrayFunction, - x: ArrayLike, -) -> tuple[MatrixHessianProduct, ArrayLike, ArrayLike]: - """ - Makes a function that returns MHP, Jacobian and value of a function. - - For a vector-valued function `fun` the matrix-Hessian-product (MHP) is here - defined as a function of a matrix `m` corresponding to - - mhp(m) = sum(m[:, :, None] * h[:, :, :], axis=(0, 1)) - - where `h` is the vector-Hessian of `f = fun(x)` wrt `x` i.e. the rank-3 - tensor of second-order partial derivatives of the vector-valued function, - such that - - h[i, j, k] = ∂²f[i] / (∂x[j] ∂x[k]) - - Assumes that the function `fun` broadcasts along the first dimension of the - input being differentiated with respect to such that a batch of outputs can - be computed concurrently for a batch of inputs. 
- """ - mhp, (jacob, val) = make_vjp(lambda x: atuple(jacobian_and_value(fun)(x)), x) - return lambda m: mhp((m, vspace(val).zeros())), jacob, val - - -@_wrapped_unary_to_nary -def hessian_grad_and_value( - fun: ArrayFunction, - x: ArrayLike, -) -> tuple[ArrayLike, ArrayLike, ScalarLike]: - """ - Makes a function that returns the Hessian, gradient & value of a function. - - Assumes that the function `fun` broadcasts along the first dimension of the - input being differentiated with respect to such that a batch of outputs can - be computed concurrently for a batch of inputs. - """ - - def grad_fun(x): - vjp, val = make_vjp(fun, x) - return vjp(vspace(val).ones()), val - - x_vspace = vspace(x) - x_rep = np.tile(x, (x_vspace.size,) + (1,) * x_vspace.ndim) - vjp_grad, (grad, val) = make_vjp(lambda x: atuple(grad_fun(x)), x_rep) - hessian_shape = x_vspace.shape + x_vspace.shape - basis_vectors = np.array(list(x_vspace.standard_basis())) - hessian = vjp_grad((basis_vectors, vspace(val).zeros())) - return np.reshape(hessian, hessian_shape), grad[0], val[0] - - -@_wrapped_unary_to_nary -def mtp_hessian_grad_and_value( - fun: ArrayFunction, - x: ArrayLike, -) -> tuple[MatrixTressianProduct, ArrayLike, ArrayLike, ScalarLike]: - """ - Makes a function that returns MTP, Jacobian and value of a function. - - For a scalar-valued function `fun` the matrix-Tressian-product (MTP) is - here defined as a function of a matrix `m` corresponding to - - mtp(m) = sum(m[:, :] * t[:, :, :], axis=(-1, -2)) - - where `t` is the 'Tressian' of `f = fun(x)` wrt `x` i.e. the 3D array of - third-order partial derivatives of the scalar-valued function such that - - t[i, j, k] = ∂³f / (∂x[i] ∂x[j] ∂x[k]) - - Assumes that the function `fun` broadcasts along the first dimension of the - input being differentiated with respect to such that a batch of outputs can - be computed concurrently for a batch of inputs. - """ - mtp, (hessian, grad, val) = make_vjp( - lambda x: atuple(hessian_grad_and_value(fun)(x)), - x, - ) - return ( - lambda m: mtp((m, vspace(grad).zeros(), vspace(val).zeros())), - hessian, - grad, - val, - ) diff --git a/src/mici/interop.py b/src/mici/interop.py index 805eb52..abf07b3 100644 --- a/src/mici/interop.py +++ b/src/mici/interop.py @@ -226,6 +226,7 @@ def sample_pymc_model( system = mici.systems.EuclideanMetricSystem( neg_log_dens=neg_log_dens, grad_neg_log_dens=grad_neg_log_dens, + backend=None, ) integrator = mici.integrators.LeapfrogIntegrator(system) @@ -423,6 +424,7 @@ def sample_stan_model( system = mici.systems.EuclideanMetricSystem( neg_log_dens=neg_log_dens, grad_neg_log_dens=grad_neg_log_dens, + backend=None, ) integrator = mici.integrators.LeapfrogIntegrator(system, step_size=stepsize) diff --git a/src/mici/systems.py b/src/mici/systems.py index dbc523b..cb5c598 100644 --- a/src/mici/systems.py +++ b/src/mici/systems.py @@ -8,7 +8,7 @@ import numpy as np from mici import matrices -from mici.autodiff import autodiff_fallback +from mici.autodiff import autodiff_fallback, wrap_function from mici.states import cache_in_state, cache_in_state_with_aux if TYPE_CHECKING: @@ -62,6 +62,7 @@ def __init__( neg_log_dens: ScalarFunction, *, grad_neg_log_dens: Optional[GradientFunction] = None, + backend: Optional[str] = None, ): """ Args: @@ -78,13 +79,19 @@ def __init__( position array. If `None` is passed (the default) an automatic differentiation fallback will be used to attempt to construct the derivative of `neg_log_dens` automatically. + backend: Name of automatic differentiation backend to use. 
See + :py:mod:`.autodiff` subpackage documentation for details of available + backends. If `None` (the default) no automatic differentiation fallback + will be used and so all derivative functions must be specified + explicitly. """ - self._neg_log_dens = neg_log_dens + self._neg_log_dens = wrap_function(neg_log_dens, backend) self._grad_neg_log_dens = autodiff_fallback( grad_neg_log_dens, neg_log_dens, "grad_and_value", "grad_neg_log_dens", + backend, ) @cache_in_state("pos") @@ -285,6 +292,7 @@ def __init__( *, metric: Optional[MetricLike] = None, grad_neg_log_dens: Optional[GradientFunction] = None, + backend: Optional[str] = None, ): """ Args: @@ -311,8 +319,17 @@ def __init__( position array. If `None` is passed (the default) an automatic differentiation fallback will be used to attempt to construct the derivative of `neg_log_dens` automatically. + backend: Name of automatic differentiation backend to use. See + :py:mod:`.autodiff` subpackage documentation for details of available + backends. If `None` (the default) no automatic differentiation fallback + will be used and so all derivative functions must be specified + explicitly. """ - super().__init__(neg_log_dens=neg_log_dens, grad_neg_log_dens=grad_neg_log_dens) + super().__init__( + neg_log_dens=neg_log_dens, + grad_neg_log_dens=grad_neg_log_dens, + backend=backend, + ) if metric is None: self.metric = matrices.IdentityMatrix() elif isinstance(metric, np.ndarray): @@ -392,6 +409,7 @@ def __init__( *, metric: Optional[MetricLike] = None, grad_neg_log_dens: Optional[GradientFunction] = None, + backend: Optional[str] = None, ): """ Args: @@ -418,11 +436,17 @@ def __init__( position array. If `None` is passed (the default) an automatic differentiation fallback will be used to attempt to construct the derivative of `neg_log_dens` automatically. + backend: Name of automatic differentiation backend to use. See + :py:mod:`.autodiff` subpackage documentation for details of available + backends. If `None` (the default) no automatic differentiation fallback + will be used and so all derivative functions must be specified + explicitly. """ super().__init__( neg_log_dens=neg_log_dens, metric=metric, grad_neg_log_dens=grad_neg_log_dens, + backend=backend, ) def h2(self, state: ChainState) -> ScalarLike: @@ -667,6 +691,7 @@ def __init__( dens_wrt_hausdorff: bool = True, grad_neg_log_dens: Optional[GradientFunction] = None, jacob_constr: Optional[JacobianFunction] = None, + backend: Optional[str] = None, ): r""" Args: @@ -692,7 +717,7 @@ def __init__( where :code:`jacob_constr` is the Jacobian of the constraint function :code:`constr` and :code:`metric` is the matrix representation of the metric on the ambient space. - constr: Function which given a position rray return as a 1D array the value + constr: Function which given a position array return as a 1D array the value of the (vector-valued) constraint function, the zero level-set of which implicitly defines the manifold the dynamic is simulated on. metric: Matrix object corresponding to matrix representation of metric on @@ -737,19 +762,26 @@ def __init__( used to attempt to construct a function to compute the Jacobian (and value) of :code:`constr` automatically. + backend: Name of automatic differentiation backend to use. See + :py:mod:`.autodiff` subpackage documentation for details of available + backends. If `None` (the default) no automatic differentiation fallback + will be used and so all derivative functions must be specified + explicitly. 
""" super().__init__( neg_log_dens=neg_log_dens, metric=metric, grad_neg_log_dens=grad_neg_log_dens, + backend=backend, ) - self._constr = constr + self._constr = wrap_function(constr, backend) self.dens_wrt_hausdorff = dens_wrt_hausdorff self._jacob_constr = autodiff_fallback( jacob_constr, constr, "jacobian_and_value", "jacob_constr", + backend, ) @cache_in_state("pos") @@ -861,6 +893,7 @@ def __init__( grad_neg_log_dens: Optional[GradientFunction] = None, jacob_constr: Optional[JacobianFunction] = None, mhp_constr: Optional[MatrixHessianProductFunction] = None, + backend: Optional[str] = None, ): """ Args: @@ -886,7 +919,7 @@ def __init__( where :code:`jacob_constr` is the Jacobian of the constraint function :code:`constr` and :code:`metric` is the matrix representation of the metric on the ambient space. - constr: Function which given a position rray return as a 1D array the value + constr: Function which given a position array return as a 1D array the value of the (vector-valued) constraint function, the zero level-set of which implicitly defines the manifold the dynamic is simulated on. metric: Matrix object corresponding to matrix representation of metric on @@ -949,6 +982,11 @@ def __init__( fallback will be used to attempt to construct a function which calculates the MHP (and Jacobian and value) of :code:`constr` automatically. + backend: Name of automatic differentiation backend to use. See + :py:mod:`.autodiff` subpackage documentation for details of available + backends. If `None` (the default) no automatic differentiation fallback + will be used and so all derivative functions must be specified + explicitly. """ super().__init__( neg_log_dens=neg_log_dens, @@ -957,6 +995,7 @@ def __init__( dens_wrt_hausdorff=dens_wrt_hausdorff, grad_neg_log_dens=grad_neg_log_dens, jacob_constr=jacob_constr, + backend=backend, ) if not dens_wrt_hausdorff: self._mhp_constr = autodiff_fallback( @@ -964,6 +1003,7 @@ def __init__( constr, "mhp_jacobian_and_value", "mhp_constr", + backend, ) @cache_in_state_with_aux("pos", ("jacob_constr", "constr")) @@ -1015,6 +1055,7 @@ def __init__( grad_neg_log_dens: Optional[GradientFunction] = None, jacob_constr: Optional[JacobianFunction] = None, mhp_constr: Optional[MatrixHessianProductFunction] = None, + backend: Optional[str] = None, ): """ Args: @@ -1099,6 +1140,11 @@ def __init__( passed (the default) an automatic differentiation fallback will be used to attempt to construct a function which calculates the MHP (and Jacobian and value) of `constr` automatically. + backend: Name of automatic differentiation backend to use. See + :py:mod:`.autodiff` subpackage documentation for details of available + backends. If `None` (the default) no automatic differentiation fallback + will be used and so all derivative functions must be specified + explicitly. """ DenseConstrainedEuclideanMetricSystem.__init__( self, @@ -1109,6 +1155,7 @@ def __init__( grad_neg_log_dens=grad_neg_log_dens, jacob_constr=jacob_constr, mhp_constr=mhp_constr, + backend=backend, ) def jacob_constr_inner_product( @@ -1192,6 +1239,7 @@ def __init__( vjp_metric_func: Optional[VectorJacobianProductFunction] = None, grad_neg_log_dens: Optional[GradientFunction] = None, metric_kwargs: Optional[dict[str, Any]] = None, + backend: Optional[str] = None, ): """ Args: @@ -1258,14 +1306,20 @@ def __init__( derivative of `neg_log_dens` automatically. metric_kwargs: An optional dictionary of any additional keyword arguments to the initializer of `metric_matrix_class`. 
+ backend: Name of automatic differentiation backend to use. See + :py:mod:`.autodiff` subpackage documentation for details of available + backends. If `None` (the default) no automatic differentiation fallback + will be used and so all derivative functions must be specified + explicitly. """ self._metric_matrix_class = metric_matrix_class - self._metric_func = metric_func + self._metric_func = wrap_function(metric_func, backend) self._vjp_metric_func = autodiff_fallback( vjp_metric_func, metric_func, "vjp_and_value", "vjp_metric_func", + backend, ) self._metric_kwargs = {} if metric_kwargs is None else metric_kwargs super().__init__(neg_log_dens, grad_neg_log_dens=grad_neg_log_dens) @@ -1376,6 +1430,7 @@ def __init__( *, vjp_metric_scalar_func: Optional[VectorJacobianProductFunction] = None, grad_neg_log_dens: Optional[GradientFunction] = None, + backend: Optional[str] = None, ): """ Args: @@ -1420,6 +1475,11 @@ def __init__( position array. If `None` is passed (the default) an automatic differentiation fallback will be used to attempt to construct the derivative of `neg_log_dens` automatically. + backend: Name of automatic differentiation backend to use. See + :py:mod:`.autodiff` subpackage documentation for details of available + backends. If `None` (the default) no automatic differentiation fallback + will be used and so all derivative functions must be specified + explicitly. """ super().__init__( neg_log_dens=neg_log_dens, @@ -1427,6 +1487,7 @@ def __init__( metric_func=metric_scalar_func, vjp_metric_func=vjp_metric_scalar_func, grad_neg_log_dens=grad_neg_log_dens, + backend=backend, ) @cache_in_state("pos") @@ -1457,6 +1518,7 @@ def __init__( *, vjp_metric_diagonal_func: Optional[VectorJacobianProductFunction] = None, grad_neg_log_dens: Optional[GradientFunction] = None, + backend: Optional[str] = None, ): """ Args: @@ -1501,6 +1563,11 @@ def __init__( position array. If `None` is passed (the default) an automatic differentiation fallback will be used to attempt to construct the derivative of `neg_log_dens` automatically. + backend: Name of automatic differentiation backend to use. See + :py:mod:`.autodiff` subpackage documentation for details of available + backends. If `None` (the default) no automatic differentiation fallback + will be used and so all derivative functions must be specified + explicitly. """ super().__init__( neg_log_dens=neg_log_dens, @@ -1508,6 +1575,7 @@ def __init__( metric_func=metric_diagonal_func, vjp_metric_func=vjp_metric_diagonal_func, grad_neg_log_dens=grad_neg_log_dens, + backend=backend, ) @@ -1530,6 +1598,7 @@ def __init__( *, vjp_metric_chol_func: Optional[VectorJacobianProductFunction] = None, grad_neg_log_dens: Optional[GradientFunction] = None, + backend: Optional[str] = None, ): """ Args: @@ -1575,6 +1644,11 @@ def __init__( position array. If `None` is passed (the default) an automatic differentiation fallback will be used to attempt to construct the derivative of `neg_log_dens` automatically. + backend: Name of automatic differentiation backend to use. See + :py:mod:`.autodiff` subpackage documentation for details of available + backends. If `None` (the default) no automatic differentiation fallback + will be used and so all derivative functions must be specified + explicitly. 
""" super().__init__( neg_log_dens=neg_log_dens, @@ -1583,6 +1657,7 @@ def __init__( vjp_metric_func=vjp_metric_chol_func, grad_neg_log_dens=grad_neg_log_dens, metric_kwargs={"factor_is_lower": True}, + backend=backend, ) @@ -1605,6 +1680,7 @@ def __init__( *, vjp_metric_func: Optional[VectorJacobianProductFunction] = None, grad_neg_log_dens: Optional[GradientFunction] = None, + backend: Optional[str] = None, ): """ Args: @@ -1650,6 +1726,11 @@ def __init__( position array. If `None` is passed (the default) an automatic differentiation fallback will be used to attempt to construct the derivative of `neg_log_dens` automatically. + backend: Name of automatic differentiation backend to use. See + :py:mod:`.autodiff` subpackage documentation for details of available + backends. If `None` (the default) no automatic differentiation fallback + will be used and so all derivative functions must be specified + explicitly. """ super().__init__( neg_log_dens=neg_log_dens, @@ -1657,6 +1738,7 @@ def __init__( metric_func=metric_func, vjp_metric_func=vjp_metric_func, grad_neg_log_dens=grad_neg_log_dens, + backend=backend, ) @@ -1706,6 +1788,7 @@ def __init__( hess_neg_log_dens: Optional[HessianFunction] = None, mtp_neg_log_dens: Optional[MatrixTressianProductFunction] = None, softabs_coeff: ScalarLike = 1.0, + backend: Optional[str] = None, ): """ Args: @@ -1762,18 +1845,25 @@ def __init__( to absolute value used to regularize Hessian eigenvalues in metric matrix representation. As the value tends to infinity the approximation becomes increasingly close to the absolute function. + backend: Name of automatic differentiation backend to use. See + :py:mod:`.autodiff` subpackage documentation for details of available + backends. If `None` (the default) no automatic differentiation fallback + will be used and so all derivative functions must be specified + explicitly. 
""" self._hess_neg_log_dens = autodiff_fallback( hess_neg_log_dens, neg_log_dens, "hessian_grad_and_value", "neg_log_dens", + backend, ) self._mtp_neg_log_dens = autodiff_fallback( mtp_neg_log_dens, neg_log_dens, "mtp_hessian_grad_and_value", "mtp_neg_log_dens", + backend, ) super().__init__( neg_log_dens=neg_log_dens, @@ -1782,6 +1872,7 @@ def __init__( vjp_metric_func=self._mtp_neg_log_dens, grad_neg_log_dens=grad_neg_log_dens, metric_kwargs={"softabs_coeff": softabs_coeff}, + backend=backend, ) def metric_func(self, state: ChainState) -> ArrayLike: diff --git a/tests/test_autodiff.py b/tests/test_autodiff.py new file mode 100644 index 0000000..71b697e --- /dev/null +++ b/tests/test_autodiff.py @@ -0,0 +1,326 @@ +import numpy as np +import pytest + +from mici.autodiff import ( + _REGISTERED_BACKENDS, + DIFF_OPS, + _get_backend, + autodiff_fallback, + wrap_function, +) + +N_POINTS_TO_TEST = 5 + +BACKENDS_AVAIALBLE = [ + name for name, backend in _REGISTERED_BACKENDS.items() if backend.available +] + +SCALAR_FUNCTION_DIFF_OPS = [ + "grad_and_value", + "hessian_grad_and_value", + "mtp_hessian_grad_and_value", +] +VECTOR_FUNCTION_DIFF_OPS = ["jacobian_and_value", "mhp_jacobian_and_value"] + + +def torus_function_and_derivatives(numpy_module): + toroidal_rad = 1.0 + poloidal_rad = 0.5 + + def constr(q): + x, y, z = q + return numpy_module.array( + [((x**2 + y**2) ** 0.5 - toroidal_rad) ** 2 + z**2 - poloidal_rad**2], + ) + + def jacob_constr(q): + x, y, z = q + r = (x**2 + y**2) ** 0.5 + return np.array( + [[2 * x * (r - toroidal_rad) / r, 2 * y * (r - toroidal_rad) / r, 2 * z]], + ) + + def mhp_constr(q): + x, y, z = q + r = (x**2 + y**2) ** 0.5 + r_cubed = r**3 + return lambda m: np.array( + [ + 2 * (toroidal_rad / r_cubed) * (m[0, 0] * x**2 + m[0, 1] * x * y) + + 2 * m[0, 0] * (1 - toroidal_rad / r), + 2 * (toroidal_rad / r_cubed) * (m[0, 1] * y**2 + m[0, 0] * x * y) + + 2 * m[0, 1] * (1 - toroidal_rad / r), + 2 * m[0, 2], + ], + ) + + return { + "function": constr, + "jacobian_function": jacob_constr, + "mhp_function": mhp_constr, + } + + +def linear_function_and_derivatives(_): + + constr_matrix = np.array([[1.0, -1.0, 2.0, 3.0], [-3.0, 2.0, 0.0, 5.0]]) + + def constr(q): + return constr_matrix @ q + + def jacob_constr(_): + return constr_matrix + + def mhp_constr(_): + return lambda _: np.zeros(constr_matrix.shape[1]) + + return { + "function": constr, + "jacobian_function": jacob_constr, + "mhp_function": mhp_constr, + } + + +def quadratic_form_function_and_derivatives(_): + + matrix = np.array([[1.3, -0.2], [-0.2, 2.5]]) + + def quadratic_form(q): + return q @ matrix @ q / 2 + + def grad_quadratic_form(q): + return matrix @ q + + def hessian_quadratic_form(_): + return matrix + + def mtp_quadratic_form(_): + return lambda _: np.zeros(matrix.shape[0]) + + return { + "function": quadratic_form, + "grad_function": grad_quadratic_form, + "hessian_function": hessian_quadratic_form, + "mtp_function": mtp_quadratic_form, + } + + +def cubic_function_and_derivatives(_): + + def cubic(q): + return (q**3).sum() / 6 + + def grad_cubic(q): + return q**2 / 2 + + def hessian_cubic(q): + return np.diag(q) + + def mtp_cubic(_): + return lambda m: m.diagonal() + + return { + "function": cubic, + "grad_function": grad_cubic, + "hessian_function": hessian_cubic, + "mtp_function": mtp_cubic, + } + + +def quartic_function_and_derivatives(_): + + def quartic(q): + return (q**4).sum() / 24 + + def grad_quartic(q): + return q**3 / 6 + + def hessian_quartic(q): + return np.diag(q**2 / 2) + + def 
+    def mtp_quartic(q):
+        return lambda m: m.diagonal() * q
+
+    return {
+        "function": quartic,
+        "grad_function": grad_quartic,
+        "hessian_function": hessian_quartic,
+        "mtp_function": mtp_quartic,
+    }
+
+
+def numpify_function(function_and_derivatives, *arg_shapes):
+    import symnum
+
+    function_and_derivatives["function"] = symnum.numpify_func(
+        function_and_derivatives["function"],
+        *arg_shapes,
+    )
+
+
+@pytest.mark.parametrize("backend_name", BACKENDS_AVAILABLE)
+def test_module_defines_all_diffops(backend_name):
+    backend = _get_backend(backend_name)
+    for diff_op_name in DIFF_OPS:
+        assert hasattr(backend.module, diff_op_name)
+        assert callable(getattr(backend.module, diff_op_name))
+
+
+def get_numpy_module(backend):
+    if backend in ("jax", "jax_nojit"):
+        import jax.numpy
+
+        return jax.numpy
+    elif backend == "autograd":
+        import autograd.numpy
+
+        return autograd.numpy
+    elif backend == "symnum":
+        import symnum.numpy
+
+        return symnum.numpy
+    else:
+        msg = f"Unrecognised backend {backend}"
+        raise ValueError(msg)
+
+
+@pytest.fixture
+def rng():
+    return np.random.default_rng(1234)
+
+
+@pytest.mark.parametrize("diff_op_name", DIFF_OPS)
+def test_autodiff_fallback_with_no_backend_raises(diff_op_name):
+    with pytest.raises(ValueError, match="None"):
+        autodiff_fallback(
+            None, lambda q: q, diff_op_name, diff_op_name + "_function", None,
+        )
+
+
+@pytest.mark.parametrize("backend_name", BACKENDS_AVAILABLE)
+def test_autodiff_fallback_with_invalid_diff_op_raises(backend_name):
+    with pytest.raises(ValueError, match="not defined"):
+        autodiff_fallback(None, lambda q: q, "foo", "bar", backend_name)
+
+
+@pytest.mark.parametrize("backend_name", BACKENDS_AVAILABLE)
+def test_wrap_function(backend_name):
+
+    def function(q):
+        return q**2
+
+    assert wrap_function(function, None) is function
+    wrapped_function = wrap_function(function, backend_name)
+    assert callable(wrapped_function)
+    test_input = np.arange(5)
+    output = wrapped_function(test_input)
+    assert isinstance(output, np.ndarray)
+    assert np.allclose(wrapped_function(test_input), function(test_input))
+
+
+@pytest.mark.parametrize("diff_op_name", VECTOR_FUNCTION_DIFF_OPS)
+@pytest.mark.parametrize("backend_name", BACKENDS_AVAILABLE)
+@pytest.mark.parametrize(
+    "function_and_derivatives_and_dim",
+    [(torus_function_and_derivatives, 1, 3), (linear_function_and_derivatives, 2, 4)],
+    ids=lambda p: p[0].__name__,
+)
+def test_vector_function_diff_ops(
+    diff_op_name, backend_name, function_and_derivatives_and_dim, rng,
+):
+    construct_function_and_derivatives, dim_c, dim_q = function_and_derivatives_and_dim
+    numpy_module = get_numpy_module(backend_name)
+    function_and_derivatives = construct_function_and_derivatives(numpy_module)
+    if backend_name == "symnum":
+        numpify_function(function_and_derivatives, dim_q)
+    diff_op_function = autodiff_fallback(
+        None,
+        function_and_derivatives["function"],
+        diff_op_name,
+        diff_op_name + "_function",
+        backend_name,
+    )
+    for _ in range(N_POINTS_TO_TEST):
+        q = rng.standard_normal(dim_q)
+        derivatives_and_values = diff_op_function(q)
+        # diff_op_function returns derivatives in descending order (of derivative) while
+        # derivatives are in increasing order in function_and_derivatives
+        for (function_name, expected_value_function), test_value in zip(
+            function_and_derivatives.items(),
+            reversed(derivatives_and_values),
+        ):
+            if function_name.startswith("mhp"):
+                m = rng.standard_normal((dim_c, dim_q))
+                assert np.allclose(test_value(m), expected_value_function(q)(m))
+            else:
+                assert np.allclose(test_value, expected_value_function(q))
+
+
+@pytest.mark.parametrize("diff_op_name", SCALAR_FUNCTION_DIFF_OPS)
+@pytest.mark.parametrize("backend_name", BACKENDS_AVAILABLE)
+@pytest.mark.parametrize(
+    "function_and_derivatives_and_dim_q",
+    [
+        (quadratic_form_function_and_derivatives, 2),
+        (cubic_function_and_derivatives, 1),
+        (cubic_function_and_derivatives, 3),
+        (quartic_function_and_derivatives, 2),
+    ],
+    ids=lambda p: p[0].__name__,
+)
+def test_scalar_function_diff_ops(
+    diff_op_name, backend_name, function_and_derivatives_and_dim_q, rng,
+):
+    construct_function_and_derivatives, dim_q = function_and_derivatives_and_dim_q
+    numpy_module = get_numpy_module(backend_name)
+    function_and_derivatives = construct_function_and_derivatives(numpy_module)
+    if backend_name == "symnum":
+        numpify_function(function_and_derivatives, dim_q)
+    diff_op_function = autodiff_fallback(
+        None,
+        function_and_derivatives["function"],
+        diff_op_name,
+        diff_op_name + "_function",
+        backend_name,
+    )
+    for _ in range(N_POINTS_TO_TEST):
+        q = rng.standard_normal(dim_q)
+        derivatives_and_values = diff_op_function(q)
+        # diff_op_function returns derivatives in descending order (of derivative) while
+        # derivatives are in increasing order in function_and_derivatives
+        for (function_name, expected_value_function), test_value in zip(
+            function_and_derivatives.items(),
+            reversed(derivatives_and_values),
+        ):
+            if function_name.startswith("mtp"):
+                m = rng.standard_normal((dim_q, dim_q))
+                assert np.allclose(test_value(m), expected_value_function(q)(m))
+            else:
+                assert np.allclose(test_value, expected_value_function(q))
+
+
+@pytest.mark.parametrize("backend_name", BACKENDS_AVAILABLE)
+@pytest.mark.parametrize(
+    "function_and_derivatives_and_dim",
+    [(torus_function_and_derivatives, 1, 3), (linear_function_and_derivatives, 2, 4)],
+    ids=lambda p: p[0].__name__,
+)
+def test_vjp_and_value(backend_name, function_and_derivatives_and_dim, rng):
+    construct_function_and_derivatives, dim_c, dim_q = function_and_derivatives_and_dim
+    numpy_module = get_numpy_module(backend_name)
+    function_and_derivatives = construct_function_and_derivatives(numpy_module)
+    if backend_name == "symnum":
+        numpify_function(function_and_derivatives, dim_q)
+    vjp_and_value_function = autodiff_fallback(
+        None,
+        function_and_derivatives["function"],
+        "vjp_and_value",
+        "vjp_and_value_function",
+        backend_name,
+    )
+    for _ in range(N_POINTS_TO_TEST):
+        q = rng.standard_normal(dim_q)
+        vjp, value = vjp_and_value_function(q)
+        assert np.allclose(function_and_derivatives["function"](q), value)
+        v = rng.standard_normal(value.shape)
+        assert np.allclose(v @ function_and_derivatives["jacobian_function"](q), vjp(v))
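
A minimal usage sketch of the autodiff_fallback interface exercised by the new tests, assuming the "autograd" backend is installed and registered; the neg_log_dens function and the "grad_neg_log_dens" name below are illustrative choices and not part of the patch.

# Minimal sketch, assuming the "autograd" backend is available; neg_log_dens and
# the "grad_neg_log_dens" label are hypothetical examples, not taken from the diff.
import autograd.numpy as anp
import numpy as np

from mici.autodiff import autodiff_fallback


def neg_log_dens(q):
    # Standard-normal negative log density, used purely for illustration.
    return 0.5 * anp.sum(q**2)


# No explicit derivative is supplied (first argument None), so the named backend
# constructs a "grad_and_value" operator for neg_log_dens, as in the tests above.
grad_and_value = autodiff_fallback(
    None,
    neg_log_dens,
    "grad_and_value",
    "grad_neg_log_dens",
    "autograd",
)

q = np.array([0.5, -1.0, 2.0])
grad, value = grad_and_value(q)  # derivatives first, value last, as in the tests
assert np.allclose(grad, q)
assert np.allclose(value, 0.5 * q @ q)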