diff --git a/DESCRIPTION b/DESCRIPTION index 1c77f87..5bcbba0 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,7 +1,7 @@ Package: tidyrules Type: Package Title: Obtain Rules from Rule Based Models as Tidy Dataframe -Version: 0.2.1 +Version: 0.2.2 Authors@R: c( person("Srikanth", "Komala Sheshachala", email = "sri.teach@gmail.com", role = c("aut", "cre")), person("Amith Kumar", "Ullur Raghavendra", email = "amith54@gmail.com", role = c("aut")) @@ -18,6 +18,8 @@ Imports: checkmate (>= 2.3.1), tidytable (>= 0.11.0), data.table (>= 1.14.6), + DescTools, + MetricsWeighted Suggests: AmesHousing (>= 0.0.3), dplyr (>= 0.8), diff --git a/NAMESPACE b/NAMESPACE index 821f68f..9947427 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -5,6 +5,7 @@ S3method(predict,ruleset) S3method(print,rulelist) S3method(print,ruleset) S3method(tidy,C5.0) +S3method(tidy,constparty) S3method(tidy,cubist) S3method(tidy,rpart) export(convert_rule_flavor) @@ -14,4 +15,6 @@ importFrom(data.table,":=") importFrom(generics,tidy) importFrom(magrittr,"%>%") importFrom(rlang,"%||%") +importFrom(stats,IQR) +importFrom(stats,weighted.mean) importFrom(utils,data) diff --git a/R/globals.R b/R/globals.R index 40c943b..46a2da6 100644 --- a/R/globals.R +++ b/R/globals.R @@ -21,6 +21,14 @@ utils::globalVariables(c(".", "rn__", "row_nbr", "pref__", - "data" + "data", + "weight", + "response", + "terminal_node_id", + "sum_weight", + "prevalence", + "winning_response", + "average", + "RMSE" ) ) \ No newline at end of file diff --git a/R/package.R b/R/package.R index cbce464..5a699a3 100644 --- a/R/package.R +++ b/R/package.R @@ -10,6 +10,8 @@ #' @importFrom rlang %||% #' @importFrom data.table := #' @importFrom utils data +#' @importFrom stats IQR +#' @importFrom stats weighted.mean "_PACKAGE" list.rules.party = getFromNamespace(".list.rules.party", "partykit") diff --git a/R/party.R b/R/party.R new file mode 100644 index 0000000..e3697b5 --- /dev/null +++ b/R/party.R @@ -0,0 +1,176 @@ +################################################################################ +# This is the part of the 'tidyrules' R package hosted at +# https://github.com/talegari/tidyrules with GPL-3 license. +################################################################################ + +#' @name tidy.constparty +#' @title Obtain rules as a ruleset/tidytable from a party model +#' @description Each row corresponds to a rule. A rule can be copied into +#' `dplyr::filter` to filter the observations corresponding to a rule +#' @param x party model +#' @param ... Other arguments (currently unused) +#' @details These party models are supported: regression (y is numeric), +#' classification (y is factor) +#' @return A tidytable where each row corresponds to a rule. The columns are: +#' rule_nbr, LHS, RHS, support, confidence (for classification only), lift +#' (for classification only) +#' @examples +#' model_party_cl = partykit::ctree(species ~ .,data = palmerpenguins::penguins) +#' model_party_cl +#' tidy(model_party_cl) +#' +#' model_party_re = partykit::ctree(bill_length_mm ~ ., +#' data = palmerpenguins::penguins +#' ) +#' model_party_re +#' tidy(model_party_re) +#' @export + +tidy.constparty = function(x, ...){ + + ##### assertions and prep #################################################### + arguments = list(...) + + # column names from the x: This will be used at the end to handle the + # variables with a space + col_names = + attr(x$terms, which = "term.labels") %>% + stringr::str_remove_all(pattern = "`") + + # throw error if there are consecutive spaces in the column names + if (any(stringr::str_count(col_names, " ") > 0)){ + rlang::abort( + "Variable names should not have two or more consecutive spaces.") + } + + # detect method using 'fitted' + fitted_df = tidytable::as_tidytable(x$fitted) + colnames(fitted_df) = c("terminal_node_id", "weight", "response") + fitted_df[["terminal_node_id"]] = as.character(fitted_df[["terminal_node_id"]]) + + y_class = class(fitted_df[["response"]]) + if (y_class == "factor") { + type = "classification" + } else if (y_class %in% c("numeric", "integer")) { + type = "regression" + } else { + rlang::inform("tidy supports only classification and regression 'party' models") + rlang::abort("Unsupported party object") + } + + #### core extraction work #################################################### + + # extract rules + raw_rules = list.rules.party(x) + + rules_df = + raw_rules %>% + stringr::str_replace_all(pattern = "\\\"","'") %>% + stringr::str_remove_all(pattern = ", 'NA'") %>% + stringr::str_remove_all(pattern = "'NA',") %>% + stringr::str_remove_all(pattern = "'NA'") %>% + stringr::str_squish() %>% + stringr::str_split(" & ") %>% + purrr::map(~ stringr::str_c("( ", .x, " )")) %>% + purrr::map_chr(~ stringr::str_c(.x, collapse = " & ")) %>% + tidytable::tidytable(LHS = .) %>% + tidytable::mutate(terminal_node_id = names(raw_rules)) + + # create metrics df + if (type == "classification"){ + + terminal_response_df = + fitted_df %>% + tidytable::summarise(sum_weight = sum(weight, na.rm = TRUE), + .by = c(terminal_node_id, response) + ) %>% + tidytable::slice_max(n = 1, + order_by = sum_weight, + by = terminal_node_id, + with_ties = FALSE + ) %>% + tidytable::select(terminal_node_id, + winning_response = response + ) + + prevalence_df = + fitted_df %>% + tidytable::summarise(prevalence = sum(weight, na.rm = TRUE), + .by = response + ) %>% + tidytable::mutate(prevalence = prevalence / sum(prevalence)) %>% + tidytable::select(response, prevalence) + + res = + fitted_df %>% + # bring 'winning_response' column + tidytable::left_join(terminal_response_df, + by = "terminal_node_id" + ) %>% + # bring 'prevalence' column + tidytable::left_join(prevalence_df, + by = c("winning_response" = "response") + ) %>% + tidytable::summarise( + support = sum(weight), + confidence = weighted.mean(response == winning_response, weight, na.rm = TRUE), + lift = weighted.mean(response == winning_response, weight, na.rm = TRUE) / prevalence[1], + RHS = winning_response[1], + .by = terminal_node_id + ) %>% + tidytable::left_join(rules_df, by = "terminal_node_id") %>% + tidytable::arrange(tidytable::desc(confidence)) %>% + tidytable::mutate(., rule_nbr = 1:nrow(.)) %>% + tidytable::select(rule_nbr, LHS, RHS, + support, confidence, lift, + terminal_node_id + ) + + } else if (type == "regression"){ + + res = + fitted_df %>% + tidytable::mutate(average = weighted.mean(response, weight, na.rm = TRUE), + .by = terminal_node_id + ) %>% + tidytable::summarise( + support = sum(weight), + IQR = DescTools::IQRw(response, weight, na.rm = TRUE), + RMSE = MetricsWeighted::rmse(actual = response, + predicted = average, + w = weight, + na.rm = TRUE + ), + average = mean(average), + .by = terminal_node_id + ) %>% + tidytable::left_join(rules_df, by = "terminal_node_id") %>% + tidytable::arrange(tidytable::desc(RMSE)) %>% + tidytable::mutate(., rule_nbr = 1:nrow(.)) %>% + tidytable::select(rule_nbr, LHS, RHS = average, + support, IQR, RMSE, + terminal_node_id + ) + } + + #### finalize output ######################################################### + + # replace variable names with spaces within backquotes + for (i in 1:length(col_names)) { + res[["LHS"]] = + stringr::str_replace_all(res[["LHS"]], + col_names[i], + addBackquotes(col_names[i]) + ) + } + + #### return ################################################################## + + class(res) = c("ruleset", class(res)) + + attr(res, "keys") = NULL + attr(res, "model_type") = "constparty" + attr(res, "estimation_type") = type + + return(res) +} diff --git a/R/rule_translators.R b/R/rule_translators.R index 73354a5..e8e0bba 100644 --- a/R/rule_translators.R +++ b/R/rule_translators.R @@ -15,19 +15,34 @@ convert_rule_flavor = function(rule, flavor){ if (flavor == "python"){ res = rule %>% + stringr::str_replace_all("\\( ", "") %>% + stringr::str_replace_all(" \\)", "") %>% + stringr::str_replace_all("%in%", "in") %>% stringr::str_replace_all("c\\(", "[") %>% stringr::str_replace_all("\\)", "]") %>% - stringr::str_replace_all("&", "and") + + stringr::str_replace_all("&", " ) and (") %>% + + stringr::str_c("( ", ., " )") %>% + stringr::str_squish() } else if (flavor == "sql"){ res = rule %>% - stringr::str_replace_all("==", "=") %>% + stringr::str_replace_all("\\( ", "") %>% + stringr::str_replace_all(" \\)", "") %>% + stringr::str_replace_all("%in%", "IN") %>% - stringr::str_replace_all("c\\(", "(") %>% - stringr::str_replace_all("&", "AND") + stringr::str_replace_all("c\\(", "[") %>% + stringr::str_replace_all("\\)", "]") %>% + + stringr::str_replace_all("&", " ) AND (") %>% + + stringr::str_c("( ", ., " )") %>% + stringr::str_squish() } + attr(res, "flavor") = flavor return(res) } diff --git a/man/tidy.constparty.Rd b/man/tidy.constparty.Rd new file mode 100644 index 0000000..75ee7de --- /dev/null +++ b/man/tidy.constparty.Rd @@ -0,0 +1,37 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/party.R +\name{tidy.constparty} +\alias{tidy.constparty} +\title{Obtain rules as a ruleset/tidytable from a party model} +\usage{ +\method{tidy}{constparty}(x, ...) +} +\arguments{ +\item{x}{party model} + +\item{...}{Other arguments (currently unused)} +} +\value{ +A tidytable where each row corresponds to a rule. The columns are: + rule_nbr, LHS, RHS, support, confidence (for classification only), lift + (for classification only) +} +\description{ +Each row corresponds to a rule. A rule can be copied into + `dplyr::filter` to filter the observations corresponding to a rule +} +\details{ +These party models are supported: regression (y is numeric), + classification (y is factor) +} +\examples{ +model_party_cl = partykit::ctree(species ~ .,data = palmerpenguins::penguins) +model_party_cl +tidy(model_party_cl) + +model_party_re = partykit::ctree(bill_length_mm ~ ., + data = palmerpenguins::penguins + ) +model_party_re +tidy(model_party_re) +} diff --git a/tests/testthat/test-party.R b/tests/testthat/test-party.R new file mode 100644 index 0000000..97e57e5 --- /dev/null +++ b/tests/testthat/test-party.R @@ -0,0 +1,51 @@ +################################################################################ +# This is the part of the 'tidyrules' R package hosted at +# https://github.com/talegari/tidyrules with GPL-3 license. +################################################################################ + +context("test-party") + +# setup some models ---- +data("penguins", package = "palmerpenguins") + +model_party_cl = partykit::ctree(species ~ .,data = penguins) +model_party_cl +tidy(model_party_cl) + +model_party_re = partykit::ctree(bill_length_mm ~ ., + data = penguins + ) +model_party_re +tidy(model_party_re) + +# function to check whether a rule is filterable +ruleFilterable = function(rule, data){ + dplyr::filter(data, eval(parse(text = rule))) +} + +# function to check whether all rules are filterable +allRulesFilterable = function(tr, data){ + parse_status = sapply( + tr[["LHS"]], + function(arule){ + trydf = try(ruleFilterable(arule, data), silent = TRUE) + if (nrow(trydf) == 0) print(arule) + inherits(trydf, "data.frame") + } + ) + return(parse_status) +} + +# test output type ---- + +test_that("creates ruleset", { + expect_is(tidy(model_party_cl), "ruleset") + expect_is(tidy(model_party_re), "ruleset") +}) + +# test parsable ---- +test_that("rules are parsable", { + expect_true(all(allRulesFilterable(tidy(model_party_cl), penguins))) + expect_true(all(allRulesFilterable(tidy(model_party_re), penguins))) +}) +