Skip to content

Commit

Permalink
Merge pull request #2799 from jeromekelleher/ls-forward-native
Browse files Browse the repository at this point in the history
LS model refactoring
  • Loading branch information
jeromekelleher authored Aug 6, 2023
2 parents 11dd2f5 + 64aedbf commit b6f9872
Show file tree
Hide file tree
Showing 11 changed files with 854 additions and 451 deletions.
91 changes: 36 additions & 55 deletions c/tests/test_haplotype_matching.c
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* MIT License
*
* Copyright (c) 2019-2022 Tskit Developers
* Copyright (c) 2019-2023 Tskit Developers
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand All @@ -28,46 +28,6 @@
#include <unistd.h>
#include <stdlib.h>

/****************************************************************
* TestHMM
****************************************************************/

static double
tsk_ls_hmm_compute_normalisation_factor_site_test(tsk_ls_hmm_t *TSK_UNUSED(self))
{
return 1.0;
}

static int
tsk_ls_hmm_next_probability_test(tsk_ls_hmm_t *TSK_UNUSED(self),
tsk_id_t TSK_UNUSED(site_id), double TSK_UNUSED(p_last), bool TSK_UNUSED(is_match),
tsk_id_t TSK_UNUSED(node), double *result)
{
*result = rand();
/* printf("next proba = %f\n", *result); */
return 0;
}

static int
run_test_hmm(tsk_ls_hmm_t *hmm, int32_t *haplotype, tsk_compressed_matrix_t *output)
{
int ret = 0;

srand(1);

ret = tsk_ls_hmm_run(hmm, haplotype, tsk_ls_hmm_next_probability_test,
tsk_ls_hmm_compute_normalisation_factor_site_test, output);
if (ret != 0) {
goto out;
}
out:
return ret;
}

/****************************************************************
* TestHMM
****************************************************************/

static void
test_single_tree_missing_alleles(void)
{
Expand Down Expand Up @@ -206,6 +166,7 @@ test_single_tree_match_impossible(void)
tsk_treeseq_t ts;
tsk_ls_hmm_t ls_hmm;
tsk_compressed_matrix_t forward;
tsk_compressed_matrix_t backward;
tsk_viterbi_matrix_t viterbi;

double rho[] = { 0.0, 0.25, 0.25 };
Expand All @@ -228,8 +189,16 @@ test_single_tree_match_impossible(void)
tsk_viterbi_matrix_print_state(&viterbi, _devnull);
tsk_ls_hmm_print_state(&ls_hmm, _devnull);

ret = tsk_ls_hmm_backward(&ls_hmm, h, forward.normalisation_factor, &backward, 0);
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_MATCH_IMPOSSIBLE);
tsk_compressed_matrix_print_state(&backward, _devnull);
/* tsk_compressed_matrix_print_state(&forward, stdout); */
/* tsk_compressed_matrix_print_state(&backward, stdout); */
tsk_ls_hmm_print_state(&ls_hmm, _devnull);

tsk_ls_hmm_free(&ls_hmm);
tsk_compressed_matrix_free(&forward);
tsk_compressed_matrix_free(&backward);
tsk_viterbi_matrix_free(&viterbi);
tsk_treeseq_free(&ts);
}
Expand Down Expand Up @@ -275,12 +244,15 @@ test_single_tree_errors(void)
ret = tsk_compressed_matrix_store_site(&forward, 4, 0, 0, NULL);
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_SITE_OUT_OF_BOUNDS);

T[0].tree_node = -1;
T[0].value = 0;
ret = tsk_compressed_matrix_store_site(&forward, 0, 1, 1, T);
CU_ASSERT_EQUAL_FATAL(ret, 0);
ret = tsk_compressed_matrix_decode(&forward, (double *) decoded);
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS);
/* FIXME disabling this tests for now because we filter out negative
* nodes when storing now, to accomodate some oddness in the initial
* conditions of the backward matrix. */
/* T[0].tree_node = -1; */
/* T[0].value = 0; */
/* ret = tsk_compressed_matrix_store_site(&forward, 0, 1, 1, T); */
/* CU_ASSERT_EQUAL_FATAL(ret, 0); */
/* ret = tsk_compressed_matrix_decode(&forward, (double *) decoded); */
/* CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS); */

