Skip to content

Commit

Permalink
Add support for alternative automatic differentiation backends (#11)
Browse files Browse the repository at this point in the history
* Add JAX differential operators for JAX backend

* Initial generalization of autodiff fallback to allow different backends

* Add missing VJP and value operator to Autograd wrapper

* Complete JAX differential operators

Deals with detupleifying VJPs and removes internal jitting

* Feed through backend in to systems and wrap base functions

* Add SymNum backend option

* Set default backend dynamically based on availability

* Remove unused imports

* Rename and add optional dependency groups

* Update README example to SymNum + update autodiff support description

* Reformat badges

* Add new lines around badges div

* Remove use of vectorization trick in autograd diffops

* Add autodiff tests

* Make autodiff subpackage + add details of backends to docstring

* Correct name of nested function

* Remove use of unary_to_nary in autograd wrapper

* Install jax and symnum in tox tests environment

* Enable JAX x64 mode in tests

* Remove strict keyword from zip

* Remove strict argument from zip

* Explicitly set None backend in interop functions

* Bump minimum SymNum version

* Drop Python 3.9 support and add Python 3.12 to test matrix

* Default to no autodiff backend
  • Loading branch information
matt-graham authored Oct 6, 2024
1 parent 7f5c73e commit ea4adee
Show file tree
Hide file tree
Showing 13 changed files with 1,166 additions and 305 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: 3.9
python-version: 3.x
cache: "pip"
cache-dependency-path: "pyproject.toml"
- name: Install tox
Expand All @@ -35,4 +35,4 @@ jobs:
publish_dir: docs/_build/html
publish_branch: gh-pages
user_name: "github-actions[bot]"
user_email: "github-actions[bot]@users.noreply.github.com"
user_email: "github-actions[bot]@users.noreply.github.com"
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ jobs:
strategy:
matrix:
python-version:
- "3.9"
- "3.10"
- "3.11"
- "3.12"

steps:
- name: Checkout source
Expand Down
161 changes: 98 additions & 63 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,23 +1,17 @@
<div style="text-align: center;" align="center">
<picture>
<source media="(prefers-color-scheme: dark)" srcset="https://raw.githubusercontent.com/matt-graham/mici/main/images/mici-logo-rectangular-light-text.svg">
<source media="(prefers-color-scheme: light)" srcset="https://raw.githubusercontent.com/matt-graham/mici/main/images/mici-logo-rectangular.svg">
<img alt="Mici logo" src="https://raw.githubusercontent.com/matt-graham/mici/main/images/mici-logo-rectangular.svg" width="400px">
</picture>
<p>
<a href="https://badge.fury.io/py/mici">
<img src="https://badge.fury.io/py/mici.svg" alt="PyPI version"/>
</a>
<a href="https://zenodo.org/badge/latestdoi/52494384">
<img src="https://zenodo.org/badge/52494384.svg" alt="DOI"/>
</a>
<a href="https://github.com/matt-graham/mici/actions/workflows/tests.yml">
<img src="https://github.com/matt-graham/mici/actions/workflows/tests.yml/badge.svg" alt="Test status" />
</a>
<a href="https://matt-graham.github.io/mici">
<img src="https://github.com/matt-graham/mici/actions/workflows/docs.yml/badge.svg" alt="Documentation status" />
</a>
</p>
<picture>
<source media="(prefers-color-scheme: dark)" srcset="https://raw.githubusercontent.com/matt-graham/mici/main/images/mici-logo-rectangular-light-text.svg">
<source media="(prefers-color-scheme: light)" srcset="https://raw.githubusercontent.com/matt-graham/mici/main/images/mici-logo-rectangular.svg">
<img alt="Mici logo" src="https://raw.githubusercontent.com/matt-graham/mici/main/images/mici-logo-rectangular.svg" width="400px">
</picture>
<div style="text-align: center;" align="center">

[![PyPI version](https://badge.fury.io/py/mici.svg)](https://pypi.org/project/mici)
[![Zenodo DOI](https://zenodo.org/badge/52494384.svg)](https://zenodo.org/badge/latestdoi/52494384)
[![Test status](https://github.com/matt-graham/mici/actions/workflows/tests.yml/badge.svg)](https://github.com/matt-graham/mici/actions/workflows/tests.yml)
[![Docs status](https://github.com/matt-graham/mici/actions/workflows/docs.yml/badge.svg)](https://matt-graham.github.io/mici)

</div>
</div>

**Mici** is a Python package providing implementations of *Markov chain Monte
Expand All @@ -34,6 +28,10 @@ Key features include
extend the package,
* a pure Python code base with minimal dependencies,
allowing easy integration within other code,
* built-in support for several automatic differentiation frameworks, including
[JAX](https://jax.readthedocs.io/en/latest/) and
[Autograd](https://github.com/HIPS/autograd), or the option to supply your own
derivative functions,
* implementations of MCMC methods for sampling from distributions on embedded
manifolds implicitly-defined by a constraint equation and distributions on
Riemannian manifolds with a user-specified metric,
Expand All @@ -45,7 +43,7 @@ Key features include

## Installation

To install and use Mici the minimal requirements are a Python 3.9+ environment
To install and use Mici the minimal requirements are a Python 3.10+ environment
with [NumPy](http://www.numpy.org/) and [SciPy](https://www.scipy.org)
installed. The latest Mici release on PyPI (and its dependencies) can be
installed in the current Python environment by running
Expand All @@ -63,6 +61,14 @@ pip install git+https://github.com/matt-graham/mici
If available in the installed Python environment the following additional
packages provide extra functionality and features

* [ArviZ](https://python.arviz.org/en/latest/index.html): if ArviZ is
available the traces (dictionary) output of a sampling run can be directly
converted to an `arviz.InferenceData` container object using
`arviz.convert_to_inference_data` or implicitly converted by passing the
traces dictionary as the `data` argument
[to ArviZ API functions](https://python.arviz.org/en/latest/api/index.html),
allowing straightforward use of the ArviZ's extensive visualisation and
diagnostic functions.
* [Autograd](https://github.com/HIPS/autograd): if available Autograd will
be used to automatically compute the required derivatives of the model
functions (providing they are specified using functions from the
Expand All @@ -74,15 +80,22 @@ packages provide extra functionality and features
serialisation (via [dill](https://github.com/uqfoundation/dill)) of a much
wider range of types, including of Autograd generated functions. Both
Autograd and multiprocess can be installed alongside Mici by running `pip
install mici[autodiff]`.
* [ArviZ](https://python.arviz.org/en/latest/index.html): if ArviZ is
available the traces (dictionary) output of a sampling run can be directly
converted to an `arviz.InferenceData` container object using
`arviz.convert_to_inference_data` or implicitly converted by passing the
traces dictionary as the `data` argument
[to ArviZ API functions](https://python.arviz.org/en/latest/api/index.html),
allowing straightforward use of the ArviZ's extensive visualisation and
diagnostic functions.
install mici[autograd]`.
* [JAX](https://jax.readthedocs.io/en/latest/): if available JAX will be used to
automatically compute the required derivatives of the model functions (providing
they are specified using functions from the [`jax`
interface](https://jax.readthedocs.io/en/latest/jax.html)). To sample chains
parallel using JAX functions you also need to install
[multiprocess](https://github.com/uqfoundation/multiprocess), though note due to
JAX's use of multithreading which [is incompatible with forking child
processes](https://docs.python.org/3/library/os.html#os.fork), this can result in
deadlock. Both JAX and multiprocess can be installed alongside Mici by running `pip
install mici[jax]`.
* [SymNum](https://github.com/matt-graham/symnum): if available SymNum will be used to
automatically compute the required derivatives of the model functions (providing
they are specified using functions from the [`symnum.numpy`
interface](https://matt-graham.github.io/symnum/symnum.numpy.html)). Symnum can be
installed alongside Mici by running `pip install mici[symnum]`.

## Why Mici?

Expand Down Expand Up @@ -122,7 +135,7 @@ chains in Python can dominate the computational cost, making sampling much
slower than packages which outsource the sampling loop to a efficient compiled
implementation.

## Overview of package
## Overview of package

API documentation for the package is available
[here](https://matt-graham.github.io/mici/). The three main user-facing
Expand Down Expand Up @@ -257,22 +270,21 @@ The manifold MCMC methods implemented in Mici have been used in several research

<img src='https://raw.githubusercontent.com/matt-graham/mici/main/images/torus-samples.gif' width='360px'/>

A simple complete example of using the package to compute approximate samples
from a distribution on a two-dimensional torus embedded in a three-dimensional
space is given below. The computed samples are visualized in the animation
above. Here we use `autograd` to automatically construct functions to calculate
the required derivatives (gradient of negative log density of target
distribution and Jacobian of constraint function), sample four chains in
parallel using `multiprocess`, use `arviz` to calculate diagnostics and use
`matplotlib` to plot the samples.

> ⚠️ **If you do not have [`multiprocess`](https://github.com/uqfoundation/multiprocess) installed the example code below will hang or raise an error when sampling the chains as the inbuilt `multiprocessing` module does not support pickling Autograd functions.**
A simple complete example of using the package to compute approximate samples from a
distribution on a two-dimensional torus embedded in a three-dimensional space is given
below. The computed samples are visualized in the animation above. Here we use
[SymNum](https://github.com/matt-graham/symnum) to automatically construct functions to
calculate the required derivatives (gradient of negative log density of target
distribution and Jacobian of constraint function), sample four chains in parallel using
`multiprocessing`, use [ArviZ](https://python.arviz.org/en/stable/) to calculate
diagnostics and use [Matplotlib](https://matplotlib.org/) to plot the samples.

```Python
from mici import systems, integrators, samplers
import autograd.numpy as np
import mici
import numpy as np
import symnum
import symnum.numpy as snp
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.animation as animation
import arviz

Expand All @@ -281,44 +293,62 @@ R = 1.0 # toroidal radius ∈ (0, ∞)
r = 0.5 # poloidal radius ∈ (0, R)
α = 0.9 # density fluctuation amplitude ∈ [0, 1)

# State dimension
dim_q = 3


# Define constraint function such that the set {q : constr(q) == 0} is a torus
@symnum.numpify(dim_q)
def constr(q):
x, y, z = q.T
return np.stack([((x**2 + y**2)**0.5 - R)**2 + z**2 - r**2], -1)
x, y, z = q
return snp.array([((x**2 + y**2) ** 0.5 - R) ** 2 + z**2 - r**2])


# Define negative log density for the target distribution on torus
# (with respect to 2D 'area' measure for torus)
@symnum.numpify(dim_q)
def neg_log_dens(q):
x, y, z = q.T
θ = np.arctan2(y, x)
ϕ = np.arctan2(z, x / np.cos(θ) - R)
return np.log1p(r * np.cos(ϕ) / R) - np.log1p(np.sin(4*θ) * np.cos(ϕ) * α)
x, y, z = q
θ = snp.arctan2(y, x)
ϕ = snp.arctan2(z, x / snp.cos(θ) - R)
return snp.log1p(r * snp.cos(ϕ) / R) - snp.log1p(snp.sin(4 * θ) * snp.cos(ϕ) * α)


# Specify constrained Hamiltonian system with default identity metric
system = systems.DenseConstrainedEuclideanMetricSystem(neg_log_dens, constr)
system = mici.systems.DenseConstrainedEuclideanMetricSystem(
neg_log_dens,
constr,
backend="symnum",
)

# System is constrained therefore use constrained leapfrog integrator
integrator = integrators.ConstrainedLeapfrogIntegrator(system)
integrator = mici.integrators.ConstrainedLeapfrogIntegrator(system)

# Seed a random number generator
rng = np.random.default_rng(seed=1234)

# Use dynamic integration-time HMC implementation as MCMC sampler
sampler = samplers.DynamicMultinomialHMC(system, integrator, rng)
sampler = mici.samplers.DynamicMultinomialHMC(system, integrator, rng)

# Sample initial positions on torus using parameterisation (θ, ϕ) ∈ [0, 2π)²
# x, y, z = (R + r * cos(ϕ)) * cos(θ), (R + r * cos(ϕ)) * sin(θ), r * sin(ϕ)
n_chain = 4
θ_init, ϕ_init = rng.uniform(0, 2 * np.pi, size=(2, n_chain))
q_init = np.stack([
(R + r * np.cos(ϕ_init)) * np.cos(θ_init),
(R + r * np.cos(ϕ_init)) * np.sin(θ_init),
r * np.sin(ϕ_init)], -1)
q_init = np.stack(
[
(R + r * np.cos(ϕ_init)) * np.cos(θ_init),
(R + r * np.cos(ϕ_init)) * np.sin(θ_init),
r * np.sin(ϕ_init),
],
-1,
)


# Define function to extract variables to trace during sampling
def trace_func(state):
x, y, z = state.pos
return {'x': x, 'y': y, 'z': z}
return {"x": x, "y": y, "z": z}


# Sample 4 chains in parallel with 500 adaptive warm up iterations in which the
# integrator step size is tuned, followed by 2000 non-adaptive iterations
Expand All @@ -327,7 +357,7 @@ final_states, traces, stats = sampler.sample_chains(
n_main_iter=2000,
init_states=q_init,
n_process=4,
trace_funcs=[trace_func]
trace_funcs=[trace_func],
)

# Print average accept probability and number of integrator steps per chain
Expand All @@ -340,19 +370,24 @@ for c in range(n_chain):
print(arviz.summary(traces))

# Visualize concatentated chain samples as animated 3D scatter plot
fig = plt.figure(figsize=(4, 4))
ax = Axes3D(fig, [0., 0., 1., 1.], proj_type='ortho')
points_3d, = ax.plot(*(np.concatenate(traces[k]) for k in 'xyz'), '.', ms=0.5)
ax.axis('off')
fig, ax = plt.subplots(
figsize=(4, 4),
subplot_kw={"projection": "3d", "proj_type": "ortho"},
)
(points_3d,) = ax.plot(*(np.concatenate(traces[k]) for k in "xyz"), ".", ms=0.5)
ax.axis("off")
for set_lim in [ax.set_xlim, ax.set_ylim, ax.set_zlim]:
set_lim((-1, 1))


def update(i):
angle = 45 * (np.sin(2 * np.pi * i / 60) + 1)
ax.view_init(elev=angle, azim=angle)
return (points_3d,)

anim = animation.FuncAnimation(fig, update, frames=60, interval=100, blit=True)

anim = animation.FuncAnimation(fig, update, frames=60, interval=100)
plt.show()
```

## References
Expand Down
29 changes: 20 additions & 9 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ classifiers = [
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3 :: Only",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Typing :: Typed",
]
dependencies = [
Expand All @@ -38,7 +38,7 @@ keywords = [
]
name = "mici"
readme = "README.md"
requires-python = ">=3.9"
requires-python = ">=3.10"
license.file = "LICENCE"
urls.homepage = "https://github.com/matt-graham/mici"
urls.documentation = "https://matt-graham.github.io/mici"
Expand All @@ -54,9 +54,16 @@ dev = [
"tox>=4",
"twine",
]
autodiff = [
autograd = [
"autograd>=1.3",
"multiprocess>=0.7.0"
"multiprocess>=0.7.0",
]
jax = [
"jax>=0.4.1",
"multiprocess>=0.7.0",
]
symnum = [
"symnum>=0.2.1",
]

[tool.coverage]
Expand Down Expand Up @@ -132,7 +139,7 @@ select = [
"W",
"YTT",
]
target-version = "py39"
target-version = "py310"
isort.known-first-party = [
"mici",
]
Expand All @@ -152,21 +159,25 @@ write_to = "src/mici/_version.py"
legacy_tox_ini = """
[gh-actions]
python =
3.9: py39
3.10: py310
3.11: py311
3.12: py312
[testenv]
commands =
pytest --cov {posargs}
extras =
autograd
jax
symnum
deps =
pytest
pytest-cov
autograd>=1.3
multiprocess>=0.7.0
pystan
pymc>=5
arviz
set_env =
JAX_ENABLE_X64=1
[testenv:docs]
commands =
Expand All @@ -178,7 +189,7 @@ legacy_tox_ini = """
[tox]
env_list =
py39
py310
py311
py312
"""
Loading

0 comments on commit ea4adee

Please sign in to comment.