diff --git a/include/barry/models/defm/counters.hpp b/include/barry/models/defm/counters.hpp index 54fb15897..4d0774557 100644 --- a/include/barry/models/defm/counters.hpp +++ b/include/barry/models/defm/counters.hpp @@ -17,19 +17,27 @@ ///@} - -#define MAKE_DEFM_HASHER(hasher,a,cov) barry::Hasher_fun_type \ - hasher = [cov](const DEFMArray & array, DEFMCounterData * d) { \ - std::vector< double > res; \ - /* Adding the column feature */ \ - for (size_t i = 0u; i < array.nrow(); ++i) \ - res.push_back(array.D()(i, cov)); \ - /* Adding the fixed dims */ \ - for (size_t i = 0u; i < (array.nrow() - 1); ++i) \ - for (size_t j = 0u; j < array.ncol(); ++j) \ - res.push_back(array(i, j)); \ - return res;\ - }; +/** + * @brief Data for the counters + * + * @details This class is used to store the data for the counters. It is + * used by the `Counters` class. + * + */ +#define MAKE_DEFM_HASHER(hasher,a,cov) \ + barry::Hasher_fun_type \ + hasher = [cov](const DEFMArray & array, DEFMCounterData * d) -> \ + std::vector< double > { \ + std::vector< double > res; \ + /* Adding the column feature */ \ + for (size_t i = 0u; i < array.nrow(); ++i) \ + res.push_back(array.D()(i, cov)); \ + /* Adding the fixed dims */ \ + for (size_t i = 0u; i < (array.nrow() - 1); ++i) \ + for (size_t j = 0u; j < array.ncol(); ++j) \ + res.push_back(array(i, j)); \ + return res;\ + }; /**@name Macros for defining counters @@ -147,6 +155,18 @@ inline void counter_ones( } + +/** + * Calculates the logit intercept for the DEFM model. + * + * @param counters A pointer to the DEFMCounters object. + * @param n_y The number of response variables. + * @param which A vector of indices indicating which response variables to use. If empty, all response variables are used. + * @param covar_index The index of the covariate to use as the intercept. + * @param vname The name of the variable to use as the intercept. If empty, the intercept is set to zero. + * @param x_names A pointer to a vector of strings containing the names of the covariates. + * @param y_names A pointer to a vector of strings containing the names of the response variables. + */ inline void counter_logit_intercept( DEFMCounters * counters, size_t n_y, @@ -611,14 +631,43 @@ inline void counter_transition_formula( } if (covar_index < 0) - throw std::logic_error("The covariate name was not found in the list of covariates."); + throw std::logic_error( + std::string("The covariate name '") + + covar_name + + std::string("' was not found in the list of covariates.") + ); + + } + + // Checking the number of coords, could be single intercept + if (coords.size() == 1u) + { + + // Getting the column + size_t coord = static_cast< size_t >( + std::floor( + static_cast(coords[0u]) / static_cast(m_order + 1) + )); + + counter_logit_intercept( + counters, n_y, {coord}, + covar_index, + vname, + x_names, + y_names + ); + + } + else + { + + counter_transition( + counters, coords, signs, m_order, n_y, covar_index, vname, + x_names, y_names + ); } - counter_transition( - counters, coords, signs, m_order, n_y, covar_index, vname, - x_names, y_names - ); } diff --git a/include/barry/models/defm/formula.hpp b/include/barry/models/defm/formula.hpp index 58bbe3f3e..4c7af40df 100644 --- a/include/barry/models/defm/formula.hpp +++ b/include/barry/models/defm/formula.hpp @@ -59,12 +59,12 @@ inline void defm_motif_parser( std::regex pattern_intercept( std::string("\\{\\s*[01]?y[0-9]+(_[0-9]+)?(\\s*,\\s*[01]?y[0-9]+(_[0-9]+)?)*\\s*\\}") + - std::string("(\\s*\\|\\s*[^\\s]+([(][^)]+[)])?\\s*)?") + std::string("(\\s*x\\s*[^\\s]+([(].+[)])?\\s*)?") ); std::regex pattern_transition( std::string("\\{\\s*[01]?y[0-9]+(_[0-9]+)?(\\s*,\\s*[01]?y[0-9]+(_[0-9]+)?)*\\}\\s*(>)\\s*") + std::string("\\{\\s*[01]?y[0-9]+(_[0-9]+)?(\\s*,\\s*[01]?y[0-9]+(_[0-9]+)?)*\\s*\\}") + - std::string("(\\s*\\|\\s*[^\\s]+([(][^)]+[)])?\\s*)?") + std::string("(\\s*x\\s*[^\\s]+([(].+[)])?\\s*)?") ); auto empty = std::sregex_iterator(); @@ -82,8 +82,7 @@ inline void defm_motif_parser( throw std::logic_error("Transition effects are only valid when the data is a markov process."); // Matching the pattern '| [no spaces]$' - std::regex pattern_conditional(".+\\|\\s*([^\\s(]+)([(][^)]+[)])?\\s*$"); - + std::regex pattern_conditional(".+[}]\\s+x\\s+([^(]+)([(][^)]+[)])?\\s*$"); std::smatch condmatch; std::regex_match(formula, condmatch, pattern_conditional); // Extracting the [no_spaces] part of the conditional @@ -98,7 +97,6 @@ inline void defm_motif_parser( } - // Will indicate where the arrow is located at size_t arrow_position = match.position(4u); @@ -183,7 +181,7 @@ inline void defm_motif_parser( { // Matching the pattern '| [no spaces]$' - std::regex pattern_conditional(".+\\|\\s*([^\\s(]+)([(][^)]+[)])?\\s*$"); + std::regex pattern_conditional(".+[}]\\s+x\\s+([^(]+)([(][^)]+[)])?\\s*$"); std::smatch condmatch; std::regex_match(formula, condmatch, pattern_conditional); // Extracting the [no_spaces] part of the conditional diff --git a/tests/15-defm-counts.cpp b/tests/15-defm-counts.cpp index 004194ab5..6a4240f52 100644 --- a/tests/15-defm-counts.cpp +++ b/tests/15-defm-counts.cpp @@ -175,17 +175,17 @@ BARRY_TEST_CASE("DEFM counts work", "[DEFM counts]") { "", &model3.get_X_names(), &model3.get_Y_names() ); counter_transition_formula( - model3.get_counters(), "{y0} | X2", 2, 3, -1, + model3.get_counters(), "{y0} x X2", 2, 3, -1, "", &model3.get_X_names(), &model3.get_Y_names() ); counter_transition_formula( - model3.get_counters(), "{0y0_0} > {1y0, 1y2} | X2(Space 1)", 2, 3, -1, + model3.get_counters(), "{0y0_0} > {1y0, 1y2} x X2(Space 1)", 2, 3, -1, "", &model3.get_X_names(), &model3.get_Y_names() ); counter_transition_formula( - model3.get_counters(), "{0y0_0} > {1y0, 1y2} | X1(excess)", 2, 3, -1, + model3.get_counters(), "{0y0_0} > {1y0, 1y2} x X1(excess)", 2, 3, -1, "", &model3.get_X_names(), &model3.get_Y_names() );