diff --git a/python/tests/ibd.py b/python/tests/ibd.py index db30ead0f8..6db35201b7 100644 --- a/python/tests/ibd.py +++ b/python/tests/ibd.py @@ -134,27 +134,69 @@ def __repr__(self): def __str__(self): return repr(self) - def add_segment(self, a, b, seg): + def add_segment_deprecated(self, a, b, seg): + # The original version of add_segment that doesn't sort or squash. key = (a, b) if a < b else (b, a) self.segments[key].append(tskit.IdentitySegment(seg.left, seg.right, seg.node)) - def sort_and_squash_segments(self, a, b): - # Sort and squash in-place. This will work differently in the C implementation. + def add_segment(self, a, b, seg): key = (a, b) if a < b else (b, a) - self.segments[key] = sorted(self.segments[key]) - num_segs = len(self.segments[key]) - new_list = [] - if num_segs > 1: - for seg in self.segments[key]: - if len(new_list) == 0: - new_list.append(seg) - else: - n = new_list[-1] # last segment - if seg.node != n.node or seg.left != n.right: - new_list.append(seg) - else: - new_list[-1] = tskit.IdentitySegment(n.left, seg.right, n.node) - self.segments[key] = new_list + + # Get position and add into the correct position. + current_segs = self.segments[key] + num_segs = len(current_segs) + + if num_segs == 0: + self.segments[key].append( + tskit.IdentitySegment(seg.left, seg.right, seg.node) + ) + else: + # Find the position for the new segment. + i = 0 + while ( + i < num_segs + and current_segs[i].node <= seg.node + and current_segs[i].right <= seg.left + ): + i += 1 + + # Calculate boolean values that determine whether to squash + # and if so, where. + PUT_FIRST = False # Insert segment at start of list. + PUT_LAST = False # Insert segment at end of list. + SQUASH_LEFT = False # Squash with the left segment. + SQUASH_RIGHT = False # Squash with the right segment. + + if i == 0: + PUT_FIRST = True + if i == num_segs: + PUT_LAST = True + if not PUT_FIRST: + if ( + current_segs[i - 1].node == seg.node + and current_segs[i - 1].right == seg.left + ): + SQUASH_LEFT = True + if not PUT_LAST: + if ( + seg.node == current_segs[i].node + and seg.right == current_segs[i].left + ): + SQUASH_RIGHT = True + + # Insert the new segment and squash if needed. + if SQUASH_LEFT and not SQUASH_RIGHT: + current_segs[i - 1].right = seg.right + elif SQUASH_RIGHT and not SQUASH_LEFT: + current_segs[i].left = seg.left + elif SQUASH_LEFT and SQUASH_RIGHT: + # To squash twice, must pop one of the existing segments. + current_segs[i - 1].right = current_segs[i].right + current_segs.pop(i) + else: + self.segments[key].insert( + i, tskit.IdentitySegment(seg.left, seg.right, seg.node) + ) class IbdFinder: @@ -238,13 +280,18 @@ def record_ibd(self, current_parent, child_segs, squash): # If there are any overlapping segments, record as a new # IBD relationship. if self.passes_filters(seg0.node, seg1.node, left, right): - self.result.add_segment( - seg0.node, - seg1.node, - Segment(left, right, current_parent), - ) if squash: - self.result.sort_and_squash_segments(seg0.node, seg1.node) + self.result.add_segment( + seg0.node, + seg1.node, + Segment(left, right, current_parent), + ) + else: + self.result.add_segment_deprecated( + seg0.node, + seg1.node, + Segment(left, right, current_parent), + ) seg1 = seg1.next seg0 = seg0.next