Merge branch 'master' of github.com:JakobRobnik/MicroCanonicalHMC
JakobRobnik committed Jan 16, 2024
2 parents 42e6d6a + fcff728 commit 87dc902
Showing 19 changed files with 1,408 additions and 501 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/python-app.yml
@@ -27,7 +27,7 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install pytest
-pip install -r requirements.txt
+pip install .
- name: Test with pytest
run: |
11 changes: 1 addition & 10 deletions Makefile
@@ -1,18 +1,9 @@
PKG_VERSION = $(shell python setup.py --version)

test:
JAX_PLATFORM_NAME=cpu pytest --benchmark-disable
-mypy mclmc/sampling/sampler.py
+mypy mclmc/sampler.py

-set-bench:
-pytest --benchmark-autosave
-
-compare-bench:
-pytest --benchmark-compare=0001 --benchmark-compare-fail=mean:2%

# We launch the package release by tagging the master branch with the package's
# new version number.
release:
git tag -a $(PKG_VERSION) -m $(PKG_VERSION)
git push --tags

22 changes: 19 additions & 3 deletions README.md
@@ -10,15 +10,31 @@


You can check out the tutorials:
-- [getting started](notebooks/tutorials/intro_tutorial.ipynb): sampling from a standard Gaussian (sequential sampling)
-- [advanced tutorial](notebooks/tutorials/advanced_tutorial.ipynb): sampling the hierarchical Stochastic Volatility model for the S&P500 returns data (sequential sampling)
+- [getting started](notebooks/tutorials/intro_tutorial.ipynb): sampling from a standard Gaussian
+- [advanced tutorial](notebooks/tutorials/advanced_tutorial.ipynb): sampling the hierarchical Stochastic Volatility model for the S&P500 returns data

A Julia implementation is available [here](https://github.com/JaimeRZP/MicroCanonicalHMC.jl).

The associated papers are:
- [method and benchmark tests](https://arxiv.org/abs/2212.08549)
- [formulation as a stochastic process and first application to lattice field theory](https://arxiv.org/abs/2303.18221)

-If you have any questions do not hesitate to contact me at jakob_robnik@berkeley.edu
+The code is still in active development, so let us know if you encounter any issues, including bad sampling performance, and we will do our best to help you out.
+You can submit a GitHub issue or contact us at jakob_robnik@berkeley.edu.

## Frequently asked questions:

### How can I sample with MCHMC if my parameters have bounds?
Check out [this tutorial](notebooks/tutorials/Constraints.ipynb).

### How does the cost of producing one sample in HMC compare to the cost of one sample in MCHMC?
MCHMC samples are less costly. What matters for the computational time is the number of gradient evaluations. Each MCHMC sample costs two gradient evaluations (one if the leapfrog integrator is used instead of the minimal-norm integrator). Each HMC sample costs L gradient evaluations, where L is the number of leapfrog steps per sample; for hard targets L can be quite large (up to 1024 in the default NUTS setting).
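The comparison above is plain arithmetic; here it is as a minimal sketch (the function names and the sample counts are hypothetical, only the per-sample costs come from the answer above):

```python
def grad_evals_mchmc(n_samples, integrator="minimal_norm"):
    """Gradient evaluations for n MCHMC samples: 2 per sample with the
    minimal-norm integrator, 1 per sample with leapfrog."""
    return n_samples * (2 if integrator == "minimal_norm" else 1)

def grad_evals_hmc(n_samples, n_leapfrog_steps):
    """Gradient evaluations for n HMC samples: L leapfrog steps per
    sample, one gradient evaluation each."""
    return n_samples * n_leapfrog_steps

# For 10,000 samples and a moderately hard target (L = 100, hypothetical):
print(grad_evals_mchmc(10_000))     # 20000
print(grad_evals_hmc(10_000, 100))  # 1000000
```

For this (hypothetical) L, HMC needs 50x more gradient evaluations for the same number of samples.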

### Is MCHMC just some weird projection of HMC onto the constant energy surface?
No; the Hamiltonian dynamics of the two methods are different (the particles move differently). Below is the motion of MCHMC particles for the Rosenbrock target distribution.





![ensemble](img/rosenbrock.gif)
4 changes: 2 additions & 2 deletions mclmc/annealing.py
@@ -1,9 +1,9 @@
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
-from mclmc.sampling import dynamics
+from mclmc import dynamics

-from mclmc.sampling.dynamics import update_momentum
+from .dynamics import update_momentum


class vmap_target:
108 changes: 108 additions & 0 deletions mclmc/boundary.py
@@ -0,0 +1,108 @@

import jax.numpy as jnp




class Boundary:
"""Forms a transformation that bounds the parameter space (the transformation is applied in the position_update of the Hamiltonian dynamics integration)."""

def __init__(self, d,
where_positive = None,
where_reflect = None,
where_periodic = None,
a = None, b = None,
):

"""
where_positive: indices of positively constrained parameters
where_reflect: indices of rectangularly constrained parameters (with reflective boundary). Use if parameter is constrained to an interval (for example 0 < x < 1), but it is not periodic.
where_periodic: indices of rectangularly constrained parameters (with periodic boundary). Use for example for angles.
a: lower bounds
b: upper bounds
Example:
We have parameters
x = [x0, x1, x2, x3, x4, x5, x6]
and we want constraints:
x0 unconstrained
x1 > 0
0 < x2 < 2 pi (periodic)
x3 unconstrained
0 < x4 < 1 (not periodic)
-1 < x5 < 1 (not periodic)
x6 > 0
We should use:
where_positive = jnp.array([1, 6])
where_reflect = jnp.array([4, 5])
where_periodic = jnp.array([2, ])
a = jnp.array([0., 0.,-1.])
b = jnp.array([2 * jnp.pi, 1., 1.])
"""


self.d = d

self.mask_positive = self.to_mask(where_positive)
self.mask_reflect = self.to_mask(where_reflect)
self.mask_periodic = self.to_mask(where_periodic)


self.a, self.b = self.extend_bounds(jnp.logical_or(self.mask_reflect, self.mask_periodic), a, b)


def map(self, x):
"""maps R^d to the constrained region
Args:
x: unconstrained parameter vector
Returns:
x': constrained parameter vector
sgn: array of signs (+1 or -1), indicating which components of the velocity should be flipped.
"""

# These functions map R^d to the constrained region (the unconstrained parameters are also mapped, but this is ignored later).
# They also return a boolean array (r) which indicates which components of the velocity should be flipped.
x0, r0 = x, False
x1, r1 = self._positive(x)
x2, r2 = self._reflect(x)
x3, r3 = self._periodic(x)

combine = lambda y0, y1, y2, y3: self.mask_positive * y1 + self.mask_reflect * y2 + self.mask_periodic * y3 + (1- (self.mask_positive + self.mask_reflect + self.mask_periodic)) * y0

return combine(x0, x1, x2, x3), 1 - 2 * combine(r0, r1, r2, r3)



def _positive(self, x):
return jnp.abs(x), x < 0.

def _periodic(self, x):
return jnp.mod(x - self.a, self.b - self.a) + self.a, False

def _reflect(self, x):
y = jnp.mod((x - self.a) / (self.b - self.a), 2.)
z = 1 - jnp.abs(1. - y)
return z * (self.b-self.a) + self.a, y > 1.



def extend_bounds(self, mask, a, b):
A = jnp.zeros(len(mask))
B = jnp.ones(len(mask))

if a is not None:
A = A.at[mask].set(a)
B = B.at[mask].set(b)

return A, B


def to_mask(self, where):

mask = jnp.zeros(self.d, dtype = bool)

if where is None:
return mask
else:
return mask.at[where].set(True)
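As a sanity check of the three maps above, the same formulas can be reproduced standalone (NumPy instead of JAX so the sketch runs without a JAX install; the test values are hypothetical, with a = 0 and b = 1 for the reflective coordinate):

```python
import numpy as np

def positive(x):
    # Boundary._positive: fold negative values; flip their velocity.
    return np.abs(x), x < 0.0

def periodic(x, a, b):
    # Boundary._periodic: wrap into [a, b); no velocity flip.
    return np.mod(x - a, b - a) + a, False

def reflect(x, a, b):
    # Boundary._reflect: triangle-wave fold into [a, b];
    # flip the velocity on odd folds (y > 1).
    y = np.mod((x - a) / (b - a), 2.0)
    z = 1.0 - np.abs(1.0 - y)
    return z * (b - a) + a, y > 1.0

xr, flip = reflect(np.array([0.3, 1.3, 2.3, -0.4]), 0.0, 1.0)
# 1.3 folds back to 0.7 and -0.4 to 0.4 (velocity flips);
# 2.3 completes a full fold and lands at 0.3 (no flip).
```

The velocity flip is what makes the boundary reflective rather than absorbing: a particle that crosses an endpoint re-enters the interval with its velocity component reversed.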
