diff --git a/tsinfer/eval_util.py b/tsinfer/eval_util.py index 3badc889..859e7f60 100644 --- a/tsinfer/eval_util.py +++ b/tsinfer/eval_util.py @@ -710,8 +710,6 @@ def run_perfect_inference( extended_checks=extended_checks, progress_monitor=progress_monitor, ) - # If time_chunking is turned on we need to stabilise the node ordering in the output - # to ensure that the node IDs are comparable. inferred_ts = inference.match_samples( sample_data, ancestors_ts, @@ -720,13 +718,34 @@ def run_perfect_inference( num_threads=num_threads, extended_checks=extended_checks, progress_monitor=progress_monitor, - stabilise_node_ordering=time_chunking and not path_compression, + simplify=False, # Don't simplify until we have stabilised the node order below ) + # If time_chunking is turned on we need to stabilise the node ordering in the output + # to ensure that the node IDs are comparable. + if time_chunking and not path_compression: + inferred_ts = stabilise_node_ordering(inferred_ts) + # to compare against the original, we need to remove unary nodes from the inferred TS inferred_ts = inferred_ts.simplify(keep_unary=False, filter_sites=False) return ts, inferred_ts +def stabilise_node_ordering(ts): + # Ensure all the node times are distinct so that they will have + # stable IDs after simplifying. This could possibly also be done + # by reversing the IDs within a time slice. This is used for comparing + # tree sequences produced by perfect inference. + tables = ts.dump_tables() + times = tables.nodes.time + for t in range(1, int(times[0])): + index = np.where(times == t)[0] + k = index.shape[0] + times[index] += np.arange(k)[::-1] / k + tables.nodes.time = times + tables.sort() + return tables.tree_sequence() + + def count_sample_child_edges(ts): """ Returns an array counting the number of edges where each sample is a child. diff --git a/tsinfer/inference.py b/tsinfer/inference.py index b78b9d64..2890bd7f 100644 --- a/tsinfer/inference.py +++ b/tsinfer/inference.py @@ -550,7 +550,6 @@ def match_samples( recombination=None, # See :class:`Matcher` mismatch=None, # See :class:`Matcher` precision=None, - stabilise_node_ordering=False, extended_checks=False, engine=constants.C_ENGINE, progress_monitor=None, @@ -636,9 +635,7 @@ def match_samples( # we sometimes assume they are in the same order as in the file manager.match_samples(sample_indexes, sample_times) - ts = manager.finalise( - simplify=simplify, stabilise_node_ordering=stabilise_node_ordering - ) + ts = manager.finalise(simplify=simplify) return ts @@ -1641,7 +1638,7 @@ def match_samples(self, sample_indexes, sample_times): progress_monitor.update() progress_monitor.close() - def finalise(self, simplify, stabilise_node_ordering): + def finalise(self, simplify): logger.info("Finalising tree sequence") ts = self.get_samples_tree_sequence() if simplify: @@ -1650,20 +1647,6 @@ def finalise(self, simplify, stabilise_node_ordering): "filter_individuals=False, keep_unary=True) on " f"{ts.num_nodes} nodes and {ts.num_edges} edges" ) - if stabilise_node_ordering: - # Ensure all the node times are distinct so that they will have - # stable IDs after simplifying. This could possibly also be done - # by reversing the IDs within a time slice. This is used for comparing - # tree sequences produced by perfect inference. - tables = ts.dump_tables() - times = tables.nodes.time - for t in range(1, int(times[0])): - index = np.where(times == t)[0] - k = index.shape[0] - times[index] += np.arange(k)[::-1] / k - tables.nodes.time = times - tables.sort() - ts = tables.tree_sequence() ts = ts.simplify( samples=list(self.sample_id_map.values()), filter_sites=False,