Skip to content

Commit

Permalink
Set metadata schemas
Browse files Browse the repository at this point in the history
Defaults to a "struct" type unless a schema already exists (or if there is a null schema that can be interpreted as JSON). Fixes #302
  • Loading branch information
hyanwong committed Jul 24, 2023
1 parent 7282a30 commit 93c0d4d
Show file tree
Hide file tree
Showing 5 changed files with 580 additions and 86 deletions.
53 changes: 30 additions & 23 deletions tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
Test cases for the python API for tsdate.
"""
import collections
import json
import logging
import unittest

Expand Down Expand Up @@ -1626,7 +1625,9 @@ def test_posterior_mean_var(self):
ts = utility_functions.single_tree_ts_n2()
for distr in ("gamma", "lognorm"):
posterior, algo = TestTotalFunctionalValueTree().find_posterior(ts, distr)
ts_node_metadata, mn_post, vr_post = posterior_mean_var(ts, posterior)
ts_node_metadata, mn_post, vr_post = posterior_mean_var(
ts, posterior, save_metadata=False
)
assert np.array_equal(
mn_post,
[
Expand All @@ -1638,31 +1639,35 @@ def test_posterior_mean_var(self):

def test_node_metadata_single_tree_n2(self):
ts = utility_functions.single_tree_ts_n2()
tables = ts.dump_tables()
tables.nodes.metadata_schema = tskit.MetadataSchema.permissive_json()
ts = tables.tree_sequence()
posterior, algo = TestTotalFunctionalValueTree().find_posterior(ts, "lognorm")
ts_node_metadata, mn_post, vr_post = posterior_mean_var(ts, posterior)
assert json.loads(ts_node_metadata.node(2).metadata)["mn"] == mn_post[2]
assert json.loads(ts_node_metadata.node(2).metadata)["vr"] == vr_post[2]
assert ts_node_metadata.node(2).metadata["mn"] == mn_post[2]
assert ts_node_metadata.node(2).metadata["vr"] == vr_post[2]

def test_node_metadata_simulated_tree(self):
larger_ts = msprime.simulate(
10, mutation_rate=1, recombination_rate=1, length=20, random_seed=12
)
_, mn_post, _, _, eps, _ = get_dates(
larger_ts, mutation_rate=None, population_size=10000
)
dated_ts = date(larger_ts, population_size=10000, mutation_rate=None)
metadata = dated_ts.tables.nodes.metadata
metadata_offset = dated_ts.tables.nodes.metadata_offset
unconstrained_mn = [
json.loads(met.decode())["mn"]
for met in tskit.unpack_bytes(metadata, metadata_offset)
if len(met.decode()) > 0
]
assert np.array_equal(unconstrained_mn, mn_post[larger_ts.num_samples :])
assert np.all(
dated_ts.tables.nodes.time[larger_ts.num_samples :]
>= mn_post[larger_ts.num_samples :]
is_sample = np.zeros(larger_ts.num_nodes, dtype=bool)
is_sample[larger_ts.samples()] = True
is_not_sample = np.logical_not(is_sample)
# This calls posterior_mean_var
_, mn_post, _, _, _, _ = get_dates(
larger_ts,
method="inside_outside",
population_size=1,
mutation_rate=1,
save_metadata=False,
)
constrained_time = constrain_ages_topo(larger_ts, mn_post, eps=1e-6)
# Samples identical in all methods
assert np.array_equal(larger_ts.nodes_time[is_sample], mn_post[is_sample])
assert np.array_equal(constrained_time[is_sample], mn_post[is_sample])
# Non-samples should adhere to constraints
assert np.all(constrained_time[is_not_sample] >= mn_post[is_not_sample])


class TestConstrainAgesTopo:
Expand Down Expand Up @@ -1862,7 +1867,9 @@ def test_node_selection_param(self):
def test_sites_time_insideoutside(self):
ts = utility_functions.two_tree_mutation_ts()
dated = tsdate.date(ts, mutation_rate=None, population_size=1)
_, mn_post, _, _, eps, _ = get_dates(ts, mutation_rate=None, population_size=1)
_, mn_post, _, _, eps, _ = get_dates(
ts, mutation_rate=None, population_size=1, save_metadata=False
)
assert np.array_equal(
mn_post[ts.tables.mutations.node],
tsdate.sites_time_from_ts(dated, unconstrained=True, min_time=0),
Expand Down Expand Up @@ -1966,15 +1973,15 @@ def test_sites_time_simulated(self):
10, mutation_rate=1, recombination_rate=1, length=20, random_seed=12
)
_, mn_post, _, _, _, _ = get_dates(
larger_ts, mutation_rate=None, population_size=10000
larger_ts, mutation_rate=None, population_size=10000, save_metadata=False
)
dated = date(larger_ts, mutation_rate=None, population_size=10000)
assert np.array_equal(
mn_post[larger_ts.tables.mutations.node],
mn_post[larger_ts.mutations_node],
tsdate.sites_time_from_ts(dated, unconstrained=True, min_time=0),
)
assert np.array_equal(
dated.tables.nodes.time[larger_ts.tables.mutations.node],
dated.nodes_time[larger_ts.mutations_node],
tsdate.sites_time_from_ts(dated, unconstrained=False, min_time=0),
)

Expand Down
Loading

0 comments on commit 93c0d4d

Please sign in to comment.