Skip to content

Commit

Permalink
refactor: custom cost measure does not require the task anymore
Browse files Browse the repository at this point in the history
  • Loading branch information
mllg committed Aug 10, 2022
1 parent 38a7120 commit 0b1cb8a
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 21 deletions.
29 changes: 10 additions & 19 deletions R/MeasureClassifCosts.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
#' For calculation of the score, the confusion matrix is multiplied element-wise with the cost matrix.
#' The costs are then summed up (and potentially divided by the number of observations if `normalize` is set to `TRUE` (default)).
#'
#' This measure requires the [Task] during scoring to ensure that the rows and columns of the cost matrix are in the same order as in the confusion matrix.
#'
#' @templateVar id classif.costs
#' @template measure
#'
Expand Down Expand Up @@ -52,7 +50,6 @@ MeasureClassifCosts = R6Class("MeasureClassifCosts",
super$initialize(
id = "classif.costs",
param_set = param_set,
properties = "requires_task",
range = c(-Inf, Inf),
minimize = TRUE,
label = "Cost-sensitive Classification",
Expand All @@ -68,7 +65,10 @@ MeasureClassifCosts = R6Class("MeasureClassifCosts",
if (missing(rhs)) {
return(private$.costs)
}
private$.costs = assert_cost_matrix(rhs)

assert_matrix(rhs, mode = "numeric", any.missing = FALSE, col.names = "unique", row.names = "unique")
assert_set_equal(rownames(rhs), colnames(rhs))
private$.costs = rhs

if (min(rhs) >= 0) {
self$range[1L] = 0
Expand All @@ -82,8 +82,11 @@ MeasureClassifCosts = R6Class("MeasureClassifCosts",
private = list(
.costs = NULL,

.score = function(prediction, task, ...) {
costs = assert_cost_matrix(private$.costs, task)
.score = function(prediction, ...) {
costs = self$costs
lvls = levels(prediction$truth)
assert_set_equal(lvls, colnames(costs))

confusion = table(response = prediction$response, truth = prediction$truth, useNA = "ifany")

# reorder rows / cols if necessary
Expand All @@ -97,6 +100,7 @@ MeasureClassifCosts = R6Class("MeasureClassifCosts",
if (self$param_set$values$normalize) {
perf = perf / sum(confusion)
}

perf
},

Expand All @@ -106,16 +110,3 @@ MeasureClassifCosts = R6Class("MeasureClassifCosts",

#' @include mlr_measures.R
mlr_measures$add("classif.costs", function() MeasureClassifCosts$new())

assert_cost_matrix = function(costs, task = NULL) {
if (is.null(task)) {
assert_matrix(costs, mode = "numeric", any.missing = FALSE, col.names = "unique", row.names = "unique")
} else {
lvls = task$class_names
assert_matrix(costs, mode = "numeric", any.missing = FALSE, nrows = length(lvls), ncols = length(lvls), row.names = "unique", col.names = "unique")
assert_names(colnames(costs), permutation.of = lvls)
assert_names(rownames(costs), permutation.of = lvls)
}

costs
}
2 changes: 0 additions & 2 deletions man/mlr_measures_classif.costs.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 0b1cb8a

Please sign in to comment.