Skip to content

Commit

Permalink
Set defaults for most things to None
Browse files Browse the repository at this point in the history
  • Loading branch information
hyanwong committed Jan 4, 2024
1 parent d319006 commit aaaf962
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 15 deletions.
8 changes: 6 additions & 2 deletions tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1629,7 +1629,9 @@ def test_node_metadata_simulated_tree(self):
algorithm = InsideOutsideMethod(
larger_ts, mutation_rate=None, population_size=10000
)
mn_post, *_ = algorithm.run(eps=1e-6, outside_standardize=True)
mn_post, *_ = algorithm.run(
eps=1e-6, outside_standardize=True, probability_space=tsdate.base.LOG
)
dated_ts = date(larger_ts, population_size=10000, mutation_rate=None)
metadata = dated_ts.tables.nodes.metadata
metadata_offset = dated_ts.tables.nodes.metadata_offset
Expand Down Expand Up @@ -1843,7 +1845,9 @@ def test_sites_time_insideoutside(self):
ts = utility_functions.two_tree_mutation_ts()
dated = tsdate.date(ts, mutation_rate=None, population_size=1)
algorithm = InsideOutsideMethod(ts, mutation_rate=None, population_size=1)
mn_post, *_ = algorithm.run(eps=1e-6, outside_standardize=True)
mn_post, *_ = algorithm.run(
eps=1e-6, outside_standardize=True, probability_space=tsdate.base.LOG
)
assert np.array_equal(
mn_post[ts.tables.mutations.node],
tsdate.sites_time_from_ts(dated, unconstrained=True, min_time=0),
Expand Down
42 changes: 29 additions & 13 deletions tsdate/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1327,8 +1327,6 @@ def mean_var(ts, posterior):
return mn_post, va_post

def setup(self, probability_space, num_threads, cache_inside):
if probability_space is None:
probability_space = base.LOG
if probability_space != base.LOG:
liklhd = Likelihoods(
self.ts,
Expand Down Expand Up @@ -1504,7 +1502,7 @@ def run(self, eps, max_iterations, max_shape, match_central_moments, global_prio
def maximization(
tree_sequence,
*,
eps=1e-6,
eps=None,
num_threads=None,
cache_inside=None,
probability_space=None,
Expand Down Expand Up @@ -1546,7 +1544,7 @@ def maximization(
:param ~tskit.TreeSequence tree_sequence: The input tree sequence to be dated.
:param float eps: The error factor in time difference calculations, and the
minimum distance separating parent and child ages in the returned tree sequence.
Default: 1e-6.
Default: None, treated as 1e-6.
:param int num_threads: The number of threads to use when precalculating likelihoods.
A simpler unthreaded algorithm is used unless this is >= 1. Default: None
:param bool ignore_oldest_root: Should the oldest root in the tree sequence be
Expand All @@ -1555,7 +1553,7 @@ def maximization(
inferred from real data. Default: False
:param string probability_space: Should the internal algorithm save
probabilities in "logarithmic" (slower, less liable to to overflow) or
"linear" space (fast, may overflow). Default: "logarithmic"
"linear" space (fast, may overflow). Default: None treated as"logarithmic"
:param \\**kwargs: Other keyword arguments as described in the :func:`date` wrapper
function, notably ``mutation_rate``, and ``population_size`` or ``priors``.
Further arguments include ``time_units``, ``progress``, and
Expand All @@ -1571,6 +1569,11 @@ def maximization(
``return_likelihood`` is ``True``) The marginal likelihood of
the mutation data given the inferred node times.
"""
if eps is None:
eps = 1e-6
if probability_space is None:
probability_space = base.LOG

algorithm = MaximizationMethod(tree_sequence, **kwargs)
result = algorithm.run(
eps=eps,
Expand Down Expand Up @@ -1661,6 +1664,10 @@ def inside_outside(
``return_likelihood`` is ``True``) The marginal likelihood of
the mutation data given the inferred node times.
"""
if eps is None:
eps = 1e-6
if probability_space is None:
probability_space = base.LOG
algorithm = InsideOutsideMethod(tree_sequence, **kwargs)
result = algorithm.run(
eps=eps,
Expand All @@ -1676,10 +1683,10 @@ def inside_outside(
def variational_gamma(
tree_sequence,
*,
eps=1e-6,
max_iterations=20,
max_shape=1000,
match_central_moments=False,
eps=None,
max_iterations=None,
max_shape=None,
match_central_moments=None,
global_prior=True,
**kwargs,
):
Expand All @@ -1706,16 +1713,16 @@ def variational_gamma(
:param ~tskit.TreeSequence tree_sequence: The input tree sequence to be dated.
:param float eps: The minimum distance separating parent and child ages in
the returned tree sequence. Default: 1e-6
the returned tree sequence. Default: None, treated as 1e-6
:param int max_iterations: The number of iterations used in the expectation
propagation algorithm. Default: 20.
propagation algorithm. Default: None, treated as 20.
:param float max_shape: The maximum value for the shape parameter in the variational
posteriors. This is equivalent to the maximum precision (inverse variance) on a
logarithmic scale. Default: 1000.
logarithmic scale. Default: None, treated as 1000.
:param bool match_central_moments: If `True`, each expectation propgation
update matches mean and variance rather than expected gamma sufficient
statistics. Faster with a similar accuracy, but does not exactly minimize
Kullback-Leibler divergence. Default: False.
Kullback-Leibler divergence. Default: None, treated as False.
:param bool global_prior: If `True`, an iid prior is used for all nodes,
and is constructed by averaging gamma sufficient statistics over the free
nodes in ``priors``. Default: True.
Expand All @@ -1741,6 +1748,15 @@ def variational_gamma(
``return_likelihood`` is ``True``) The marginal likelihood of
the mutation data given the inferred node times.
"""
if eps is None:
eps = 1e-6
if max_iterations is None:
max_iterations = 20
if max_shape is None:
max_shape = 1000
if match_central_moments is None:
match_central_moments = False

algorithm = VariationalGammaMethod(tree_sequence, **kwargs)
result = algorithm.run(
eps=eps,
Expand Down

0 comments on commit aaaf962

Please sign in to comment.