Skip to content

Commit

Permalink
Merge pull request #350 from nspope/refactor-api
Browse files Browse the repository at this point in the history
Refactor/redocument API
  • Loading branch information
hyanwong authored Dec 29, 2023
2 parents 50f53b4 + 9c2c977 commit 343dd63
Show file tree
Hide file tree
Showing 10 changed files with 489 additions and 423 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,13 @@
`action="count"`, so `-v` turns verbosity to INFO level,
whereas `-vv` turns verbosity to DEBUG level.

- The `return_posteriors=True` option with `method="inside_outside"`
previously returned a dict that included keys `start_time` and `end_time`,
giving the impression that the posterior for node age is discretised over
time slices (intervals) in this algorithm. In fact, the posterior consists of
point masses at discrete time points, so `start_time` and `end_time` have been
replaced by a single key `time`.

- Python 3.7 is no longer supported.

**Features**
Expand Down
2 changes: 1 addition & 1 deletion docs/priors.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ See below for more explanation of the interpretation of the parameters passed to
For {ref}`sec_methods_discrete_time` methods, it is possible to switch from the
(default) lognormal approximation to a gamma distribution, used when building a
mixture prior for nodes that have variable numbers of children in different
genomic regions. The discretized prior is then constructed by evaluating the
genomic regions. The discretised prior is then constructed by evaluating the
lognormal (or gamma) distribution across a grid of fixed times. Tests have shown that the
lognormal is usually a better fit to the true prior.

Expand Down
2 changes: 2 additions & 0 deletions docs/python-api.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ This page provides formal documentation for the `tsdate` Python API.

```{eval-rst}
.. autofunction:: tsdate.date
.. autofunction:: tsdate.discretised_dates
.. autofunction:: tsdate.variational_dates
```

## Prior and Time Discretisation Options
Expand Down
2 changes: 1 addition & 1 deletion docs/usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ different stages of dating will take.

If the {ref}`method<sec_methods>` used for dating involves discrete time slices, `tsdate` scales
quadratically in the number of time slices chosen. For greater temporal resolution,
you are thus advised to use the `variational_gamma` method, which does not discretize time.
you are thus advised to use the `variational_gamma` method, which does not discretise time.

#### Optimisations

Expand Down
50 changes: 12 additions & 38 deletions tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,11 @@
from tsdate import base
from tsdate.core import constrain_ages_topo
from tsdate.core import date
from tsdate.core import get_dates
from tsdate.core import discretised_dates
from tsdate.core import discretised_mean_var
from tsdate.core import InOutAlgorithms
from tsdate.core import Likelihoods
from tsdate.core import LogLikelihoods
from tsdate.core import posterior_mean_var
from tsdate.core import variational_dates
from tsdate.core import VariationalLikelihoods
from tsdate.demography import PopulationSizeHistory
Expand Down Expand Up @@ -500,7 +500,7 @@ def test_custom_timegrid_is_not_rescaled(self):
prior = MixturePrior(ts)
demography = PopulationSizeHistory(3)
timepoints = np.array([0, 300, 1000, 2000])
prior_grid = prior.make_discretized_prior(demography, timepoints=timepoints)
prior_grid = prior.make_discretised_prior(demography, timepoints=timepoints)
assert np.array_equal(prior_grid.timepoints, timepoints)


Expand Down Expand Up @@ -1602,35 +1602,16 @@ def test_bad_Ne(self):
tsdate.build_prior_grid(ts, population_size=-10)


class TestCallingErrors:
def test_bad_vgamma_probability_space(self):
ts = utility_functions.single_tree_ts_n2()
with pytest.raises(ValueError, match="Cannot specify"):
variational_dates(ts, 1, 1, probability_space=base.LOG)

def test_bad_vgamma_num_threads(self):
# Test can be removed if we specify num_threads in the future
ts = utility_functions.single_tree_ts_n2()
with pytest.raises(ValueError, match="does not currently"):
variational_dates(ts, 1, 1, num_threads=2)

def test_bad_vgamma_ignore_oldest_root(self):
# Test can be removed in the future if this is implemented
ts = utility_functions.single_tree_ts_n2()
with pytest.raises(ValueError, match="not implemented"):
variational_dates(ts, 1, 1, ignore_oldest_root=True)


class TestPosteriorMeanVar:
class TestDiscretisedMeanVar:
"""
Test posterior_mean_var works as expected
Test discretised_mean_var works as expected
"""

def test_posterior_mean_var(self):
def test_discretised_mean_var(self):
ts = utility_functions.single_tree_ts_n2()
for distr in ("gamma", "lognorm"):
posterior, algo = TestTotalFunctionalValueTree().find_posterior(ts, distr)
ts_node_metadata, mn_post, vr_post = posterior_mean_var(ts, posterior)
mn_post, vr_post = discretised_mean_var(ts, posterior)
assert np.array_equal(
mn_post,
[
Expand All @@ -1640,19 +1621,12 @@ def test_posterior_mean_var(self):
],
)

def test_node_metadata_single_tree_n2(self):
ts = utility_functions.single_tree_ts_n2()
posterior, algo = TestTotalFunctionalValueTree().find_posterior(ts, "lognorm")
ts_node_metadata, mn_post, vr_post = posterior_mean_var(ts, posterior)
assert json.loads(ts_node_metadata.node(2).metadata)["mn"] == mn_post[2]
assert json.loads(ts_node_metadata.node(2).metadata)["vr"] == vr_post[2]

