Skip to content

Commit

Permalink
residuals.nnet now handling missing values properly
Browse files Browse the repository at this point in the history
  • Loading branch information
robjhyndman committed Dec 11, 2023
1 parent af4d9fb commit bfcce99
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions R/nnet.R
Original file line number Diff line number Diff line change
Expand Up @@ -52,33 +52,33 @@ train_nnet <- function(.data, specials, n_nodes, n_networks, scale_inputs, wts =
}

# Remove missing values if present
j <- complete.cases(xreg, y)
xreg <- xreg[j, , drop = FALSE]
y <- y[j]
nonmissing <- complete.cases(xreg, y)
## Stop if there's no data to fit
if (NROW(xreg) == 0) {
if (NROW(xreg[nonmissing, , drop = FALSE]) == 0) {
abort("No data to fit (possibly due to missing values)")
}

# Fit the nnet and consider the Wts argument for nnet::nnet() if provided:
if (is.null(wts)) {
nn_models <- map(
seq_len(n_networks),
function(.) wrap_nnet(xreg, y, size = n_nodes, ...)
function(.) wrap_nnet(xreg[nonmissing, , drop = FALSE], y[nonmissing], size = n_nodes, ...)
)
} else {
maxnwts <- max(lengths(wts), na.rm = TRUE)
nn_models <- map(
wts,
function(i) {
wrap_nnet(x = xreg, y = y, size = n_nodes, MaxNWts = maxnwts, Wts = i, ...)
wrap_nnet(x = xreg[nonmissing, , drop = FALSE], y = y[nonmissing], size = n_nodes, MaxNWts = maxnwts, Wts = i, ...)
})
}

# Calculate fitted values
pred <- map_dbl(transpose(map(nn_models, predict)), function(x) mean(unlist(x)))
fits <- pred
fits <- y*NA
fits[nonmissing] <- pred
res <- y - fits


# Construct model output
structure(
Expand Down

0 comments on commit bfcce99

Please sign in to comment.