From 6c1d07bef5a07e37ce763953f053b5e2daaa8015 Mon Sep 17 00:00:00 2001 From: Nate Pope Date: Tue, 8 Aug 2023 11:57:15 -0700 Subject: [PATCH 1/4] Mixture prior works More streamlined numerical checks Initialize gamma mixture from conditional coalescent prior Add pdf Update mixture.py to use natural parameterization WIP Moved fully into numba Cleanup Cleanup More debugging WIP Working wording Add missing constant to loglikelihood Skip prior update completely instead of components Skip prior update completely instead of components Remove verbose; use logweights in conditional posterior Move mixture initialization to function Docstrings and CLI Remove some debugging inserts Remove preemptive reference Fix tests --- tests/test_inference.py | 21 ++-- tsdate/cli.py | 31 ++++++ tsdate/core.py | 160 ++++++++++++++++++++++----- tsdate/mixture.py | 240 ++++++++++++++++++++++++++++++++++++++++ tsdate/prior.py | 2 + 5 files changed, 419 insertions(+), 35 deletions(-) create mode 100644 tsdate/mixture.py diff --git a/tests/test_inference.py b/tests/test_inference.py index fcd25023..9616dde3 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -419,13 +419,13 @@ def test_nonglobal_priors(self): priors = tsdate.prior.MixturePrior(ts, prior_distribution="gamma") grid = priors.make_parameter_grid(population_size=1) grid.grid_data[:] = [1.0, 0.0] # noninformative prior - tsdate.date( - ts, - mutation_rate=5, - method="variational_gamma", - priors=grid, - global_prior=False, - ) + with pytest.raises(ValueError, match="not yet implemented"): + tsdate.date( + ts, + mutation_rate=5, + method="variational_gamma", + priors=grid, + ) def test_bad_arguments(self): ts = utility_functions.two_tree_mutation_ts() @@ -437,6 +437,13 @@ def test_bad_arguments(self): method="variational_gamma", max_iterations=-1, ) + with pytest.raises(ValueError, match="must be a positive integer"): + tsdate.date( + ts, + mutation_rate=5, + method="variational_gamma", + global_prior=False, + ) def test_match_central_moments(self): ts = msprime.simulate(8, mutation_rate=5, recombination_rate=5, random_seed=2) diff --git a/tsdate/cli.py b/tsdate/cli.py index abe1fb05..54e89a96 100644 --- a/tsdate/cli.py +++ b/tsdate/cli.py @@ -199,6 +199,34 @@ def tsdate_cli_parser(): "but does not exactly minimize KL divergence in each EP update." ), ) + parser.add_argument( + "--max-iterations", + type=int, + help=( + "The number of iterations used in the expectation propagation " + "algorithm. Default: 20" + ), + default=20, + ) + parser.add_argument( + "--em-iterations", + type=int, + help=( + "The number of expectation-maximization iterations used to optimize the " + "global mixture prior at the end of each expectation propagation step. " + "Setting to zero disables optimization. Default: 10" + ), + default=10, + ) + parser.add_argument( + "--global-prior", + type=int, + help=( + "The number of components in the i.i.d. mixture prior for node " + "ages. Default: 1" + ), + default=1, + ) parser.set_defaults(runner=run_date) parser = subparsers.add_parser( @@ -253,8 +281,11 @@ def run_date(args): method=args.method, eps=args.epsilon, progress=args.progress, + max_iterations=args.max_iterations, max_shape=args.max_shape, match_central_moments=args.match_central_moments, + em_iterations=args.em_iterations, + global_prior=args.global_prior, ) else: params = dict( diff --git a/tsdate/core.py b/tsdate/core.py index 1dd36fce..9fb42c69 100644 --- a/tsdate/core.py +++ b/tsdate/core.py @@ -41,6 +41,7 @@ from . import approx from . import base from . import demography +from . import mixture from . import prior from . import provenance @@ -954,7 +955,7 @@ class ExpectationPropagation(InOutAlgorithms): Bayesian Inference" """ - def __init__(self, *args, **kwargs): + def __init__(self, *args, global_prior, **kwargs): super().__init__(*args, **kwargs) assert self.priors.probability_space == base.GAMMA_PAR @@ -962,24 +963,29 @@ def __init__(self, *args, **kwargs): assert self.lik.grid_size == 2 assert self.priors.timepoints.size == 2 + # global distribution of node ages + self.global_prior = global_prior.copy() + + # messages passed from prior to nodes + self.prior_messages = np.zeros((self.ts.num_nodes, 2)) + # mutation likelihoods, as gamma natural parameters self.likelihoods = np.zeros((self.ts.num_edges, 2)) for e in self.ts.edges(): self.likelihoods[e.id] = self.lik.to_natural(e) - # messages passed from factors to nodes + # messages passed from edge likelihoods to nodes self.messages = np.zeros((self.ts.num_edges, 2, 2)) - # normalizing constants from each factor + # normalizing constants from each edge likelihood self.log_partition = np.zeros(self.ts.num_edges) # the approximate posterior marginals self.posterior = np.zeros((self.ts.num_nodes, 2)) - for n in self.priors.nonfixed_nodes: - self.posterior[n] = self.priors[n] - # edge traversal order + # edge, node traversal order self.edges, self.leaves = self.factorize(self.ts.edges(), self.fixednodes) + self.freenodes = self.priors.nonfixed_nodes # factors for edges leading from fixed nodes are invariant # and can be incorporated into the posterior beforehand @@ -1013,7 +1019,7 @@ def factorize(edge_list, fixed_nodes): @staticmethod @numba.njit("f8(i4[:, :], f8[:, :], f8[:, :], f8[:, :, :], f8[:], f8[:], f8, b1)") - def propagate( + def propagate_likelihood( edges, likelihoods, posterior, @@ -1045,13 +1051,20 @@ def propagate( assert max_shape >= 1.0 - upper = max_shape - 1.0 - lower = 1.0 / max_shape - 1.0 + # Bound the shape parameter for the posterior and cavity distributions + # so that lower_cavi < lower_post < upper_post < upper_cavi. + upper_post = max_shape - 1.0 + lower_post = 1.0 / max_shape - 1.0 + upper_cavi = 2.0 * max_shape - 1.0 + lower_cavi = 0.5 / max_shape - 1.0 def cavity_damping(x, y): + assert upper_cavi > x[0] > lower_cavi d = 1.0 - if (y[0] > 0.0) and (x[0] - y[0] < lower): - d = min(d, (x[0] - lower) / y[0]) + if (y[0] > 0.0) and (x[0] - y[0] < lower_cavi): + d = min(d, (x[0] - lower_cavi) / y[0]) + if (y[0] < 0.0) and (x[0] - y[0] > upper_cavi): + d = min(d, (x[0] - upper_cavi) / y[0]) if (y[1] > 0.0) and (x[1] - y[1] < 0.0): d = min(d, x[1] / y[1]) assert 0.0 < d <= 1.0 @@ -1059,7 +1072,11 @@ def cavity_damping(x, y): def posterior_damping(x): assert x[0] > -1.0 and x[1] > 0.0 - d = min(1.0, upper / abs(x[0])) if (x[0] > 0) else 1.0 + d = 1.0 + if x[0] > upper_post: + d = upper_post / x[0] + if x[0] < lower_post: + d = lower_post / x[0] assert 0.0 < d <= 1.0 return d @@ -1098,13 +1115,81 @@ def posterior_damping(x): return 0.0 # TODO, placeholder - def iterate(self, max_shape=1000, min_kl=True): + @staticmethod + @numba.njit("f8(i4[:], f8[:, :], f8[:, :], f8[:, :], f8[:], f8, i4, f8)") + def propagate_prior( + nodes, global_prior, posterior, messages, scale, max_shape, em_maxitt, em_reltol + ): + """TODO + + :param ndarray nodes: ids of nodes that should be updated + :param ndarray global_prior: rows are mixture components, columns are + zeroth, first, and second natural parameters of gamma mixture + components. Updated in place. + :param ndarray posterior: rows are nodes, columns are first and + second natural parameters of gamma posteriors. Updated in + place. + :param ndarray messages: rows are edges, columns are first and + second natural parameters of prior messages. Updated in place. + :param float max_shape: the maximum allowed shape for node posteriors + :param int em_maxitt: the maximum number of EM iterations to use when + fitting the mixture model + :param int em_reltol: the termination criterion for relative change in + log-likelihood + """ + + if global_prior.shape[0] == 0: + return 0.0 + + assert max_shape >= 1.0 + + upper = max_shape - 1.0 + lower = 1.0 / max_shape - 1.0 + + def posterior_damping(x): + assert x[0] > -1.0 and x[1] > 0.0 + d = 1.0 + if x[0] > upper: + d = upper / x[0] + if x[0] < lower: + d = lower / x[0] + assert 0.0 < d <= 1.0 + return d + + cavity = np.zeros(posterior.shape) + cavity[nodes] = posterior[nodes] - messages[nodes] * scale[nodes, np.newaxis] + global_prior, posterior[nodes] = mixture.fit_gamma_mixture( + global_prior, cavity[nodes], em_maxitt, em_reltol, False + ) + messages[nodes] = (posterior[nodes] - cavity[nodes]) / scale[nodes, np.newaxis] + + for n in nodes: + eta = posterior_damping(posterior[n]) + posterior[n] *= eta + scale[n] *= eta + + return 0.0 + + def iterate(self, em_maxitt=100, em_reltol=1e-6, max_shape=1000, min_kl=True): """ Update edge factors from leaves to root then from root to leaves, and return approximate log marginal likelihood (TODO) """ - self.propagate( + # prior update + self.propagate_prior( + self.freenodes, + self.global_prior, + self.posterior, + self.prior_messages, + self.scale, + max_shape, + em_maxitt, + em_reltol, + ) + + # rootward pass + self.propagate_likelihood( self.edges, self.likelihoods, self.posterior, @@ -1114,7 +1199,9 @@ def iterate(self, max_shape=1000, min_kl=True): max_shape, min_kl, ) - self.propagate( + + # leafward pass + self.propagate_likelihood( self.edges[::-1], self.likelihoods, self.posterior, @@ -1439,32 +1526,36 @@ def main_algorithm(self): self.recombination_rate, fixed_node_set=self.get_fixed_nodes_set(), ) - return ExpectationPropagation(self.priors, lik, progress=self.pbar) + return ExpectationPropagation(self.priors, lik, progress=self.pbar, global_prior=self.prior_mixture) - def run(self, eps, max_iterations, max_shape, match_central_moments, global_prior): + def run(self, eps, max_iterations, max_shape, match_central_moments, global_prior, em_iterations): if self.provenance_params is not None: self.provenance_params.update( {k: v for k, v in locals().items() if k != "self"} ) if not max_iterations >= 1: raise ValueError("Maximum number of EP iterations must be greater than 0") + if not (isinstance(global_prior, int) and global_prior > 0): + raise ValueError("'global_prior' must be a positive integer") if self.mutation_rate is None: raise ValueError("Variational gamma method requires mutation rate") - if global_prior: - logging.info("Pooling node-specific priors into global prior") - self.priors.grid_data[:] = approx.average_gammas( - self.priors.grid_data[:, 0], self.priors.grid_data[:, 1] - ) + + self.prior_mixture = mixture.initialize_mixture(self.priors.grid_data, global_prior) + self.priors.grid_data[:] = [0.0, 0.0] # TODO: support node-specific priors # match sufficient statistics or match central moments min_kl = not match_central_moments dynamic_prog = self.main_algorithm() - for _ in tqdm( + for itt in tqdm( np.arange(max_iterations), desc="Expectation Propagation", disable=not self.pbar, ): - dynamic_prog.iterate(max_shape=max_shape, min_kl=min_kl) + dynamic_prog.iterate( + em_maxitt=em_iterations if itt else 0, + max_shape=max_shape, + min_kl=min_kl, + ) num_skipped = np.sum(np.isnan(dynamic_prog.log_partition)) if num_skipped > 0: @@ -1682,7 +1773,8 @@ def variational_gamma( max_iterations=None, max_shape=None, match_central_moments=None, - global_prior=True, + global_prior=1, + em_iterations=10, **kwargs, ): """ @@ -1697,6 +1789,11 @@ def variational_gamma( new_ts = tsdate.variational_gamma( ts, mutation_rate=1e-8, population_size=1e4, max_iterations=10) + An i.i.d. gamma mixture is used as a prior for each node, that is + initialized from the conditional coalescent and updated via expectation + maximization in each iteration. In addition, node-specific priors may be + specified via a grid of shape/rate parameters. + .. note:: The prior parameters for each node-to-be-dated take the form of a gamma-distributed prior on node age, parameterised by shape and rate. @@ -1718,9 +1815,11 @@ def variational_gamma( update matches mean and variance rather than expected gamma sufficient statistics. Faster with a similar accuracy, but does not exactly minimize Kullback-Leibler divergence. Default: None, treated as False. - :param bool global_prior: If `True`, an iid prior is used for all nodes, - and is constructed by averaging gamma sufficient statistics over the free - nodes in ``priors``. Default: True. + :param int global_prior: The number of components in the i.i.d. mixture prior + for node ages. Default: None, treated as 1. + :param int em_iterations: The number of expectation maximization iterations used + to optimize the global mixture prior. Setting to zero disables optimization. + Default: None, treated as 10. :param \\**kwargs: Other keyword arguments as described in the :func:`date` wrapper function, notably ``mutation_rate``, and ``population_size`` or ``priors``. Further arguments include ``time_units``, ``progress``, and @@ -1752,6 +1851,10 @@ def variational_gamma( max_shape = 1000 if match_central_moments is None: match_central_moments = False + if global_prior is None: + global_prior = 1 + if em_iterations is None: + em_iterations = 10 dating_method = VariationalGammaMethod(tree_sequence, **kwargs) result = dating_method.run( @@ -1760,6 +1863,7 @@ def variational_gamma( max_shape=max_shape, match_central_moments=match_central_moments, global_prior=global_prior, + em_iterations=em_iterations, ) return dating_method.parse_result(result, eps, {"parameter": ["shape", "rate"]}) diff --git a/tsdate/mixture.py b/tsdate/mixture.py new file mode 100644 index 00000000..51538b9f --- /dev/null +++ b/tsdate/mixture.py @@ -0,0 +1,240 @@ +# MIT License +# +# Copyright (c) 2021-23 Tskit Developers +# Copyright (c) 2020-21 University of Oxford +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +""" +Mixture of gamma distributions that may be fit via EM to distribution-valued observations +""" +import numba +import numpy as np + +from . import approx +from . import hypergeo + + +@numba.njit("UniTuple(f8[:], 4)(f8[:], f8[:], f8[:], f8, f8)") +def _conditional_posterior(prior_logweight, prior_alpha, prior_beta, alpha, beta): + r""" + Return expectations of node age :math:`t` from the mixture model, + + ..math:: + + Ga(t | a, b) \sum_j \pi_j w_j Ga(t | \alpha_j, \beta_j) + + where :math:`a` and :math:`b` are variational parameters, + and :math:`\pi_j, \alpha_j, \beta_j` are prior weights and + parameters for a gamma mixture; and :math:`w_j` are fixed, + observation-specific weights. We use natural parameterization, + so that the shape parameter is :math:`\alpha + 1`. + + TODO: + The normalizing constants of the prior are assumed to have already + been integrated into `prior_weight`. + + Returns the contribution from each component to the + posterior expectations of :math:`E[1]`, :math:`E[t]`, :math:`E[log t]`, + and :math:`E[t log t]`. + + Note that :math:`E[1]` is *unnormalized* and *log-transformed*. + """ + + dim = prior_logweight.size + E = np.full(dim, -np.inf) # E[1] (e.g. normalizing constant) + E_t = np.zeros(dim) # E[t] + E_logt = np.zeros(dim) # E[log(t)] + E_tlogt = np.zeros(dim) # E[t * log(t)] + C = (alpha + 1) * np.log(beta) - hypergeo._gammaln(alpha + 1) if beta > 0 else 0.0 + for i in range(dim): + post_alpha = prior_alpha[i] + alpha + post_beta = prior_beta[i] + beta + if (post_alpha <= -1) or (post_beta <= 0): # skip node if divergent + E[:] = -np.inf + break + E[i] = C + ( + +hypergeo._gammaln(post_alpha + 1) + - (post_alpha + 1) * np.log(post_beta) + + prior_logweight[i] + ) + assert np.isfinite(E[i]) + # TODO: option to use moments instead of sufficient statistics? + E_t[i] = (post_alpha + 1) / post_beta + E_logt[i] = hypergeo._digamma(post_alpha + 1) - np.log(post_beta) + E_tlogt[i] = E_t[i] * E_logt[i] + E_t[i] / (post_alpha + 1) + + return E, E_t, E_logt, E_tlogt + + +@numba.njit("f8(f8[:], f8[:], f8[:], f8[:], f8[:])") +def _em_update(prior_weight, prior_alpha, prior_beta, alpha, beta): + """ + Perform an expectation maximization step for parameters of mixture components, + given variational parameters `alpha`, `beta` for each node. + + The maximization step is performed using Ye & Chen (2017) "Closed form + estimators for the gamma distribution ..." + + ``prior_weight``, ``prior_alpha``, ``prior_beta`` are updated in place. + """ + assert alpha.size == beta.size + + dim = prior_weight.size + n = np.zeros(dim) + t = np.zeros(dim) + logt = np.zeros(dim) + tlogt = np.zeros(dim) + + # incorporate prior normalizing constants into weights + prior_logweight = np.log(prior_weight) + for k in range(dim): + prior_logweight[k] += (prior_alpha[k] + 1) * np.log( + prior_beta[k] + ) - hypergeo._gammaln(prior_alpha[k] + 1) + + # expectation step: + loglik = 0.0 + for a, b in zip(alpha, beta): + E, E_t, E_logt, E_tlogt = _conditional_posterior( + prior_logweight, prior_alpha, prior_beta, a, b + ) + + # skip if posterior is improper + if np.any(np.isinf(E)): + continue + + # convert evidence to posterior weights + norm_const = np.log(np.sum(np.exp(E - np.max(E)))) + np.max(E) + weight = np.exp(E - norm_const) + + # weighted contributions to sufficient statistics + loglik += norm_const + n += weight + t += E_t * weight + logt += E_logt * weight + tlogt += E_tlogt * weight + + # maximization step: update parameters in place + prior_weight[:] = n / np.sum(n) + prior_beta[:] = n**2 / (n * tlogt - t * logt) + prior_alpha[:] = n * t / (n * tlogt - t * logt) - 1.0 + + return loglik + + +@numba.njit("f8[:](f8[:], f8[:], f8[:], f8[:], f8[:])") +def _gamma_projection(prior_weight, prior_alpha, prior_beta, alpha, beta): + """ + Given variational approximation to posterior: multiply by exact prior, + calculate sufficient statistics, and moment match to get new + approximate posterior. + + Updates ``alpha`` and ``beta`` in-place. + """ + assert alpha.size == beta.size + + dim = prior_weight.size + + # incorporate prior normalizing constants into weights + prior_logweight = np.log(prior_weight) + for k in range(dim): + prior_logweight[k] += (prior_alpha[k] + 1) * np.log( + prior_beta[k] + ) - hypergeo._gammaln(prior_alpha[k] + 1) + + log_const = np.full(alpha.size, -np.inf) + for i in range(alpha.size): + E, E_t, E_logt, E_tlogt = _conditional_posterior( + prior_logweight, prior_alpha, prior_beta, alpha[i], beta[i] + ) + + # skip if posterior is improper for all components + if np.any(np.isinf(E)): + continue + + norm = np.log(np.sum(np.exp(E - np.max(E)))) + np.max(E) + weight = np.exp(E - norm) + t = np.sum(weight * E_t) + logt = np.sum(weight * E_logt) + # tlogt = np.sum(weight * E_tlogt) + log_const[i] = norm + alpha[i], beta[i] = approx.approximate_gamma_kl(t, logt) + # beta[i] = 1 / (tlogt - t * logt) + # alpha[i] = t * beta[i] - 1.0 + + return log_const + + +@numba.njit("Tuple((f8[:,:], f8[:,:]))(f8[:,:], f8[:,:], i4, f8, b1)") +def fit_gamma_mixture(mixture, observations, max_iterations, tolerance, verbose): + """ + Run EM until relative tolerance or maximum number of iterations is + reached. Then, perform expectation-propagation update and return new + variational parameters for the posterior approximation. + """ + + assert mixture.shape[1] == 3 + assert observations.shape[1] == 2 + + mix_weight, mix_alpha, mix_beta = mixture.T + alpha, beta = observations.T + + last = np.inf + for itt in range(max_iterations): + loglik = _em_update(mix_weight, mix_alpha, mix_beta, alpha, beta) + loglik /= float(alpha.size) + update = np.abs(loglik - last) + last = loglik + if verbose: + print("EM iteration:", itt, "mean(loglik):", np.round(loglik, 5)) + print(" -> weights:", mix_weight) + print(" -> alpha:", mix_alpha) + print(" -> beta:", mix_beta) + if update < np.abs(loglik) * tolerance: + break + + # conditional posteriors for each observation + # log_const = _gamma_projection(mix_weight, mix_alpha, mix_beta, alpha, beta) + _gamma_projection(mix_weight, mix_alpha, mix_beta, alpha, beta) + + new_mixture = np.zeros(mixture.shape) + new_observations = np.zeros(observations.shape) + new_observations[:, 0] = alpha + new_observations[:, 1] = beta + new_mixture[:, 0] = mix_weight + new_mixture[:, 1] = mix_alpha + new_mixture[:, 2] = mix_beta + + return new_mixture, new_observations + + +def initialize_mixture(parameters, num_components): + """initialize clusters by dividing nodes into equal groups""" + global_prior = np.empty((num_components, 3)) + num_nodes = parameters.shape[0] + age_classes = np.tile(np.arange(num_components), num_nodes // num_components + 1)[ + :num_nodes + ] + for k in range(num_components): + indices = np.equal(age_classes, k) + alpha, beta = approx.average_gammas( + parameters[indices, 0] - 1.0, parameters[indices, 1] + ) + global_prior[k] = [1.0 / num_components, alpha, beta] + return global_prior diff --git a/tsdate/prior.py b/tsdate/prior.py index 72a7087e..33586b26 100644 --- a/tsdate/prior.py +++ b/tsdate/prior.py @@ -30,6 +30,8 @@ import numba import numpy as np +import scipy.cluster +import scipy.special import scipy.stats import tskit from tqdm.auto import tqdm From 3b3b6cb0d94ee2e482f8ccfdcdbae5f246b37a3c Mon Sep 17 00:00:00 2001 From: Nate Pope Date: Fri, 5 Jan 2024 15:11:11 -0800 Subject: [PATCH 2/4] Sorting out tests, prior parameterization --- tests/test_inference.py | 17 +++++++++++++++-- tsdate/core.py | 29 ++++++++++++++++++++++------- tsdate/mixture.py | 14 +++++++++----- 3 files changed, 46 insertions(+), 14 deletions(-) diff --git a/tests/test_inference.py b/tests/test_inference.py index 9616dde3..1f30ebec 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -414,12 +414,12 @@ def test_simple_sim_multi_tree(self): ts = msprime.simulate(8, mutation_rate=5, recombination_rate=5, random_seed=2) tsdate.date(ts, mutation_rate=5, population_size=1, method="variational_gamma") - def test_nonglobal_priors(self): + def test_invalid_priors(self): ts = msprime.simulate(8, mutation_rate=5, recombination_rate=5, random_seed=2) priors = tsdate.prior.MixturePrior(ts, prior_distribution="gamma") grid = priors.make_parameter_grid(population_size=1) grid.grid_data[:] = [1.0, 0.0] # noninformative prior - with pytest.raises(ValueError, match="not yet implemented"): + with pytest.raises(ValueError, match="Non-positive shape/rate"): tsdate.date( ts, mutation_rate=5, @@ -427,6 +427,18 @@ def test_nonglobal_priors(self): priors=grid, ) + def test_custom_priors(self): + ts = msprime.simulate(8, mutation_rate=5, recombination_rate=5, random_seed=2) + priors = tsdate.prior.MixturePrior(ts, prior_distribution="gamma") + grid = priors.make_parameter_grid(population_size=1) + grid.grid_data[:] += 1.0 + tsdate.date( + ts, + mutation_rate=5, + method="variational_gamma", + priors=grid, + ) + def test_bad_arguments(self): ts = utility_functions.two_tree_mutation_ts() with pytest.raises(ValueError, match="Maximum number of EP iterations"): @@ -441,6 +453,7 @@ def test_bad_arguments(self): tsdate.date( ts, mutation_rate=5, + population_size=1, method="variational_gamma", global_prior=False, ) diff --git a/tsdate/core.py b/tsdate/core.py index 9fb42c69..dcf077b5 100644 --- a/tsdate/core.py +++ b/tsdate/core.py @@ -1487,11 +1487,14 @@ class VariationalGammaMethod(EstimationMethod): def __init__(self, ts, **kwargs): super().__init__(ts, **kwargs) - # convert priors to natural parameterization and average + # convert priors to natural parameterization for n in self.priors.nonfixed_nodes: + if not np.all(self.priors[n] > 0.0): + raise ValueError( + f"Non-positive shape/rate parameters for node {n}: " + f"{self.priors[n]}" + ) self.priors[n][0] -= 1.0 - assert self.priors[n][0] > -1.0 - assert self.priors[n][1] >= 0.0 @staticmethod def mean_var(ts, posterior): @@ -1526,9 +1529,19 @@ def main_algorithm(self): self.recombination_rate, fixed_node_set=self.get_fixed_nodes_set(), ) - return ExpectationPropagation(self.priors, lik, progress=self.pbar, global_prior=self.prior_mixture) + return ExpectationPropagation( + self.priors, lik, progress=self.pbar, global_prior=self.prior_mixture + ) - def run(self, eps, max_iterations, max_shape, match_central_moments, global_prior, em_iterations): + def run( + self, + eps, + max_iterations, + max_shape, + match_central_moments, + global_prior, + em_iterations, + ): if self.provenance_params is not None: self.provenance_params.update( {k: v for k, v in locals().items() if k != "self"} @@ -1540,8 +1553,10 @@ def run(self, eps, max_iterations, max_shape, match_central_moments, global_prio if self.mutation_rate is None: raise ValueError("Variational gamma method requires mutation rate") - self.prior_mixture = mixture.initialize_mixture(self.priors.grid_data, global_prior) - self.priors.grid_data[:] = [0.0, 0.0] # TODO: support node-specific priors + self.prior_mixture = mixture.initialize_mixture( + self.priors.grid_data, global_prior + ) + self.priors.grid_data[:] = [0.0, 0.0] # TODO: support node-specific priors # match sufficient statistics or match central moments min_kl = not match_central_moments diff --git a/tsdate/mixture.py b/tsdate/mixture.py index 51538b9f..0eb99687 100644 --- a/tsdate/mixture.py +++ b/tsdate/mixture.py @@ -225,16 +225,20 @@ def fit_gamma_mixture(mixture, observations, max_iterations, tolerance, verbose) def initialize_mixture(parameters, num_components): - """initialize clusters by dividing nodes into equal groups""" + """ + Initialize clusters by dividing nodes into equal groups. + "parameters" are in natural parameterization (not shape/rate) + """ global_prior = np.empty((num_components, 3)) num_nodes = parameters.shape[0] - age_classes = np.tile(np.arange(num_components), num_nodes // num_components + 1)[ - :num_nodes - ] + age_classes = np.tile( + np.arange(num_components), + num_nodes // num_components + 1, + )[:num_nodes] for k in range(num_components): indices = np.equal(age_classes, k) alpha, beta = approx.average_gammas( - parameters[indices, 0] - 1.0, parameters[indices, 1] + parameters[indices, 0], parameters[indices, 1] ) global_prior[k] = [1.0 / num_components, alpha, beta] return global_prior From 8f5ead820560524924f4af2e46b53c2fe0503b5e Mon Sep 17 00:00:00 2001 From: Nate Pope Date: Fri, 5 Jan 2024 15:26:32 -0800 Subject: [PATCH 3/4] Rename global_prior arg; more tests --- tests/test_inference.py | 14 +++++++++++++- tsdate/cli.py | 6 +++--- tsdate/core.py | 37 ++++++++++++++++++++----------------- 3 files changed, 36 insertions(+), 21 deletions(-) diff --git a/tests/test_inference.py b/tests/test_inference.py index 1f30ebec..5f9182da 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -439,6 +439,18 @@ def test_custom_priors(self): priors=grid, ) + def test_prior_mixture_dim(self): + ts = msprime.simulate(8, mutation_rate=5, recombination_rate=5, random_seed=2) + priors = tsdate.prior.MixturePrior(ts, prior_distribution="gamma") + grid = priors.make_parameter_grid(population_size=1) + tsdate.date( + ts, + mutation_rate=5, + method="variational_gamma", + priors=grid, + prior_mixture_dim=2, + ) + def test_bad_arguments(self): ts = utility_functions.two_tree_mutation_ts() with pytest.raises(ValueError, match="Maximum number of EP iterations"): @@ -455,7 +467,7 @@ def test_bad_arguments(self): mutation_rate=5, population_size=1, method="variational_gamma", - global_prior=False, + prior_mixture_dim=0.1, ) def test_match_central_moments(self): diff --git a/tsdate/cli.py b/tsdate/cli.py index 54e89a96..6248f358 100644 --- a/tsdate/cli.py +++ b/tsdate/cli.py @@ -213,13 +213,13 @@ def tsdate_cli_parser(): type=int, help=( "The number of expectation-maximization iterations used to optimize the " - "global mixture prior at the end of each expectation propagation step. " + "i.i.d. mixture prior at the end of each expectation propagation step. " "Setting to zero disables optimization. Default: 10" ), default=10, ) parser.add_argument( - "--global-prior", + "--prior-mixture-dim", type=int, help=( "The number of components in the i.i.d. mixture prior for node " @@ -285,7 +285,7 @@ def run_date(args): max_shape=args.max_shape, match_central_moments=args.match_central_moments, em_iterations=args.em_iterations, - global_prior=args.global_prior, + prior_mixture_dim=args.prior_mixture_dim, ) else: params = dict( diff --git a/tsdate/core.py b/tsdate/core.py index dcf077b5..d1377f07 100644 --- a/tsdate/core.py +++ b/tsdate/core.py @@ -1030,8 +1030,9 @@ def propagate_likelihood( min_kl, ): """ - Update approximating factors for each edge, returning average relative - difference in natural parameters (TODO) + Update approximating factors for each edge. + + TODO: return max difference in natural parameters for stopping criterion :param ndarray edges: integer array of dimension `[num_edges, 3]` containing edge id, parent id, and child id. @@ -1120,7 +1121,8 @@ def posterior_damping(x): def propagate_prior( nodes, global_prior, posterior, messages, scale, max_shape, em_maxitt, em_reltol ): - """TODO + """ + Update approximating factors for global prior at each node. :param ndarray nodes: ids of nodes that should be updated :param ndarray global_prior: rows are mixture components, columns are @@ -1530,7 +1532,7 @@ def main_algorithm(self): fixed_node_set=self.get_fixed_nodes_set(), ) return ExpectationPropagation( - self.priors, lik, progress=self.pbar, global_prior=self.prior_mixture + self.priors, lik, progress=self.pbar, global_prior=self.global_prior ) def run( @@ -1539,7 +1541,7 @@ def run( max_iterations, max_shape, match_central_moments, - global_prior, + prior_mixture_dim, em_iterations, ): if self.provenance_params is not None: @@ -1548,13 +1550,13 @@ def run( ) if not max_iterations >= 1: raise ValueError("Maximum number of EP iterations must be greater than 0") - if not (isinstance(global_prior, int) and global_prior > 0): - raise ValueError("'global_prior' must be a positive integer") + if not (isinstance(prior_mixture_dim, int) and prior_mixture_dim > 0): + raise ValueError("Number of mixture components must be a positive integer") if self.mutation_rate is None: raise ValueError("Variational gamma method requires mutation rate") - self.prior_mixture = mixture.initialize_mixture( - self.priors.grid_data, global_prior + self.global_prior = mixture.initialize_mixture( + self.priors.grid_data, prior_mixture_dim ) self.priors.grid_data[:] = [0.0, 0.0] # TODO: support node-specific priors @@ -1788,7 +1790,7 @@ def variational_gamma( max_iterations=None, max_shape=None, match_central_moments=None, - global_prior=1, + prior_mixture_dim=1, em_iterations=10, **kwargs, ): @@ -1806,8 +1808,9 @@ def variational_gamma( An i.i.d. gamma mixture is used as a prior for each node, that is initialized from the conditional coalescent and updated via expectation - maximization in each iteration. In addition, node-specific priors may be - specified via a grid of shape/rate parameters. + maximization in each iteration. If node-specific priors are supplied + (via a grid of shape/rate parameters) then these are used for + initialization. .. note:: The prior parameters for each node-to-be-dated take the form of a @@ -1830,10 +1833,10 @@ def variational_gamma( update matches mean and variance rather than expected gamma sufficient statistics. Faster with a similar accuracy, but does not exactly minimize Kullback-Leibler divergence. Default: None, treated as False. - :param int global_prior: The number of components in the i.i.d. mixture prior + :param int prior_mixture_dim: The number of components in the i.i.d. mixture prior for node ages. Default: None, treated as 1. :param int em_iterations: The number of expectation maximization iterations used - to optimize the global mixture prior. Setting to zero disables optimization. + to optimize the i.i.d. mixture prior. Setting to zero disables optimization. Default: None, treated as 10. :param \\**kwargs: Other keyword arguments as described in the :func:`date` wrapper function, notably ``mutation_rate``, and ``population_size`` or ``priors``. @@ -1866,8 +1869,8 @@ def variational_gamma( max_shape = 1000 if match_central_moments is None: match_central_moments = False - if global_prior is None: - global_prior = 1 + if prior_mixture_dim is None: + prior_mixture_dim = 1 if em_iterations is None: em_iterations = 10 @@ -1877,7 +1880,7 @@ def variational_gamma( max_iterations=max_iterations, max_shape=max_shape, match_central_moments=match_central_moments, - global_prior=global_prior, + prior_mixture_dim=prior_mixture_dim, em_iterations=em_iterations, ) return dating_method.parse_result(result, eps, {"parameter": ["shape", "rate"]}) From cf58a1ced8187c9ce9035e0dac50bfd7bc16e80b Mon Sep 17 00:00:00 2001 From: Nate Pope Date: Fri, 5 Jan 2024 17:23:58 -0800 Subject: [PATCH 4/4] Minor fixes --- tsdate/core.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/tsdate/core.py b/tsdate/core.py index d1377f07..1477248a 100644 --- a/tsdate/core.py +++ b/tsdate/core.py @@ -956,6 +956,11 @@ class ExpectationPropagation(InOutAlgorithms): """ def __init__(self, *args, global_prior, **kwargs): + """ + `global_prior` contains parameters of a gamma mixture, used as an iid + prior, wheras `self.priors` (node-specific priors) are not used, except + to provide a list of nonfixed nodes. + """ super().__init__(*args, **kwargs) assert self.priors.probability_space == base.GAMMA_PAR @@ -1555,6 +1560,9 @@ def run( if self.mutation_rate is None: raise ValueError("Variational gamma method requires mutation rate") + # initialize weights/shapes/rates for i.i.d mixture prior + # note that self.priors (node-specific priors) are not currently + # used except for initialization of the mixture self.global_prior = mixture.initialize_mixture( self.priors.grid_data, prior_mixture_dim ) @@ -1790,8 +1798,8 @@ def variational_gamma( max_iterations=None, max_shape=None, match_central_moments=None, - prior_mixture_dim=1, - em_iterations=10, + prior_mixture_dim=None, + em_iterations=None, **kwargs, ): """