From 741f902f16f643ab80442fe170ff8ff957cd5819 Mon Sep 17 00:00:00 2001 From: "George G. Vega Yon" Date: Wed, 25 Oct 2023 16:39:22 -0600 Subject: [PATCH] Adding more syntax to DEFM and tests --- include/barry/model-meat.hpp | 6 +-- include/barry/models/defm/counters.hpp | 33 +++++++++--- include/barry/models/defm/formula.hpp | 66 ++++++++++++++++------- tests/15-defm-counts.cpp | 72 ++++++++++++++++++++++++++ tests/16-defm-counts-with-formulas.cpp | 15 +++--- 5 files changed, 158 insertions(+), 34 deletions(-) diff --git a/include/barry/model-meat.hpp b/include/barry/model-meat.hpp index 89a3ad2e9..c3784bd16 100644 --- a/include/barry/model-meat.hpp +++ b/include/barry/model-meat.hpp @@ -1272,8 +1272,8 @@ inline void Model(nterms() + 1); - min_v /= static_cast(nterms() + 1); + // max_v /= static_cast(nterms() + 1); + // min_v /= static_cast(nterms() + 1); printf_barry("Num. of Arrays : %li\n", this->size()); printf_barry("Support size : %li\n", this->size_unique()); @@ -1282,7 +1282,7 @@ inline void Model(std::accumulate(pset_sizes.begin(), pset_sizes.end(), 0u)) ); } printf_barry("Transform. Fun. : %s\n", transform_model_fun ? "yes": "no"); diff --git a/include/barry/models/defm/counters.hpp b/include/barry/models/defm/counters.hpp index 94cea3e5a..54fb15897 100644 --- a/include/barry/models/defm/counters.hpp +++ b/include/barry/models/defm/counters.hpp @@ -80,7 +80,7 @@ barry::Rule_fun_type a = \ */ inline void counter_ones( DEFMCounters * counters, - int covar_index = -1, + int covar_index = -1, std::string vname = "", const std::vector< std::string > * x_names = nullptr ) @@ -103,7 +103,6 @@ inline void counter_ones( }; - if (vname == "") { if (x_names != nullptr) @@ -119,9 +118,9 @@ inline void counter_ones( "Overall number of ones" ); - - - } else { + } + else + { DEFM_COUNTER_LAMBDA(count_ones) { @@ -589,11 +588,33 @@ inline void counter_transition_formula( std::vector< size_t > coords; std::vector< bool > signs; + std::string covar_name = ""; defm_motif_parser( - formula, coords, signs, m_order, n_y + formula, coords, signs, m_order, n_y, covar_name, vname ); + if ((covar_name != "") && (covar_index >= 0)) + throw std::logic_error("Can't have both a formula and a covariate index."); + + if (covar_name != "") + { + + if (x_names != nullptr) + { + for (size_t i = 0u; i < x_names->size(); ++i) + if (x_names->operator[](i) == covar_name) + { + covar_index = static_cast(i); + break; + } + } + + if (covar_index < 0) + throw std::logic_error("The covariate name was not found in the list of covariates."); + + } + counter_transition( counters, coords, signs, m_order, n_y, covar_index, vname, x_names, y_names diff --git a/include/barry/models/defm/formula.hpp b/include/barry/models/defm/formula.hpp index b8a0739e3..58bbe3f3e 100644 --- a/include/barry/models/defm/formula.hpp +++ b/include/barry/models/defm/formula.hpp @@ -18,7 +18,7 @@ * ## Intercept effects * * Intercept effects only involve a single set of curly brackets. Using the - * 'greater-than' symbol (i.e., '<') is only for transition effects. When + * 'greater-than' symbol (i.e., '>') is only for transition effects. When * specifying intercept effects, users can skip the `row_id`, e.g., * `y0_0` is equivalent to `y0`. If the passed `row id` is different from * the Markov order, i.e., `row_id != m_order`, then the function returns @@ -48,7 +48,9 @@ inline void defm_motif_parser( std::vector< size_t > & locations, std::vector< bool > & signs, size_t m_order, - size_t y_ncol + size_t y_ncol, + std::string & covar_name, + std::string & vname ) { // Resetting the results @@ -56,11 +58,13 @@ inline void defm_motif_parser( signs.clear(); std::regex pattern_intercept( - "\\{\\s*0?y[0-9]+(_[0-9]+)?(\\s*,\\s*0?y[0-9]+(_[0-9]+)?)*\\s*\\}" + std::string("\\{\\s*[01]?y[0-9]+(_[0-9]+)?(\\s*,\\s*[01]?y[0-9]+(_[0-9]+)?)*\\s*\\}") + + std::string("(\\s*\\|\\s*[^\\s]+([(][^)]+[)])?\\s*)?") ); std::regex pattern_transition( - std::string("\\{\\s*0?y[0-9]+(_[0-9]+)?(\\s*,\\s*0?y[0-9]+(_[0-9]+)?)*\\}\\s*(>)\\s*") + - std::string("\\{\\s*0?y[0-9]+(_[0-9]+)?(\\s*,\\s*0?y[0-9]+(_[0-9]+)?)*\\s*\\}") + std::string("\\{\\s*[01]?y[0-9]+(_[0-9]+)?(\\s*,\\s*[01]?y[0-9]+(_[0-9]+)?)*\\}\\s*(>)\\s*") + + std::string("\\{\\s*[01]?y[0-9]+(_[0-9]+)?(\\s*,\\s*[01]?y[0-9]+(_[0-9]+)?)*\\s*\\}") + + std::string("(\\s*\\|\\s*[^\\s]+([(][^)]+[)])?\\s*)?") ); auto empty = std::sregex_iterator(); @@ -77,6 +81,24 @@ inline void defm_motif_parser( if (m_order == 0) throw std::logic_error("Transition effects are only valid when the data is a markov process."); + // Matching the pattern '| [no spaces]$' + std::regex pattern_conditional(".+\\|\\s*([^\\s(]+)([(][^)]+[)])?\\s*$"); + + std::smatch condmatch; + std::regex_match(formula, condmatch, pattern_conditional); + // Extracting the [no_spaces] part of the conditional + if (!condmatch.empty()) + { + covar_name = condmatch[1].str(); + vname = condmatch[2].str(); + + // Removing starting and ending parenthesis + if (vname != "") + vname = vname.substr(1, vname.size() - 2); + + } + + // Will indicate where the arrow is located at size_t arrow_position = match.position(4u); @@ -92,13 +114,9 @@ inline void defm_motif_parser( size_t current_location = i->position(0u); // First value true/false - bool is_positive; - if (i->operator[](1u).str() == "") - is_positive = true; - else if (i->operator[](1u).str() == "0") + bool is_positive = true; + if (i->operator[](1u).str() == "0") is_positive = false; - else - throw std::logic_error("The number preceding y should be either none or zero."); // Variable position size_t y_col = std::stoul(i->operator[](2u).str()); @@ -161,7 +179,23 @@ inline void defm_motif_parser( } std::regex_match(formula, match, pattern_intercept); - if (!match.empty()){ + if (!match.empty()) + { + + // Matching the pattern '| [no spaces]$' + std::regex pattern_conditional(".+\\|\\s*([^\\s(]+)([(][^)]+[)])?\\s*$"); + std::smatch condmatch; + std::regex_match(formula, condmatch, pattern_conditional); + // Extracting the [no_spaces] part of the conditional + if (!condmatch.empty()) + { + covar_name = condmatch[1].str(); + vname = condmatch[2].str(); + + // Removing starting and ending parenthesis + if (vname != "") + vname = vname.substr(1, vname.size() - 2); + } // This pattern will match std::regex pattern("(0?)y([0-9]+)(_([0-9]+))?"); @@ -172,13 +206,9 @@ inline void defm_motif_parser( { // First value true/false - bool is_positive; - if (i->operator[](1u).str() == "") - is_positive = true; - else if (i->operator[](1u).str() == "0") + bool is_positive = true; + if (i->operator[](1u).str() == "0") is_positive = false; - else - throw std::logic_error("The number preceding y should be either none or zero."); // Variable position size_t y_col = std::stoul(i->operator[](2u).str()); diff --git a/tests/15-defm-counts.cpp b/tests/15-defm-counts.cpp index eb23439cb..004194ab5 100644 --- a/tests/15-defm-counts.cpp +++ b/tests/15-defm-counts.cpp @@ -85,6 +85,8 @@ BARRY_TEST_CASE("DEFM counts work", "[DEFM counts]") { // Creating the model, need to pass the data DEFM model(&id[0u], &Y[0u], &X[0u], 8, 3, 2, 2); model.get_model().store_psets(); + + model.set_names({"A", "B", "C"}, {"X1", "X2"}); // Generating the model specification counter_ones(model.get_model().get_counters()); @@ -158,6 +160,66 @@ BARRY_TEST_CASE("DEFM counts work", "[DEFM counts]") { std::vector< double > params2(model2.get_model().nterms(), 0.0); model2.simulate(params2, &out_sim2[0u]); + // Checking formulas ------------------------------------------------------- + + // Creating the model, need to pass the data + DEFM model3(&id[0u], &Y[0u], &X[0u], 8, 3, 2, 2); + model3.get_model().store_psets(); + + model3.set_names({"A", "B", "C"}, {"X1", "X2"}); + + // Generating the model specification + counter_ones(model3.get_model().get_counters()); + counter_transition_formula( + model3.get_counters(), "{y0}", 2, 3, -1, + "", &model3.get_X_names(), &model3.get_Y_names() + ); + counter_transition_formula( + model3.get_counters(), "{y0} | X2", 2, 3, -1, + "", &model3.get_X_names(), &model3.get_Y_names() + ); + + counter_transition_formula( + model3.get_counters(), "{0y0_0} > {1y0, 1y2} | X2(Space 1)", 2, 3, -1, + "", &model3.get_X_names(), &model3.get_Y_names() + ); + + counter_transition_formula( + model3.get_counters(), "{0y0_0} > {1y0, 1y2} | X1(excess)", 2, 3, -1, + "", &model3.get_X_names(), &model3.get_Y_names() + ); + + + model3.init(); + + DEFM model3b(&id[0u], &Y[0u], &X[0u], 8, 3, 2, 2); + model3b.get_model().store_psets(); + + model3b.set_names({"A", "B", "C"}, {"X1", "X2"}); + + // Generating the model specification + counter_ones(model3b.get_model().get_counters()); + counter_transition_formula( + model3b.get_counters(), "{y0}", 2, 3, -1, + "", &model3b.get_X_names(), &model3b.get_Y_names() + ); + counter_transition_formula( + model3b.get_counters(), "{y0}", 2, 3, 1, + "", &model3b.get_X_names(), &model3b.get_Y_names() + ); + + counter_transition_formula( + model3b.get_counters(), "{0y0_0} > {1y0, 1y2}", 2, 3, 1, + "Space 1", &model3b.get_X_names(), &model3b.get_Y_names() + ); + + counter_transition_formula( + model3b.get_counters(), "{0y0_0} > {1y0, 1y2}", 2, 3, 0, + "excess", &model3b.get_X_names(), &model3b.get_Y_names() + ); + + model3b.init(); + #ifndef CATCH_CONFIG_MAIN auto res = model.get_model().likelihood_total(par0, true); model.get_model().print(); @@ -165,7 +227,17 @@ BARRY_TEST_CASE("DEFM counts work", "[DEFM counts]") { model.get_model().print_stats(1u); model.get_model().print_stats(2u); (void) model.get_model().get_stats_target(); + model.print(); + model3.print(); + model3b.print(); return 0; + #else + + auto terms3 = model3.get_counters()->get_names(); + auto terms3b = model3b.get_counters()->get_names(); + + REQUIRE_THAT(terms3, Catch::Equals(terms3b)); + #endif diff --git a/tests/16-defm-counts-with-formulas.cpp b/tests/16-defm-counts-with-formulas.cpp index b15888910..64b8f9aeb 100644 --- a/tests/16-defm-counts-with-formulas.cpp +++ b/tests/16-defm-counts-with-formulas.cpp @@ -29,13 +29,14 @@ BARRY_TEST_CASE("DEFM motif formula", "[DEFM motif formula]") { std::vector< bool > sign_2 = {true, true, true, false}; std::vector< bool > sign_3 = {true}; - defm::defm_motif_parser("{y1, y3}", res_locations1a, res_sign1a, 1, 4); - defm::defm_motif_parser("{y1_0, y3} > {y1_1, 0y3_1}", res_locations2a, res_sign2a, 1, 4); - defm::defm_motif_parser("{y1}", res_locations3a, res_sign3a, 1, 4); - - defm::defm_motif_parser("{y1_1, y3}", res_locations1b, res_sign1b, 1, 4); - defm::defm_motif_parser("{y1_0, y3_0} > {y1, 0y3_1}", res_locations2b, res_sign2b, 1, 4); - defm::defm_motif_parser("{y1_1}", res_locations3b, res_sign3b, 1, 4); + std::string covar_name; + defm::defm_motif_parser("{y1, y3}", res_locations1a, res_sign1a, 1, 4, covar_name, covar_name); + defm::defm_motif_parser("{y1_0, y3} > {y1_1, 0y3_1}", res_locations2a, res_sign2a, 1, 4, covar_name, covar_name); + defm::defm_motif_parser("{y1}", res_locations3a, res_sign3a, 1, 4, covar_name, covar_name); + + defm::defm_motif_parser("{y1_1, y3}", res_locations1b, res_sign1b, 1, 4, covar_name, covar_name); + defm::defm_motif_parser("{y1_0, y3_0} > {y1, 0y3_1}", res_locations2b, res_sign2b, 1, 4, covar_name, covar_name); + defm::defm_motif_parser("{y1_1}", res_locations3b, res_sign3b, 1, 4, covar_name, covar_name); #ifdef CATCH_CONFIG_MAIN