Skip to content

Commit

Permalink
Set defaults for most things to None
Browse files Browse the repository at this point in the history
  • Loading branch information
hyanwong committed Jan 4, 2024
1 parent d319006 commit 46366c6
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 19 deletions.
8 changes: 6 additions & 2 deletions tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1629,7 +1629,9 @@ def test_node_metadata_simulated_tree(self):
algorithm = InsideOutsideMethod(
larger_ts, mutation_rate=None, population_size=10000
)
mn_post, *_ = algorithm.run(eps=1e-6, outside_standardize=True)
mn_post, *_ = algorithm.run(
eps=1e-6, outside_standardize=True, probability_space=tsdate.base.LOG
)
dated_ts = date(larger_ts, population_size=10000, mutation_rate=None)
metadata = dated_ts.tables.nodes.metadata
metadata_offset = dated_ts.tables.nodes.metadata_offset
Expand Down Expand Up @@ -1843,7 +1845,9 @@ def test_sites_time_insideoutside(self):
ts = utility_functions.two_tree_mutation_ts()
dated = tsdate.date(ts, mutation_rate=None, population_size=1)
algorithm = InsideOutsideMethod(ts, mutation_rate=None, population_size=1)
mn_post, *_ = algorithm.run(eps=1e-6, outside_standardize=True)
mn_post, *_ = algorithm.run(
eps=1e-6, outside_standardize=True, probability_space=tsdate.base.LOG
)
assert np.array_equal(
mn_post[ts.tables.mutations.node],
tsdate.sites_time_from_ts(dated, unconstrained=True, min_time=0),
Expand Down
50 changes: 33 additions & 17 deletions tsdate/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1327,8 +1327,6 @@ def mean_var(ts, posterior):
return mn_post, va_post

def setup(self, probability_space, num_threads, cache_inside):
if probability_space is None:
probability_space = base.LOG
if probability_space != base.LOG:
liklhd = Likelihoods(
self.ts,
Expand Down Expand Up @@ -1504,14 +1502,14 @@ def run(self, eps, max_iterations, max_shape, match_central_moments, global_prio
def maximization(
tree_sequence,
*,
eps=1e-6,
eps=None,
num_threads=None,
cache_inside=None,
probability_space=None,
**kwargs,
):
"""
Infer dates for nodes in a genealogical graph using the "outside maximuzation"
Infer dates for nodes in a genealogical graph using the "outside maximization"
algorithm. This approximates the marginal posterior distribution of a node's
age using an atomic discretization of time (e.g. point masses at particular
timepoints).
Expand All @@ -1523,7 +1521,7 @@ def maximization(
on each edge). The outside maximization step passes forwards in time from the roots,
updating each node's time on the basis of the most likely timepoint for
each parent of that node. This provides a reasonable point estimate for node times,
but does not generaten a true posterior time distribution.
but does not generate a true posterior time distribution.
For example:
Expand All @@ -1546,7 +1544,7 @@ def maximization(
:param ~tskit.TreeSequence tree_sequence: The input tree sequence to be dated.
:param float eps: The error factor in time difference calculations, and the
minimum distance separating parent and child ages in the returned tree sequence.
Default: 1e-6.
Default: None, treated as 1e-6.
:param int num_threads: The number of threads to use when precalculating likelihoods.
A simpler unthreaded algorithm is used unless this is >= 1. Default: None
:param bool ignore_oldest_root: Should the oldest root in the tree sequence be
Expand All @@ -1555,7 +1553,7 @@ def maximization(
inferred from real data. Default: False
:param string probability_space: Should the internal algorithm save
probabilities in "logarithmic" (slower, less liable to to overflow) or
"linear" space (fast, may overflow). Default: "logarithmic"
"linear" space (fast, may overflow). Default: None treated as"logarithmic"
:param \\**kwargs: Other keyword arguments as described in the :func:`date` wrapper
function, notably ``mutation_rate``, and ``population_size`` or ``priors``.
Further arguments include ``time_units``, ``progress``, and
Expand All @@ -1571,6 +1569,11 @@ def maximization(
``return_likelihood`` is ``True``) The marginal likelihood of
the mutation data given the inferred node times.
"""
if eps is None:
eps = 1e-6
if probability_space is None:
probability_space = base.LOG

algorithm = MaximizationMethod(tree_sequence, **kwargs)
result = algorithm.run(
eps=eps,
Expand Down Expand Up @@ -1661,6 +1664,10 @@ def inside_outside(
``return_likelihood`` is ``True``) The marginal likelihood of
the mutation data given the inferred node times.
"""
if eps is None:
eps = 1e-6
if probability_space is None:
probability_space = base.LOG
algorithm = InsideOutsideMethod(tree_sequence, **kwargs)
result = algorithm.run(
eps=eps,
Expand All @@ -1676,18 +1683,18 @@ def inside_outside(
def variational_gamma(
tree_sequence,
*,
eps=1e-6,
max_iterations=20,
max_shape=1000,
match_central_moments=False,
eps=None,
max_iterations=None,
max_shape=None,
match_central_moments=None,
global_prior=True,
**kwargs,
):
"""
Infer dates for nodes in a tree sequence using expectation propagation,
which approximates the marginal posterior distribution of a given node's
age with a gamma distribution. Convergence to the correct posteriors is
obtained by updating the distributions for node dates using several rounds
age with a gamma distribution. Convergence to the correct posterior moments
is obtained by updating the distributions for node dates using several rounds
of iteration. For example:
.. code-block:: python
Expand All @@ -1706,16 +1713,16 @@ def variational_gamma(
:param ~tskit.TreeSequence tree_sequence: The input tree sequence to be dated.
:param float eps: The minimum distance separating parent and child ages in
the returned tree sequence. Default: 1e-6
the returned tree sequence. Default: None, treated as 1e-6
:param int max_iterations: The number of iterations used in the expectation
propagation algorithm. Default: 20.
propagation algorithm. Default: None, treated as 20.
:param float max_shape: The maximum value for the shape parameter in the variational
posteriors. This is equivalent to the maximum precision (inverse variance) on a
logarithmic scale. Default: 1000.
logarithmic scale. Default: None, treated as 1000.
:param bool match_central_moments: If `True`, each expectation propgation
update matches mean and variance rather than expected gamma sufficient
statistics. Faster with a similar accuracy, but does not exactly minimize
Kullback-Leibler divergence. Default: False.
Kullback-Leibler divergence. Default: None, treated as False.
:param bool global_prior: If `True`, an iid prior is used for all nodes,
and is constructed by averaging gamma sufficient statistics over the free
nodes in ``priors``. Default: True.
Expand All @@ -1741,6 +1748,15 @@ def variational_gamma(
``return_likelihood`` is ``True``) The marginal likelihood of
the mutation data given the inferred node times.
"""
if eps is None:
eps = 1e-6
if max_iterations is None:
max_iterations = 20
if max_shape is None:
max_shape = 1000
if match_central_moments is None:
match_central_moments = False

algorithm = VariationalGammaMethod(tree_sequence, **kwargs)
result = algorithm.run(
eps=eps,
Expand Down

0 comments on commit 46366c6

Please sign in to comment.