From 0b1cb8a4a1f88c5ebb6da2b374027688815b7bdd Mon Sep 17 00:00:00 2001 From: Michel Lang Date: Wed, 10 Aug 2022 14:37:52 +0200 Subject: [PATCH] refactor: custom cost measure does not require the task anymore --- R/MeasureClassifCosts.R | 29 ++++++++++------------------- man/mlr_measures_classif.costs.Rd | 2 -- 2 files changed, 10 insertions(+), 21 deletions(-) diff --git a/R/MeasureClassifCosts.R b/R/MeasureClassifCosts.R index 75f832759..e79eb9c7b 100644 --- a/R/MeasureClassifCosts.R +++ b/R/MeasureClassifCosts.R @@ -11,8 +11,6 @@ #' For calculation of the score, the confusion matrix is multiplied element-wise with the cost matrix. #' The costs are then summed up (and potentially divided by the number of observations if `normalize` is set to `TRUE` (default)). #' -#' This measure requires the [Task] during scoring to ensure that the rows and columns of the cost matrix are in the same order as in the confusion matrix. -#' #' @templateVar id classif.costs #' @template measure #' @@ -52,7 +50,6 @@ MeasureClassifCosts = R6Class("MeasureClassifCosts", super$initialize( id = "classif.costs", param_set = param_set, - properties = "requires_task", range = c(-Inf, Inf), minimize = TRUE, label = "Cost-sensitive Classification", @@ -68,7 +65,10 @@ MeasureClassifCosts = R6Class("MeasureClassifCosts", if (missing(rhs)) { return(private$.costs) } - private$.costs = assert_cost_matrix(rhs) + + assert_matrix(rhs, mode = "numeric", any.missing = FALSE, col.names = "unique", row.names = "unique") + assert_set_equal(rownames(rhs), colnames(rhs)) + private$.costs = rhs if (min(rhs) >= 0) { self$range[1L] = 0 @@ -82,8 +82,11 @@ MeasureClassifCosts = R6Class("MeasureClassifCosts", private = list( .costs = NULL, - .score = function(prediction, task, ...) { - costs = assert_cost_matrix(private$.costs, task) + .score = function(prediction, ...) { + costs = self$costs + lvls = levels(prediction$truth) + assert_set_equal(lvls, colnames(costs)) + confusion = table(response = prediction$response, truth = prediction$truth, useNA = "ifany") # reorder rows / cols if necessary @@ -97,6 +100,7 @@ MeasureClassifCosts = R6Class("MeasureClassifCosts", if (self$param_set$values$normalize) { perf = perf / sum(confusion) } + perf }, @@ -106,16 +110,3 @@ MeasureClassifCosts = R6Class("MeasureClassifCosts", #' @include mlr_measures.R mlr_measures$add("classif.costs", function() MeasureClassifCosts$new()) - -assert_cost_matrix = function(costs, task = NULL) { - if (is.null(task)) { - assert_matrix(costs, mode = "numeric", any.missing = FALSE, col.names = "unique", row.names = "unique") - } else { - lvls = task$class_names - assert_matrix(costs, mode = "numeric", any.missing = FALSE, nrows = length(lvls), ncols = length(lvls), row.names = "unique", col.names = "unique") - assert_names(colnames(costs), permutation.of = lvls) - assert_names(rownames(costs), permutation.of = lvls) - } - - costs -} diff --git a/man/mlr_measures_classif.costs.Rd b/man/mlr_measures_classif.costs.Rd index 7a1652b43..c9383906b 100644 --- a/man/mlr_measures_classif.costs.Rd +++ b/man/mlr_measures_classif.costs.Rd @@ -11,8 +11,6 @@ The cost matrix is stored as slot \verb{$costs}. For calculation of the score, the confusion matrix is multiplied element-wise with the cost matrix. The costs are then summed up (and potentially divided by the number of observations if \code{normalize} is set to \code{TRUE} (default)). - -This measure requires the \link{Task} during scoring to ensure that the rows and columns of the cost matrix are in the same order as in the confusion matrix. } \section{Dictionary}{