Skip to content

Commit

Permalink
add taxa ordering and an Alignment class for keeping track of it
Browse files Browse the repository at this point in the history
  • Loading branch information
js51 committed Feb 21, 2024
1 parent 6400054 commit eb6140d
Show file tree
Hide file tree
Showing 8 changed files with 105 additions and 47 deletions.
10 changes: 5 additions & 5 deletions examples/playground.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,18 @@

# %%
# Define an 8-taxon rooted binary tree in Newick format with random branch lengths
#newick = "((((A:0.1,B:0.04):0.06,(C:0.13,D:0.05):0.06):0.1,(E:0.1,F:0.1):0.1):0.15,(G:0.1,H:0.2):0.08);"
newick = "((A:0.05,B:0.4):0.025,(C:0.05,D:0.4):0.025)"
newick = "((((A:0.1,B:0.04):0.06,(C:0.13,D:0.05):0.06):0.1,(E:0.1,F:0.1):0.1):0.15,(G:0.1,H:0.2):0.08);"
#newick = "((A:0.05,B:0.4):0.025,(C:0.05,D:0.4):0.025)"
tree = sp.Phylogeny(newick)
tree.draw()

# %%
# Define the model
model = sp.model.GTR.Kimura(0.5)
model = sp.model.GTR.JukesCantor(0.5)

# %%
# Generate a sequence alignment
alignment = sp.simulation.generate_alignment(tree, model, 5000)
alignment = sp.simulation.generate_alignment(tree, model, 10000)

# %%
# Get all the true splits
Expand Down Expand Up @@ -44,7 +44,7 @@
score_mutual_information = 0
for a in range(100):
print(f"Alignment {a}")
alignment = sp.simulation.generate_alignment(tree, model, 1000)
alignment = sp.simulation.generate_alignment(tree, model, 200)
splits_flat = sp.phylogenetics.erickson_SVD(alignment, taxa=tree.get_taxa(), method=sp.Method.flattening)
splits_KL = sp.phylogenetics.erickson_SVD(alignment, taxa=tree.get_taxa(), method=sp.Method.mutual_information)
if set(splits_flat) >= set(splits):
Expand Down
2 changes: 2 additions & 0 deletions splitp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from splitp import simulation
from splitp import phylogenetics
from splitp import trees
from splitp import alignment
from splitp import constants

# Import other important functions and classes
from splitp.constructions import flattening, subflattening
Expand Down
6 changes: 6 additions & 0 deletions splitp/alignment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from collections import UserDict

class Alignment(UserDict):
    """Dictionary-like table that also records the ordering of its taxa.

    The table itself is stored and accessed exactly like a normal
    dictionary (all mapping behaviour comes from ``UserDict``); the extra
    ``taxa`` attribute lets consumers recover which taxon each position of
    a key refers to.
    NOTE(review): keys are presumably site patterns and values their
    probabilities/frequencies -- confirm against the callers.

    Attributes:
        data: the backing dictionary managed by ``UserDict``.
        taxa: the taxa ordering supplied at construction, stored as given.
    """

    def __init__(self, data, taxa):
        """Create an alignment table.

        Args:
            data: initial contents; anything ``UserDict`` accepts (a mapping,
                an iterable of key/value pairs, or None for an empty table).
            taxa: the ordering of the taxa associated with the table.
        """
        # Go through UserDict's initialiser so the backing store is always a
        # fresh dict: the caller's object is not aliased (later edits to the
        # alignment do not leak back into it) and non-dict inputs are
        # normalised into a proper mapping.
        super().__init__(data)
        self.taxa = taxa
