Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Set metadata schemas #303

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 37 additions & 25 deletions tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
Test cases for the python API for tsdate.
"""
import collections
import json
import logging
import unittest

Expand Down Expand Up @@ -1619,7 +1618,9 @@ def test_posterior_mean_var(self):
ts = utility_functions.single_tree_ts_n2()
for distr in ("gamma", "lognorm"):
posterior, algo = TestTotalFunctionalValueTree().find_posterior(ts, distr)
ts_node_metadata, mn_post, vr_post = posterior_mean_var(ts, posterior)
ts_node_metadata, mn_post, vr_post = posterior_mean_var(
ts, posterior, save_metadata=False
)
assert np.array_equal(
mn_post,
[
Expand All @@ -1631,31 +1632,35 @@ def test_posterior_mean_var(self):

def test_node_metadata_single_tree_n2(self):
ts = utility_functions.single_tree_ts_n2()
tables = ts.dump_tables()
tables.nodes.metadata_schema = tskit.MetadataSchema.permissive_json()
ts = tables.tree_sequence()
posterior, algo = TestTotalFunctionalValueTree().find_posterior(ts, "lognorm")
ts_node_metadata, mn_post, vr_post = posterior_mean_var(ts, posterior)
assert json.loads(ts_node_metadata.node(2).metadata)["mn"] == mn_post[2]
assert json.loads(ts_node_metadata.node(2).metadata)["vr"] == vr_post[2]
assert ts_node_metadata.node(2).metadata["mn"] == mn_post[2]
assert ts_node_metadata.node(2).metadata["vr"] == vr_post[2]

def test_node_metadata_simulated_tree(self):
larger_ts = msprime.simulate(
10, mutation_rate=1, recombination_rate=1, length=20, random_seed=12
)
_, mn_post, _, _, eps, _ = get_dates(
larger_ts, mutation_rate=None, population_size=10000
)
dated_ts = date(larger_ts, population_size=10000, mutation_rate=None)
metadata = dated_ts.tables.nodes.metadata
metadata_offset = dated_ts.tables.nodes.metadata_offset
unconstrained_mn = [
json.loads(met.decode())["mn"]
for met in tskit.unpack_bytes(metadata, metadata_offset)
if len(met.decode()) > 0
]
assert np.allclose(unconstrained_mn, mn_post[larger_ts.num_samples :])
assert np.all(
dated_ts.tables.nodes.time[larger_ts.num_samples :]
>= mn_post[larger_ts.num_samples :]
is_sample = np.zeros(larger_ts.num_nodes, dtype=bool)
is_sample[larger_ts.samples()] = True
is_not_sample = np.logical_not(is_sample)
# This calls posterior_mean_var
_, mn_post, _, _, _, _ = get_dates(
larger_ts,
method="inside_outside",
population_size=1,
mutation_rate=1,
save_metadata=False,
)
constrained_time = constrain_ages_topo(larger_ts, mn_post, eps=1e-6)
# Samples identical in all methods
assert np.allclose(larger_ts.nodes_time[is_sample], mn_post[is_sample])
assert np.allclose(constrained_time[is_sample], mn_post[is_sample])
# Non-samples should adhere to constraints
assert np.all(constrained_time[is_not_sample] >= mn_post[is_not_sample])


class TestConstrainAgesTopo:
Expand Down Expand Up @@ -1833,7 +1838,7 @@ def test_node_times(self):

def test_fails_unconstrained(self):
ts = utility_functions.two_tree_mutation_ts()
with pytest.raises(ValueError):
with pytest.raises(ValueError, match="must be tsdated"):
nodes_time_unconstrained(ts)


Expand All @@ -1847,15 +1852,22 @@ def test_no_sites(self):
with pytest.raises(ValueError):
tsdate.sites_time_from_ts(ts)

def test_undated(self):
ts = utility_functions.two_tree_mutation_ts()
with pytest.raises(ValueError, match="Try calling"):
tsdate.sites_time_from_ts(ts, unconstrained=True)

def test_node_selection_param(self):
ts = utility_functions.two_tree_mutation_ts()
with pytest.raises(ValueError):
with pytest.raises(ValueError, match="node_selection parameter"):
tsdate.sites_time_from_ts(ts, node_selection="sibling")

def test_sites_time_insideoutside(self):
ts = utility_functions.two_tree_mutation_ts()
dated = tsdate.date(ts, mutation_rate=None, population_size=1)
_, mn_post, _, _, eps, _ = get_dates(ts, mutation_rate=None, population_size=1)
_, mn_post, _, _, eps, _ = get_dates(
ts, mutation_rate=None, population_size=1, save_metadata=False
)
assert np.array_equal(
mn_post[ts.tables.mutations.node],
tsdate.sites_time_from_ts(dated, unconstrained=True, min_time=0),
Expand Down Expand Up @@ -1959,15 +1971,15 @@ def test_sites_time_simulated(self):
10, mutation_rate=1, recombination_rate=1, length=20, random_seed=12
)
_, mn_post, _, _, _, _ = get_dates(
larger_ts, mutation_rate=None, population_size=10000
larger_ts, mutation_rate=None, population_size=10000, save_metadata=False
)
dated = date(larger_ts, mutation_rate=None, population_size=10000)
assert np.allclose(
mn_post[larger_ts.tables.mutations.node],
mn_post[larger_ts.mutations_node],
tsdate.sites_time_from_ts(dated, unconstrained=True, min_time=0),
)
assert np.allclose(
dated.tables.nodes.time[larger_ts.tables.mutations.node],
dated.tables.nodes.time[larger_ts.mutations_node],
tsdate.sites_time_from_ts(dated, unconstrained=False, min_time=0),
)

Expand Down
Loading