Skip to content

Commit

Permalink
Make sample_chains return named tuples
Browse files Browse the repository at this point in the history
  • Loading branch information
matt-graham committed Aug 7, 2023
1 parent d81d226 commit 454db41
Showing 1 changed file with 86 additions and 60 deletions.
146 changes: 86 additions & 60 deletions src/mici/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import tempfile
import signal
from warnings import warn
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, NamedTuple
import numpy as np
from numpy.random import default_rng
from mici.transitions import (
Expand Down Expand Up @@ -752,6 +752,37 @@ def _sample_chains_parallel(
return (*_collate_chain_outputs(chain_outputs), exception)


class MCMCSampleChainsOutputs(NamedTuple):
"""Outputs returned by :py:meth:`MarkovChainMonteCarloMethod.sample_chains` call.
Parameters:
final_states: States of chains after final iteration. May be used to resume
sampling a chain by passing as the initial states to a new `sample_chains`
call.
traces: Dictionary of chain trace arrays. Values in dictionary are list of
arrays of variables outputted by trace functions in `trace_funcs` with each
array in the list corresponding to a single chain and the leading dimension
of each array corresponding to the iteration (draw) index, within the main
non-adaptive sampling stage if `trace_warm_up=False` and across both warm-up
and main sampling stages otherwise. The key for each value is the
corresponding key in the dictionary returned by the trace function which
computed the traced value.
statistics: Dictionary of chain transition statistic dictionaries. Values in
outer dictionary are dictionaries of statistics for each chain transition,
keyed by the string key for the transition. The values in each inner
transition dictionary are lists of arrays of chain statistic values with
each array in the list corresponding to a single chain and the leading
dimension of each array corresponding to the iteration (draw) index, within
the main non-adaptive sampling stage if `trace_warm_up=False` and across
both warm-up and main sampling stages otherwise. The key for each value is a
string description of the corresponding transition statistic.
"""

final_states: list[ChainState]
traces: dict[str, list[NDArray]]
statistics: dict[str, dict[str, list[NDArray]]]


class MarkovChainMonteCarloMethod:
"""Generic Markov chain Monte Carlo (MCMC) sampler.
Expand Down Expand Up @@ -782,6 +813,7 @@ def sample_chains(
n_warm_up_iter: int,
n_main_iter: int,
init_states: Iterable[Union[ChainState, dict]],
*,
trace_funcs: Optional[Sequence[TraceFunction]] = None,
adapters: Optional[dict[str, Sequence[Adapter]]] = None,
stager: Optional[Stager] = None,
Expand All @@ -793,9 +825,7 @@ def sample_chains(
monitor_stats: Optional[dict[str, list[str]]] = None,
display_progress: bool = True,
progress_bar_class: Optional[ProgressBar] = None,
) -> tuple[
list[ChainState], dict[str, list[NDArray]], dict[str, dict[str, list[NDArray]]]
]:
) -> MCMCSampleChainsOutputs:
"""Sample Markov chains from given initial states with optional adaptive warm up
One or more Markov chains are sampled, with each chain iteration consisting of
Expand All @@ -818,8 +848,6 @@ def sample_chains(
init_states: Initial chain states. Each entry can be either a `ChainState`
object or a dictionary with entries specifying initial values for all
state variables used by chain transition `sample` methods.
Kwargs:
trace_funcs: Sequence of functions which compute the variables to be
recorded at each chain iteration (during only the main non-adaptive
sampling stage if `trace_warm_up` is False), with each trace function
Expand Down Expand Up @@ -891,25 +919,9 @@ def sample_chains(
`display_progress=True`.
Returns:
final_states (list[ChainState]): States of chains after final iteration. May
be used to resume sampling a chain by passing as the initial states to a
new `sample_chains` call.
traces (dict[str, list[array]]): dictionary of chain trace arrays. Values in
dictionary are list of arrays of variables outputted by trace functions
in `trace_funcs` with each array in the list corresponding to a single
chain and the leading dimension of each array corresponding to the
iteration (draw) index in the main non-adaptive sampling stage. The key
for each value is the corresponding key in the dictionary returned by
the trace function which computed the traced value.
chain_stats (dict[str, dict[str, list[array]]]): dictionary of chain
transition statistic dictionaries. Values in outer dictionary are
dictionaries of statistics for each chain transition, keyed by the
string key for the transition. The values in each inner transition
dictionary are lists of arrays of chain statistic values with each array
in the list corresponding to a single chain and the leading dimension of
each array corresponding to the iteration (draw) index in the main
non-adaptive sampling stage. The key for each value is a string
description of the corresponding integration transition statistic.
Named tuple :code:`(final_states, traces, statistics)` corresponding to
states of chains after final interatinos, dictionary of chain trace arrays
and dictionary of chain statistics dictionaries.
"""
if not display_progress:
progress_bar_class = DummyProgressBar
Expand Down Expand Up @@ -1010,14 +1022,43 @@ def sample_chains(
if stage.trace_funcs is not None or stage.record_stats:
sampling_index_offset += stage.n_iter
if isinstance(exception, KeyboardInterrupt):
return chain_states, traces, stats
return chain_states, traces, stats
return MCMCSampleChainsOutputs(chain_states, traces, stats)
return MCMCSampleChainsOutputs(chain_states, traces, stats)


class HMCSampleChainsOutputs(NamedTuple):
"""Outputs returned by :py:meth:`HamiltonianMCMC.sample_chains` call.
Parameters:
final_states: States of chains after final iteration. May be used to resume
sampling a chain by passing as the initial states to a new `sample_chains`
call.
traces: Dictionary of chain trace arrays. Values in dictionary are list of
arrays of variables outputted by trace functions in `trace_funcs` with each
array in the list corresponding to a single chain and the leading dimension
of each array corresponding to the iteration (draw) index, within the main
non-adaptive sampling stage if `trace_warm_up=False` and across both warm-up
and main sampling stages otherwise. The key for each value is the
corresponding key in the dictionary returned by the trace function which
computed the traced value.
statistics: Dictionary of chain transition statistic dictionaries. Values in
dictionary are lists of arrays of chain statistic values with each array in
the list corresponding to a single chain and the leading dimension of each
array corresponding to the iteration (draw) index, within the main
non-adaptive sampling stage if `trace_warm_up=False` and across both warm-up
and main sampling stages otherwise. The key for each value is a string
description of the corresponding integration transition statistic.
"""

final_states: list[ChainState]
traces: dict[str, list[NDArray]]
statistics: dict[str, list[NDArray]]


class HamiltonianMCMC(MarkovChainMonteCarloMethod):
"""Wrapper class for Hamiltonian Markov chain Monte Carlo (H-MCMC) methods.
class HamiltonianMonteCarlo(MarkovChainMonteCarloMethod):
"""Wrapper class for Hamiltonian Monte Carlo (HMC) methods.
Here H-MCMC is defined as a MCMC method which augments the original target variable
Here HMC is defined as a MCMC method which augments the original target variable
(henceforth position variable) with a momentum variable with a user specified
conditional distribution given the position variable. In each chain iteration two
Markov transitions leaving the resulting joint distribution on position and momentum
Expand Down Expand Up @@ -1105,7 +1146,7 @@ def sample_chains(
n_main_iter: int,
init_states: Iterable[Union[ChainState, NDArray, dict]],
**kwargs,
) -> tuple[list[ChainState], dict[str, list[NDArray]], dict[str, list[NDArray]]]:
) -> HMCSampleChainsOutputs:
"""Sample Markov chains from given initial states with adaptive warm up.
One or more Markov chains are sampled, with each chain iteration consisting of a
Expand Down Expand Up @@ -1137,7 +1178,7 @@ def sample_chains(
conditional distribution. One chain will be run for each state in the
iterable.
Kwargs:
Keyword args:
trace_funcs: Sequence of functions which compute the variables to be
recorded at each chain iteration (during only the main non-adaptive
sampling stage if `trace_warm_up` is False), with each trace function
Expand All @@ -1146,7 +1187,7 @@ def sample_chains(
returned dictionaries are used to index the trace arrays in the returned
traces dictionary. If a key appears in multiple dictionaries only the
the value corresponding to the last trace function to return that key
will be stored. Default is to use a single function which recordes the
will be stored. Default is to use a single function which records the
position component of the state under the key `pos` and the Hamiltonian
at the state under the key `hamiltonian`.
adapters: Sequence of `mici.adapters.Adapter` instances to use to
Expand Down Expand Up @@ -1204,24 +1245,9 @@ def sample_chains(
`mici.progressbars.SequenceProgressBar`.
Returns:
final_states: States of chains after final iteration. May be used to resume
sampling a chain by passing as the initial states to a new
`sample_chains` call.
traces: dictionary of chain trace arrays. Values in dictionary are list of
arrays of variables outputted by trace functions in `trace_funcs` with
each array in the list corresponding to a single chain and the leading
dimension of each array corresponding to the iteration (draw) index
(within the main non-adaptive sampling stage if `trace_warm_up` is
False). The key for each value is the corresponding key in the
dictionary returned by the trace function which computed the traced
value.
stats: dictionary of chain statistics. Values in dictionary are lists of
arrays of chain statistic values with each array in the list
corresponding to a single chain and the leading dimension of each array
corresponding to the iteration (draw) index (within the main
non-adaptive sampling stage if `trace_warm_up` is False). The key for
each value is a string description of the corresponding integration
transition statistic.
Named tuple :code:`(final_states, traces, statistics)` corresponding to
states of chains after final interatinos, dictionary of chain trace arrays
and dictionary of chain statistics dictionaries.
"""
init_states = [self._preprocess_init_state(i) for i in init_states]
# default to single dual-averaging step size adapter
Expand All @@ -1247,11 +1273,11 @@ def sample_chains(
n_warm_up_iter, n_main_iter, init_states, **kwargs
)
stats = stats.get("integration_transition", {})
return final_states, traces, stats
return HMCSampleChainsOutputs(final_states, traces, stats)


class StaticMetropolisHMC(HamiltonianMCMC):
"""Static integration time H-MCMC implementation with Metropolis sampling.
class StaticMetropolisHMC(HamiltonianMonteCarlo):
"""Static integration time HMC method with Metropolis sampling.
In each transition a trajectory is generated by integrating the Hamiltonian dynamics
from the current state in the current integration time direction for a fixed integer
Expand Down Expand Up @@ -1314,8 +1340,8 @@ def n_step(self, value: int):
self.transitions["integration_transition"].n_step = value


class RandomMetropolisHMC(HamiltonianMCMC):
"""Random integration time H-MCMC with Metropolis sampling of new state.
class RandomMetropolisHMC(HamiltonianMonteCarlo):
"""Random integration time HMC method with Metropolis sampling of new state.
In each transition a trajectory is generated by integrating the Hamiltonian dynamics
from the current state in the current integration time direction for a random
Expand Down Expand Up @@ -1386,8 +1412,8 @@ def n_step_range(self, value: tuple[int, int]):
self.transitions["integration_transition"].n_step_range = value


class DynamicMultinomialHMC(HamiltonianMCMC):
"""Dynamic integration time H-MCMC with multinomial sampling of new state.
class DynamicMultinomialHMC(HamiltonianMonteCarlo):
"""Dynamic integration time HMC method with multinomial sampling from trajectory.
In each transition a binary tree of states is recursively computed by integrating
randomly forward and backward in time by a number of steps equal to the previous
Expand Down Expand Up @@ -1492,8 +1518,8 @@ def max_delta_h(self, value: float):
self.transitions["integration_transition"].max_delta_h = value


class DynamicSliceHMC(HamiltonianMCMC):
"""Dynamic integration time H-MCMC with slice sampling of new state.
class DynamicSliceHMC(HamiltonianMonteCarlo):
"""Dynamic integration time HMC method with slice sampling from trajectory.
In each transition a binary tree of states is recursively computed by integrating
randomly forward and backward in time by a number of steps equal to the previous
Expand Down

0 comments on commit 454db41

Please sign in to comment.