diff --git a/tests/test_inference.py b/tests/test_inference.py
index fcd25023..5f9182da 100644
--- a/tests/test_inference.py
+++ b/tests/test_inference.py
@@ -414,17 +414,41 @@ def test_simple_sim_multi_tree(self):
         ts = msprime.simulate(8, mutation_rate=5, recombination_rate=5, random_seed=2)
         tsdate.date(ts, mutation_rate=5, population_size=1, method="variational_gamma")
 
-    def test_nonglobal_priors(self):
+    def test_invalid_priors(self):
         ts = msprime.simulate(8, mutation_rate=5, recombination_rate=5, random_seed=2)
         priors = tsdate.prior.MixturePrior(ts, prior_distribution="gamma")
         grid = priors.make_parameter_grid(population_size=1)
         grid.grid_data[:] = [1.0, 0.0]  # noninformative prior
+        with pytest.raises(ValueError, match="Non-positive shape/rate"):
+            tsdate.date(
+                ts,
+                mutation_rate=5,
+                method="variational_gamma",
+                priors=grid,
+            )
+
+    def test_custom_priors(self):
+        ts = msprime.simulate(8, mutation_rate=5, recombination_rate=5, random_seed=2)
+        priors = tsdate.prior.MixturePrior(ts, prior_distribution="gamma")
+        grid = priors.make_parameter_grid(population_size=1)
+        grid.grid_data[:] += 1.0
         tsdate.date(
             ts,
             mutation_rate=5,
             method="variational_gamma",
             priors=grid,
-            global_prior=False,
+        )
+
+    def test_prior_mixture_dim(self):
+        ts = msprime.simulate(8, mutation_rate=5, recombination_rate=5, random_seed=2)
+        priors = tsdate.prior.MixturePrior(ts, prior_distribution="gamma")
+        grid = priors.make_parameter_grid(population_size=1)
+        tsdate.date(
+            ts,
+            mutation_rate=5,
+            method="variational_gamma",
+            priors=grid,
+            prior_mixture_dim=2,
         )
 
     def test_bad_arguments(self):
@@ -437,6 +461,14 @@ def test_bad_arguments(self):
             method="variational_gamma",
             max_iterations=-1,
         )
+        with pytest.raises(ValueError, match="must be a positive integer"):
+            tsdate.date(
+                ts,
+                mutation_rate=5,
+                population_size=1,
+                method="variational_gamma",
+                prior_mixture_dim=0.1,
+            )
 
     def test_match_central_moments(self):
         ts = msprime.simulate(8, mutation_rate=5, recombination_rate=5, random_seed=2)
diff --git a/tsdate/cli.py b/tsdate/cli.py
index abe1fb05..6248f358 100644
--- a/tsdate/cli.py
+++ b/tsdate/cli.py
@@ -199,6 +199,34 @@ def tsdate_cli_parser():
             "but does not exactly minimize KL divergence in each EP update."
         ),
     )
+    parser.add_argument(
+        "--max-iterations",
+        type=int,
+        help=(
+            "The number of iterations used in the expectation propagation "
+            "algorithm. Default: 20"
+        ),
+        default=20,
+    )
+    parser.add_argument(
+        "--em-iterations",
+        type=int,
+        help=(
+            "The number of expectation-maximization iterations used to optimize the "
+            "i.i.d. mixture prior at the end of each expectation propagation step. "
+            "Setting to zero disables optimization. Default: 10"
+        ),
+        default=10,
+    )
+    parser.add_argument(
+        "--prior-mixture-dim",
+        type=int,
+        help=(
+            "The number of components in the i.i.d. mixture prior for node "
+            "ages. Default: 1"
+        ),
+        default=1,
+    )
     parser.set_defaults(runner=run_date)
 
     parser = subparsers.add_parser(
@@ -253,8 +281,11 @@ def run_date(args):
             method=args.method,
             eps=args.epsilon,
             progress=args.progress,
+            max_iterations=args.max_iterations,
             max_shape=args.max_shape,
             match_central_moments=args.match_central_moments,
+            em_iterations=args.em_iterations,
+            prior_mixture_dim=args.prior_mixture_dim,
         )
     else:
         params = dict(
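The new CLI flags map one-to-one onto keyword arguments of `tsdate.date`, as `run_date` above shows. A minimal sketch of the equivalent Python API call (parameter values are illustrative, mirroring the tests in this patch):

```python
import msprime
import tsdate

# simulate a small tree sequence to date (values chosen for illustration)
ts = msprime.simulate(8, mutation_rate=5, recombination_rate=5, random_seed=2)

dated = tsdate.date(
    ts,
    mutation_rate=5,
    population_size=1,
    method="variational_gamma",
    max_iterations=20,    # EP sweeps (CLI default: 20)
    em_iterations=10,     # EM steps per EP sweep (CLI default: 10)
    prior_mixture_dim=2,  # components in the i.i.d. mixture prior (CLI default: 1)
)
```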
diff --git a/tsdate/core.py b/tsdate/core.py
index 1dd36fce..1477248a 100644
--- a/tsdate/core.py
+++ b/tsdate/core.py
@@ -41,6 +41,7 @@
 from . import approx
 from . import base
 from . import demography
+from . import mixture
 from . import prior
 from . import provenance
@@ -954,7 +955,12 @@ class ExpectationPropagation(InOutAlgorithms):
     Bayesian Inference"
     """
 
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args, global_prior, **kwargs):
+        """
+        `global_prior` contains parameters of a gamma mixture, used as an iid
+        prior, whereas `self.priors` (node-specific priors) are not used, except
+        to provide a list of nonfixed nodes.
+        """
         super().__init__(*args, **kwargs)
 
         assert self.priors.probability_space == base.GAMMA_PAR
@@ -962,24 +968,29 @@ def __init__(self, *args, **kwargs):
         assert self.lik.grid_size == 2
         assert self.priors.timepoints.size == 2
 
+        # global distribution of node ages
+        self.global_prior = global_prior.copy()
+
+        # messages passed from prior to nodes
+        self.prior_messages = np.zeros((self.ts.num_nodes, 2))
+
         # mutation likelihoods, as gamma natural parameters
         self.likelihoods = np.zeros((self.ts.num_edges, 2))
         for e in self.ts.edges():
             self.likelihoods[e.id] = self.lik.to_natural(e)
 
-        # messages passed from factors to nodes
+        # messages passed from edge likelihoods to nodes
         self.messages = np.zeros((self.ts.num_edges, 2, 2))
 
-        # normalizing constants from each factor
+        # normalizing constants from each edge likelihood
         self.log_partition = np.zeros(self.ts.num_edges)
 
         # the approximate posterior marginals
         self.posterior = np.zeros((self.ts.num_nodes, 2))
-        for n in self.priors.nonfixed_nodes:
-            self.posterior[n] = self.priors[n]
 
-        # edge traversal order
+        # edge, node traversal order
         self.edges, self.leaves = self.factorize(self.ts.edges(), self.fixednodes)
+        self.freenodes = self.priors.nonfixed_nodes
 
         # factors for edges leading from fixed nodes are invariant
         # and can be incorporated into the posterior beforehand
@@ -1013,7 +1024,7 @@ def factorize(edge_list, fixed_nodes):
 
     @staticmethod
    @numba.njit("f8(i4[:, :], f8[:, :], f8[:, :], f8[:, :, :], f8[:], f8[:], f8, b1)")
-    def propagate(
+    def propagate_likelihood(
         edges,
         likelihoods,
         posterior,
@@ -1024,8 +1035,9 @@ def propagate(
         min_kl,
     ):
         """
-        Update approximating factors for each edge, returning average relative
-        difference in natural parameters (TODO)
+        Update approximating factors for each edge.
+
+        TODO: return max difference in natural parameters for stopping criterion
 
         :param ndarray edges: integer array of dimension `[num_edges, 3]`
             containing edge id, parent id, and child id.
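A side note for review (not part of the patch): throughout these routines a gamma with shape `k` and rate `r` is stored as natural parameters `(k - 1, r)`, so `max_shape` translates into bounds on the first natural parameter, and a cavity distribution is formed by subtracting a damped message from the posterior. A minimal numpy sketch of that bookkeeping, with made-up values:

```python
import numpy as np

def to_natural(shape, rate):
    # gamma(shape, rate) stored as natural parameters (shape - 1, rate)
    return np.array([shape - 1.0, rate])

max_shape = 1000.0
upper_post = max_shape - 1.0        # largest admissible first natural parameter
lower_post = 1.0 / max_shape - 1.0  # smallest admissible first natural parameter

posterior = to_natural(2.5, 0.3)    # illustrative posterior
message = np.array([0.4, 0.1])      # illustrative factor message

# cavity: posterior with one factor's message removed; damping shrinks
# the subtracted message so the cavity stays a proper gamma
delta = 0.5
cavity = posterior - delta * message
assert lower_post < cavity[0] < upper_post and cavity[1] > 0.0
```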
@@ -1045,13 +1057,20 @@ def propagate_likelihood(
 
         assert max_shape >= 1.0
 
-        upper = max_shape - 1.0
-        lower = 1.0 / max_shape - 1.0
+        # Bound the shape parameter for the posterior and cavity distributions
+        # so that lower_cavi < lower_post < upper_post < upper_cavi.
+        upper_post = max_shape - 1.0
+        lower_post = 1.0 / max_shape - 1.0
+        upper_cavi = 2.0 * max_shape - 1.0
+        lower_cavi = 0.5 / max_shape - 1.0
 
         def cavity_damping(x, y):
+            assert upper_cavi > x[0] > lower_cavi
             d = 1.0
-            if (y[0] > 0.0) and (x[0] - y[0] < lower):
-                d = min(d, (x[0] - lower) / y[0])
+            if (y[0] > 0.0) and (x[0] - y[0] < lower_cavi):
+                d = min(d, (x[0] - lower_cavi) / y[0])
+            if (y[0] < 0.0) and (x[0] - y[0] > upper_cavi):
+                d = min(d, (x[0] - upper_cavi) / y[0])
             if (y[1] > 0.0) and (x[1] - y[1] < 0.0):
                 d = min(d, x[1] / y[1])
             assert 0.0 < d <= 1.0
@@ -1059,7 +1078,11 @@ def cavity_damping(x, y):
 
         def posterior_damping(x):
             assert x[0] > -1.0 and x[1] > 0.0
-            d = min(1.0, upper / abs(x[0])) if (x[0] > 0) else 1.0
+            d = 1.0
+            if x[0] > upper_post:
+                d = upper_post / x[0]
+            if x[0] < lower_post:
+                d = lower_post / x[0]
             assert 0.0 < d <= 1.0
             return d
 
@@ -1098,13 +1121,82 @@ def posterior_damping(x):
 
         return 0.0  # TODO, placeholder
 
-    def iterate(self, max_shape=1000, min_kl=True):
+    @staticmethod
+    @numba.njit("f8(i4[:], f8[:, :], f8[:, :], f8[:, :], f8[:], f8, i4, f8)")
+    def propagate_prior(
+        nodes, global_prior, posterior, messages, scale, max_shape, em_maxitt, em_reltol
+    ):
+        """
+        Update approximating factors for global prior at each node.
+
+        :param ndarray nodes: ids of nodes that should be updated
+        :param ndarray global_prior: rows are mixture components, columns are
+            zeroth, first, and second natural parameters of gamma mixture
+            components. Updated in place.
+        :param ndarray posterior: rows are nodes, columns are first and
+            second natural parameters of gamma posteriors. Updated in
+            place.
+        :param ndarray messages: rows are nodes, columns are first and
+            second natural parameters of prior messages. Updated in place.
+        :param float max_shape: the maximum allowed shape for node posteriors
+        :param int em_maxitt: the maximum number of EM iterations to use when
+            fitting the mixture model
+        :param float em_reltol: the termination criterion for relative change in
+            log-likelihood
+        """
+
+        if global_prior.shape[0] == 0:
+            return 0.0
+
+        assert max_shape >= 1.0
+
+        upper = max_shape - 1.0
+        lower = 1.0 / max_shape - 1.0
+
+        def posterior_damping(x):
+            assert x[0] > -1.0 and x[1] > 0.0
+            d = 1.0
+            if x[0] > upper:
+                d = upper / x[0]
+            if x[0] < lower:
+                d = lower / x[0]
+            assert 0.0 < d <= 1.0
+            return d
+
+        cavity = np.zeros(posterior.shape)
+        cavity[nodes] = posterior[nodes] - messages[nodes] * scale[nodes, np.newaxis]
+        global_prior, posterior[nodes] = mixture.fit_gamma_mixture(
+            global_prior, cavity[nodes], em_maxitt, em_reltol, False
+        )
+        messages[nodes] = (posterior[nodes] - cavity[nodes]) / scale[nodes, np.newaxis]
+
+        for n in nodes:
+            eta = posterior_damping(posterior[n])
+            posterior[n] *= eta
+            scale[n] *= eta
+
+        return 0.0
+
+    def iterate(self, em_maxitt=100, em_reltol=1e-6, max_shape=1000, min_kl=True):
         """
         Update edge factors from leaves to root then from root to leaves,
         and return approximate log marginal likelihood (TODO)
         """
-        self.propagate(
+        # prior update
+        self.propagate_prior(
+            self.freenodes,
+            self.global_prior,
+            self.posterior,
+            self.prior_messages,
+            self.scale,
+            max_shape,
+            em_maxitt,
+            em_reltol,
+        )
+
+        # rootward pass
+        self.propagate_likelihood(
             self.edges,
             self.likelihoods,
             self.posterior,
@@ -1114,7 +1206,9 @@ def iterate(self, max_shape=1000, min_kl=True):
             max_shape,
             min_kl,
         )
-        self.propagate(
+
+        # leafward pass
+        self.propagate_likelihood(
             self.edges[::-1],
             self.likelihoods,
             self.posterior,
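For orientation (again not part of the patch), the edge update in `propagate_likelihood` follows the standard EP pattern. A schematic in plain Python, with `moment_match` standing in for the tilted-distribution projection that the real code performs via the `approx` module (plus damping, omitted here):

```python
import numpy as np

def ep_edge_update(post_parent, post_child, message, moment_match):
    """Schematic single-edge EP update in natural parameters.

    `moment_match` is a placeholder for the KL projection (exact
    sufficient statistics or central moments) used by the patch.
    """
    # 1. cavity: remove this edge's previous contribution
    cavity_parent = post_parent - message[0]
    cavity_child = post_child - message[1]
    # 2. project cavity times edge likelihood back onto gammas
    new_parent, new_child = moment_match(cavity_parent, cavity_child)
    # 3. the new message is whatever the projection added to the cavity
    message = np.array([new_parent - cavity_parent, new_child - cavity_child])
    return new_parent, new_child, message
```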
@@ -1400,11 +1494,14 @@ class VariationalGammaMethod(EstimationMethod):
 
     def __init__(self, ts, **kwargs):
         super().__init__(ts, **kwargs)
-        # convert priors to natural parameterization and average
+        # convert priors to natural parameterization
         for n in self.priors.nonfixed_nodes:
+            if not np.all(self.priors[n] > 0.0):
+                raise ValueError(
+                    f"Non-positive shape/rate parameters for node {n}: "
+                    f"{self.priors[n]}"
+                )
             self.priors[n][0] -= 1.0
-            assert self.priors[n][0] > -1.0
-            assert self.priors[n][1] >= 0.0
 
     @staticmethod
     def mean_var(ts, posterior):
@@ -1439,32 +1536,51 @@ def main_algorithm(self):
             self.recombination_rate,
             fixed_node_set=self.get_fixed_nodes_set(),
         )
-        return ExpectationPropagation(self.priors, lik, progress=self.pbar)
+        return ExpectationPropagation(
+            self.priors, lik, progress=self.pbar, global_prior=self.global_prior
+        )
 
-    def run(self, eps, max_iterations, max_shape, match_central_moments, global_prior):
+    def run(
+        self,
+        eps,
+        max_iterations,
+        max_shape,
+        match_central_moments,
+        prior_mixture_dim,
+        em_iterations,
+    ):
         if self.provenance_params is not None:
             self.provenance_params.update(
                 {k: v for k, v in locals().items() if k != "self"}
             )
         if not max_iterations >= 1:
             raise ValueError("Maximum number of EP iterations must be greater than 0")
+        if not (isinstance(prior_mixture_dim, int) and prior_mixture_dim > 0):
+            raise ValueError("Number of mixture components must be a positive integer")
         if self.mutation_rate is None:
             raise ValueError("Variational gamma method requires mutation rate")
-        if global_prior:
-            logging.info("Pooling node-specific priors into global prior")
-            self.priors.grid_data[:] = approx.average_gammas(
-                self.priors.grid_data[:, 0], self.priors.grid_data[:, 1]
-            )
+
+        # initialize weights/shapes/rates for i.i.d. mixture prior
+        # note that self.priors (node-specific priors) are not currently
+        # used except for initialization of the mixture
+        self.global_prior = mixture.initialize_mixture(
+            self.priors.grid_data, prior_mixture_dim
+        )
+        self.priors.grid_data[:] = [0.0, 0.0]  # TODO: support node-specific priors
 
         # match sufficient statistics or match central moments
         min_kl = not match_central_moments
 
         dynamic_prog = self.main_algorithm()
-        for _ in tqdm(
+        for itt in tqdm(
             np.arange(max_iterations),
             desc="Expectation Propagation",
             disable=not self.pbar,
         ):
-            dynamic_prog.iterate(max_shape=max_shape, min_kl=min_kl)
+            dynamic_prog.iterate(
+                em_maxitt=em_iterations if itt else 0,
+                max_shape=max_shape,
+                min_kl=min_kl,
+            )
 
         num_skipped = np.sum(np.isnan(dynamic_prog.log_partition))
         if num_skipped > 0:
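The mixture initialization that `run` now performs can be sanity-checked in isolation. A small sketch with made-up natural parameters (rows of the result are weight, then the two natural parameters of each component):

```python
import numpy as np
from tsdate import mixture

# natural-parameter grid for four nodes (illustrative values only)
grid_data = np.array([[0.5, 1.0], [1.5, 2.0], [2.5, 1.0], [3.5, 0.5]])

global_prior = mixture.initialize_mixture(grid_data, num_components=2)
print(global_prior.shape)  # (2, 3): weight, first and second natural parameter
print(global_prior[:, 0])  # equal initial weights: [0.5, 0.5]
```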
@@ -1682,7 +1798,8 @@ def variational_gamma(
     max_iterations=None,
     max_shape=None,
     match_central_moments=None,
-    global_prior=True,
+    prior_mixture_dim=None,
+    em_iterations=None,
     **kwargs,
 ):
     """
@@ -1697,6 +1814,12 @@ def variational_gamma(
 
         new_ts = tsdate.variational_gamma(
             ts, mutation_rate=1e-8, population_size=1e4, max_iterations=10)
 
+    An i.i.d. gamma mixture is used as a prior for each node, which is
+    initialized from the conditional coalescent and updated via expectation
+    maximization in each iteration. If node-specific priors are supplied
+    (via a grid of shape/rate parameters) then these are used for
+    initialization.
+
     .. note::
         The prior parameters for each node-to-be-dated take the form of a
         gamma-distributed prior on node age, parameterised by shape and rate.
@@ -1718,9 +1841,11 @@ def variational_gamma(
         update matches mean and variance rather than expected gamma sufficient
         statistics. Faster with a similar accuracy, but does not exactly minimize
         Kullback-Leibler divergence. Default: None, treated as False.
-    :param bool global_prior: If `True`, an iid prior is used for all nodes,
-        and is constructed by averaging gamma sufficient statistics over the free
-        nodes in ``priors``. Default: True.
+    :param int prior_mixture_dim: The number of components in the i.i.d. mixture prior
+        for node ages. Default: None, treated as 1.
+    :param int em_iterations: The number of expectation maximization iterations used
+        to optimize the i.i.d. mixture prior. Setting to zero disables optimization.
+        Default: None, treated as 10.
     :param \\**kwargs: Other keyword arguments as described in the :func:`date` wrapper
         function, notably ``mutation_rate``, and ``population_size`` or ``priors``.
         Further arguments include ``time_units``, ``progress``, and
@@ -1752,6 +1877,10 @@ def variational_gamma(
         max_shape = 1000
     if match_central_moments is None:
         match_central_moments = False
+    if prior_mixture_dim is None:
+        prior_mixture_dim = 1
+    if em_iterations is None:
+        em_iterations = 10
 
     dating_method = VariationalGammaMethod(tree_sequence, **kwargs)
     result = dating_method.run(
@@ -1759,7 +1888,8 @@ def variational_gamma(
         max_iterations=max_iterations,
         max_shape=max_shape,
         match_central_moments=match_central_moments,
-        global_prior=global_prior,
+        prior_mixture_dim=prior_mixture_dim,
+        em_iterations=em_iterations,
     )
 
     return dating_method.parse_result(result, eps, {"parameter": ["shape", "rate"]})
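End to end, the replacement for the removed `global_prior` flag then looks like this from the top-level API (illustrative values, following the docstring example above):

```python
import msprime
import tsdate

ts = msprime.simulate(8, mutation_rate=5, recombination_rate=5, random_seed=2)

# a two-component i.i.d. gamma mixture prior, refined by 10 EM steps per
# EP iteration (the defaults are 1 component and 10 steps)
dated = tsdate.variational_gamma(
    ts,
    mutation_rate=5,
    population_size=1,
    prior_mixture_dim=2,
    em_iterations=10,
)
```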
diff --git a/tsdate/mixture.py b/tsdate/mixture.py
new file mode 100644
index 00000000..0eb99687
--- /dev/null
+++ b/tsdate/mixture.py
@@ -0,0 +1,244 @@
+# MIT License
+#
+# Copyright (c) 2021-23 Tskit Developers
+# Copyright (c) 2020-21 University of Oxford
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+"""
+Mixture of gamma distributions that may be fit via EM to distribution-valued observations
+"""
+import numba
+import numpy as np
+
+from . import approx
+from . import hypergeo
+
+
+@numba.njit("UniTuple(f8[:], 4)(f8[:], f8[:], f8[:], f8, f8)")
+def _conditional_posterior(prior_logweight, prior_alpha, prior_beta, alpha, beta):
+    r"""
+    Return expectations of node age :math:`t` from the mixture model,
+
+    .. math::
+
+        Ga(t | a, b) \sum_j \pi_j w_j Ga(t | \alpha_j, \beta_j)
+
+    where :math:`a` and :math:`b` are variational parameters,
+    and :math:`\pi_j, \alpha_j, \beta_j` are prior weights and
+    parameters for a gamma mixture; and :math:`w_j` are fixed,
+    observation-specific weights. We use natural parameterization,
+    so that the shape parameter is :math:`\alpha + 1`.
+
+    TODO:
+    The normalizing constants of the prior are assumed to have already
+    been integrated into `prior_logweight`.
+
+    Returns the contribution from each component to the
+    posterior expectations of :math:`E[1]`, :math:`E[t]`, :math:`E[\log t]`,
+    and :math:`E[t \log t]`.
+
+    Note that :math:`E[1]` is *unnormalized* and *log-transformed*.
+    """
+
+    dim = prior_logweight.size
+    E = np.full(dim, -np.inf)  # E[1] (e.g. normalizing constant)
+    E_t = np.zeros(dim)  # E[t]
+    E_logt = np.zeros(dim)  # E[log(t)]
+    E_tlogt = np.zeros(dim)  # E[t * log(t)]
+    C = (alpha + 1) * np.log(beta) - hypergeo._gammaln(alpha + 1) if beta > 0 else 0.0
+    for i in range(dim):
+        post_alpha = prior_alpha[i] + alpha
+        post_beta = prior_beta[i] + beta
+        if (post_alpha <= -1) or (post_beta <= 0):  # skip node if divergent
+            E[:] = -np.inf
+            break
+        E[i] = C + (
+            +hypergeo._gammaln(post_alpha + 1)
+            - (post_alpha + 1) * np.log(post_beta)
+            + prior_logweight[i]
+        )
+        assert np.isfinite(E[i])
+        # TODO: option to use moments instead of sufficient statistics?
+        E_t[i] = (post_alpha + 1) / post_beta
+        E_logt[i] = hypergeo._digamma(post_alpha + 1) - np.log(post_beta)
+        E_tlogt[i] = E_t[i] * E_logt[i] + E_t[i] / (post_alpha + 1)
+
+    return E, E_t, E_logt, E_tlogt
+
+
+@numba.njit("f8(f8[:], f8[:], f8[:], f8[:], f8[:])")
+def _em_update(prior_weight, prior_alpha, prior_beta, alpha, beta):
+    """
+    Perform an expectation maximization step for parameters of mixture components,
+    given variational parameters `alpha`, `beta` for each node.
+
+    The maximization step is performed using Ye & Chen (2017) "Closed form
+    estimators for the gamma distribution ..."
+
+    ``prior_weight``, ``prior_alpha``, ``prior_beta`` are updated in place.
+    """
+    assert alpha.size == beta.size
+
+    dim = prior_weight.size
+    n = np.zeros(dim)
+    t = np.zeros(dim)
+    logt = np.zeros(dim)
+    tlogt = np.zeros(dim)
+
+    # incorporate prior normalizing constants into weights
+    prior_logweight = np.log(prior_weight)
+    for k in range(dim):
+        prior_logweight[k] += (prior_alpha[k] + 1) * np.log(
+            prior_beta[k]
+        ) - hypergeo._gammaln(prior_alpha[k] + 1)
+
+    # expectation step:
+    loglik = 0.0
+    for a, b in zip(alpha, beta):
+        E, E_t, E_logt, E_tlogt = _conditional_posterior(
+            prior_logweight, prior_alpha, prior_beta, a, b
+        )
+
+        # skip if posterior is improper
+        if np.any(np.isinf(E)):
+            continue
+
+        # convert evidence to posterior weights
+        norm_const = np.log(np.sum(np.exp(E - np.max(E)))) + np.max(E)
+        weight = np.exp(E - norm_const)
+
+        # weighted contributions to sufficient statistics
+        loglik += norm_const
+        n += weight
+        t += E_t * weight
+        logt += E_logt * weight
+        tlogt += E_tlogt * weight
+
+    # maximization step: update parameters in place
+    prior_weight[:] = n / np.sum(n)
+    prior_beta[:] = n**2 / (n * tlogt - t * logt)
+    prior_alpha[:] = n * t / (n * tlogt - t * logt) - 1.0
+
+    return loglik
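The maximization step above is the weighted analogue of the Ye & Chen (2017) closed-form estimators. As a sanity check, the same formulas recover shape and rate from raw gamma samples (standalone sketch, not part of the module):

```python
import numpy as np

rng = np.random.default_rng(1)
x = rng.gamma(shape=3.0, scale=1.0 / 2.0, size=10_000)  # true shape 3, rate 2

# Ye & Chen (2017) closed-form estimators, with n = count,
# t = sum(x), logt = sum(log x), tlogt = sum(x * log x)
n = x.size
t, logt, tlogt = x.sum(), np.log(x).sum(), (x * np.log(x)).sum()
rate = n**2 / (n * tlogt - t * logt)
shape = n * t / (n * tlogt - t * logt)
print(shape, rate)  # close to (3, 2); shape - 1 is the natural parameter
```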
+ """ + assert alpha.size == beta.size + + dim = prior_weight.size + + # incorporate prior normalizing constants into weights + prior_logweight = np.log(prior_weight) + for k in range(dim): + prior_logweight[k] += (prior_alpha[k] + 1) * np.log( + prior_beta[k] + ) - hypergeo._gammaln(prior_alpha[k] + 1) + + log_const = np.full(alpha.size, -np.inf) + for i in range(alpha.size): + E, E_t, E_logt, E_tlogt = _conditional_posterior( + prior_logweight, prior_alpha, prior_beta, alpha[i], beta[i] + ) + + # skip if posterior is improper for all components + if np.any(np.isinf(E)): + continue + + norm = np.log(np.sum(np.exp(E - np.max(E)))) + np.max(E) + weight = np.exp(E - norm) + t = np.sum(weight * E_t) + logt = np.sum(weight * E_logt) + # tlogt = np.sum(weight * E_tlogt) + log_const[i] = norm + alpha[i], beta[i] = approx.approximate_gamma_kl(t, logt) + # beta[i] = 1 / (tlogt - t * logt) + # alpha[i] = t * beta[i] - 1.0 + + return log_const + + +@numba.njit("Tuple((f8[:,:], f8[:,:]))(f8[:,:], f8[:,:], i4, f8, b1)") +def fit_gamma_mixture(mixture, observations, max_iterations, tolerance, verbose): + """ + Run EM until relative tolerance or maximum number of iterations is + reached. Then, perform expectation-propagation update and return new + variational parameters for the posterior approximation. + """ + + assert mixture.shape[1] == 3 + assert observations.shape[1] == 2 + + mix_weight, mix_alpha, mix_beta = mixture.T + alpha, beta = observations.T + + last = np.inf + for itt in range(max_iterations): + loglik = _em_update(mix_weight, mix_alpha, mix_beta, alpha, beta) + loglik /= float(alpha.size) + update = np.abs(loglik - last) + last = loglik + if verbose: + print("EM iteration:", itt, "mean(loglik):", np.round(loglik, 5)) + print(" -> weights:", mix_weight) + print(" -> alpha:", mix_alpha) + print(" -> beta:", mix_beta) + if update < np.abs(loglik) * tolerance: + break + + # conditional posteriors for each observation + # log_const = _gamma_projection(mix_weight, mix_alpha, mix_beta, alpha, beta) + _gamma_projection(mix_weight, mix_alpha, mix_beta, alpha, beta) + + new_mixture = np.zeros(mixture.shape) + new_observations = np.zeros(observations.shape) + new_observations[:, 0] = alpha + new_observations[:, 1] = beta + new_mixture[:, 0] = mix_weight + new_mixture[:, 1] = mix_alpha + new_mixture[:, 2] = mix_beta + + return new_mixture, new_observations + + +def initialize_mixture(parameters, num_components): + """ + Initialize clusters by dividing nodes into equal groups. + "parameters" are in natural parameterization (not shape/rate) + """ + global_prior = np.empty((num_components, 3)) + num_nodes = parameters.shape[0] + age_classes = np.tile( + np.arange(num_components), + num_nodes // num_components + 1, + )[:num_nodes] + for k in range(num_components): + indices = np.equal(age_classes, k) + alpha, beta = approx.average_gammas( + parameters[indices, 0], parameters[indices, 1] + ) + global_prior[k] = [1.0 / num_components, alpha, beta] + return global_prior diff --git a/tsdate/prior.py b/tsdate/prior.py index 72a7087e..33586b26 100644 --- a/tsdate/prior.py +++ b/tsdate/prior.py @@ -30,6 +30,8 @@ import numba import numpy as np +import scipy.cluster +import scipy.special import scipy.stats import tskit from tqdm.auto import tqdm