Skip to content

Commit

Permalink
Adding functions to combine likelihood lists with proper likelihood c…
Browse files Browse the repository at this point in the history
…alculations now
  • Loading branch information
jhellewell14 committed Jul 21, 2023
1 parent 5faf482 commit 3aa17d7
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 26 deletions.
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ edition = "2021"
ndarray = "0.15.6"
rand = "0.8.5"
needletail = "0.5.1"
nalgebra = "0.32.3"

[profile.release]
debug = 1
debug = 1
36 changes: 32 additions & 4 deletions src/gen_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,19 @@ use needletail::*;
#[derive(Debug, Copy, Clone)]
pub struct Mutation(pub usize, pub f64, pub f64, pub f64, pub f64);

impl Mutation {
pub fn prod(self, r: Mutation) -> Mutation {
Mutation(self.0, self.1 * r.1, self.2 *r.2, self.3 * r.3, self.4 * r.4)
}

pub fn likelihood(self, branch_length: f64, prob_matrix: &na::Matrix4<f64>) -> Mutation {

let x = prob_matrix * na::Vector4::new(self.1, self.2, self.3, self.4);

Mutation(self.0, x[0], x[1], x[2], x[3])
}
}

pub fn char_to_mutation(i: usize, e: &char) -> Mutation {
match e {
// (A, C, G, T)
Expand Down Expand Up @@ -41,7 +54,11 @@ pub fn create_list(refseq: &[char], seq: &[char]) -> Vec<Mutation> {
}

// Combines two vectors of Mutations into a single vector
pub fn combine_lists(seq1: Option<&Vec<Mutation>>, seq2: Option<&Vec<Mutation>>) -> Vec<Mutation> {
pub fn combine_lists(seq1: Option<&Vec<Mutation>>,
seq2: Option<&Vec<Mutation>>,
branchlengths: (f64, f64),
rate_matrix: &na::Matrix4<f64>) -> Vec<Mutation> {

let mut out: Vec<Mutation> = Vec::new();
let seq1 = seq1.unwrap();
let seq2 = seq2.unwrap();
Expand All @@ -58,11 +75,22 @@ pub fn combine_lists(seq1: Option<&Vec<Mutation>>, seq2: Option<&Vec<Mutation>>)
let mut s1_loc = s1_node.unwrap().0;
let mut s2_loc = s2_node.unwrap().0;

// Branch lengths
let b1 = branchlengths.0;
let b2 = branchlengths.1;

// Probability matrices
let p1 = na::Matrix::exp(&(rate_matrix * b1));
let p2 = na::Matrix::exp(&(rate_matrix * b2));

while (s1_i > 0) | (s2_i > 0) {
match s1_loc.cmp(&s2_loc) {
Ordering::Equal => {
// There should be a step here to calculate combined likelihoods
out.push(Mutation(s1_loc, 5.0, 5.0, 5.0, 5.0));

out.push(s1_node.unwrap()
.likelihood(b1, &p1)
.prod(s2_node.unwrap().likelihood(b2, &p2)));

s1_i -= 1;
s1_node = seq1.get(s1_i);
Expand All @@ -73,13 +101,13 @@ pub fn combine_lists(seq1: Option<&Vec<Mutation>>, seq2: Option<&Vec<Mutation>>)
s2_loc = s2_node.unwrap().0;
},
Ordering::Greater => {
out.push(*s1_node.unwrap());
out.push(s1_node.unwrap().likelihood(b1, &p1));
s1_i -= 1;
s1_node = seq1.get(s1_i);
s1_loc = s1_node.unwrap().0;
},
Ordering::Less => {
out.push(*s2_node.unwrap());
out.push(s2_node.unwrap().likelihood(b2, &p2));
s2_i -= 1;
s2_node = seq2.get(s2_i);
s2_loc = s2_node.unwrap().0;
Expand Down
45 changes: 25 additions & 20 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,37 +13,42 @@ use ndarray::ViewRepr;
use ndarray::array;
use needletail::*;
use std::time::Instant;
extern crate nalgebra as na;
// extern crate blas;
// extern crate openblas_src;
// use approx::assert_ulps_eq;

fn main() {
let start = Instant::now();


let filename = "listeria0.aln";
let ll = create_genetic_data(filename);

let tr = phylo2vec_quad(vec![0, 1, 0]);
let mut q: na::Matrix4<f64> = na::Matrix4::new(-2.0, 1.0, 1.0, 1.0,
1.0, -2.0, 1.0, 1.0,
1.0, 1.0, -2.0, 1.0,
1.0, 1.0, 1.0 , -2.0);
// println!("{:?}", ll.likelihood_lists.get_mut());

let combo = combine_lists(ll.likelihood_lists.get(0), ll.likelihood_lists.get(1));
let start = Instant::now();
let combo = combine_lists(ll.likelihood_lists.get(0),
ll.likelihood_lists.get(1),
(tr.get_branchlength(0), tr.get_branchlength(1)),
&q);
// println!("seq1: {:?}",ll.likelihood_lists);
// println!("combined seq: {:?}", combo);

// let q = ndarray::Array2::<f64>::eye(4);
let q = array![[-2., 1., 1., 1.], [1., -2., 1., 1.], [1., 1., -2., 1.], [1., 1., 1., -2.]];
println!("{:?}", q);
let mut p = ndarray::Array::<f64, _>::zeros((4, 4));
println!("combined seq: {:?}", combo);

// let a = Mutation(1, 0.55, 0.15, 0.15, 0.1);
// let b = Mutation(1, 0.35, 0.25, 0.25, 0.1);

let a = Mutation(1, 0.25, 0.25, 0.25, 0.25);
let b: ArrayBase<ViewRepr<&f64>, ndarray::Dim<[usize; 1]>> = q.row(0);
println!("{:?}", b);

pub fn likelihood_sum(a: Mutation, b: ArrayBase<ViewRepr<&f64>, ndarray::Dim<[usize; 1]>>) -> Mutation {
Mutation(a.0, a.1 * b[[0]], a.2 * b [[1]], a.3 * b[[2]], a.4 * b[[3]])
}

println!("{:?}", likelihood_sum(a, b));
// println!("{:?}", q);

// let temp = a.likelihood(0.125, &q)
// .prod(b.likelihood(0.5, &q));
// println!("{:?}", temp);

// println!("{:?}", ll.likelihood_lists.get(0).unwrap().get(0));
// let tr = phylo2vec_quad(vec![0, 1, 0]);

// let tr2 = phylo2vec_lin(vec![0, 0, 2, 3], false);

// println!("{:?}", tr);
Expand All @@ -57,7 +62,7 @@ fn main() {
// println!("{:?}", tr);
// println!("{:?}", tr2);

// for el in tr.postorder(tr.get_root()) {
// for el in tr.postorder_notips(tr.get_root()) {
// println!("{}", el);
// }

Expand Down
4 changes: 4 additions & 0 deletions src/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ pub struct Node {
pub index: usize,
pub depth: usize,
pub ll_list: Option<Vec<Mutation>>,
pub branch_length: f64,
}

impl Node {
Expand All @@ -18,6 +19,7 @@ impl Node {
index: usize,
depth: usize,
ll_list: Option<Vec<Mutation>>,
branch_length: f64,
) -> Node {
Node {
children,
Expand All @@ -26,6 +28,7 @@ impl Node {
depth,
tip: matches!(children, (None, None)),
ll_list,
branch_length,
}
}

Expand Down Expand Up @@ -80,6 +83,7 @@ impl Default for Node {
index: 0,
depth: 0,
ll_list: None,
branch_length: 1.0,
}
}
}
6 changes: 5 additions & 1 deletion src/tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ impl<'a> Tree {
}
}

pub fn get_branchlength(&self, index: usize) -> f64 {
self.get_node(index).unwrap().branch_length
}

// Returns vector of nodes in tree that are tips
pub fn get_tips(&self) -> Vec<&Node> {
self.nodes.iter().filter(|n| n.tip).collect()
Expand Down Expand Up @@ -106,7 +110,7 @@ impl<'a> Tree {
dpth = self.get_node(par).unwrap().depth + 1;
}

self.nodes[index] = Node::new(parent, (None, None), index, dpth, None);
self.nodes[index] = Node::new(parent, (None, None), index, dpth, None, 1.0);
}

pub fn get_handedness(&self, index: usize) -> Handedness {
Expand Down

0 comments on commit 3aa17d7

Please sign in to comment.