diff --git a/R/logistic.R b/R/logistic.R index 1c76d0d..07bc0ef 100644 --- a/R/logistic.R +++ b/R/logistic.R @@ -19,12 +19,22 @@ train_logistic <- function(.data, specials, ...) { keep <- complete.cases(xreg) & complete.cases(y) fit <- stats::glm.fit(xreg[keep, , drop = FALSE], y[keep, , drop = FALSE], family = stats::binomial()) - resid <- fits <- matrix(nrow = nrow(y), ncol = ncol(y)) - resid[keep, ] <- as.matrix(fit$residuals) - fit$residuals <- resid - fits[keep, ] <- as.matrix(fit$fitted.values) - fit$fitted.values <- fits - + # Fill in missing values + tmp <- matrix(nrow = nrow(y), ncol = ncol(y)) + fit$y <- y + tmp[keep, ] <- as.matrix(fit$residuals) + fit$residuals <- tmp + tmp[keep, ] <- as.matrix(fit$fitted.values) + fit$fitted.values <- tmp + tmp[keep, ] <- as.matrix(fit$effects) + fit$effects <- tmp + tmp[keep, ] <- as.matrix(fit$linear.predictors) + fit$linear.predictors <- tmp + tmp[keep,] <- as.matrix(fit$weights) + fit$weights <- tmp + tmp[keep,] <- as.matrix(fit$prior.weights) + fit$prior.weights <- tmp + if (is_empty(fit$coefficients)) { fit$coefficients <- matrix(nrow = 0, ncol = NCOL(y)) }