T[0].tree_node = 7;
T[0].value = 0;
Expand Down Expand Up @@ -443,7 +415,7 @@ test_multi_tree_exact_match(void)
int ret = 0;
tsk_treeseq_t ts;
tsk_ls_hmm_t ls_hmm;
tsk_compressed_matrix_t forward;
tsk_compressed_matrix_t forward, backward;
tsk_viterbi_matrix_t viterbi;

double rho[] = { 0.0, 0.25, 0.25 };
Expand All @@ -465,6 +437,13 @@ test_multi_tree_exact_match(void)
ret = tsk_compressed_matrix_decode(&forward, decoded_compressed_matrix);
CU_ASSERT_EQUAL_FATAL(ret, 0);

ret = tsk_ls_hmm_backward(&ls_hmm, h, forward.normalisation_factor, &backward, 0);
CU_ASSERT_EQUAL_FATAL(ret, 0);
tsk_ls_hmm_print_state(&ls_hmm, _devnull);
tsk_compressed_matrix_print_state(&backward, _devnull);
ret = tsk_compressed_matrix_decode(&backward, decoded_compressed_matrix);
CU_ASSERT_EQUAL_FATAL(ret, 0);

ret = tsk_ls_hmm_viterbi(&ls_hmm, h, &viterbi, 0);
CU_ASSERT_EQUAL_FATAL(ret, 0);
tsk_viterbi_matrix_print_state(&viterbi, _devnull);
Expand Down Expand Up @@ -492,6 +471,7 @@ test_multi_tree_exact_match(void)

tsk_ls_hmm_free(&ls_hmm);
tsk_compressed_matrix_free(&forward);
tsk_compressed_matrix_free(&backward);
tsk_viterbi_matrix_free(&viterbi);
tsk_treeseq_free(&ts);
}
Expand Down Expand Up @@ -529,7 +509,8 @@ test_caterpillar_tree_many_values(void)
int ret = 0;
tsk_ls_hmm_t ls_hmm;
tsk_compressed_matrix_t matrix;
double unused[] = { 0, 0, 0, 0, 0 };
double rho[] = { 0.1, 0.1, 0.1, 0.1, 0.1 };
double mu[] = { 0.0, 0.0, 0.0, 0.0, 0.0 };
int32_t h[] = { 0, 0, 0, 0, 0 };
tsk_size_t n[] = {
8,
Expand All @@ -542,11 +523,11 @@ test_caterpillar_tree_many_values(void)

for (j = 0; j < sizeof(n) / sizeof(*n); j++) {
ts = caterpillar_tree(n[j], 5, n[j] - 2);
ret = tsk_ls_hmm_init(&ls_hmm, ts, unused, unused, 0);
ret = tsk_ls_hmm_init(&ls_hmm, ts, rho, mu, 0);
CU_ASSERT_EQUAL_FATAL(ret, 0);
ret = tsk_compressed_matrix_init(&matrix, ts, 1 << 10, 0);
CU_ASSERT_EQUAL_FATAL(ret, 0);
ret = run_test_hmm(&ls_hmm, h, &matrix);
ret = tsk_ls_hmm_forward(&ls_hmm, h, &matrix, TSK_NO_INIT);
CU_ASSERT_EQUAL_FATAL(ret, 0);
tsk_compressed_matrix_print_state(&matrix, _devnull);
tsk_ls_hmm_print_state(&ls_hmm, _devnull);
Expand All @@ -559,13 +540,13 @@ test_caterpillar_tree_many_values(void)

j = 40;
ts = caterpillar_tree(j, 5, j - 2);
ret = tsk_ls_hmm_init(&ls_hmm, ts, unused, unused, 0);
ret = tsk_ls_hmm_init(&ls_hmm, ts, rho, mu, 0);
CU_ASSERT_EQUAL_FATAL(ret, 0);
ret = tsk_compressed_matrix_init(&matrix, ts, 1 << 20, 0);
CU_ASSERT_EQUAL_FATAL(ret, 0);
/* Short circuit this value so we can run the test in reasonable time */
ls_hmm.max_parsimony_words = 1;
ret = run_test_hmm(&ls_hmm, h, &matrix);
/* Short circuit this value so we can run the test */
ls_hmm.max_parsimony_words = 0;
ret = tsk_ls_hmm_forward(&ls_hmm, h, &matrix, TSK_NO_INIT);
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_TOO_MANY_VALUES);

tsk_ls_hmm_free(&ls_hmm);
Expand Down
Loading

0 comments on commit b6f9872

Please sign in to comment.