Skip to content

Commit

Permalink
Add default intercept to ulsif
Browse files Browse the repository at this point in the history
  • Loading branch information
thomvolker committed Dec 7, 2023
1 parent 93a9f65 commit fb649f2
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 9 deletions.
7 changes: 7 additions & 0 deletions R/checks.R
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,13 @@ check.centers <- function(nu, centers, ncenters) {
centers
}

check.intercept <- function(intercept) {
if (!is.logical(intercept)) {
stop("'intercept' must be either 'TRUE' or 'FALSE'")
}
intercept
}

check.symmetric <- function(nu, centers) {
if(isTRUE(all.equal(nu, centers))) {
symmetric <- TRUE
Expand Down
15 changes: 10 additions & 5 deletions R/predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,18 @@ predict.ulsif <- function(object, newdata = NULL, sigma = c("sigmaopt", "all"),
newlambda <- check.lambda.predict(object, lambda)
newdata <- check.newdata(object, newdata)

alpha <- extract.alpha(object, newsigma, newlambda)
nsigma <- dim(alpha)[2]
nlambda <- dim(alpha)[3]
dratio <- array(0, c(nrow(newdata), nlambda, nsigma))
alpha <- extract.alpha(object, newsigma, newlambda)
nsigma <- length(newsigma)
nlambda <- length(newlambda)
dratio <- array(0, c(nrow(newdata), nlambda, nsigma))
intercept <- nrow(object$alpha) > nrow(object$centers)

for (i in 1:nsigma) {
K <- distance(newdata, object$centers) |> kernel_gaussian(newsigma[i])
if (intercept) {
K <- cbind(0, distance(newdata, object$centers)) |> kernel_gaussian(newsigma[i])
} else {
K <- distance(newdata, object$centers) |> kernel_gaussian(newsigma[i])
}
for (j in 1:nlambda) {
dratio[ , i, j] <- K %*% alpha[, i, j]
}
Expand Down
15 changes: 11 additions & 4 deletions R/ulsif.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
#' @param df_denominator \code{data.frame} with exclusively numeric variables
#' with the denominator samples (must have the same variables as
#' \code{df_denominator})
#' @param intercept \code{logical} Indicating whether to include an intercept
#' term in the model. Defaults to \code{TRUE}.
#' @param nsigma Integer indicating the number of sigma values (bandwidth
#' parameter of the Gaussian kernel gram matrix) to use in cross-validation.
#' @param sigma_quantile \code{NULL} or numeric vector with probabilities to
Expand Down Expand Up @@ -41,10 +43,10 @@
#' ulsif(x, y)
#' ulsif(x, y, sigma = 2, lambda = 2)

ulsif <- function(df_numerator, df_denominator, nsigma = 10, sigma_quantile = NULL,
sigma = NULL, nlambda = 20, lambda = NULL, ncenters = 200,
centers = NULL, parallel = FALSE, nthreads = NULL,
progressbar = TRUE) {
ulsif <- function(df_numerator, df_denominator, intercept = TRUE, nsigma = 10,
sigma_quantile = NULL, sigma = NULL, nlambda = 20,
lambda = NULL, ncenters = 200, centers = NULL,
parallel = FALSE, nthreads = NULL, progressbar = TRUE) {

cl <- match.call()
nu <- as.matrix(df_numerator)
Expand All @@ -55,9 +57,14 @@ ulsif <- function(df_numerator, df_denominator, nsigma = 10, sigma_quantile = NU
symmetric <- check.symmetric(nu, centers)
parallel <- check.parallel(parallel, nthreads, sigma, lambda)
nthreads <- check.threads(parallel, nthreads)
intercept <- check.intercept(intercept)

dist_nu <- distance(nu, centers, symmetric)
dist_de <- distance(de, centers)
if (intercept) {
dist_nu <- cbind(0, dist_nu)
dist_de <- cbind(0, dist_de)
}

sigma <- check.sigma(nsigma, sigma_quantile, sigma, dist_nu)
lambda <- check.lambda(nlambda, lambda)
Expand Down
4 changes: 4 additions & 0 deletions man/ulsif.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit fb649f2

Please sign in to comment.