Skip to content

Commit

Permalink
fix swapped one_hot
Browse files Browse the repository at this point in the history
  • Loading branch information
EmilHvitfeldt committed Oct 18, 2024
1 parent d4890ec commit 3b28b51
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 41 deletions.
6 changes: 3 additions & 3 deletions R/sparse_dummy.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#'
#' @param x A factor.
#' @param one_hot A single logical value. Should the first factor level be
#' ignored. Defaults to `FALSE`.
#' included or not. Defaults to `FALSE`.
#'
#' @details
#' Only factor variables can be used with [sparse_dummy()]. A call to
Expand Down Expand Up @@ -33,7 +33,7 @@
#'
#' sparse_dummy(x, one_hot = TRUE)
#' @export
sparse_dummy <- function(x, one_hot = FALSE) {
sparse_dummy <- function(x, one_hot = TRUE) {
if (!is.factor(x)) {
cli::cli_abort("{.arg x} must be a factor, not {.obj_type_friendly {x}}.")
}
Expand All @@ -42,7 +42,7 @@ sparse_dummy <- function(x, one_hot = FALSE) {

x <- as.integer(x)

if (one_hot) {
if (!one_hot) {
lvls <- lvls[-1]
x <- x - 1L
}
Expand Down
4 changes: 2 additions & 2 deletions man/sparse_dummy.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

21 changes: 10 additions & 11 deletions src/sparse-dummy.c
Original file line number Diff line number Diff line change
Expand Up @@ -81,15 +81,10 @@ SEXP ffi_sparse_dummy(SEXP x, SEXP lvls, SEXP counts, SEXP one_hot) {

// Itterate over input, find its position index, and place the position value
// at the index. Increment specific index.

if ((bool) one_hot) {
if (R_isTRUE(one_hot)) {
for (R_xlen_t i = 0; i < len; ++i) {
int current_val = v_x[i] - 1;

if (current_val == -1) {
continue;
}

int index = v_pos_index[current_val];

SEXP pos_vec = VECTOR_ELT(out, current_val);
Expand All @@ -103,6 +98,10 @@ SEXP ffi_sparse_dummy(SEXP x, SEXP lvls, SEXP counts, SEXP one_hot) {
for (R_xlen_t i = 0; i < len; ++i) {
int current_val = v_x[i] - 1;

if (current_val == -1) {
continue;
}

int index = v_pos_index[current_val];

SEXP pos_vec = VECTOR_ELT(out, current_val);
Expand Down Expand Up @@ -155,7 +154,7 @@ SEXP ffi_sparse_dummy_na(SEXP x, SEXP lvls, SEXP counts, SEXP one_hot) {
// Itterate over input, find its position index, and place the position value
// at the index. Increment specific index.

if ((bool) one_hot) {
if (R_isTRUE(one_hot)) {
for (R_xlen_t i = 0; i < len; ++i) {
int current_val = v_x[i];

Expand All @@ -174,10 +173,6 @@ SEXP ffi_sparse_dummy_na(SEXP x, SEXP lvls, SEXP counts, SEXP one_hot) {
}
} else {
--current_val;
if (current_val == -1) {
continue;
}

int index = v_pos_index[current_val];

SEXP pos_vec = VECTOR_ELT(out_positions, current_val);
Expand Down Expand Up @@ -210,6 +205,10 @@ SEXP ffi_sparse_dummy_na(SEXP x, SEXP lvls, SEXP counts, SEXP one_hot) {
}
} else {
--current_val;
if (current_val == -1) {
continue;
}

int index = v_pos_index[current_val];

SEXP pos_vec = VECTOR_ELT(out_positions, current_val);
Expand Down
50 changes: 25 additions & 25 deletions tests/testthat/test-sparse_dummy.R
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
# one_hot = FALSE --------------------------------------------------------------
# one_hot = TRUE --------------------------------------------------------------

test_that("sparse_dummy(one_hot = FALSE) works with single level", {
test_that("sparse_dummy(one_hot = TRUE) works with single level", {
x <- factor(c("a", "a", "a"))
exp <- list(
a = c(1L, 1L, 1L)
)

res <- sparse_dummy(x, one_hot = FALSE)
res <- sparse_dummy(x, one_hot = TRUE)
expect_identical(
res,
exp
Expand All @@ -20,7 +20,7 @@ test_that("sparse_dummy(one_hot = FALSE) works zero length input", {
x <- factor(character())
exp <- structure(list(), names = character(0))

res <- sparse_dummy(x, one_hot = TRUE)
res <- sparse_dummy(x, one_hot = FALSE)
expect_identical(
res,
exp
Expand All @@ -29,7 +29,7 @@ test_that("sparse_dummy(one_hot = FALSE) works zero length input", {

## anyNA = FALSE ---------------------------------------------------------------

test_that("sparse_dummy(one_hot = FALSE) works with no NAs", {
test_that("sparse_dummy(one_hot = TRUE) works with no NAs", {
x <- factor(c("a", "b", "c", "d", "a"))
exp <- list(
a = sparse_integer(c(1, 1), c(1, 5), 5),
Expand All @@ -38,7 +38,7 @@ test_that("sparse_dummy(one_hot = FALSE) works with no NAs", {
d = sparse_integer(1, 4, 5)
)

res <- sparse_dummy(x, one_hot = FALSE)
res <- sparse_dummy(x, one_hot = TRUE)
expect_identical(
res,
exp
Expand All @@ -49,7 +49,7 @@ test_that("sparse_dummy(one_hot = FALSE) works with no NAs", {
)
})

test_that("sparse_dummy(one_hot = FALSE) works with no NAs and unseen levels", {
test_that("sparse_dummy(one_hot = TRUE) works with no NAs and unseen levels", {
x <- factor(c("a", "b", "c", "d", "a"), levels = letters[1:6])
exp <- list(
a = sparse_integer(c(1, 1), c(1, 5), 5),
Expand All @@ -60,7 +60,7 @@ test_that("sparse_dummy(one_hot = FALSE) works with no NAs and unseen levels", {
f = sparse_integer(integer(), integer(), 5)
)

res <- sparse_dummy(x, one_hot = FALSE)
res <- sparse_dummy(x, one_hot = TRUE)
expect_identical(
res,
exp
Expand All @@ -73,15 +73,15 @@ test_that("sparse_dummy(one_hot = FALSE) works with no NAs and unseen levels", {

## anyNA = TRUE ----------------------------------------------------------------

test_that("sparse_dummy(one_hot = FALSE) works with NA", {
test_that("sparse_dummy(one_hot = TRUE) works with NA", {
x <- factor(c("a", NA, "b", "c", "a", NA))
exp <- list(
a = sparse_integer(c(1, NA, 1, NA), c(1, 2, 5, 6), 6),
b = sparse_integer(c(NA, 1, NA), c(2, 3, 6), 6),
c = sparse_integer(c(NA, 1, NA), c(2, 4, 6), 6)
)

res <- sparse_dummy(x, one_hot = FALSE)
res <- sparse_dummy(x, one_hot = TRUE)
expect_identical(
res,
exp
Expand All @@ -92,7 +92,7 @@ test_that("sparse_dummy(one_hot = FALSE) works with NA", {
)
})

test_that("sparse_dummy(one_hot = FALSE) works with NA and unseen levels", {
test_that("sparse_dummy(one_hot = TRUE) works with NA and unseen levels", {
x <- factor(c("a", NA, "b", "c", "a", NA), levels = letters[1:5])
exp <- list(
a = sparse_integer(c(1, NA, 1, NA), c(1, 2, 5, 6), 6),
Expand All @@ -102,7 +102,7 @@ test_that("sparse_dummy(one_hot = FALSE) works with NA and unseen levels", {
e = sparse_integer(c(NA, NA), c(2, 6), 6)
)

res <- sparse_dummy(x, one_hot = FALSE)
res <- sparse_dummy(x, one_hot = TRUE)
expect_identical(
res,
exp
Expand All @@ -113,25 +113,25 @@ test_that("sparse_dummy(one_hot = FALSE) works with NA and unseen levels", {
)
})

# one_hot = TRUE ---------------------------------------------------------------
# one_hot = FALSE ---------------------------------------------------------------

test_that("sparse_dummy(one_hot = TRUE) works with single level", {
test_that("sparse_dummy(one_hot = FALSE) works with single level", {
x <- factor(c("a", "a", "a"))
exp <- structure(list(), names = character(0))

res <- sparse_dummy(x, one_hot = TRUE)
res <- sparse_dummy(x, one_hot = FALSE)

expect_identical(
res,
exp
)
})

test_that("sparse_dummy(one_hot = FALSE) works zero length input", {
test_that("sparse_dummy(one_hot = TRUE) works zero length input", {
x <- factor(character())
exp <- structure(list(), names = character(0))

res <- sparse_dummy(x, one_hot = TRUE)
res <- sparse_dummy(x, one_hot = FALSE)
expect_identical(
res,
exp
Expand All @@ -140,15 +140,15 @@ test_that("sparse_dummy(one_hot = FALSE) works zero length input", {

## anyNA = FALSE ---------------------------------------------------------------

test_that("sparse_dummy(one_hot = TRUE) works with no NAs", {
test_that("sparse_dummy(one_hot = FALSE) works with no NAs", {
x <- factor(c("a", "b", "c", "d", "a"))
exp <- list(
b = sparse_integer(1, 2, 5),
c = sparse_integer(1, 3, 5),
d = sparse_integer(1, 4, 5)
)

res <- sparse_dummy(x, one_hot = TRUE)
res <- sparse_dummy(x, one_hot = FALSE)
expect_identical(
res,
exp
Expand All @@ -159,7 +159,7 @@ test_that("sparse_dummy(one_hot = TRUE) works with no NAs", {
)
})

test_that("sparse_dummy(one_hot = TRUE) works with no NAs and unseen levels", {
test_that("sparse_dummy(one_hot = FALSE) works with no NAs and unseen levels", {
x <- factor(c("a", "b", "c", "d", "a"), levels = letters[1:6])
exp <- list(
b = sparse_integer(1, 2, 5),
Expand All @@ -169,7 +169,7 @@ test_that("sparse_dummy(one_hot = TRUE) works with no NAs and unseen levels", {
f = sparse_integer(integer(), integer(), 5)
)

res <- sparse_dummy(x, one_hot = TRUE)
res <- sparse_dummy(x, one_hot = FALSE)
expect_identical(
res,
exp
Expand All @@ -182,14 +182,14 @@ test_that("sparse_dummy(one_hot = TRUE) works with no NAs and unseen levels", {

## anyNA = TRUE ----------------------------------------------------------------

test_that("sparse_dummy(one_hot = TRUE) works with NA", {
test_that("sparse_dummy(one_hot = FALSE) works with NA", {
x <- factor(c("a", NA, "b", "c", "a", NA))
exp <- list(
b = sparse_integer(c(NA, 1, NA), c(2, 3, 6), 6),
c = sparse_integer(c(NA, 1, NA), c(2, 4, 6), 6)
)

res <- sparse_dummy(x, one_hot = TRUE)
res <- sparse_dummy(x, one_hot = FALSE)
expect_identical(
res,
exp
Expand All @@ -200,7 +200,7 @@ test_that("sparse_dummy(one_hot = TRUE) works with NA", {
)
})

test_that("sparse_dummy(one_hot = TRUE) works with NA and unseen levels", {
test_that("sparse_dummy(one_hot = FALSE) works with NA and unseen levels", {
x <- factor(c("a", NA, "b", "c", "a", NA), levels = letters[1:5])
exp <- list(
b = sparse_integer(c(NA, 1, NA), c(2, 3, 6), 6),
Expand All @@ -209,7 +209,7 @@ test_that("sparse_dummy(one_hot = TRUE) works with NA and unseen levels", {
e = sparse_integer(c(NA, NA), c(2, 6), 6)
)

res <- sparse_dummy(x, one_hot = TRUE)
res <- sparse_dummy(x, one_hot = FALSE)
expect_identical(
res,
exp
Expand Down

0 comments on commit 3b28b51

Please sign in to comment.