diff --git a/tests/test_functions.py b/tests/test_functions.py index a3c37948..81c9f968 100644 --- a/tests/test_functions.py +++ b/tests/test_functions.py @@ -25,7 +25,6 @@ Test cases for the python API for tsdate. """ import collections -import json import logging import unittest @@ -1619,7 +1618,9 @@ def test_posterior_mean_var(self): ts = utility_functions.single_tree_ts_n2() for distr in ("gamma", "lognorm"): posterior, algo = TestTotalFunctionalValueTree().find_posterior(ts, distr) - ts_node_metadata, mn_post, vr_post = posterior_mean_var(ts, posterior) + ts_node_metadata, mn_post, vr_post = posterior_mean_var( + ts, posterior, save_metadata=False + ) assert np.array_equal( mn_post, [ @@ -1631,31 +1632,35 @@ def test_posterior_mean_var(self): def test_node_metadata_single_tree_n2(self): ts = utility_functions.single_tree_ts_n2() + tables = ts.dump_tables() + tables.nodes.metadata_schema = tskit.MetadataSchema.permissive_json() + ts = tables.tree_sequence() posterior, algo = TestTotalFunctionalValueTree().find_posterior(ts, "lognorm") ts_node_metadata, mn_post, vr_post = posterior_mean_var(ts, posterior) - assert json.loads(ts_node_metadata.node(2).metadata)["mn"] == mn_post[2] - assert json.loads(ts_node_metadata.node(2).metadata)["vr"] == vr_post[2] + assert ts_node_metadata.node(2).metadata["mn"] == mn_post[2] + assert ts_node_metadata.node(2).metadata["vr"] == vr_post[2] def test_node_metadata_simulated_tree(self): larger_ts = msprime.simulate( 10, mutation_rate=1, recombination_rate=1, length=20, random_seed=12 ) - _, mn_post, _, _, eps, _ = get_dates( - larger_ts, mutation_rate=None, population_size=10000 - ) - dated_ts = date(larger_ts, population_size=10000, mutation_rate=None) - metadata = dated_ts.tables.nodes.metadata - metadata_offset = dated_ts.tables.nodes.metadata_offset - unconstrained_mn = [ - json.loads(met.decode())["mn"] - for met in tskit.unpack_bytes(metadata, metadata_offset) - if len(met.decode()) > 0 - ] - assert 
np.allclose(unconstrained_mn, mn_post[larger_ts.num_samples :]) - assert np.all( - dated_ts.tables.nodes.time[larger_ts.num_samples :] - >= mn_post[larger_ts.num_samples :] + is_sample = np.zeros(larger_ts.num_nodes, dtype=bool) + is_sample[larger_ts.samples()] = True + is_not_sample = np.logical_not(is_sample) + # This calls posterior_mean_var + _, mn_post, _, _, _, _ = get_dates( + larger_ts, + method="inside_outside", + population_size=1, + mutation_rate=1, + save_metadata=False, ) + constrained_time = constrain_ages_topo(larger_ts, mn_post, eps=1e-6) + # Samples identical in all methods + assert np.allclose(larger_ts.nodes_time[is_sample], mn_post[is_sample]) + assert np.allclose(constrained_time[is_sample], mn_post[is_sample]) + # Non-samples should adhere to constraints + assert np.all(constrained_time[is_not_sample] >= mn_post[is_not_sample]) class TestConstrainAgesTopo: @@ -1855,7 +1860,9 @@ def test_node_selection_param(self): def test_sites_time_insideoutside(self): ts = utility_functions.two_tree_mutation_ts() dated = tsdate.date(ts, mutation_rate=None, population_size=1) - _, mn_post, _, _, eps, _ = get_dates(ts, mutation_rate=None, population_size=1) + _, mn_post, _, _, eps, _ = get_dates( + ts, mutation_rate=None, population_size=1, save_metadata=False + ) assert np.array_equal( mn_post[ts.tables.mutations.node], tsdate.sites_time_from_ts(dated, unconstrained=True, min_time=0), @@ -1959,15 +1966,15 @@ def test_sites_time_simulated(self): 10, mutation_rate=1, recombination_rate=1, length=20, random_seed=12 ) _, mn_post, _, _, _, _ = get_dates( - larger_ts, mutation_rate=None, population_size=10000 + larger_ts, mutation_rate=None, population_size=10000, save_metadata=False ) dated = date(larger_ts, mutation_rate=None, population_size=10000) assert np.allclose( - mn_post[larger_ts.tables.mutations.node], + mn_post[larger_ts.mutations_node], tsdate.sites_time_from_ts(dated, unconstrained=True, min_time=0), ) assert np.allclose( - 
dated.tables.nodes.time[larger_ts.tables.mutations.node], + dated.tables.nodes.time[larger_ts.mutations_node], tsdate.sites_time_from_ts(dated, unconstrained=False, min_time=0), ) diff --git a/tests/test_metadata.py b/tests/test_metadata.py new file mode 100644 index 00000000..85820633 --- /dev/null +++ b/tests/test_metadata.py @@ -0,0 +1,303 @@ +# MIT License +# +# Copyright (c) 2021-23 Tskit Developers +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +""" +Test cases for metadata setting functionality in tsdate. 
+""" +import json +import logging + +import numpy as np +import pytest +import tskit +import utility_functions + +from tsdate.core import date +from tsdate.metadata import node_md_struct +from tsdate.metadata import save_node_metadata + +struct_obj_only_example = tskit.MetadataSchema( + { + "codec": "struct", + "type": "object", + "properties": { + "node_id": {"type": "integer", "binaryFormat": "i"}, + }, + "additionalProperties": False, + } +) + +struct_bad_mn = tskit.MetadataSchema( + { + "codec": "struct", + "type": "object", + "properties": { + "mn": {"type": "integer", "binaryFormat": "i"}, + }, + "additionalProperties": False, + } +) + +struct_bad_vr = tskit.MetadataSchema( + { + "codec": "struct", + "type": "object", + "properties": { + "vr": {"type": "string", "binaryFormat": "10p"}, + }, + "additionalProperties": False, + } +) + + +class TestBytes: + """ + Tests for when existing node metadata is in raw bytes + """ + + def test_no_existing(self): + ts = utility_functions.single_tree_ts_n2() + root = ts.first().root + assert ts.node(root).metadata == b"" + assert ts.table_metadata_schemas.node == tskit.MetadataSchema(None) + ts = date(ts, mutation_rate=1, population_size=1) + assert ts.node(root).metadata["mn"] == pytest.approx(ts.nodes_time[root]) + assert ts.node(root).metadata["vr"] > 0 + + def test_append_existing(self): + ts = utility_functions.single_tree_ts_n2() + root = ts.first().root + assert ts.table_metadata_schemas.node == tskit.MetadataSchema(None) + tables = ts.dump_tables() + tables.nodes.clear() + for nd in ts.nodes(): + tables.nodes.append(nd.replace(metadata=b'{"node_id": %d}' % nd.id)) + ts = tables.tree_sequence() + assert json.loads(ts.node(root).metadata.decode())["node_id"] == root + ts = date(ts, mutation_rate=1, population_size=1) + assert ts.node(root).metadata["node_id"] == root + assert ts.node(root).metadata["mn"] == pytest.approx(ts.nodes_time[root]) + assert ts.node(root).metadata["vr"] > 0 + + def 
test_replace_existing(self): + ts = utility_functions.single_tree_ts_n2() + root = ts.first().root + assert ts.table_metadata_schemas.node == tskit.MetadataSchema(None) + tables = ts.dump_tables() + tables.nodes.clear() + for nd in ts.nodes(): + tables.nodes.append(nd.replace(metadata=b'{"mn": 1.0}')) + ts = tables.tree_sequence() + assert json.loads(ts.node(root).metadata.decode())["mn"] == pytest.approx(1.0) + ts = date(ts, mutation_rate=1, population_size=1) + assert ts.node(root).metadata["mn"] != pytest.approx(1.0) + assert ts.node(root).metadata["mn"] == pytest.approx(ts.nodes_time[root]) + assert ts.node(root).metadata["vr"] > 0 + + def test_existing_bad(self): + ts = utility_functions.single_tree_ts_n2() + assert ts.table_metadata_schemas.node == tskit.MetadataSchema(None) + tables = ts.dump_tables() + tables.nodes.clear() + for nd in ts.nodes(): + tables.nodes.append(nd.replace(metadata=b"!!")) + ts = tables.tree_sequence() + with pytest.raises(ValueError, match="Cannot modify"): + date(ts, mutation_rate=1, population_size=1) + + def test_erase_existing_bad(self, caplog): + ts = utility_functions.single_tree_ts_n2() + root = ts.first().root + assert ts.table_metadata_schemas.node == tskit.MetadataSchema(None) + tables = ts.dump_tables() + tables.nodes.clear() + for nd in ts.nodes(): + tables.nodes.append(nd.replace(metadata=b"!!")) + ts = tables.tree_sequence() + # Should be able to replace using set_metadata=True + with caplog.at_level(logging.WARNING): + ts = date(ts, mutation_rate=1, population_size=1, set_metadata=True) + assert "Erasing existing node metadata" in caplog.text + assert ts.table_metadata_schemas.node.schema["codec"] == "struct" + assert ts.node(root).metadata["mn"] == pytest.approx(ts.nodes_time[root]) + assert ts.node(root).metadata["vr"] > 0 + + +class TestStruct: + """ + Tests for when existing node metadata is as a struct + """ + + def test_append_existing(self): + ts = utility_functions.single_tree_ts_n2() + root = ts.first().root + 
tables = ts.dump_tables() + tables.nodes.metadata_schema = struct_obj_only_example + tables.nodes.packset_metadata( + [ + tables.nodes.metadata_schema.validate_and_encode_row({"node_id": i}) + for i in range(ts.num_nodes) + ] + ) + ts = tables.tree_sequence() + assert ts.node(root).metadata["node_id"] == root + ts = date(ts, mutation_rate=1, population_size=1) + assert ts.node(root).metadata["node_id"] == root + assert ts.node(root).metadata["mn"] == pytest.approx(ts.nodes_time[root]) + assert ts.node(root).metadata["vr"] > 0 + + def test_replace_existing(self, caplog): + ts = utility_functions.single_tree_ts_n2() + root = ts.first().root + tables = ts.dump_tables() + tables.nodes.metadata_schema = node_md_struct + tables.nodes.packset_metadata( + [ + tables.nodes.metadata_schema.validate_and_encode_row(None) + for _ in range(ts.num_nodes) + ] + ) + ts = tables.tree_sequence() + assert ts.node(root).metadata is None + with caplog.at_level(logging.INFO): + ts = date(ts, mutation_rate=1, population_size=1) + assert ts.table_metadata_schemas.node.schema["codec"] == "struct" + assert "Replacing 'mn'" in caplog.text + assert "Replacing 'vr'" in caplog.text + assert "Schema modified" in caplog.text + assert ts.node(root).metadata["mn"] == pytest.approx(ts.nodes_time[root]) + assert ts.node(root).metadata["vr"] > 0 + sample = ts.samples()[0] + assert ts.node(sample).metadata is None + + def test_existing_bad_mn(self, caplog): + ts = utility_functions.single_tree_ts_n2() + tables = ts.dump_tables() + tables.nodes.metadata_schema = struct_bad_mn + tables.nodes.packset_metadata( + [ + tables.nodes.metadata_schema.validate_and_encode_row({"mn": 1}) + for _ in range(ts.num_nodes) + ] + ) + ts = tables.tree_sequence() + with pytest.raises( + ValueError, match=r"Cannot change type of node.metadata\['mn'\]" + ): + date(ts, mutation_rate=1, population_size=1) + + def test_existing_bad_vr(self, caplog): + ts = utility_functions.single_tree_ts_n2() + tables = ts.dump_tables() + 
tables.nodes.metadata_schema = struct_bad_vr + tables.nodes.packset_metadata( + [ + tables.nodes.metadata_schema.validate_and_encode_row({"vr": "foo"}) + for _ in range(ts.num_nodes) + ] + ) + ts = tables.tree_sequence() + with pytest.raises( + ValueError, match=r"Cannot change type of node.metadata\['vr'\]" + ): + date(ts, mutation_rate=1, population_size=1) + + +class TestJson: + """ + Tests for when existing node metadata is json encoded + """ + + def test_replace_existing(self, caplog): + ts = utility_functions.single_tree_ts_n2() + root = ts.first().root + tables = ts.dump_tables() + schema = tables.nodes.metadata_schema = tskit.MetadataSchema.permissive_json() + tables.nodes.packset_metadata( + [ + schema.validate_and_encode_row( + {f"node {i}": 1, "mn": "foo", "vr": "bar"} + ) + for i in range(ts.num_nodes) + ] + ) + ts = tables.tree_sequence() + assert "node 0" in ts.node(0).metadata + assert ts.node(0).metadata["mn"] == "foo" + with caplog.at_level(logging.INFO): + ts = date(ts, mutation_rate=1, population_size=1) + assert ts.table_metadata_schemas.node.schema["codec"] == "json" + assert "Schema modified" in caplog.text + sample = ts.samples()[0] + assert f"node {sample}" in ts.node(sample).metadata + # Should have deleted mn and vr + assert "mn" not in ts.node(sample).metadata + assert "vr" not in ts.node(sample).metadata + assert f"node {root}" in ts.node(root).metadata + assert ts.node(root).metadata["mn"] == pytest.approx(ts.nodes_time[root]) + assert ts.node(root).metadata["vr"] > 0 + + +class TestNoSetMetadata: + """ + Tests for when metadata is not saved + """ + + @pytest.mark.parametrize( + "method", ["inside_outside", "maximization", "variational_gamma"] + ) + def test_empty(self, method): + ts = utility_functions.single_tree_ts_n2() + assert len(ts.tables.nodes.metadata) == 0 + ts = date( + ts, mutation_rate=1, population_size=1, method=method, set_metadata=False + ) + assert len(ts.tables.nodes.metadata) == 0 + + @pytest.mark.parametrize( + 
"method", ["inside_outside", "maximization", "variational_gamma"] + ) + def test_random_md(self, method): + ts = utility_functions.single_tree_ts_n2() + assert len(ts.tables.nodes.metadata) == 0 + tables = ts.dump_tables() + tables.nodes.packset_metadata([(b"random %i" % u) for u in range(ts.num_nodes)]) + ts = tables.tree_sequence() + assert len(ts.tables.nodes.metadata) > 0 + dts = date( + ts, mutation_rate=1, population_size=1, method=method, set_metadata=False + ) + assert len(ts.tables.nodes.metadata) == len(dts.tables.nodes.metadata) + + +class TestFunctions: + """ + Test internal metadata functions + """ + + def test_bad_save_node_metadata(self): + ts = utility_functions.single_tree_ts_n2() + bad_arr = np.zeros(ts.num_nodes + 1) + good_arr = np.zeros(ts.num_nodes) + for m, v in ([bad_arr, good_arr], [good_arr, bad_arr]): + with pytest.raises(ValueError, match="arrays of length ts.num_nodes"): + save_node_metadata(ts, m, v, fixed_node_set=set(ts.samples())) diff --git a/tsdate/core.py b/tsdate/core.py index e0025b74..d1405c6c 100644 --- a/tsdate/core.py +++ b/tsdate/core.py @@ -25,7 +25,6 @@ """ import functools import itertools -import json import logging import multiprocessing import operator @@ -40,6 +39,7 @@ from . import approx from . import base from . import demography +from . import metadata from . import prior from . import provenance @@ -1057,44 +1057,34 @@ def iterate(self, *, iter_num=None, progress=None): # return marginal_lik -def posterior_mean_var(ts, posterior, *, fixed_node_set=None): +def posterior_mean_var(ts, posterior, *, save_metadata=True, fixed_node_set=None): """ Mean and variance of node age. Fixed nodes will be given a mean of their exact time in the tree sequence, and zero variance (as long as they are identified by the fixed_node_set). If fixed_node_set is None, we attempt to date all the non-sample nodes - Also assigns the estimated mean and variance of the age of each node - as metadata in the tree sequence. 
+ + If save_metadata is True, the estimated mean and variance of the age of each node + is assigned as metadata in the returned tree sequence, otherwise the tree + sequence that was passed in is returned unchanged. """ mn_post = np.full(ts.num_nodes, np.nan) # Fill with NaNs so we detect when there's vr_post = np.full(ts.num_nodes, np.nan) # been an error - tables = ts.dump_tables() if fixed_node_set is None: - fixed_node_set = ts.samples() + fixed_node_set = set(ts.samples()) fixed_nodes = np.array(list(fixed_node_set)) - mn_post[fixed_nodes] = tables.nodes.time[fixed_nodes] + mn_post[fixed_nodes] = ts.nodes_time[fixed_nodes] vr_post[fixed_nodes] = 0 - metadata_array = tskit.unpack_bytes( - ts.tables.nodes.metadata, ts.tables.nodes.metadata_offset - ) for u in posterior.nonfixed_nodes: probs = posterior[u] times = posterior.timepoints mn_post[u] = np.sum(probs * times) / np.sum(probs) vr_post[u] = np.sum(((mn_post[u] - (times)) ** 2) * (probs / np.sum(probs))) - metadata_array[u] = json.dumps({"mn": mn_post[u], "vr": vr_post[u]}).encode() - md, md_offset = tskit.pack_bytes(metadata_array) - tables.nodes.set_columns( - flags=tables.nodes.flags, - time=tables.nodes.time, - population=tables.nodes.population, - individual=tables.nodes.individual, - metadata=md, - metadata_offset=md_offset, - ) - ts = tables.tree_sequence() + + if save_metadata: + ts = metadata.save_node_metadata(ts, mn_post, vr_post, fixed_node_set) return ts, mn_post, vr_post @@ -1138,6 +1128,7 @@ def date( *, Ne=None, return_posteriors=None, + set_metadata=None, progress=False, **kwargs, ): @@ -1194,7 +1185,17 @@ def date( conditional coalescent prior with a standard set of time points as given by :func:`build_prior_grid`. :param bool return_posteriors: If ``True``, instead of returning just a dated tree - sequence, return a tuple of ``(dated_ts, posteriors)`` (see note above). + sequence, return a tuple of ``(dated_ts, posteriors)`` (see note above). Default: + ``None`` (treated as ``False``). 
+ :param bool set_metadata: If ``True``, replace all existing node metadata with + details of the times (means and variances) for each node. If ``False``, + do not touch any existing metadata in the tree sequence. If ``None`` + (default), attempt to modify any existing node metadata to add the + times (means and variances) for each node, overwriting only those specific + metadata values. If no node metadata schema has been set, this will be possible + only if either (a) the raw metadata can be decoded as JSON, in which case the + schema is set to permissive_json or (b) no node metadata exists (in which case + a default schema will be set), otherwise an error will be raised. :param float eps: Specify minimum distance separating time points. Also specifies the error factor in time difference calculations. Default: 1e-6 :param int num_threads: The number of threads to use. A simpler unthreaded algorithm @@ -1227,41 +1228,47 @@ def date( ) else: population_size = Ne - + ts, save_metadata = metadata.set_tsdate_node_md_schema(tree_sequence, set_metadata) if isinstance(population_size, dict): population_size = demography.PopulationSizeHistory(**population_size) if method == "variational_gamma": - tree_sequence, dates, posteriors, timepoints, eps, nds = variational_dates( - tree_sequence, + ts, dates, posteriors, timepoints, eps, nds = variational_dates( + ts, population_size=population_size, mutation_rate=mutation_rate, recombination_rate=recombination_rate, priors=priors, progress=progress, + save_metadata=save_metadata, **kwargs, ) else: - tree_sequence, dates, posteriors, timepoints, eps, nds = get_dates( - tree_sequence, + ts, dates, posteriors, timepoints, eps, nds = get_dates( + ts, population_size=population_size, mutation_rate=mutation_rate, recombination_rate=recombination_rate, priors=priors, progress=progress, method=method, + save_metadata=save_metadata, **kwargs, ) - constrained = constrain_ages_topo(tree_sequence, dates, eps, progress) - tables = 
tree_sequence.dump_tables() + constrained = constrain_ages_topo(ts, dates, eps, progress) + tables = ts.dump_tables() tables.time_units = time_units tables.nodes.time = constrained # Remove any times associated with mutations - tables.mutations.time = np.full(tree_sequence.num_mutations, tskit.UNKNOWN_TIME) + tables.mutations.time = np.full(ts.num_mutations, tskit.UNKNOWN_TIME) tables.sort() params = dict( mutation_rate=mutation_rate, recombination_rate=recombination_rate, + time_units=time_units, + method=method, + set_metadata=set_metadata, + return_posteriors=return_posteriors, progress=progress, ) if isinstance(population_size, (int, float)): @@ -1298,6 +1305,7 @@ def get_dates( progress=False, cache_inside=False, probability_space=None, + save_metadata=True, ): """ Infer dates for the nodes in a tree sequence, returning an array of inferred dates @@ -1305,6 +1313,11 @@ etc. Parameters are identical to the date() method, which calls this method, then injects the resulting date estimates into the tree sequence + If ``save_metadata`` is ``True``, "mn" and "vr" fields are set in the node metadata of + the returned tree sequence, representing the mean and variance of the node ages. This + assumes that the node metadata schema allows "mn" and "vr" fields to be set (if not, + use ``save_metadata=False``). + + :return: a tuple of ``(mn_post, posteriors, timepoints, eps, nodes_to_date)``. 
If the "inside_outside" method is used, ``posteriors`` will contain the posterior probabilities for each node in each time slice, else the returned @@ -1383,7 +1396,10 @@ def get_dates( posterior.force_probability_space(base.LIN) posterior.to_probabilities() tree_sequence, mn_post, _ = posterior_mean_var( - tree_sequence, posterior, fixed_node_set=fixed_nodes + tree_sequence, + posterior, + fixed_node_set=fixed_nodes, + save_metadata=save_metadata, ) elif method == "maximization": if mutation_rate is not None: @@ -1405,47 +1421,39 @@ def get_dates( ) -def variational_mean_var(ts, posterior, *, fixed_node_set=None): +def variational_mean_var(ts, posterior, *, save_metadata=True, fixed_node_set=None): """ - Mean and variance of node age from variational posterior (e.g. gamma - distributions). Fixed nodes will be given a mean of their exact time in - the tree sequence, and zero variance (as long as they are identified by the + Return the mean and variance of node age from variational posterior (e.g. gamma + distributions). The returned mean for fixed nodes will be their exact time in + the tree sequence with zero variance (as long as they are identified by the fixed_node_set). If fixed_node_set is None, we attempt to date all the - non-sample nodes Also assigns the estimated mean and variance of the age of + non-sample nodes. Also assigns the estimated mean and variance of the age of each node as metadata in the tree sequence. 
+ + + TODO - we should be able to get the set of fixed nodes from the posterior """ mn_post = np.full(ts.num_nodes, np.nan) # Fill with NaNs so we detect when there's vr_post = np.full(ts.num_nodes, np.nan) # been an error - tables = ts.dump_tables() if fixed_node_set is None: - fixed_node_set = ts.samples() + fixed_node_set = set(ts.samples()) fixed_nodes = np.array(list(fixed_node_set)) - mn_post[fixed_nodes] = tables.nodes.time[fixed_nodes] + mn_post[fixed_nodes] = ts.nodes_time[fixed_nodes] vr_post[fixed_nodes] = 0 assert np.all(posterior.grid_data[:, 0] > 0), "Invalid posterior" - metadata_array = tskit.unpack_bytes( - ts.tables.nodes.metadata, ts.tables.nodes.metadata_offset - ) for u in posterior.nonfixed_nodes: # TODO: with method posterior.mean_and_var(node_id) this could be # easily combined with posterior_mean_var pars = posterior[u] mn_post[u] = pars[0] / pars[1] vr_post[u] = pars[0] / pars[1] ** 2 - metadata_array[u] = json.dumps({"mn": mn_post[u], "vr": vr_post[u]}).encode() - md, md_offset = tskit.pack_bytes(metadata_array) - tables.nodes.set_columns( - flags=tables.nodes.flags, - time=tables.nodes.time, - population=tables.nodes.population, - individual=tables.nodes.individual, - metadata=md, - metadata_offset=md_offset, - ) - ts = tables.tree_sequence() + + if save_metadata: + ts = metadata.save_node_metadata(ts, mn_post, vr_post, fixed_node_set) + return ts, mn_post, vr_post @@ -1460,6 +1468,7 @@ def variational_dates( global_prior=True, eps=1e-6, progress=False, + save_metadata=True, num_threads=None, # Unused, matches get_dates() probability_space=None, # Can only be None, simply to match get_dates() ignore_oldest_root=False, # Can only be False, simply to match get_dates() @@ -1547,7 +1556,10 @@ def variational_dates( dynamic_prog.iterate(iter_num=it) posterior = dynamic_prog.posterior tree_sequence, mn_post, _ = variational_mean_var( - tree_sequence, posterior, fixed_node_set=fixed_nodes + tree_sequence, + posterior, + 
fixed_node_set=fixed_nodes, + save_metadata=save_metadata, ) return ( diff --git a/tsdate/metadata.py b/tsdate/metadata.py new file mode 100644 index 00000000..c1d1d0e7 --- /dev/null +++ b/tsdate/metadata.py @@ -0,0 +1,173 @@ +# MIT License +# +# Copyright (c) 2023 Tskit Developers +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +""" +Functions for setting or merging schemas in tsdate-generated tree sequences. Note that +tsdate will only add metadata to the node table, so this is the only relevant schema. +""" +import json +import logging + +import tskit + +MEAN_KEY = "mn" +VARIANCE_KEY = "vr" + +node_md_mean = { + "description": ( + "The mean time of this node, calculated from the tsdate posterior " + "probabilities. This may not be the same as the node time, as it " + "is not constrained by parent-child order." 
+ ), + "type": "number", + "binaryFormat": "d", +} + +node_md_variance = { + "description": ( + "The variance in times of this node, calculated from the tsdate " + "posterior probabilities" + ), + "type": "number", + "binaryFormat": "d", +} + +node_md_struct = tskit.MetadataSchema( + { + "codec": "struct", + "type": ["object", "null"], + "default": None, + "properties": {MEAN_KEY: node_md_mean, VARIANCE_KEY: node_md_variance}, + "additionalProperties": False, + } +) + + +def set_tsdate_node_md_schema(ts, set_metadata=None): + """ + Taken from the ``tsdate.date()`` docs: + If set_metadata is ``True``, replace all existing node metadata with + details of the times (means and variances) for each node. If ``False``, + do not touch any existing metadata in the tree sequence. If ``None`` + (default), attempt to modify any existing node metadata to add the + times (means and variances) for each node, overwriting only those specific + metadata values. If no node metadata schema has been set, this will be possible + only if either (a) the raw metadata can be decoded as JSON, in which case the + schema is set to permissive_json or (b) no node metadata exists (in which + case a default schema will be set), otherwise an error will be raised. + + Returns: + A tuple of the tree sequence and a boolean indicating whether the + metadata should be saved or not. 
+ """ + if set_metadata is not None and not set_metadata: + logging.debug("Not setting node metadata") + return ts, False + + tables = ts.dump_tables() + if set_metadata: + # Erase existing metadata, force schema to be node_md_struct + if len(tables.nodes.metadata) != 0: + logging.warning("Erasing existing node metadata") + tables.nodes.packset_metadata([b"" for _ in range(tables.nodes.num_rows)]) + tables.nodes.metadata_schema = node_md_struct + return tables.tree_sequence(), True + + # set_metadata is None: try to set or modify any existing metadata schema + schema_object = node_md_struct + if tables.nodes.metadata_schema == tskit.MetadataSchema(schema=None): + if len(tables.nodes.metadata) > 0: + # For backwards compatibility, if the node metadata is bytes (schema=None) + # but can be decoded as JSON, we change the schema to permissive_json + schema_object = tskit.MetadataSchema.permissive_json() + try: + for node in ts.nodes(): + _ = json.loads(node.metadata.decode() or "{}") + except json.JSONDecodeError: + raise ValueError( + "Cannot modify node metadata if schema is " + "None and non-JSON node metadata already exists" + ) + logging.info("Null schema now set to permissive_json") + else: + # Make new schema on basis of existing + schema = tables.nodes.metadata_schema.schema + if "properties" not in schema: + schema["properties"] = {} + prop = schema["properties"] + for key, dfn in zip((MEAN_KEY, VARIANCE_KEY), (node_md_mean, node_md_variance)): + if key in prop: + if not prop[key].get("type") in {dfn["type"], None}: + raise ValueError( + f"Cannot change type of node.metadata['{key}'] in schema" + ) + else: + logging.info(f"Replacing '{key}' in existing node metadata schema") + + prop[key] = dfn.copy() + if schema["codec"] == "struct": + # If we are adding to an existing struct codec, a "null" entry may + # not be allowed, so we need to add null defaults for the new fields + logging.info("Adding NaN default to schema") + prop[key]["default"] = float("NaN") + + 
# Repack, erasing old metadata present in the target keys + schema_object = tskit.MetadataSchema(schema) + metadata_array = [] + for node in ts.nodes(): + md = node.metadata + for key in [MEAN_KEY, VARIANCE_KEY]: + try: + del md[key] + logging.debug(f"Deleting existing '{key}' value in node metadata") + except (KeyError, TypeError): + pass + metadata_array.append(schema_object.validate_and_encode_row(md)) + tables.nodes.packset_metadata(metadata_array) + logging.info("Schema modified") + + tables.nodes.metadata_schema = schema_object + return tables.tree_sequence(), True + + +def save_node_metadata(ts, means, variances, fixed_node_set): + """ + Assign means and variances (both arrays of length ts.num_nodes) + to the node metadata and return the resulting tree sequence. + + Assumes that the metadata schema in the tree sequence allows + MEAN_KEY and VARIANCE_KEY to be set to numbers + """ + if len(means) != ts.num_nodes or len(variances) != ts.num_nodes: + raise ValueError("means and variances must be arrays of length ts.num_nodes") + tables = ts.dump_tables() + nodes = tables.nodes + metadata_array = [] + for node, mean, var in zip(ts.nodes(), means, variances): + md = node.metadata + if node.id not in fixed_node_set: + if md is None: + md = {} + md[MEAN_KEY] = mean + md[VARIANCE_KEY] = var + metadata_array.append(nodes.metadata_schema.validate_and_encode_row(md)) + nodes.packset_metadata(metadata_array) + return tables.tree_sequence() diff --git a/tsdate/util.py b/tsdate/util.py index 8ab6bbb5..5b25f638 100644 --- a/tsdate/util.py +++ b/tsdate/util.py @@ -1,6 +1,7 @@ # MIT License # # Copyright (c) 2020 University of Oxford +# Copyright (c) 2021-23 Tskit Developers # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -23,7 +24,6 @@ Utility functions for tsdate. 
Many of these can be removed when tskit is updated to a more recent version which has the functionality built-in """ -import json import logging import numpy as np @@ -186,13 +186,12 @@ def nodes_time_unconstrained(tree_sequence): not contain this information. """ nodes_time = tree_sequence.tables.nodes.time.copy() - metadata = tree_sequence.tables.nodes.metadata - metadata_offset = tree_sequence.tables.nodes.metadata_offset - for index, met in enumerate(tskit.unpack_bytes(metadata, metadata_offset)): - if index not in tree_sequence.samples(): + sample_set = set(tree_sequence.samples()) + for node in tree_sequence.nodes(): + if node.id not in sample_set: try: - nodes_time[index] = json.loads(met.decode())["mn"] - except (KeyError, json.decoder.JSONDecodeError): + nodes_time[node.id] = node.metadata["mn"] + except (KeyError, TypeError): raise ValueError( "Tree Sequence must be tsdated with the Inside-Outside Method." ) @@ -265,8 +264,9 @@ def sites_time_from_ts( try: nodes_time = nodes_time_unconstrained(tree_sequence) except ValueError as e: - e.args += "Try calling sites_time_from_ts() with unconstrained=False." - raise + raise ValueError( + "Try calling sites_time_from_ts() with unconstrained=False." + ) from e else: nodes_time = tree_sequence.tables.nodes.time sites_time = np.full(tree_sequence.num_sites, np.nan)