Skip to content

Commit

Permalink
Bug fixes
Browse files Browse the repository at this point in the history
- fix square root in sparse split score
- fix reconstruct pattern function used by the subflattening function to work with non-integer (or specifically formatted t_0, t_1 etc) taxon labels.
  • Loading branch information
js51 committed Feb 16, 2024
1 parent 2111356 commit 31d941c
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 15 deletions.
2 changes: 1 addition & 1 deletion VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.2.0
0.3.0
18 changes: 10 additions & 8 deletions splitp/constructions.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,9 @@ def subflattening(split, pattern_probabilities, data=None):
re-usable information to reduce the number of calls to the multiplications function.
"""
state_space = constants.DNA_state_space
taxa = sorted(set(split[0]) | set(split[1]))
taxa_indexer = {taxon: i for i, taxon in enumerate(taxa)}

if data is None:
data = {}
try:
Expand All @@ -107,7 +110,7 @@ def subflattening(split, pattern_probabilities, data=None):

if isinstance(split, str):
split = split.split("|")
sp1, sp2 = len(split[0]), len(split[1])
sp1, sp2 = map(len, split)
subflattening = [[0 for _ in range(3 * sp2 + 1)] for _ in range(3 * sp1 + 1)]
try:
row_labels = labels[sp1]
Expand All @@ -126,7 +129,7 @@ def subflattening(split, pattern_probabilities, data=None):
)
for r, row in enumerate(row_labels):
for c, col in enumerate(col_labels):
pattern = __reconstruct_pattern(split, row, col)
pattern = __reconstruct_pattern(split, row, col, taxa_indexer)
signed_sum = 0
for table_pattern, value in pattern_probabilities.items():
try:
Expand Down Expand Up @@ -168,11 +171,10 @@ def __subflattening_labels_generator(length):
yield "".join(special_state for _ in range(n))


def __reconstruct_pattern(split, row_label, col_label):
n = len(split[0]) + len(split[1])
def __reconstruct_pattern(split, row_label, col_label, taxa_indexer):
n = len(taxa_indexer)
pattern = {}
for splindex, loc in enumerate(split[0]):
pattern[int(str(loc), n) if len(str(loc)) == 1 else int(str(loc)[1:])] = row_label[splindex]
for splindex, loc in enumerate(split[1]):
pattern[int(str(loc), n) if len(str(loc)) == 1 else int(str(loc)[1:])] = col_label[splindex]
for split_half, label in zip(split, (row_label, col_label)):
for split_index, taxon in enumerate(split_half):
pattern[taxa_indexer[taxon]] = label[split_index]
return "".join(pattern[i] for i in range(n))
12 changes: 6 additions & 6 deletions splitp/phylogenetics.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from splitp import splits
from splitp.matrix import is_sparse, frobenius_norm
import scipy
from math import sqrt


def parsimony_score(self, pattern):
Expand Down Expand Up @@ -94,7 +95,7 @@ def hartigan_algorithm(self, pattern):
return score


def erickson_SVD(alignment, taxa=None, method=sp.Method.flattening):
def erickson_SVD(alignment, taxa=None, method=sp.Method.flattening, show_work=False):
all_scores = {}
subflattening_data = {}

Expand Down Expand Up @@ -135,6 +136,7 @@ def _erickstep(all_taxa, alignment):

all_scores[split] = score
scores[pair] = (pair, split, score)
if show_work: print(f"Scores: {scores}")
best_pair, best_split, best_score = min(scores.values(), key=lambda x: x[2])
return best_pair, best_split, best_score

Expand Down Expand Up @@ -299,7 +301,8 @@ def __sparse_split_score(
)
squared_singular_values = [sigma**2 for sigma in largest_four_singular_values]
norm = frobenius_norm(matrix, data_table=data_table_for_frob_norm)
return (1 - (sum(squared_singular_values) / (norm**2))) ** (1 / 2)
operand = 1 - (sum(squared_singular_values) / (norm**2))
return sqrt(operand if operand > 0 else 0)


def split_score(
Expand Down Expand Up @@ -497,7 +500,4 @@ def midpoint_rooting(networkx_tree, weight_label="weight"):
networkx_tree.edges[new_node, path[k]][weight_label] = current_dist - midpoint_dist
networkx_tree.edges[new_node, path[k + 1]][weight_label] = midpoint_dist - prev_dist
break





0 comments on commit 31d941c

Please sign in to comment.