Skip to content

Commit

Permalink
fixed confidence issue, added ruleset class (#31)
Browse files Browse the repository at this point in the history
  • Loading branch information
talegari authored Jun 25, 2024
1 parent 7578f46 commit 97788de
Show file tree
Hide file tree
Showing 20 changed files with 306 additions and 294 deletions.
2 changes: 2 additions & 0 deletions .Rbuildignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
^renv$
^renv\.lock$
^\.travis\.yml$
^.*\.Rproj$
^\.Rproj\.user$
Expand Down
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,6 @@
.Rhistory
.RData
.Ruserdata
.Rprofile
renv*
*.Rproj
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Package: tidyrules
Type: Package
Title: Utilities to Retrieve Rulelists from Model Fits, Filter, Prune, Reorder and Predict on unseen data
Version: 0.2.5
Version: 0.2.6
Authors@R: c(
person("Srikanth", "Komala Sheshachala", email = "sri.teach@gmail.com", role = c("aut", "cre")),
person("Amith Kumar", "Ullur Raghavendra", email = "amith54@gmail.com", role = c("aut"))
Expand Down Expand Up @@ -38,7 +38,7 @@ Suggests:
knitr (>= 1.23),
rmarkdown (>= 1.13),
palmerpenguins (>= 0.1.1),
Description: Extract rules as a rulelist (a class based on dataframe) along with metrics per rule such as support, confidence, lift, RMSE, IQR. Rulelists can be augmented using validation data, manipulated using standard dataframe operations, rulelists can be used to predict on unseen data, prune them based on some metrics and reoder them to optimize them for a metric. Utilities include manually creating rulesets, exporting a rulelist to SQL syntax and so on.
Description: Provides a framework to work with decision rules. Rules can be extracted from supported models, augmented with (custom) metrics using validation data, manipulated using standard dataframe operations, reordered and pruned based on a metric, predict on unseen (test) data. Utilities include; Creating a rulelist manually, Exporting a rulelist as a SQL case statement and so on. The package offers two classes; rulelist and rulelset based on dataframe.
URL: https://github.com/talegari/tidyrules
BugReports: https://github.com/talegari/tidyrules/issues
License: GPL-3
Expand Down
3 changes: 3 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,18 @@ S3method(calculate,rulelist)
S3method(plot,prune_rulelist)
S3method(plot,rulelist)
S3method(predict,rulelist)
S3method(predict,ruleset)
S3method(print,prune_rulelist)
S3method(print,rulelist)
S3method(print,ruleset)
S3method(prune,rulelist)
S3method(reorder,rulelist)
S3method(tidy,C5.0)
S3method(tidy,constparty)
S3method(tidy,cubist)
S3method(tidy,rpart)
export(as_rulelist)
export(as_ruleset)
export(augment)
export(calculate)
export(convert_rule_flavor)
Expand Down
17 changes: 9 additions & 8 deletions R/package.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@
#' @name package_tidyrules
#' @title `tidyrules`
#' @description `tidyrules` package provides a framework to work with decision
#' rules stored as a [rulelist] backed by a tidy dataframe. Rules can be
#' extracted from supported models using [tidy], augmented using validation data
#' by [augment][augment.rulelist], manipulated using standard dataframe
#' operations, (modified) rulelists can be used to [predict][predict.rulelist]
#' on unseen (test) data. Utilities include: Create a rulelist
#' manually ([as_rulelist][as_rulelist.data.frame]), Export a rulelist to SQL
#' ([to_sql_case]) and so on.
#' @seealso [rulelist], [tidy], [augment][augment.rulelist], [predict][predict.rulelist]
#' rules. Rules can be extracted from supported models using [tidy], augmented
#' using validation data by [augment][augment.rulelist], manipulated using
#' standard dataframe operations, (modified) rulelists can be used to
#' [predict][predict.rulelist] on unseen (test) data. Utilities include:
#' Create a rulelist manually ([as_rulelist][as_rulelist.data.frame]), Export
#' a rulelist to SQL ([to_sql_case]) and so on. The package offers two
#' classes; [rulelist] and [ruleset] based on dataframe.
#' @seealso [rulelist], [tidy], [augment][augment.rulelist],
#' [predict][predict.rulelist]
#' @importFrom magrittr %>%
#' @importFrom rlang %||%
#' @importFrom data.table :=
Expand Down
81 changes: 45 additions & 36 deletions R/rulelist.R
Original file line number Diff line number Diff line change
Expand Up @@ -359,23 +359,27 @@ set_validation_data = function(x, validation_data, y_name, weight = 1){
#' @title Print method for [rulelist] class
#' @description Prints [rulelist] attributes and first few rows.
#' @param x A [rulelist] object
#' @param banner (flag, default: `TRUE`) Should the banner be displayed
#' @param ... Passed to `tidytable::print`
#' @return input [rulelist] (invisibly)
#' @seealso [rulelist], [tidy], [augment][augment.rulelist],
#' [predict][predict.rulelist], [calculate][calculate.rulelist],
#' [prune][prune.rulelist], [reorder][reorder.rulelist]
#' @export
print.rulelist = function(x, ...){
print.rulelist = function(x, banner = TRUE, ...){

validate_rulelist(x)
rulelist = rlang::duplicate(x)

keys = attr(x, "keys")
estimation_type = attr(x, "estimation_type")
model_type = attr(x, "model_type")
validation_data = attr(x, "validation_data")
keys = attr(rulelist, "keys")
estimation_type = attr(rulelist, "estimation_type")
model_type = attr(rulelist, "model_type")
validation_data = attr(rulelist, "validation_data")

cli::cli_rule(left = "Rulelist")
cli::cli_text("")
if (banner) {
cli::cli_rule(left = "Rulelist")
cli::cli_text("")
}

if (is.null(keys)) {
cli::cli_alert_info("{.emph Keys}: {.strong NULL}")
Expand Down Expand Up @@ -407,10 +411,13 @@ print.rulelist = function(x, ...){

cli::cli_text("")

class(x) = setdiff(class(x), "rulelist")
print(x, ...)
cli::cli_rule()
class(x) = c("rulelist", class(x))
class(rulelist) = setdiff(class(rulelist), "rulelist")
# now 'rulelist' is a dataframe and not a 'rulelist'
print(rulelist, ...)

if (banner) {
cli::cli_rule()
}

return(invisible(x))
}
Expand Down Expand Up @@ -706,20 +713,21 @@ predict_rulelist = function(rulelist, new_data){
#' @returns A dataframe. See **Details**.
#'
#' @details If a `row_nbr` is covered more than one `rule_nbr` per 'keys', then
#' `rule_nbr` appearing earlier (as in row order of the [rulelist]) takes
#' precedence.
#' `rule_nbr` appearing earlier (as in row order of the [rulelist]) takes
#' precedence.
#'
#' ## Output Format
#'
#' - When multiple is `FALSE`(default), output is a dataframe with three
#' or more columns: `row_number` (int), columns corresponding to 'keys',
#' `rule_nbr` (int).
#' or more columns: `row_number` (int), columns corresponding to 'keys',
#' `rule_nbr` (int).
#'
#' - When multiple is `TRUE`(default), output is a tidytable/dataframe with three
#' or more columns: `row_number` (int), columns corresponding to 'keys',
#' `rule_nbr` (list column of integers).
#' - When multiple is `TRUE`, output is a dataframe with three
#' or more columns: `row_number` (int), columns corresponding to 'keys',
#' `rule_nbr` (list column of integers).
#'
#' - If a row number and 'keys' combination is not covered by any rule, then `rule_nbr` column has missing value.
#' - If a row number and 'keys' combination is not covered by any rule, then
#' `rule_nbr` column has missing value.
#'
#' @examples
#' model_c5 = C50::C5.0(species ~.,
Expand All @@ -740,7 +748,6 @@ predict_rulelist = function(rulelist, new_data){
#' [predict][predict.rulelist], [calculate][calculate.rulelist],
#' [prune][prune.rulelist], [reorder][reorder.rulelist]
#' @importFrom stats predict
#' @family Core Rulelist Utility
#' @export
#'
predict.rulelist = function(object, new_data, multiple = FALSE, ...){
Expand Down Expand Up @@ -790,22 +797,23 @@ augment_class_no_keys = function(x, new_data, y_name, weight, ...){
mutate(prevalence = prevalence_0 / sum(prevalence_0)) %>%
select(all_of(c(eval(y_name), "prevalence")))

na_to_false = function(x) ifelse(is.na(x), FALSE, x)

aggregatees_df =
new_data_with_rule_nbr %>%
# bring 'prevalence' column
left_join(prevalence_df,by = eval(y_name)) %>%
summarise(
support = sum(weight__, na.rm = TRUE),
confidence = weighted.mean(ifelse(is.na(eval(y_name) == RHS), FALSE, TRUE),
weight__,
na.rm = TRUE
),
lift = weighted.mean(ifelse(is.na(eval(y_name) == RHS), FALSE, TRUE),
weight__,
na.rm = TRUE
) / prevalence[1],
confidence =
( as.character(.data[[y_name]]) == as.character(RHS) ) %>%
na_to_false() %>%
weighted.mean(weight__, na.rm = TRUE),
prevalence = prevalence[1],
.by = rule_nbr
) %>%
mutate(lift = confidence / prevalence) %>%
select(-prevalence) %>%
nest(.by = rule_nbr, .key = "augmented_stats")

# output has all columns of 'tidy' along with 'augment_stats'
Expand Down Expand Up @@ -858,23 +866,24 @@ augment_class_keys = function(x, new_data, y_name, weight, ...){
) %>%
select(all_of(c(keys, eval(y_name), "prevalence")))

na_to_false = function(x) ifelse(is.na(x), FALSE, x)

# add aggregates at rule_nbr and 'keys' level
aggregatees_df =
new_data_with_rule_nbr %>%
left_join(prevalence_df, by = c(keys, eval(y_name))) %>%
summarise(
support = sum(weight__, na.rm = TRUE),
confidence = weighted.mean(ifelse(is.na(eval(y_name) == RHS), FALSE, TRUE),
weight__,
na.rm = TRUE
),
lift = weighted.mean(ifelse(is.na(eval(y_name) == RHS), FALSE, TRUE),
weight__,
na.rm = TRUE
) / prevalence[1],
confidence =
( as.character(.data[[y_name]]) == as.character(RHS) ) %>%
na_to_false() %>%
weighted.mean(weight__, na.rm = TRUE),
prevalence = prevalence[1],
...,
.by = c(keys, "rule_nbr")
) %>%
mutate(lift = confidence / prevalence) %>%
select(-prevalence) %>%
nest(.by = c("rule_nbr", keys), .key = "augmented_stats")

# output has all columns of 'tidy' along with 'augment_stats'
Expand Down
109 changes: 109 additions & 0 deletions R/ruleset.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
#*******************************************************************************
# This is the part of the 'tidyrules' R package hosted at
# https://github.com/talegari/tidyrules with GPL-3 license.
#*******************************************************************************

#' @name ruleset
#' @title Ruleset
#' @description ruleset class is a piggyback class that inherits [rulelist]
#' class for convenience of [print] and [predict] methods.
identity # just a placeholder for 'ruleset' documentation, not exported

#' @name as_ruleset
#' @title Get a ruleset from a rulelist
#' @description Returns a ruleset object
#' @param rulelist A [rulelist]
#' @returns A [ruleset]
#'
#' @examples
#' model_class_party = partykit::ctree(species ~ .,
#' data = palmerpenguins::penguins
#' )
#' as_ruleset(tidy(model_class_party))
#'
#' @seealso [rulelist]
#' @export
as_ruleset = function(rulelist){

validate_rulelist(rulelist)

x = rlang::duplicate(rulelist)
class(x) = c("ruleset", class(x))

return(x)
}

#' @name print.ruleset
#' @title Print method for ruleset class
#' @description Prints the ruleset object
#' @param x A [rulelist]
#' @param banner (flag, default: `TRUE`) Should the banner be displayed
#' @param ... Passed to `print.rulelist`
#' @returns (invisibly) Returns the ruleset object
#'
#' @examples
#' model_class_party = partykit::ctree(species ~ .,
#' data = palmerpenguins::penguins
#' )
#' as_ruleset(tidy(model_class_party))
#'
#' @seealso [print.rulelist]
#' @export
print.ruleset = function(x, banner = TRUE, ...){

ruleset = rlang::duplicate(x)

if (banner) {
cli::cli_rule(left = "Ruleset")
cli::cli_text("")
}

class(ruleset) = setdiff(class(ruleset), "ruleset")
# now 'ruleset' is a rulelist
print(ruleset, banner = FALSE, ...)

if (banner) {
cli::cli_rule()
}

return(invisible(x))
}

#' @name predict.ruleset
#' @title `predict` method for a [ruleset]
#' @description Predicts multiple `rule_nbr`(s) applicable for a `row_nbr` (per
#' key) in new_data
#'
#' @param object A [ruleset]
#' @param new_data (dataframe)
#' @param ... unused
#'
#' @returns A dataframe with three or more columns: `row_number` (int), columns
#' corresponding to 'keys', `rule_nbr` (list column of integers). If a row
#' number and 'keys' combination is not covered by any rule, then `rule_nbr`
#' column has missing value.
#'
#' @examples
#' model_c5 = C50::C5.0(species ~.,
#' data = palmerpenguins::penguins,
#' trials = 5,
#' rules = TRUE
#' )
#' tidy_c5_ruleset = as_ruleset(tidy(model_c5))
#' tidy_c5_ruleset
#'
#' predict(tidy_c5_ruleset, palmerpenguins::penguins)
#'
#' @seealso [predict.rulelist]
#' @importFrom stats predict
#' @export
predict.ruleset = function(object, new_data, ...){

x = rlang::duplicate(object)
class(x) = setdiff(class(x), "ruleset")

# now 'ruleset' is a rulelist
res = predict(x, new_data, multiple = TRUE, ...)

return(res)
}
8 changes: 4 additions & 4 deletions R/utils.R
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
################################################################################
#*******************************************************************************
# This is the part of the 'tidyrules' R package hosted at
# https://github.com/talegari/tidyrules with GPL-3 license.
################################################################################
#*******************************************************************************

#' @keywords internal
#' @name positionSpaceOutsideSinglequotes
Expand Down Expand Up @@ -334,10 +334,10 @@ convert_rule_flavor = function(rule, flavor){
#' @description Extract SQL case statement from a [rulelist]
#' @param rulelist A [rulelist] object
#' @param rhs_column_name (string, default: "RHS") Name of the column in the
#' rulelist to be used as RHS (WHEN <some rule> THEN {rhs}) in the sql case
#' rulelist to be used as RHS (WHEN some_rule THEN rhs) in the sql case
#' statement
#' @param output_colname (string, default: "output") Name of the output column
#' created by the SQL statement (used in case ... AS {output_column})
#' created by the SQL statement (used in case ... AS output_column)
#' @return (string invisibly) SQL case statement
#' @details As a side-effect, the SQL statement is cat to stdout. The output
#' contains newline character.
Expand Down
27 changes: 27 additions & 0 deletions man/as_ruleset.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 97788de

Please sign in to comment.