Skip to content

Commit

Permalink
fix issue related to #1657
Browse files Browse the repository at this point in the history
  • Loading branch information
paul-buerkner committed Sep 23, 2024
1 parent 9053e9f commit 6c551b0
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 64 deletions.
4 changes: 1 addition & 3 deletions R/data-response.R
Original file line number Diff line number Diff line change
Expand Up @@ -328,10 +328,8 @@ data_response.brmsframe <- function(x, data, check_response = TRUE,
}
out$cens <- as.array(cens)
icens <- cens %in% 2
y2_expr <- get_ad_expr(x, "cens", "y2")
if (any(icens) || !is.null(y2_expr)) {
if (any(icens) || has_interval_cens(x)) {
# interval censoring is required
# check for 'y2' above as well to prevent issue #1367
y2 <- unname(get_ad_values(x, "cens", "y2", data))
if (is.null(y2)) {
stop2("Argument 'y2' is required for interval censored data.")
Expand Down
5 changes: 5 additions & 0 deletions R/formula-ad.R
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,11 @@ get_cens <- function(bterms, data, resp = NULL) {
out
}

# indicates if the model may have interval censored observations
has_interval_cens <- function(bterms) {
!is.null(get_ad_expr(bterms, "cens", "y2"))
}

# extract truncation boundaries
# @param bterms a brmsterms object
# @param data data.frame containing the truncation variables
Expand Down
20 changes: 6 additions & 14 deletions R/stan-likelihood.R
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,12 @@ stan_log_lik_cens <- function(ll, bterms, threads, normalize, ...) {
tp <- tp()
has_weights <- has_ad_terms(bterms, "weights")
has_trunc <- has_ad_terms(bterms, "trunc")
has_interval_cens <- cens$vars$y2 != "NA"
if (ll$vec && !(has_weights || has_trunc)) {
has_interval_cens <- has_interval_cens(bterms)
if (ll$vec && !(has_interval_cens || has_weights || has_trunc)) {
# vectorized log-likelihood contributions
types <- c("event", "rcens", "lcens", "icens")
# cannot vectorize over interval censored observations as
# vectorized lpdf functions return scalars not vectors (#1657)
types <- c("event", "rcens", "lcens")
J <- args <- named_list(types)
for (t in types) {
Jt <- glue("J{t}{resp}[1:N{t}{resp}]")
Expand All @@ -137,15 +139,6 @@ stan_log_lik_cens <- function(ll, bterms, threads, normalize, ...) {
"{tp}{ll$dist}_lccdf(Y{resp}{J$rcens}{ll$shift} | {args$rcens});\n",
"{tp}{ll$dist}_lcdf(Y{resp}{J$lcens}{ll$shift} | {args$lcens});\n"
)
if (has_interval_cens) {
rcens <- glue("rcens{resp}")
str_add(out) <- glue(
"{tp}log_diff_exp(\n",
" {ll$dist}_lcdf(rcens{resp}{J$icens}{ll$shift} | {args$icens}),\n",
" {ll$dist}_lcdf(Y{resp}{J$icens}{ll$shift} | {args$icens})\n",
" );\n"
)
}
} else {
# non-vectorized likelihood contributions
n <- stan_nn(threads)
Expand Down Expand Up @@ -219,8 +212,7 @@ stan_log_lik_mix <- function(ll, bterms, pred_mix_prob, threads,
" ps[{mix}] = {theta} + ",
"{ll$dist}_lcdf({Y}{resp}{n}{ll$shift} | {ll$args}){tr};\n"
)
has_interval_cens <- cens$vars$y2 != "NA"
if (has_interval_cens) {
if (has_interval_cens(bterms)) {
str_add(out) <- glue(
" }} else if (cens{resp}{n} == 2) {{\n",
" ps[{mix}] = {theta} + log_diff_exp(\n",
Expand Down
86 changes: 41 additions & 45 deletions R/stan-response.R
Original file line number Diff line number Diff line change
Expand Up @@ -139,10 +139,8 @@ stan_response <- function(bframe, threads, normalize, ...) {
" array[N{resp}] int<lower=-1,upper=2> cens{resp};\n"
)
str_add(out$pll_args) <- glue(", data array[] int cens{resp}")
y2_expr <- get_ad_expr(bframe, "cens", "y2")
is_interval_censored <- !is.null(y2_expr)
if (is_interval_censored) {
# some observations are interval censored
if (has_interval_cens(bframe)) {
# some observations may be interval censored
str_add(out$data) <- " // right censor points for interval censoring\n"
if (rtype == "int") {
str_add(out$data) <- glue(
Expand All @@ -155,48 +153,46 @@ stan_response <- function(bframe, threads, normalize, ...) {
)
str_add(out$pll_args) <- glue(", data vector rcens{resp}")
}
}
n <- stan_nn(threads)
cens_indicators_def <- glue(
" // indices of censored data\n",
" int Nevent{resp} = 0;\n",
" int Nrcens{resp} = 0;\n",
" int Nlcens{resp} = 0;\n",
" int Nicens{resp} = 0;\n",
" array[N{resp}] int Jevent{resp};\n",
" array[N{resp}] int Jrcens{resp};\n",
" array[N{resp}] int Jlcens{resp};\n",
" array[N{resp}] int Jicens{resp};\n"
)
cens_indicators_comp <- glue(
" // collect indices of censored data\n",
" for (n in 1:N{resp}) {{\n",
stan_nn_def(threads),
" if (cens{resp}{n} == 0) {{\n",
" Nevent{resp} += 1;\n",
" Jevent{resp}[Nevent{resp}] = n;\n",
" }} else if (cens{resp}{n} == 1) {{\n",
" Nrcens{resp} += 1;\n",
" Jrcens{resp}[Nrcens{resp}] = n;\n",
" }} else if (cens{resp}{n} == -1) {{\n",
" Nlcens{resp} += 1;\n",
" Jlcens{resp}[Nlcens{resp}] = n;\n",
" }} else if (cens{resp}{n} == 2) {{\n",
" Nicens{resp} += 1;\n",
" Jicens{resp}[Nicens{resp}] = n;\n",
" }}\n",
" }}\n"
)
if (use_threading(threads)) {
# in threaded Stan code, gathering the indices has to be done on the fly
# inside the reduce_sum call since the indices are dependent on the slice
# of observations whose log likelihood is being evaluated
str_add(out$fun) <- " #include 'fun_add_int.stan'\n"
str_add(out$pll_def) <- cens_indicators_def
str_add(out$model_comp_basic) <- cens_indicators_comp
} else {
str_add(out$tdata_def) <- cens_indicators_def
str_add(out$tdata_comp) <- cens_indicators_comp
# cannot yet vectorize over interval censored observations
# hence there is no need to collect the indices in that case
cens_indicators_def <- glue(
" // indices of censored data\n",
" int Nevent{resp} = 0;\n",
" int Nrcens{resp} = 0;\n",
" int Nlcens{resp} = 0;\n",
" array[N{resp}] int Jevent{resp};\n",
" array[N{resp}] int Jrcens{resp};\n",
" array[N{resp}] int Jlcens{resp};\n"
)
n <- stan_nn(threads)
cens_indicators_comp <- glue(
" // collect indices of censored data\n",
" for (n in 1:N{resp}) {{\n",
stan_nn_def(threads),
" if (cens{resp}{n} == 0) {{\n",
" Nevent{resp} += 1;\n",
" Jevent{resp}[Nevent{resp}] = n;\n",
" }} else if (cens{resp}{n} == 1) {{\n",
" Nrcens{resp} += 1;\n",
" Jrcens{resp}[Nrcens{resp}] = n;\n",
" }} else if (cens{resp}{n} == -1) {{\n",
" Nlcens{resp} += 1;\n",
" Jlcens{resp}[Nlcens{resp}] = n;\n",
" }}\n",
" }}\n"
)
if (use_threading(threads)) {
# in threaded Stan code, gathering the indices has to be done on the fly
# inside the reduce_sum call since the indices are dependent on the slice
# of observations whose log likelihood is being evaluated
str_add(out$fun) <- " #include 'fun_add_int.stan'\n"
str_add(out$pll_def) <- cens_indicators_def
str_add(out$model_comp_basic) <- cens_indicators_comp
} else {
str_add(out$tdata_def) <- cens_indicators_def
str_add(out$tdata_comp) <- cens_indicators_comp
}
}
}
bounds <- bframe$frame$resp$bounds
Expand Down
4 changes: 2 additions & 2 deletions tests/testthat/tests.stancode.R
Original file line number Diff line number Diff line change
Expand Up @@ -1406,8 +1406,8 @@ test_that("weighted, censored, and truncated likelihoods are correct", {
)

scode <- stancode(y | cens(x, y2) ~ 1, dat, family = poisson())
expect_match2(scode, "target += poisson_lpmf(Y[Jevent[1:Nevent]] | mu[Jevent[1:Nevent]]);")
expect_match2(scode, "poisson_lcdf(rcens[Jicens[1:Nicens]] | mu[Jicens[1:Nicens]])")
expect_match2(scode, "target += poisson_lpmf(Y[n] | mu[n]);")
expect_match2(scode, "poisson_lcdf(rcens[n] | mu[n])")

scode <- stancode(y | cens(x) ~ 1, dat, family = asym_laplace())
expect_match2(scode, "target += asym_laplace_lccdf(Y[n] | mu[n], sigma, quantile);")
Expand Down

0 comments on commit 6c551b0

Please sign in to comment.