Skip to content

Commit

Permalink
add vectorized versions of cox_lpdf etc. functions
Browse files Browse the repository at this point in the history
  • Loading branch information
paul-buerkner committed Sep 20, 2024
1 parent f5932e2 commit 0034377
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 45 deletions.
2 changes: 1 addition & 1 deletion R/brmsterms.R
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ brmsterms.brmsformula <- function(formula, check_response = TRUE,
y$dpars[[dp]]$respform <- y$respform
y$dpars[[dp]]$adforms <- y$adforms
}
y$dpars[[dp]]$transform <- stan_eta_transform(y$dpars[[dp]]$family, y)
y$dpars[[dp]]$transform <- stan_eta_transform(y, y$dpars[[dp]]$family)
check_cs(y$dpars[[dp]])
}

Expand Down
20 changes: 8 additions & 12 deletions R/stan-helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -87,23 +87,19 @@ stan_cor_gen_comp <- function(cor, ncol) {

# indicates if a family-link combination has a built in
# function in Stan (such as binomial_logit)
# @param bterms brmsterms object of the univariate model
# @param family a list with elements 'family' and 'link'
# ideally a (brms)family object
# @param bterms brmsterms object of the univariate model
stan_has_built_in_fun <- function(family, bterms) {
stopifnot(all(c("family", "link") %in% names(family)))
stan_has_built_in_fun <- function(bterms, family = NULL) {
stopifnot(is.brmsterms(bterms))
family <- family %||% bterms$family
stopifnot(all(c("family", "link") %in% names(family)))
link <- family[["link"]]
dpar <- family[["dpar"]]
if (has_ad_terms(bterms, c("cens", "trunc"))) {
# only few families have special lcdf and lccdf functions
out <- has_built_in_fun(family, link, cdf = TRUE) ||
has_built_in_fun(bterms, link, dpar = dpar, cdf = TRUE)
} else {
out <- has_built_in_fun(family, link) ||
has_built_in_fun(bterms, link, dpar = dpar)
}
out
# only few families have special lcdf and lccdf functions
cdf <- has_ad_terms(bterms, c("cens", "trunc"))
has_built_in_fun(family, link, cdf = cdf) ||
has_built_in_fun(bterms, link, dpar = dpar, cdf = cdf)
}

# get all variable names accepted in Stan
Expand Down
44 changes: 19 additions & 25 deletions R/stan-likelihood.R
Original file line number Diff line number Diff line change
Expand Up @@ -326,11 +326,10 @@ stan_log_lik_advars <- function(bterms, advars,

# adjust lpdf name if a more efficient version is available
# for a specific link. For instance 'poisson_log'
stan_log_lik_simple_lpdf <- function(lpdf, link, bterms, sep = "_") {
stan_log_lik_simple_lpdf <- function(lpdf, bterms, sep = "_") {
stopifnot(is.brmsterms(bterms))
has_cens_or_trunc <- has_ad_terms(bterms, c("cens", "trunc"))
if (bterms$family$link == link && !has_cens_or_trunc) {
lpdf <- paste0(lpdf, sep, link)
if (stan_has_built_in_fun(bterms)) {
lpdf <- paste0(lpdf, sep, bterms$family$link)
}
lpdf
}
Expand Down Expand Up @@ -576,7 +575,7 @@ stan_log_lik_poisson <- function(bterms, ...) {
} else {
p <- stan_log_lik_dpars(bterms)
p$mu <- stan_log_lik_multiply_rate_denom(p$mu, bterms, log = TRUE, ...)
lpdf <- stan_log_lik_simple_lpdf("poisson", "log", bterms)
lpdf <- stan_log_lik_simple_lpdf("poisson", bterms)
out <- sdist(lpdf, p$mu)
}
out
Expand All @@ -591,7 +590,7 @@ stan_log_lik_negbinomial <- function(bterms, ...) {
p <- stan_log_lik_dpars(bterms)
p$mu <- stan_log_lik_multiply_rate_denom(p$mu, bterms, log = TRUE, ...)
p$shape <- stan_log_lik_multiply_rate_denom(p$shape, bterms, ...)
lpdf <- stan_log_lik_simple_lpdf("neg_binomial_2", "log", bterms)
lpdf <- stan_log_lik_simple_lpdf("neg_binomial_2", bterms)
out <- sdist(lpdf, p$mu, p$shape)
}
out
Expand All @@ -609,7 +608,7 @@ stan_log_lik_negbinomial2 <- function(bterms, ...) {
p$shape <- stan_log_lik_multiply_rate_denom(
p$sigma, bterms, transform = "inv", ...
)
lpdf <- stan_log_lik_simple_lpdf("neg_binomial_2", "log", bterms)
lpdf <- stan_log_lik_simple_lpdf("neg_binomial_2", bterms)
out <- sdist(lpdf, p$mu, p$shape)
}
out
Expand All @@ -625,15 +624,15 @@ stan_log_lik_geometric <- function(bterms, ...) {
p$shape <- "1"
p$mu <- stan_log_lik_multiply_rate_denom(p$mu, bterms, log = TRUE, ...)
p$shape <- stan_log_lik_multiply_rate_denom(p$shape, bterms, ...)
lpdf <- stan_log_lik_simple_lpdf("neg_binomial_2", "log", bterms)
lpdf <- stan_log_lik_simple_lpdf("neg_binomial_2", bterms)
out <- sdist(lpdf, p$mu, p$shape)
}
}

stan_log_lik_binomial <- function(bterms, ...) {
p <- stan_log_lik_dpars(bterms)
p$trials <- stan_log_lik_advars(bterms, "trials", ...)$trials
lpdf <- stan_log_lik_simple_lpdf("binomial", "logit", bterms)
lpdf <- stan_log_lik_simple_lpdf("binomial", bterms)
sdist(lpdf, p$trials, p$mu)
}

Expand All @@ -655,7 +654,7 @@ stan_log_lik_bernoulli <- function(bterms, ...) {
out <- sdist("bernoulli_logit_glm", p$x, p$alpha, p$beta)
} else {
p <- stan_log_lik_dpars(bterms)
lpdf <- stan_log_lik_simple_lpdf("bernoulli", "logit", bterms)
lpdf <- stan_log_lik_simple_lpdf("bernoulli", bterms)
out <- sdist(lpdf, p$mu)
}
out
Expand All @@ -668,7 +667,7 @@ stan_log_lik_discrete_weibull <- function(bterms, ...) {

stan_log_lik_com_poisson <- function(bterms, ...) {
p <- stan_log_lik_dpars(bterms, reqn = TRUE)
lpdf <- stan_log_lik_simple_lpdf("com_poisson", "log", bterms)
lpdf <- stan_log_lik_simple_lpdf("com_poisson", bterms)
sdist(lpdf, p$mu, p$shape, vec = FALSE)
}

Expand Down Expand Up @@ -745,15 +744,10 @@ stan_log_lik_von_mises <- function(bterms, ...) {
}

stan_log_lik_cox <- function(bterms, ...) {
p <- stan_log_lik_dpars(bterms, reqn = TRUE)
resp <- usc(bterms$resp)
p$bhaz <- paste0("bhaz", resp, "[n]")
p$cbhaz <- paste0("cbhaz", resp, "[n]")
lpdf <- "cox"
if (bterms$family$link == "log") {
str_add(lpdf) <- "_log"
}
sdist(lpdf, p$mu, p$bhaz, p$cbhaz, vec = FALSE)
p <- stan_log_lik_dpars(bterms)
c(p) <- stan_log_lik_advars(bterms, c("bhaz", "cbhaz"))
lpdf <- stan_log_lik_simple_lpdf("cox", bterms)
sdist(lpdf, p$mu, p$bhaz, p$cbhaz, vec = TRUE)
}

stan_log_lik_cumulative <- function(bterms, ...) {
Expand Down Expand Up @@ -862,14 +856,14 @@ stan_log_lik_ordinal <- function(bterms, ...) {

stan_log_lik_hurdle_poisson <- function(bterms, ...) {
p <- stan_log_lik_dpars(bterms, reqn = TRUE)
lpdf <- stan_log_lik_simple_lpdf("hurdle_poisson", "log", bterms)
lpdf <- stan_log_lik_simple_lpdf("hurdle_poisson", bterms)
lpdf <- paste0(lpdf, stan_log_lik_dpar_usc_logit(bterms, "hu"))
sdist(lpdf, p$mu, p$hu, vec = FALSE)
}

stan_log_lik_hurdle_negbinomial <- function(bterms, ...) {
p <- stan_log_lik_dpars(bterms, reqn = TRUE)
lpdf <- stan_log_lik_simple_lpdf("hurdle_neg_binomial", "log", bterms)
lpdf <- stan_log_lik_simple_lpdf("hurdle_neg_binomial", bterms)
lpdf <- paste0(lpdf, stan_log_lik_dpar_usc_logit(bterms, "hu"))
sdist(lpdf, p$mu, p$shape, p$hu, vec = FALSE)
}
Expand Down Expand Up @@ -924,14 +918,14 @@ stan_log_lik_hurdle_cumulative <- function(bterms, ...) {

stan_log_lik_zero_inflated_poisson <- function(bterms, ...) {
p <- stan_log_lik_dpars(bterms, reqn = TRUE)
lpdf <- stan_log_lik_simple_lpdf("zero_inflated_poisson", "log", bterms)
lpdf <- stan_log_lik_simple_lpdf("zero_inflated_poisson", bterms)
lpdf <- paste0(lpdf, stan_log_lik_dpar_usc_logit(bterms, "zi"))
sdist(lpdf, p$mu, p$zi, vec = FALSE)
}

stan_log_lik_zero_inflated_negbinomial <- function(bterms, ...) {
p <- stan_log_lik_dpars(bterms, reqn = TRUE)
lpdf <- stan_log_lik_simple_lpdf("zero_inflated_neg_binomial", "log", bterms)
lpdf <- stan_log_lik_simple_lpdf("zero_inflated_neg_binomial", bterms)
lpdf <- paste0(lpdf, stan_log_lik_dpar_usc_logit(bterms, "zi"))
sdist(lpdf, p$mu, p$shape, p$zi, vec = FALSE)
}
Expand All @@ -940,7 +934,7 @@ stan_log_lik_zero_inflated_binomial <- function(bterms, ...) {
p <- stan_log_lik_dpars(bterms, reqn = TRUE)
p$trials <- stan_log_lik_advars(bterms, "trials", reqn = TRUE, ...)$trials
lpdf <- "zero_inflated_binomial"
lpdf <- stan_log_lik_simple_lpdf(lpdf, "logit", bterms, sep = "_b")
lpdf <- stan_log_lik_simple_lpdf(lpdf, bterms, sep = "_b")
lpdf <- paste0(lpdf, stan_log_lik_dpar_usc_logit(bterms, "zi"))
sdist(lpdf, p$trials, p$mu, p$zi, vec = FALSE)
}
Expand Down
4 changes: 2 additions & 2 deletions R/stan-predictor.R
Original file line number Diff line number Diff line change
Expand Up @@ -2110,10 +2110,10 @@ stan_eta_rsp <- function(r) {
}

# does eta need to be transformed manually using the inv_link function
stan_eta_transform <- function(family, bframe) {
stan_eta_transform <- function(bframe, family) {
no_transform <- family$link == "identity" ||
has_joint_link(family) && !is.customfamily(family)
!no_transform && !stan_has_built_in_fun(family, bframe)
!no_transform && !stan_has_built_in_fun(bframe, family)
}

# indicate if the population-level design matrix should be centered
Expand Down
34 changes: 33 additions & 1 deletion inst/chunks/fun_cox.stan
Original file line number Diff line number Diff line change
Expand Up @@ -11,27 +11,59 @@
real cox_lhaz(real y, real mu, real bhaz, real cbhaz) {
return log(bhaz) + log(mu);
}
vector cox_lhaz(vector y, vector mu, vector bhaz, vector cbhaz) {
return log(bhaz) + log(mu);
}

// equivalent to the log survival function
real cox_lccdf(real y, real mu, real bhaz, real cbhaz) {
// equivalent to the log survival function
return - cbhaz * mu;
}
real cox_lccdf(vector y, vector mu, vector bhaz, vector cbhaz) {
return - dot_product(cbhaz, mu);
}

real cox_lcdf(real y, real mu, real bhaz, real cbhaz) {
return log1m_exp(cox_lccdf(y | mu, bhaz, cbhaz));
}
real cox_lcdf(vector y, vector mu, vector bhaz, vector cbhaz) {
return sum(log1m_exp(- cbhaz .* mu));
}

real cox_lpdf(real y, real mu, real bhaz, real cbhaz) {
return cox_lhaz(y, mu, bhaz, cbhaz) + cox_lccdf(y | mu, bhaz, cbhaz);
}
real cox_lpdf(vector y, vector mu, vector bhaz, vector cbhaz) {
return sum(cox_lhaz(y, mu, bhaz, cbhaz)) + cox_lccdf(y | mu, bhaz, cbhaz);
}

// Distribution functions of the Cox model in log parameterization
real cox_log_lhaz(real y, real log_mu, real bhaz, real cbhaz) {
return log(bhaz) + log_mu;
}
vector cox_log_lhaz(vector y, vector log_mu, vector bhaz, vector cbhaz) {
return log(bhaz) + log_mu;
}

real cox_log_lccdf(real y, real log_mu, real bhaz, real cbhaz) {
return - cbhaz * exp(log_mu);
}
real cox_log_lccdf(vector y, vector log_mu, vector bhaz, vector cbhaz) {
return - dot_product(cbhaz, exp(log_mu));
}

real cox_log_lcdf(real y, real log_mu, real bhaz, real cbhaz) {
return log1m_exp(cox_log_lccdf(y | log_mu, bhaz, cbhaz));
}
real cox_log_lcdf(vector y, vector log_mu, vector bhaz, vector cbhaz) {
return sum(log1m_exp(- cbhaz .* exp(log_mu)));
}

real cox_log_lpdf(real y, real log_mu, real bhaz, real cbhaz) {
return cox_log_lhaz(y, log_mu, bhaz, cbhaz) +
cox_log_lccdf(y | log_mu, bhaz, cbhaz);
}
real cox_log_lpdf(vector y, vector log_mu, vector bhaz, vector cbhaz) {
return sum(cox_log_lhaz(y, log_mu, bhaz, cbhaz)) +
cox_log_lccdf(y | log_mu, bhaz, cbhaz);
}
11 changes: 7 additions & 4 deletions tests/testthat/tests.stancode.R
Original file line number Diff line number Diff line change
Expand Up @@ -1409,8 +1409,8 @@ test_that("weighted, censored, and truncated likelihoods are correct", {
expect_match2(scode, "target += poisson_lpmf(Y[Jevent[1:Nevent]] | mu[Jevent[1:Nevent]]);")
expect_match2(scode, "poisson_lcdf(rcens[Jicens[1:Nicens]] | mu[Jicens[1:Nicens]])")

scode <- stancode(y | cens(x) ~ 1, dat, family = cox())
expect_match2(scode, "target += cox_log_lccdf(Y[n] | mu[n], bhaz[n], cbhaz[n]);")
scode <- stancode(y | cens(x) ~ 1, dat, family = asym_laplace())
expect_match2(scode, "target += asym_laplace_lccdf(Y[n] | mu[n], sigma, quantile);")

dat$x[1] <- 2
scode <- stancode(y | cens(x, y2) ~ 1, dat, family = asym_laplace())
Expand Down Expand Up @@ -1717,13 +1717,16 @@ test_that("Stan code of Cox models is correct", {
x = rnorm(100), g = sample(1:3, 100, TRUE))
bform <- bf(y | cens(ce) ~ x)
scode <- stancode(bform, data, brmsfamily("cox"))
expect_match2(scode, "target += cox_log_lpdf(Y[n] | mu[n], bhaz[n], cbhaz[n]);")
expect_match2(scode,
"target += cox_log_lpdf(Y[Jevent[1:Nevent]] | mu[Jevent[1:Nevent]], bhaz[Jevent[1:Nevent]], cbhaz[Jevent[1:Nevent]]);"
)
expect_match2(scode, "vector[N] cbhaz = Zcbhaz * sbhaz;")
expect_match2(scode, "lprior += dirichlet_lpdf(sbhaz | con_sbhaz);")
expect_match2(scode, "simplex[Kbhaz] sbhaz;")

bform <- bf(y ~ x)
scode <- stancode(bform, data, brmsfamily("cox", "identity"))
expect_match2(scode, "target += cox_lccdf(Y[n] | mu[n], bhaz[n], cbhaz[n]);")
expect_match2(scode, "target += cox_lpdf(Y | mu, bhaz, cbhaz);")

bform <- bf(y | bhaz(gr = g) ~ x)
scode <- stancode(bform, data, brmsfamily("cox"))
Expand Down

0 comments on commit 0034377

Please sign in to comment.