From aaaf96287da972c2917d82a9838676a61ad378b2 Mon Sep 17 00:00:00 2001 From: Yan Wong Date: Thu, 4 Jan 2024 23:17:29 +0000 Subject: [PATCH] Set defaults for most things to None --- tests/test_functions.py | 8 ++++++-- tsdate/core.py | 42 ++++++++++++++++++++++++++++------------- 2 files changed, 35 insertions(+), 15 deletions(-) diff --git a/tests/test_functions.py b/tests/test_functions.py index d9ee0677..776db07b 100644 --- a/tests/test_functions.py +++ b/tests/test_functions.py @@ -1629,7 +1629,9 @@ def test_node_metadata_simulated_tree(self): algorithm = InsideOutsideMethod( larger_ts, mutation_rate=None, population_size=10000 ) - mn_post, *_ = algorithm.run(eps=1e-6, outside_standardize=True) + mn_post, *_ = algorithm.run( + eps=1e-6, outside_standardize=True, probability_space=tsdate.base.LOG + ) dated_ts = date(larger_ts, population_size=10000, mutation_rate=None) metadata = dated_ts.tables.nodes.metadata metadata_offset = dated_ts.tables.nodes.metadata_offset @@ -1843,7 +1845,9 @@ def test_sites_time_insideoutside(self): ts = utility_functions.two_tree_mutation_ts() dated = tsdate.date(ts, mutation_rate=None, population_size=1) algorithm = InsideOutsideMethod(ts, mutation_rate=None, population_size=1) - mn_post, *_ = algorithm.run(eps=1e-6, outside_standardize=True) + mn_post, *_ = algorithm.run( + eps=1e-6, outside_standardize=True, probability_space=tsdate.base.LOG + ) assert np.array_equal( mn_post[ts.tables.mutations.node], tsdate.sites_time_from_ts(dated, unconstrained=True, min_time=0), diff --git a/tsdate/core.py b/tsdate/core.py index 108af99d..9fea4d0a 100644 --- a/tsdate/core.py +++ b/tsdate/core.py @@ -1327,8 +1327,6 @@ def mean_var(ts, posterior): return mn_post, va_post def setup(self, probability_space, num_threads, cache_inside): - if probability_space is None: - probability_space = base.LOG if probability_space != base.LOG: liklhd = Likelihoods( self.ts, @@ -1504,7 +1502,7 @@ def run(self, eps, max_iterations, max_shape, 
match_central_moments, global_prio def maximization( tree_sequence, *, - eps=1e-6, + eps=None, num_threads=None, cache_inside=None, probability_space=None, @@ -1546,7 +1544,7 @@ def maximization( :param ~tskit.TreeSequence tree_sequence: The input tree sequence to be dated. :param float eps: The error factor in time difference calculations, and the minimum distance separating parent and child ages in the returned tree sequence. - Default: 1e-6. + Default: None, treated as 1e-6. :param int num_threads: The number of threads to use when precalculating likelihoods. A simpler unthreaded algorithm is used unless this is >= 1. Default: None :param bool ignore_oldest_root: Should the oldest root in the tree sequence be @@ -1555,7 +1553,7 @@ def maximization( inferred from real data. Default: False :param string probability_space: Should the internal algorithm save probabilities in "logarithmic" (slower, less liable to to overflow) or - "linear" space (fast, may overflow). Default: "logarithmic" + "linear" space (fast, may overflow). Default: None, treated as "logarithmic" :param \\**kwargs: Other keyword arguments as described in the :func:`date` wrapper function, notably ``mutation_rate``, and ``population_size`` or ``priors``. Further arguments include ``time_units``, ``progress``, and @@ -1571,6 +1569,11 @@ def maximization( ``return_likelihood`` is ``True``) The marginal likelihood of the mutation data given the inferred node times. """ + if eps is None: + eps = 1e-6 + if probability_space is None: + probability_space = base.LOG + algorithm = MaximizationMethod(tree_sequence, **kwargs) result = algorithm.run( eps=eps, @@ -1661,6 +1664,10 @@ def inside_outside( ``return_likelihood`` is ``True``) The marginal likelihood of the mutation data given the inferred node times. 
""" + if eps is None: + eps = 1e-6 + if probability_space is None: + probability_space = base.LOG algorithm = InsideOutsideMethod(tree_sequence, **kwargs) result = algorithm.run( eps=eps, @@ -1676,10 +1683,10 @@ def inside_outside( def variational_gamma( tree_sequence, *, - eps=1e-6, - max_iterations=20, - max_shape=1000, - match_central_moments=False, + eps=None, + max_iterations=None, + max_shape=None, + match_central_moments=None, global_prior=True, **kwargs, ): @@ -1706,16 +1713,16 @@ def variational_gamma( :param ~tskit.TreeSequence tree_sequence: The input tree sequence to be dated. :param float eps: The minimum distance separating parent and child ages in - the returned tree sequence. Default: 1e-6 + the returned tree sequence. Default: None, treated as 1e-6 :param int max_iterations: The number of iterations used in the expectation - propagation algorithm. Default: 20. + propagation algorithm. Default: None, treated as 20. :param float max_shape: The maximum value for the shape parameter in the variational posteriors. This is equivalent to the maximum precision (inverse variance) on a - logarithmic scale. Default: 1000. + logarithmic scale. Default: None, treated as 1000. :param bool match_central_moments: If `True`, each expectation propgation update matches mean and variance rather than expected gamma sufficient statistics. Faster with a similar accuracy, but does not exactly minimize - Kullback-Leibler divergence. Default: False. + Kullback-Leibler divergence. Default: None, treated as False. :param bool global_prior: If `True`, an iid prior is used for all nodes, and is constructed by averaging gamma sufficient statistics over the free nodes in ``priors``. Default: True. @@ -1741,6 +1748,15 @@ def variational_gamma( ``return_likelihood`` is ``True``) The marginal likelihood of the mutation data given the inferred node times. 
""" + if eps is None: + eps = 1e-6 + if max_iterations is None: + max_iterations = 20 + if max_shape is None: + max_shape = 1000 + if match_central_moments is None: + match_central_moments = False + algorithm = VariationalGammaMethod(tree_sequence, **kwargs) result = algorithm.run( eps=eps,