20 changes: 16 additions & 4 deletions splitp/constructions.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ def flattening(split, pattern_probabilities, flattening_format=FlatFormat.sparse
"""
if isinstance(split, str):
split = split.split("|")
taxa = sorted(set(split[0]) | set(split[1]))
try:
taxa = pattern_probabilities.taxa
except AttributeError:
taxa = sorted(set.union(*map(set, split)))
if flattening_format is FlatFormat.sparse:
return __sparse_flattening(split, pattern_probabilities, taxa)
if flattening_format is FlatFormat.reduced:
Expand Down Expand Up @@ -88,21 +91,30 @@ def __sparse_flattening(
row = __index_of(row_pattern)
col_pattern = "".join([str(pattern[taxa_indexer[s]]) for s in split[1]])
col = __index_of(col_pattern)
if (ban_col_patterns is not None and col_pattern.count(ban_col_patterns) > 1) or (ban_row_patterns is not None and row_pattern.count(ban_row_patterns) > 1):
flattening[row, col] = 0
if (
ban_col_patterns is not None and col_pattern.count(ban_col_patterns) > 1
) or (
ban_row_patterns is not None and row_pattern.count(ban_row_patterns) > 1
):
flattening[row, col] = 0
else:
flattening[row, col] = r[1]
return flattening


sparse_flattening_with_banned_patterns = __sparse_flattening


def subflattening(split, pattern_probabilities, data=None):
"""
A faster version of signed sum subflattening. Requires a data dictionary and can be supplied with a bundle of
re-usable information to reduce the number of calls to the multiplications function.
"""
state_space = constants.DNA_state_space
taxa = sorted(set(split[0]) | set(split[1]))
try:
taxa = pattern_probabilities.taxa
except AttributeError:
taxa = sorted(set.union(*map(set, split)))
taxa_indexer = {taxon: i for i, taxon in enumerate(taxa)}

if data is None:
Expand Down
68 changes: 52 additions & 16 deletions splitp/phylogenetics.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,16 +128,21 @@ def _erickstep(all_taxa, alignment):
score = sp.split_score(flattening)

elif method == sp.Method.subflattening:
subflattening = sp.subflattening(split, alignment, subflattening_data)
subflattening = sp.subflattening(
split, alignment, subflattening_data
)
score = sp.split_score(subflattening)

elif method == sp.Method.mutual_information:
flattening = sp.flattening(split, alignment, sp.FlatFormat.reduced)
score = sp.phylogenetics.flattening_rank_1_approximation_divergence(flattening)
score = sp.phylogenetics.flattening_rank_1_approximation_divergence(
flattening
)

all_scores[split] = score
scores[pair] = (pair, split, score)
if show_work: print(f"Scores: {scores}")
if show_work:
print(f"Scores: {scores}")
best_pair, best_split, best_score = min(scores.values(), key=lambda x: x[2])
return best_pair, best_split, best_score

Expand Down Expand Up @@ -189,6 +194,7 @@ def _consolidate(tup, smaller_halves):
_consolidate(smaller_half, smaller_halves),
)
)

splits = sorted(splits, key=lambda x: min(len(x[0]), len(x[1])), reverse=True)
if len(splits) == 1:
return str(splits[0]).replace("'", "").replace(" ", "") + ";"
Expand Down Expand Up @@ -322,7 +328,9 @@ def split_score(
)


def flattening_rank_1_approximation(flattening, return_vectors=False, dont_compute_matrix=False):
def flattening_rank_1_approximation(
flattening, return_vectors=False, dont_compute_matrix=False
):
r = np.array([sum(flattening)])
c = np.array([sum(flattening.T)])
approximation = None if dont_compute_matrix else r.T @ c
Expand All @@ -335,15 +343,28 @@ def flattening_rank_1_approximation(flattening, return_vectors=False, dont_compu
def flattening_rank_k_approximation(split, alignment):
taxa = sorted(set(split[0]) | set(split[1]))
sums_of_rows = [
sum(sp.constructions.sparse_flattening_with_banned_patterns(split, alignment, taxa, ban_row_patterns=char)) for char in constants.DNA_state_space
sum(
sp.constructions.sparse_flattening_with_banned_patterns(
split, alignment, taxa, ban_row_patterns=char
)
)
for char in constants.DNA_state_space
]
sums_of_cols = [
sum(sp.constructions.sparse_flattening_with_banned_patterns(split, alignment, taxa, ban_col_patterns=char).T) for char in constants.DNA_state_space
sum(
sp.constructions.sparse_flattening_with_banned_patterns(
split, alignment, taxa, ban_col_patterns=char
).T
)
for char in constants.DNA_state_space
]
return sum(A.T * B for A, B in zip(sums_of_rows, sums_of_cols))


def flattening_rank_1_approximation_divergence(flattening):
_, r, c = flattening_rank_1_approximation(flattening, return_vectors=True, dont_compute_matrix=True)
_, r, c = flattening_rank_1_approximation(
flattening, return_vectors=True, dont_compute_matrix=True
)
total = 0
for x in range(len(c)):
for y in range(len(r)):
Expand Down Expand Up @@ -388,7 +409,7 @@ def neighbour_joining(distance_matrix, labels=None, return_newick=False):
if labels is not None:
# Add a label to each node
for i in range(n):
T.nodes[i]['label'] = labels[i]
T.nodes[i]["label"] = labels[i]

# NJ Algorithm
while num_leaves > 2:
Expand Down Expand Up @@ -446,37 +467,48 @@ def neighbour_joining(distance_matrix, labels=None, return_newick=False):

return T


def distance_matrix(networkx_tree):
    """Distance matrix between the leaves of a tree.

    Leaves are the nodes with out-degree 0. Row/column ``i`` of the result
    corresponds to the ``i``-th leaf in the order the leaves appear in
    ``networkx_tree.nodes``. Distances are weighted shortest-path lengths
    (edge attribute ``"weight"``) measured on the undirected version of
    the tree.

    Args:
        networkx_tree (networkx.DiGraph): A tree.
    Returns:
        numpy.ndarray: A symmetric distance matrix with a zero diagonal.
    """
    # Get all the leaves
    leaf_nodes = [
        node for node in networkx_tree.nodes if networkx_tree.out_degree(node) == 0
    ]
    num_leaves = len(leaf_nodes)
    # Build the undirected copy once; it is the same for every pair of leaves,
    # so recreating it inside the loop only repeats the graph copy.
    undirected_tree = networkx_tree.to_undirected()
    # Create the distance matrix (named so it does not shadow this function)
    distances = np.zeros((num_leaves, num_leaves))
    for i, source in enumerate(leaf_nodes):
        # Only the upper triangle is computed; each value is mirrored below.
        for j in range(i + 1, num_leaves):
            distance = nx.shortest_path_length(
                undirected_tree,
                source,
                leaf_nodes[j],
                weight="weight",
            )
            distances[i, j] = distance
            distances[j, i] = distance
    return distances


def midpoint_rooting(networkx_tree, weight_label="weight"):
"""Midpoint rooting of a tree.
Args:
networkx_tree (networkx.DiGraph): A tree.
Returns:
networkx.DiGraph: A rooted tree.
"""
# Get all the leaves
leaf_nodes = [node for node in networkx_tree.nodes if networkx_tree.out_degree(node) == 0]
leaf_nodes = [
node for node in networkx_tree.nodes if networkx_tree.out_degree(node) == 0
]
# Get the distance matrix
D = distance_matrix(networkx_tree)
# Get the index of the largest distance
Expand All @@ -486,7 +518,7 @@ def midpoint_rooting(networkx_tree, weight_label="weight"):
tree_undirected = networkx_tree.to_undirected()
# Get the path between the two leaves
path = nx.shortest_path(tree_undirected, leaf_nodes[i], leaf_nodes[j])
midpoint_dist = max_dist / 2
midpoint_dist = max_dist / 2
# Travel along the path until the midpoint is reached. Then go back and add a new node
current_dist = 0
prev_dist = 0
Expand All @@ -509,6 +541,10 @@ def midpoint_rooting(networkx_tree, weight_label="weight"):
else:
raise ValueError("Edge not found. Is tree already rooted?")
# Add the branch lengths
networkx_tree.edges[new_node, path[k]][weight_label] = current_dist - midpoint_dist
networkx_tree.edges[new_node, path[k + 1]][weight_label] = midpoint_dist - prev_dist
networkx_tree.edges[new_node, path[k]][weight_label] = (
current_dist - midpoint_dist
)
networkx_tree.edges[new_node, path[k + 1]][weight_label] = (
midpoint_dist - prev_dist
)
break
31 changes: 16 additions & 15 deletions splitp/phylogeny.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,14 @@


class Phylogeny:
__slots__ = (
"name",
"networkx_graph",
"newick_string",
)
__slots__ = ("name", "networkx_graph", "newick_string", "taxa")

def __init__(
self,
newick_string,
name=None,
override_branch_length=None,
taxa_sort_order=None,
):
"""A rooted phylogenetic tree.
Expand All @@ -27,6 +24,8 @@ def __init__(
Attributes:
name: A name for the tree
networkx_graph: the underlying networkx graph
newick_string: the newick string representation of the tree
taxa: the taxa of the tree, ordered by taxa_sort_order when given, otherwise sorted
"""

# Set chosen tree properties
Expand All @@ -47,6 +46,15 @@ def __init__(
json_graph.tree_data(self.networkx_graph, self.root(return_index=False))
)

if taxa_sort_order is None:
self.taxa = sorted(self.get_taxa())
else:
self.taxa = taxa_sort_order
if set(self.taxa) != set(self.get_taxa()):
raise ValueError(
"The taxa sort order must contain all taxa in the tree."
)

def __str__(self):
"""Return the tree in Newick format"""
return json_to_newick(
Expand All @@ -73,13 +81,6 @@ def unrooted_networkx_graph(self):
# Return the unrooted graph
return unrooted_graph

def reassign_transition_matrices(self, transition_matrix):
"""DEPRECATED: Reassign transition matrices to all nodes in the tree"""
for node in self.networkx_graph.nodes:
self.networkx_graph.nodes[node]["transition_matrix"] = np.array(
transition_matrix
)

def get_num_nodes(self):
return len(self.networkx_graph.nodes)

Expand Down Expand Up @@ -120,7 +121,7 @@ def get_taxa(self):
return [n for n in self.networkx_graph.nodes if self.is_leaf(n)]

def get_num_taxa(self):
return len(self.get_taxa())
return len(self.taxa)

def is_leaf(self, n_index_or_name):
"""Determines whether a node is a leaf node from its index."""
Expand Down Expand Up @@ -150,7 +151,7 @@ def splits(self, include_trivial=False, as_strings=False):
"""Returns set of all splits displayed by the tree."""
from networkx.algorithms.traversal.depth_first_search import dfs_tree

all_taxa = [x for x in self.get_taxa()]
all_taxa = [x for x in self.taxa]
splits = set()
for node in list(self.nodes()):
subtree = dfs_tree(self.networkx_graph, node)
Expand Down Expand Up @@ -178,7 +179,7 @@ def format_split(self, split):
raise ValueError(
"Cannot produce string format for split with more than 35 taxa."
)
if all(len(taxon) == 1 for taxon in self.get_taxa()):
if all(len(taxon) == 1 for taxon in self.taxa):
return f'{"".join(split[0])}|{"".join(split[1])}'

def draw(self, draw_format=DrawFormat.ASCII):
Expand Down
9 changes: 5 additions & 4 deletions splitp/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from warnings import warn
from random import choices
from networkx import dfs_tree
from splitp import constants
from splitp import constants, alignment


def evolve_pattern(tree, model=None):
Expand Down Expand Up @@ -35,7 +35,7 @@ def __evolve_on_subtree(subtree, state):
result = {
pair.split(":")[0]: pair.split(":")[1] for pair in result_string.split(",")
}
taxa = tree.get_taxa()
taxa = tree.taxa
return "".join(result[k] for k in sorted(result.keys(), key=taxa.index))


Expand All @@ -52,6 +52,7 @@ def generate_alignment(tree, model, sequence_length):
counts.keys(), key=lambda p: [constants.DNA_state_space.index(c) for c in p]
):
probs[k] = counts[k] / float(sequence_length)

return probs


Expand All @@ -69,7 +70,6 @@ def draw_from_multinomial(pattern_probabilities, n):

def get_pattern_probabilities(tree, model=None):
"""Returns a full table of site-pattern probabilities (DNA character set)"""
# Creating a table with binary labels and calling likelihood_start() to fill it in with probabilities
combinations = list(
itertools.product(
"".join(s for s in constants.DNA_state_space), repeat=tree.get_num_taxa()
Expand All @@ -79,6 +79,7 @@ def get_pattern_probabilities(tree, model=None):
emptyArray = {
combination: __likelihood_start(tree, combination, model) for combination in combinations
}
pattern_probs = alignment.Alignment(emptyArray, taxa=tree.taxa)
return emptyArray

pattern_probabilities = get_pattern_probabilities
Expand Down Expand Up @@ -124,7 +125,7 @@ def _to_int(p):
pattern = [
_to_int(p) for p in pattern
] # A list of indices which correspond to taxa.
taxa = tree.get_taxa() # The list of taxa.
taxa = tree.taxa # The list of taxa.
# Likelihood table for dynamic prog alg ~ lTable[node_index, character]
likelihood_table = np.array(
[[None for _ in range(4)] for _ in range(tree.get_num_nodes())]
Expand Down
Loading

0 comments on commit eb6140d

Please sign in to comment.