Skip to content

Commit

Permalink
Formulas renaming interactions
Browse files Browse the repository at this point in the history
  • Loading branch information
gvegayon committed Nov 1, 2023
1 parent 741f902 commit eecdcd8
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 27 deletions.
85 changes: 67 additions & 18 deletions include/barry/models/defm/counters.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,27 @@

///@}


#define MAKE_DEFM_HASHER(hasher,a,cov) barry::Hasher_fun_type<DEFMArray,DEFMCounterData> \
hasher = [cov](const DEFMArray & array, DEFMCounterData * d) { \
std::vector< double > res; \
/* Adding the column feature */ \
for (size_t i = 0u; i < array.nrow(); ++i) \
res.push_back(array.D()(i, cov)); \
/* Adding the fixed dims */ \
for (size_t i = 0u; i < (array.nrow() - 1); ++i) \
for (size_t j = 0u; j < array.ncol(); ++j) \
res.push_back(array(i, j)); \
return res;\
};
/**
* @brief Data for the counters
*
* @details This class is used to store the data for the counters. It is
* used by the `Counters` class.
*
*/
#define MAKE_DEFM_HASHER(hasher,a,cov) \
barry::Hasher_fun_type<DEFMArray, DEFMCounterData> \
hasher = [cov](const DEFMArray & array, DEFMCounterData * d) -> \
std::vector< double > { \
std::vector< double > res; \
/* Adding the column feature */ \
for (size_t i = 0u; i < array.nrow(); ++i) \
res.push_back(array.D()(i, cov)); \
/* Adding the fixed dims */ \
for (size_t i = 0u; i < (array.nrow() - 1); ++i) \
for (size_t j = 0u; j < array.ncol(); ++j) \
res.push_back(array(i, j)); \
return res;\
};


/**@name Macros for defining counters
Expand Down Expand Up @@ -147,6 +155,18 @@ inline void counter_ones(

}


/**
* Calculates the logit intercept for the DEFM model.
*
* @param counters A pointer to the DEFMCounters object.
* @param n_y The number of response variables.
* @param which A vector of indices indicating which response variables to use. If empty, all response variables are used.
* @param covar_index The index of the covariate to use as the intercept.
* @param vname The name of the variable to use as the intercept. If empty, the intercept is set to zero.
* @param x_names A pointer to a vector of strings containing the names of the covariates.
* @param y_names A pointer to a vector of strings containing the names of the response variables.
*/
inline void counter_logit_intercept(
DEFMCounters * counters,
size_t n_y,
Expand Down Expand Up @@ -611,14 +631,43 @@ inline void counter_transition_formula(
}

if (covar_index < 0)
throw std::logic_error("The covariate name was not found in the list of covariates.");
throw std::logic_error(
std::string("The covariate name '") +
covar_name +
std::string("' was not found in the list of covariates.")
);

}

// Checking the number of coords, could be single intercept
if (coords.size() == 1u)
{

// Getting the column
size_t coord = static_cast< size_t >(
std::floor(
static_cast<double>(coords[0u]) / static_cast<double>(m_order + 1)
));

counter_logit_intercept(
counters, n_y, {coord},
covar_index,
vname,
x_names,
y_names
);

}
else
{

counter_transition(
counters, coords, signs, m_order, n_y, covar_index, vname,
x_names, y_names
);

}

counter_transition(
counters, coords, signs, m_order, n_y, covar_index, vname,
x_names, y_names
);

}

Expand Down
10 changes: 4 additions & 6 deletions include/barry/models/defm/formula.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,12 @@ inline void defm_motif_parser(

std::regex pattern_intercept(
std::string("\\{\\s*[01]?y[0-9]+(_[0-9]+)?(\\s*,\\s*[01]?y[0-9]+(_[0-9]+)?)*\\s*\\}") +
std::string("(\\s*\\|\\s*[^\\s]+([(][^)]+[)])?\\s*)?")
std::string("(\\s*x\\s*[^\\s]+([(].+[)])?\\s*)?")
);
std::regex pattern_transition(
std::string("\\{\\s*[01]?y[0-9]+(_[0-9]+)?(\\s*,\\s*[01]?y[0-9]+(_[0-9]+)?)*\\}\\s*(>)\\s*") +
std::string("\\{\\s*[01]?y[0-9]+(_[0-9]+)?(\\s*,\\s*[01]?y[0-9]+(_[0-9]+)?)*\\s*\\}") +
std::string("(\\s*\\|\\s*[^\\s]+([(][^)]+[)])?\\s*)?")
std::string("(\\s*x\\s*[^\\s]+([(].+[)])?\\s*)?")
);

auto empty = std::sregex_iterator();
Expand All @@ -82,8 +82,7 @@ inline void defm_motif_parser(
throw std::logic_error("Transition effects are only valid when the data is a markov process.");

// Matching the pattern '| [no spaces]$'
std::regex pattern_conditional(".+\\|\\s*([^\\s(]+)([(][^)]+[)])?\\s*$");

std::regex pattern_conditional(".+[}]\\s+x\\s+([^(]+)([(][^)]+[)])?\\s*$");
std::smatch condmatch;
std::regex_match(formula, condmatch, pattern_conditional);
// Extracting the [no_spaces] part of the conditional
Expand All @@ -98,7 +97,6 @@ inline void defm_motif_parser(

}


// Will indicate where the arrow is located at
size_t arrow_position = match.position(4u);

Expand Down Expand Up @@ -183,7 +181,7 @@ inline void defm_motif_parser(
{

// Matching the pattern '| [no spaces]$'
std::regex pattern_conditional(".+\\|\\s*([^\\s(]+)([(][^)]+[)])?\\s*$");
std::regex pattern_conditional(".+[}]\\s+x\\s+([^(]+)([(][^)]+[)])?\\s*$");
std::smatch condmatch;
std::regex_match(formula, condmatch, pattern_conditional);
// Extracting the [no_spaces] part of the conditional
Expand Down
6 changes: 3 additions & 3 deletions tests/15-defm-counts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,17 +175,17 @@ BARRY_TEST_CASE("DEFM counts work", "[DEFM counts]") {
"", &model3.get_X_names(), &model3.get_Y_names()
);
counter_transition_formula(
model3.get_counters(), "{y0} | X2", 2, 3, -1,
model3.get_counters(), "{y0} x X2", 2, 3, -1,
"", &model3.get_X_names(), &model3.get_Y_names()
);

counter_transition_formula(
model3.get_counters(), "{0y0_0} > {1y0, 1y2} | X2(Space 1)", 2, 3, -1,
model3.get_counters(), "{0y0_0} > {1y0, 1y2} x X2(Space 1)", 2, 3, -1,
"", &model3.get_X_names(), &model3.get_Y_names()
);

counter_transition_formula(
model3.get_counters(), "{0y0_0} > {1y0, 1y2} | X1(excess)", 2, 3, -1,
model3.get_counters(), "{0y0_0} > {1y0, 1y2} x X1(excess)", 2, 3, -1,
"", &model3.get_X_names(), &model3.get_Y_names()
);

Expand Down

0 comments on commit eecdcd8

Please sign in to comment.