def test_node_metadata_simulated_tree(self):
larger_ts = msprime.simulate(
10, mutation_rate=1, recombination_rate=1, length=20, random_seed=12
)
_, mn_post, _, _, eps, _, _ = get_dates(
larger_ts, mutation_rate=None, population_size=10000
mn_post, *_ = discretised_dates(
larger_ts, mutation_rate=None, population_size=10000, eps=1e-6
)
dated_ts = date(larger_ts, population_size=10000, mutation_rate=None)
metadata = dated_ts.tables.nodes.metadata
Expand Down Expand Up @@ -1866,8 +1840,8 @@ def test_node_selection_param(self):
def test_sites_time_insideoutside(self):
ts = utility_functions.two_tree_mutation_ts()
dated = tsdate.date(ts, mutation_rate=None, population_size=1)
_, mn_post, _, _, eps, _, _ = get_dates(
ts, mutation_rate=None, population_size=1
mn_post, *_ = discretised_dates(
ts, mutation_rate=None, population_size=1, eps=1e-6
)
assert np.array_equal(
mn_post[ts.tables.mutations.node],
Expand Down Expand Up @@ -1971,7 +1945,7 @@ def test_sites_time_simulated(self):
larger_ts = msprime.simulate(
10, mutation_rate=1, recombination_rate=1, length=20, random_seed=12
)
_, mn_post, _, _, _, _, _ = get_dates(
mn_post, *_ = discretised_dates(
larger_ts, mutation_rate=None, population_size=10000
)
dated = date(larger_ts, mutation_rate=None, population_size=10000)
Expand Down
38 changes: 24 additions & 14 deletions tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,27 +128,37 @@ def test_no_posteriors(self):
method="maximization",
mutation_rate=1,
)
assert len(posteriors) == ts.num_nodes - ts.num_samples + 2
assert len(posteriors["start_time"]) == len(posteriors["end_time"])
assert len(posteriors["start_time"]) > 0
assert posteriors is None

def test_discretised_posteriors(self):
ts = utility_functions.two_tree_mutation_ts()
ts, posteriors = tsdate.date(
ts, mutation_rate=None, population_size=1, return_posteriors=True
)
assert len(posteriors) == ts.num_nodes - ts.num_samples + 1
assert len(posteriors["time"]) > 0
for node in ts.nodes():
if not node.is_sample():
assert node.id in posteriors
assert posteriors[node.id] is None
assert len(posteriors[node.id]) == len(posteriors["time"])
assert np.isclose(np.sum(posteriors[node.id]), 1)

def test_posteriors(self):
def test_variational_posteriors(self):
ts = utility_functions.two_tree_mutation_ts()
ts, posteriors = tsdate.date(
ts, mutation_rate=None, population_size=1, return_posteriors=True
ts,
mutation_rate=1e-2,
population_size=1,
method="variational_gamma",
return_posteriors=True,
)
assert len(posteriors) == ts.num_nodes - ts.num_samples + 2
assert len(posteriors["start_time"]) == len(posteriors["end_time"])
assert len(posteriors["start_time"]) > 0
assert len(posteriors) == ts.num_nodes - ts.num_samples + 1
assert len(posteriors["parameter"]) == 2
for node in ts.nodes():
if not node.is_sample():
assert node.id in posteriors
assert len(posteriors[node.id]) == len(posteriors["start_time"])
assert np.isclose(np.sum(posteriors[node.id]), 1)
assert len(posteriors[node.id]) == 2
assert np.all(posteriors[node.id] > 0)

def test_marginal_likelihood(self):
ts = utility_functions.two_tree_mutation_ts()
Expand Down Expand Up @@ -419,7 +429,7 @@ def test_nonglobal_priors(self):

def test_bad_arguments(self):
ts = utility_functions.two_tree_mutation_ts()
with pytest.raises(ValueError, match="Maximum number of iterations"):
with pytest.raises(ValueError, match="Maximum number of EP iterations"):
tsdate.date(
ts,
mutation_rate=5,
Expand All @@ -435,13 +445,13 @@ def test_match_central_moments(self):
mutation_rate=5,
population_size=1,
method="variational_gamma",
method_of_moments=False,
match_central_moments=False,
)
ts1 = tsdate.date(
ts,
mutation_rate=5,
population_size=1,
method="variational_gamma",
method_of_moments=True,
match_central_moments=True,
)
assert np.any(np.not_equal(ts0.nodes_time, ts1.nodes_time))
3 changes: 2 additions & 1 deletion tsdate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
# SOFTWARE.
from .cache import * # NOQA: F401,F403
from .core import date # NOQA: F401
from .core import get_dates # NOQA: F401
from .core import discretised_dates # NOQA: F401
from .core import variational_dates # NOQA: F401
from .prior import build_grid as build_prior_grid # NOQA: F401
from .prior import parameter_grid as build_parameter_grid # NOQA: F401
from .provenance import __version__ # NOQA: F401
Expand Down
Loading

0 comments on commit 343dd63

Please sign in to comment.