Skip to content

Commit

Permalink
Add knndm method from {CAST} (#229)
Browse files Browse the repository at this point in the history
  • Loading branch information
pat-s authored Aug 23, 2023
1 parent 3e7eeb2 commit 01ce238
Show file tree
Hide file tree
Showing 64 changed files with 8,512 additions and 2,052 deletions.
5 changes: 2 additions & 3 deletions .github/workflows/tic.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ jobs:
# - { os: windows-latest, r: "release" }
# [Custom matrix env var]
# - { os: macOS-latest, r: "release", latex: "true", codecov: "false" }
- { os: ubuntu-latest, r: "devel", codecov: "true" }
- { os: ubuntu-22.04, r: "devel", codecov: "true" }
# [Custom matrix env var]
- { os: ubuntu-latest, r: "release", pkgdown: "true", latex: "true", codecov: "false" }
- { os: ubuntu-22.04, r: "release", pkgdown: "true", latex: "true", codecov: "false" }

env:
# make sure to run `tic::use_ghactions_deploy()` to set up deployment
Expand Down Expand Up @@ -82,7 +82,6 @@ jobs:
- name: "[Custom block] [macOS] Install spatial libraries"
if: runner.os == 'macOS'
run: |
rm '/usr/local/bin/gfortran'
brew install ccache gdal geos proj udunits jpeg sqlite
brew install xquartz
mkdir ~/.R && echo -e "CPPFLAGS += -L/usr/local/opt/jpeg/lib" >> ~/.R/Makevars
Expand Down
3 changes: 3 additions & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,15 @@ Suggests:
bbotk,
blockCV (>= 3.1.2),
caret,
CAST (>= 0.8.0),
ggsci,
ggtext,
here,
knitr,
lgr,
mlr3filters (>= 0.7.0.9000),
mlr3pipelines,
mlr3spatial,
mlr3tuning,
patchwork,
plotly,
Expand All @@ -51,6 +53,7 @@ Suggests:
sperrorest,
terra,
testthat (>= 3.0.0),
twosamples,
vdiffr (>= 1.0.0),
withr
VignetteBuilder:
Expand Down
6 changes: 6 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,15 @@ S3method(autoplot,ResamplingRepeatedSpCVBlock)
S3method(autoplot,ResamplingRepeatedSpCVCoords)
S3method(autoplot,ResamplingRepeatedSpCVDisc)
S3method(autoplot,ResamplingRepeatedSpCVEnv)
S3method(autoplot,ResamplingRepeatedSpCVKnndm)
S3method(autoplot,ResamplingRepeatedSpCVTiles)
S3method(autoplot,ResamplingRepeatedSptCVCstf)
S3method(autoplot,ResamplingSpCVBlock)
S3method(autoplot,ResamplingSpCVBuffer)
S3method(autoplot,ResamplingSpCVCoords)
S3method(autoplot,ResamplingSpCVDisc)
S3method(autoplot,ResamplingSpCVEnv)
S3method(autoplot,ResamplingSpCVKnndm)
S3method(autoplot,ResamplingSpCVTiles)
S3method(autoplot,ResamplingSptCVCstf)
S3method(plot,ResamplingCV)
Expand All @@ -32,26 +34,30 @@ S3method(plot,ResamplingRepeatedSpCVBlock)
S3method(plot,ResamplingRepeatedSpCVCoords)
S3method(plot,ResamplingRepeatedSpCVDisc)
S3method(plot,ResamplingRepeatedSpCVEnv)
S3method(plot,ResamplingRepeatedSpCVKnndm)
S3method(plot,ResamplingRepeatedSpCVTiles)
S3method(plot,ResamplingRepeatedSptCVCstf)
S3method(plot,ResamplingSpCVBlock)
S3method(plot,ResamplingSpCVBuffer)
S3method(plot,ResamplingSpCVCoords)
S3method(plot,ResamplingSpCVDisc)
S3method(plot,ResamplingSpCVEnv)
S3method(plot,ResamplingSpCVKnndm)
S3method(plot,ResamplingSpCVTiles)
S3method(plot,ResamplingSptCVCstf)
export(ResamplingRepeatedSpCVBlock)
export(ResamplingRepeatedSpCVCoords)
export(ResamplingRepeatedSpCVDisc)
export(ResamplingRepeatedSpCVEnv)
export(ResamplingRepeatedSpCVKnndm)
export(ResamplingRepeatedSpCVTiles)
export(ResamplingRepeatedSptCVCstf)
export(ResamplingSpCVBlock)
export(ResamplingSpCVBuffer)
export(ResamplingSpCVCoords)
export(ResamplingSpCVDisc)
export(ResamplingSpCVEnv)
export(ResamplingSpCVKnndm)
export(ResamplingSpCVTiles)
export(ResamplingSptCVCstf)
export(TaskClassifST)
Expand Down
199 changes: 199 additions & 0 deletions R/ResamplingRepeatedSpCVknndm.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
#' @title (CAST) Repeated K-fold Nearest Neighbour Distance Matching
#'
#' @template rox_sptcv_knndm
#' @name mlr_resamplings_repeated_spcv_knndm
#'
#' @section Parameters:
#'
#' * `repeats` (`integer(1)`)\cr
#' Number of repeats.
#'
#' @references
#' `r format_bib("linnenbrink2023")`
#'
#' @export
#' @examples
#' library(mlr3)
#' library(mlr3spatial)
#' set.seed(42)
#' simarea = list(matrix(c(0, 0, 0, 100, 100, 100, 100, 0, 0, 0), ncol = 2, byrow = TRUE))
#' simarea = sf::st_polygon(simarea)
#' train_points = sf::st_sample(simarea, 1000, type = "random")
#' train_points = sf::st_as_sf(train_points)
#' train_points$target = as.factor(sample(c("TRUE", "FALSE"), 1000, replace = TRUE))
#' pred_points = sf::st_sample(simarea, 1000, type = "regular")
#'
#' task = mlr3spatial::as_task_classif_st(sf::st_as_sf(train_points), "target", positive = "TRUE")
#'
#' cv_knndm = rsmp("repeated_spcv_knndm", ppoints = pred_points, repeats = 2)
#' cv_knndm$instantiate(task)

#' #' ### Individual sets:
#' # cv_knndm$train_set(1)
#' # cv_knndm$test_set(1)
#' # check that no obs are in both sets
#' intersect(cv_knndm$train_set(1), cv_knndm$test_set(1)) # good!
#'
#' # Internal storage:
#' # cv_knndm$instance # table
ResamplingRepeatedSpCVKnndm = R6Class("ResamplingRepeatedSpCVKnndm",
inherit = mlr3::Resampling,
public = list(
#' @description
#' Create a "K-fold Nearest Neighbour Distance Matching" resampling instance.
#'
#' @param id `character(1)`\cr
#' Identifier for the resampling strategy.
initialize = function(id = "repeated_spcv_knndm") {
ps = ParamSet$new(params = list(
ParamUty$new("modeldomain", default = NULL,
custom_check = function(x) {
checkmate::check_class(x, "SpatRaster",
null.ok = TRUE)
}
),
ParamUty$new("ppoints", default = NULL,
custom_check = function(x) {
checkmate::check_class(x, "sfc_POINT",
null.ok = TRUE)
}
),
ParamFct$new("space", levels = "geographical", default = "geographical"),
ParamInt$new("folds", default = 10, lower = 2),
ParamDbl$new("maxp", default = 0.5, lower = 0, upper = 1),
ParamFct$new("clustering", default = "hierarchical",
levels = c("hierarchical", "kmeans"), tags = "required"),
ParamUty$new("linkf", default = "ward.D2"),
ParamInt$new("samplesize"),
ParamFct$new("sampling", levels = c("random", "hexagonal", "regular", "Fibonacci")),
ParamInt$new("repeats", lower = 1, default = 1L, tags = "required")
))
ps$values = list(repeats = 1, folds = 10)

super$initialize(
id = id,
param_set = ps,
label = "Repeated Spatial 'K-fold Nearest Neighbour Distance Matching",
man = "mlr3spatiotempcv::mlr_resamplings_repeated_spcv_knndm"
)
},

#' @description Translates iteration numbers to fold number.
#' @param iters `integer()`\cr
#' Iteration number.
folds = function(iters) {
iters = assert_integerish(iters, any.missing = FALSE, coerce = TRUE)
((iters - 1L) %% as.integer(self$param_set$values$folds)) + 1L
},

#' @description Translates iteration numbers to repetition number.
#' @param iters `integer()`\cr
#' Iteration number.
repeats = function(iters) {
iters = assert_integerish(iters, any.missing = FALSE, coerce = TRUE)
((iters - 1L) %/% as.integer(self$param_set$values$folds)) + 1L
},

#' @description
#' Materializes fixed training and test splits for a given task.
#' @param task [Task]\cr
#' A task to instantiate.
instantiate = function(task) {

mlr3::assert_task(task)
assert_spatial_task(task)
groups = task$groups

pv = self$param_set$values

# Set values to default if missing
mlr3misc::map(
c("modeldomain", "ppoints", "space", "maxp", "folds", "clustering",
"linkf", "samplesize", "sampling"),
function(x) private$.set_default_param_values(x)
)

if (is.null(pv$modeldomain) && is.null(pv$ppoints)) {
stopf("Either 'modeldomain' or 'ppoints' need to be set.")
}

if (!is.null(groups)) {
stopf("Grouping is not supported for spatial resampling methods.")
}

if (!is.null(groups)) {
stopf("Grouping is not supported for spatial resampling methods")
}

instance = private$.sample(
task$row_ids,
task$coordinates(),
task$crs
)

self$instance = instance
self$task_hash = task$hash
self$task_nrow = task$nrow
invisible(self)
}
),
active = list(

#' @field iters `integer(1)`\cr
#' Returns the number of resampling iterations, depending on the
#' values stored in the `param_set`.
iters = function() {
pv = self$param_set$values
# hack for autoplot
if (!is.null(names(self$instance))) {
as.integer(pv$repeats) * as.integer(length(self$instance$train))
} else {
as.integer(pv$repeats) * as.integer(length(self$instance[[1]]$test))
}
}
),
private = list(
.sample = function(ids, coords, crs) {

points = sf::st_as_sf(coords,
coords = colnames(coords),
crs = crs
)

pv = self$param_set$values
map(seq_len(pv$repeats), function(i) {
inds = CAST::knndm(tpoints = points,
modeldomain = self$param_set$values$modeldomain,
ppoints = self$param_set$values$ppoints,
k = self$param_set$values$folds,
maxp = self$param_set$values$maxp,
clustering = self$param_set$values$clustering,
linkf = self$param_set$values$linkf,
samplesize = self$param_set$values$samplesize,
sampling = self$param_set$values$sampling)

list(train = inds$indx_train, test = inds$indx_test)
})

},
.set_default_param_values = function(param) {
if (is.null(self$param_set$values[[param]])) {
self$param_set$values[[param]] = self$param_set$default[[param]]
}
},
.get_train = function(i) {
i = as.integer(i) - 1L
folds = as.integer(self$param_set$values$folds)
rep = i %/% folds + 1L
fold = i %% folds + 1L
self$instance[[rep]]$train[[fold]]
},
.get_test = function(i) {
i = as.integer(i) - 1L
folds = as.integer(self$param_set$values$folds)
rep = i %/% folds + 1L
fold = i %% folds + 1L
self$instance[[rep]]$test[[fold]]
}
)
)
Loading

0 comments on commit 01ce238

Please sign in to comment.