Skip to content

Commit

Permalink
Adding more syntax to DEFM and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
gvegayon committed Oct 25, 2023
1 parent 66bf0fc commit 741f902
Show file tree
Hide file tree
Showing 5 changed files with 158 additions and 34 deletions.
6 changes: 3 additions & 3 deletions include/barry/model-meat.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1272,8 +1272,8 @@ inline void Model<Array_Type,Data_Counter_Type,Data_Rule_Type, Data_Rule_Dyn_Typ
}

// The vectors in the support reflect the size of nterms x entries
max_v /= static_cast<int>(nterms() + 1);
min_v /= static_cast<int>(nterms() + 1);
// max_v /= static_cast<int>(nterms() + 1);
// min_v /= static_cast<int>(nterms() + 1);

printf_barry("Num. of Arrays : %li\n", this->size());
printf_barry("Support size : %li\n", this->size_unique());
Expand All @@ -1282,7 +1282,7 @@ inline void Model<Array_Type,Data_Counter_Type,Data_Rule_Type, Data_Rule_Dyn_Typ
if (with_pset)
{
printf_barry("Arrays in powerset : %li\n",
std::accumulate(pset_sizes.begin(), pset_sizes.end(), 0u)
static_cast<size_t>(std::accumulate(pset_sizes.begin(), pset_sizes.end(), 0u))
);
}
printf_barry("Transform. Fun. : %s\n", transform_model_fun ? "yes": "no");
Expand Down
33 changes: 27 additions & 6 deletions include/barry/models/defm/counters.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ barry::Rule_fun_type<DEFMArray, DEFMRuleDynData> a = \
*/
inline void counter_ones(
DEFMCounters * counters,
int covar_index = -1,
int covar_index = -1,
std::string vname = "",
const std::vector< std::string > * x_names = nullptr
)
Expand All @@ -103,7 +103,6 @@ inline void counter_ones(

};


