diff --git a/src/mici/samplers.py b/src/mici/samplers.py index 09aab6d..555f650 100644 --- a/src/mici/samplers.py +++ b/src/mici/samplers.py @@ -11,7 +11,7 @@ import tempfile import signal from warnings import warn -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, NamedTuple import numpy as np from numpy.random import default_rng from mici.transitions import ( @@ -752,6 +752,37 @@ def _sample_chains_parallel( return (*_collate_chain_outputs(chain_outputs), exception) +class MCMCSampleChainsOutputs(NamedTuple): + """Outputs returned by :py:meth:`MarkovChainMonteCarloMethod.sample_chains` call. + + Parameters: + final_states: States of chains after final iteration. May be used to resume + sampling a chain by passing as the initial states to a new `sample_chains` + call. + traces: Dictionary of chain trace arrays. Values in dictionary are list of + arrays of variables outputted by trace functions in `trace_funcs` with each + array in the list corresponding to a single chain and the leading dimension + of each array corresponding to the iteration (draw) index, within the main + non-adaptive sampling stage if `trace_warm_up=False` and across both warm-up + and main sampling stages otherwise. The key for each value is the + corresponding key in the dictionary returned by the trace function which + computed the traced value. + statistics: Dictionary of chain transition statistic dictionaries. Values in + outer dictionary are dictionaries of statistics for each chain transition, + keyed by the string key for the transition. The values in each inner + transition dictionary are lists of arrays of chain statistic values with + each array in the list corresponding to a single chain and the leading + dimension of each array corresponding to the iteration (draw) index, within + the main non-adaptive sampling stage if `trace_warm_up=False` and across + both warm-up and main sampling stages otherwise. The key for each value is a + string description of the corresponding transition statistic. + """ + + final_states: list[ChainState] + traces: dict[str, list[NDArray]] + statistics: dict[str, dict[str, list[NDArray]]] + + class MarkovChainMonteCarloMethod: """Generic Markov chain Monte Carlo (MCMC) sampler. @@ -782,6 +813,7 @@ def sample_chains( n_warm_up_iter: int, n_main_iter: int, init_states: Iterable[Union[ChainState, dict]], + *, trace_funcs: Optional[Sequence[TraceFunction]] = None, adapters: Optional[dict[str, Sequence[Adapter]]] = None, stager: Optional[Stager] = None, @@ -793,9 +825,7 @@ def sample_chains( monitor_stats: Optional[dict[str, list[str]]] = None, display_progress: bool = True, progress_bar_class: Optional[ProgressBar] = None, - ) -> tuple[ - list[ChainState], dict[str, list[NDArray]], dict[str, dict[str, list[NDArray]]] - ]: + ) -> MCMCSampleChainsOutputs: """Sample Markov chains from given initial states with optional adaptive warm up One or more Markov chains are sampled, with each chain iteration consisting of @@ -818,8 +848,6 @@ def sample_chains( init_states: Initial chain states. Each entry can be either a `ChainState` object or a dictionary with entries specifying initial values for all state variables used by chain transition `sample` methods. - - Kwargs: trace_funcs: Sequence of functions which compute the variables to be recorded at each chain iteration (during only the main non-adaptive sampling stage if `trace_warm_up` is False), with each trace function @@ -891,25 +919,9 @@ def sample_chains( `display_progress=True`. Returns: - final_states (list[ChainState]): States of chains after final iteration. May - be used to resume sampling a chain by passing as the initial states to a - new `sample_chains` call. - traces (dict[str, list[array]]): dictionary of chain trace arrays. Values in - dictionary are list of arrays of variables outputted by trace functions - in `trace_funcs` with each array in the list corresponding to a single - chain and the leading dimension of each array corresponding to the - iteration (draw) index in the main non-adaptive sampling stage. The key - for each value is the corresponding key in the dictionary returned by - the trace function which computed the traced value. - chain_stats (dict[str, dict[str, list[array]]]): dictionary of chain - transition statistic dictionaries. Values in outer dictionary are - dictionaries of statistics for each chain transition, keyed by the - string key for the transition. The values in each inner transition - dictionary are lists of arrays of chain statistic values with each array - in the list corresponding to a single chain and the leading dimension of - each array corresponding to the iteration (draw) index in the main - non-adaptive sampling stage. The key for each value is a string - description of the corresponding integration transition statistic. + Named tuple :code:`(final_states, traces, statistics)` corresponding to + states of chains after final interatinos, dictionary of chain trace arrays + and dictionary of chain statistics dictionaries. """ if not display_progress: progress_bar_class = DummyProgressBar @@ -1010,14 +1022,43 @@ def sample_chains( if stage.trace_funcs is not None or stage.record_stats: sampling_index_offset += stage.n_iter if isinstance(exception, KeyboardInterrupt): - return chain_states, traces, stats - return chain_states, traces, stats + return MCMCSampleChainsOutputs(chain_states, traces, stats) + return MCMCSampleChainsOutputs(chain_states, traces, stats) + + +class HMCSampleChainsOutputs(NamedTuple): + """Outputs returned by :py:meth:`HamiltonianMCMC.sample_chains` call. + + Parameters: + final_states: States of chains after final iteration. May be used to resume + sampling a chain by passing as the initial states to a new `sample_chains` + call. + traces: Dictionary of chain trace arrays. Values in dictionary are list of + arrays of variables outputted by trace functions in `trace_funcs` with each + array in the list corresponding to a single chain and the leading dimension + of each array corresponding to the iteration (draw) index, within the main + non-adaptive sampling stage if `trace_warm_up=False` and across both warm-up + and main sampling stages otherwise. The key for each value is the + corresponding key in the dictionary returned by the trace function which + computed the traced value. + statistics: Dictionary of chain transition statistic dictionaries. Values in + dictionary are lists of arrays of chain statistic values with each array in + the list corresponding to a single chain and the leading dimension of each + array corresponding to the iteration (draw) index, within the main + non-adaptive sampling stage if `trace_warm_up=False` and across both warm-up + and main sampling stages otherwise. The key for each value is a string + description of the corresponding integration transition statistic. + """ + + final_states: list[ChainState] + traces: dict[str, list[NDArray]] + statistics: dict[str, list[NDArray]] -class HamiltonianMCMC(MarkovChainMonteCarloMethod): - """Wrapper class for Hamiltonian Markov chain Monte Carlo (H-MCMC) methods. +class HamiltonianMonteCarlo(MarkovChainMonteCarloMethod): + """Wrapper class for Hamiltonian Monte Carlo (HMC) methods. - Here H-MCMC is defined as a MCMC method which augments the original target variable + Here HMC is defined as a MCMC method which augments the original target variable (henceforth position variable) with a momentum variable with a user specified conditional distribution given the position variable. In each chain iteration two Markov transitions leaving the resulting joint distribution on position and momentum @@ -1105,7 +1146,7 @@ def sample_chains( n_main_iter: int, init_states: Iterable[Union[ChainState, NDArray, dict]], **kwargs, - ) -> tuple[list[ChainState], dict[str, list[NDArray]], dict[str, list[NDArray]]]: + ) -> HMCSampleChainsOutputs: """Sample Markov chains from given initial states with adaptive warm up. One or more Markov chains are sampled, with each chain iteration consisting of a @@ -1137,7 +1178,7 @@ def sample_chains( conditional distribution. One chain will be run for each state in the iterable. - Kwargs: + Keyword args: trace_funcs: Sequence of functions which compute the variables to be recorded at each chain iteration (during only the main non-adaptive sampling stage if `trace_warm_up` is False), with each trace function @@ -1146,7 +1187,7 @@ def sample_chains( returned dictionaries are used to index the trace arrays in the returned traces dictionary. If a key appears in multiple dictionaries only the the value corresponding to the last trace function to return that key - will be stored. Default is to use a single function which recordes the + will be stored. Default is to use a single function which records the position component of the state under the key `pos` and the Hamiltonian at the state under the key `hamiltonian`. adapters: Sequence of `mici.adapters.Adapter` instances to use to @@ -1204,24 +1245,9 @@ def sample_chains( `mici.progressbars.SequenceProgressBar`. Returns: - final_states: States of chains after final iteration. May be used to resume - sampling a chain by passing as the initial states to a new - `sample_chains` call. - traces: dictionary of chain trace arrays. Values in dictionary are list of - arrays of variables outputted by trace functions in `trace_funcs` with - each array in the list corresponding to a single chain and the leading - dimension of each array corresponding to the iteration (draw) index - (within the main non-adaptive sampling stage if `trace_warm_up` is - False). The key for each value is the corresponding key in the - dictionary returned by the trace function which computed the traced - value. - stats: dictionary of chain statistics. Values in dictionary are lists of - arrays of chain statistic values with each array in the list - corresponding to a single chain and the leading dimension of each array - corresponding to the iteration (draw) index (within the main - non-adaptive sampling stage if `trace_warm_up` is False). The key for - each value is a string description of the corresponding integration - transition statistic. + Named tuple :code:`(final_states, traces, statistics)` corresponding to + states of chains after final interatinos, dictionary of chain trace arrays + and dictionary of chain statistics dictionaries. """ init_states = [self._preprocess_init_state(i) for i in init_states] # default to single dual-averaging step size adapter @@ -1247,11 +1273,11 @@ def sample_chains( n_warm_up_iter, n_main_iter, init_states, **kwargs ) stats = stats.get("integration_transition", {}) - return final_states, traces, stats + return HMCSampleChainsOutputs(final_states, traces, stats) -class StaticMetropolisHMC(HamiltonianMCMC): - """Static integration time H-MCMC implementation with Metropolis sampling. +class StaticMetropolisHMC(HamiltonianMonteCarlo): + """Static integration time HMC method with Metropolis sampling. In each transition a trajectory is generated by integrating the Hamiltonian dynamics from the current state in the current integration time direction for a fixed integer @@ -1314,8 +1340,8 @@ def n_step(self, value: int): self.transitions["integration_transition"].n_step = value -class RandomMetropolisHMC(HamiltonianMCMC): - """Random integration time H-MCMC with Metropolis sampling of new state. +class RandomMetropolisHMC(HamiltonianMonteCarlo): + """Random integration time HMC method with Metropolis sampling of new state. In each transition a trajectory is generated by integrating the Hamiltonian dynamics from the current state in the current integration time direction for a random @@ -1386,8 +1412,8 @@ def n_step_range(self, value: tuple[int, int]): self.transitions["integration_transition"].n_step_range = value -class DynamicMultinomialHMC(HamiltonianMCMC): - """Dynamic integration time H-MCMC with multinomial sampling of new state. +class DynamicMultinomialHMC(HamiltonianMonteCarlo): + """Dynamic integration time HMC method with multinomial sampling from trajectory. In each transition a binary tree of states is recursively computed by integrating randomly forward and backward in time by a number of steps equal to the previous @@ -1492,8 +1518,8 @@ def max_delta_h(self, value: float): self.transitions["integration_transition"].max_delta_h = value -class DynamicSliceHMC(HamiltonianMCMC): - """Dynamic integration time H-MCMC with slice sampling of new state. +class DynamicSliceHMC(HamiltonianMonteCarlo): + """Dynamic integration time HMC method with slice sampling from trajectory. In each transition a binary tree of states is recursively computed by integrating randomly forward and backward in time by a number of steps equal to the previous