if (vname == "")
{
if (x_names != nullptr)
Expand All @@ -119,9 +118,9 @@ inline void counter_ones(
"Overall number of ones"
);



} else {
}
else
{

DEFM_COUNTER_LAMBDA(count_ones)
{
Expand Down Expand Up @@ -589,11 +588,33 @@ inline void counter_transition_formula(

std::vector< size_t > coords;
std::vector< bool > signs;
std::string covar_name = "";

defm_motif_parser(
formula, coords, signs, m_order, n_y
formula, coords, signs, m_order, n_y, covar_name, vname
);

if ((covar_name != "") && (covar_index >= 0))
throw std::logic_error("Can't have both a formula and a covariate index.");

if (covar_name != "")
{

if (x_names != nullptr)
{
for (size_t i = 0u; i < x_names->size(); ++i)
if (x_names->operator[](i) == covar_name)
{
covar_index = static_cast<int>(i);
break;
}
}

if (covar_index < 0)
throw std::logic_error("The covariate name was not found in the list of covariates.");

}

counter_transition(
counters, coords, signs, m_order, n_y, covar_index, vname,
x_names, y_names
Expand Down
66 changes: 48 additions & 18 deletions include/barry/models/defm/formula.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
* ## Intercept effects
*
* Intercept effects only involve a single set of curly brackets. Using the
* 'greater-than' symbol (i.e., '<') is only for transition effects. When
* 'greater-than' symbol (i.e., '>') is only for transition effects. When
* specifying intercept effects, users can skip the `row_id`, e.g.,
* `y0_0` is equivalent to `y0`. If the passed `row id` is different from
* the Markov order, i.e., `row_id != m_order`, then the function returns
Expand Down Expand Up @@ -48,19 +48,23 @@ inline void defm_motif_parser(
std::vector< size_t > & locations,
std::vector< bool > & signs,
size_t m_order,
size_t y_ncol
size_t y_ncol,
std::string & covar_name,
std::string & vname
)
{
// Resetting the results
locations.clear();
signs.clear();

std::regex pattern_intercept(
"\\{\\s*0?y[0-9]+(_[0-9]+)?(\\s*,\\s*0?y[0-9]+(_[0-9]+)?)*\\s*\\}"
std::string("\\{\\s*[01]?y[0-9]+(_[0-9]+)?(\\s*,\\s*[01]?y[0-9]+(_[0-9]+)?)*\\s*\\}") +
std::string("(\\s*\\|\\s*[^\\s]+([(][^)]+[)])?\\s*)?")
);
std::regex pattern_transition(
std::string("\\{\\s*0?y[0-9]+(_[0-9]+)?(\\s*,\\s*0?y[0-9]+(_[0-9]+)?)*\\}\\s*(>)\\s*") +
std::string("\\{\\s*0?y[0-9]+(_[0-9]+)?(\\s*,\\s*0?y[0-9]+(_[0-9]+)?)*\\s*\\}")
std::string("\\{\\s*[01]?y[0-9]+(_[0-9]+)?(\\s*,\\s*[01]?y[0-9]+(_[0-9]+)?)*\\}\\s*(>)\\s*") +
std::string("\\{\\s*[01]?y[0-9]+(_[0-9]+)?(\\s*,\\s*[01]?y[0-9]+(_[0-9]+)?)*\\s*\\}") +
std::string("(\\s*\\|\\s*[^\\s]+([(][^)]+[)])?\\s*)?")
);

auto empty = std::sregex_iterator();
Expand All @@ -77,6 +81,24 @@ inline void defm_motif_parser(
if (m_order == 0)
throw std::logic_error("Transition effects are only valid when the data is a markov process.");

// Matching the pattern '| [no spaces]$'
std::regex pattern_conditional(".+\\|\\s*([^\\s(]+)([(][^)]+[)])?\\s*$");

std::smatch condmatch;
std::regex_match(formula, condmatch, pattern_conditional);
// Extracting the [no_spaces] part of the conditional
if (!condmatch.empty())
{
covar_name = condmatch[1].str();
vname = condmatch[2].str();

// Removing the starting and ending parentheses
if (vname != "")
vname = vname.substr(1, vname.size() - 2);

}


// Will indicate where the arrow is located at
size_t arrow_position = match.position(4u);

Expand All @@ -92,13 +114,9 @@ inline void defm_motif_parser(
size_t current_location = i->position(0u);

// First value true/false
bool is_positive;
if (i->operator[](1u).str() == "")
is_positive = true;
else if (i->operator[](1u).str() == "0")
bool is_positive = true;
if (i->operator[](1u).str() == "0")
is_positive = false;
else
throw std::logic_error("The number preceding y should be either none or zero.");

// Variable position
size_t y_col = std::stoul(i->operator[](2u).str());
Expand Down Expand Up @@ -161,7 +179,23 @@ inline void defm_motif_parser(
}

std::regex_match(formula, match, pattern_intercept);
if (!match.empty()){
if (!match.empty())
{

// Matching the pattern '| [no spaces]$'
std::regex pattern_conditional(".+\\|\\s*([^\\s(]+)([(][^)]+[)])?\\s*$");
std::smatch condmatch;
std::regex_match(formula, condmatch, pattern_conditional);
// Extracting the [no_spaces] part of the conditional
if (!condmatch.empty())
{
covar_name = condmatch[1].str();
vname = condmatch[2].str();

// Removing the starting and ending parentheses
if (vname != "")
vname = vname.substr(1, vname.size() - 2);
}

// This pattern will match
std::regex pattern("(0?)y([0-9]+)(_([0-9]+))?");
Expand All @@ -172,13 +206,9 @@ inline void defm_motif_parser(
{

// First value true/false
bool is_positive;
if (i->operator[](1u).str() == "")
is_positive = true;
else if (i->operator[](1u).str() == "0")
bool is_positive = true;
if (i->operator[](1u).str() == "0")
is_positive = false;
else
throw std::logic_error("The number preceding y should be either none or zero.");

// Variable position
size_t y_col = std::stoul(i->operator[](2u).str());
Expand Down
72 changes: 72 additions & 0 deletions tests/15-defm-counts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ BARRY_TEST_CASE("DEFM counts work", "[DEFM counts]") {
// Creating the model, need to pass the data
DEFM model(&id[0u], &Y[0u], &X[0u], 8, 3, 2, 2);
model.get_model().store_psets();

model.set_names({"A", "B", "C"}, {"X1", "X2"});

// Generating the model specification
counter_ones(model.get_model().get_counters());
Expand Down Expand Up @@ -158,14 +160,84 @@ BARRY_TEST_CASE("DEFM counts work", "[DEFM counts]") {
std::vector< double > params2(model2.get_model().nterms(), 0.0);
model2.simulate(params2, &out_sim2[0u]);

// Checking formulas -------------------------------------------------------

// Creating the model, need to pass the data
DEFM model3(&id[0u], &Y[0u], &X[0u], 8, 3, 2, 2);
model3.get_model().store_psets();

model3.set_names({"A", "B", "C"}, {"X1", "X2"});

// Generating the model specification
counter_ones(model3.get_model().get_counters());
counter_transition_formula(
model3.get_counters(), "{y0}", 2, 3, -1,
"", &model3.get_X_names(), &model3.get_Y_names()
);
counter_transition_formula(
model3.get_counters(), "{y0} | X2", 2, 3, -1,
"", &model3.get_X_names(), &model3.get_Y_names()
);

counter_transition_formula(
model3.get_counters(), "{0y0_0} > {1y0, 1y2} | X2(Space 1)", 2, 3, -1,
"", &model3.get_X_names(), &model3.get_Y_names()
);

counter_transition_formula(
model3.get_counters(), "{0y0_0} > {1y0, 1y2} | X1(excess)", 2, 3, -1,
"", &model3.get_X_names(), &model3.get_Y_names()
);


model3.init();

DEFM model3b(&id[0u], &Y[0u], &X[0u], 8, 3, 2, 2);
model3b.get_model().store_psets();

model3b.set_names({"A", "B", "C"}, {"X1", "X2"});

// Generating the model specification
counter_ones(model3b.get_model().get_counters());
counter_transition_formula(
model3b.get_counters(), "{y0}", 2, 3, -1,
"", &model3b.get_X_names(), &model3b.get_Y_names()
);
counter_transition_formula(
model3b.get_counters(), "{y0}", 2, 3, 1,
"", &model3b.get_X_names(), &model3b.get_Y_names()
);

counter_transition_formula(
model3b.get_counters(), "{0y0_0} > {1y0, 1y2}", 2, 3, 1,
"Space 1", &model3b.get_X_names(), &model3b.get_Y_names()
);

counter_transition_formula(
model3b.get_counters(), "{0y0_0} > {1y0, 1y2}", 2, 3, 0,
"excess", &model3b.get_X_names(), &model3b.get_Y_names()
);

model3b.init();

#ifndef CATCH_CONFIG_MAIN
auto res = model.get_model().likelihood_total(par0, true);
model.get_model().print();
model.get_model().print_stats(0u);
model.get_model().print_stats(1u);
model.get_model().print_stats(2u);
(void) model.get_model().get_stats_target();
model.print();
model3.print();
model3b.print();
return 0;
#else

auto terms3 = model3.get_counters()->get_names();
auto terms3b = model3b.get_counters()->get_names();

REQUIRE_THAT(terms3, Catch::Equals(terms3b));

#endif


Expand Down
15 changes: 8 additions & 7 deletions tests/16-defm-counts-with-formulas.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,14 @@ BARRY_TEST_CASE("DEFM motif formula", "[DEFM motif formula]") {
std::vector< bool > sign_2 = {true, true, true, false};
std::vector< bool > sign_3 = {true};

defm::defm_motif_parser("{y1, y3}", res_locations1a, res_sign1a, 1, 4);
defm::defm_motif_parser("{y1_0, y3} > {y1_1, 0y3_1}", res_locations2a, res_sign2a, 1, 4);
defm::defm_motif_parser("{y1}", res_locations3a, res_sign3a, 1, 4);

defm::defm_motif_parser("{y1_1, y3}", res_locations1b, res_sign1b, 1, 4);
defm::defm_motif_parser("{y1_0, y3_0} > {y1, 0y3_1}", res_locations2b, res_sign2b, 1, 4);
defm::defm_motif_parser("{y1_1}", res_locations3b, res_sign3b, 1, 4);
std::string covar_name;
defm::defm_motif_parser("{y1, y3}", res_locations1a, res_sign1a, 1, 4, covar_name, covar_name);
defm::defm_motif_parser("{y1_0, y3} > {y1_1, 0y3_1}", res_locations2a, res_sign2a, 1, 4, covar_name, covar_name);
defm::defm_motif_parser("{y1}", res_locations3a, res_sign3a, 1, 4, covar_name, covar_name);

defm::defm_motif_parser("{y1_1, y3}", res_locations1b, res_sign1b, 1, 4, covar_name, covar_name);
defm::defm_motif_parser("{y1_0, y3_0} > {y1, 0y3_1}", res_locations2b, res_sign2b, 1, 4, covar_name, covar_name);
defm::defm_motif_parser("{y1_1}", res_locations3b, res_sign3b, 1, 4, covar_name, covar_name);


#ifdef CATCH_CONFIG_MAIN
Expand Down

0 comments on commit 741f902

Please sign in to comment.