Skip to content

Commit

Permalink
major update removing non-burnin samplers
Browse files Browse the repository at this point in the history
  • Loading branch information
Fradenti committed May 16, 2024
1 parent 8d5d193 commit 23d1c3a
Show file tree
Hide file tree
Showing 31 changed files with 1,548 additions and 1,264 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# SANple 0.1.1

* Updated the MCMC functions to take a burn-in value (`burn`) as input;
* Some algorithms ran `nrep-1` iterations. Updated to `nrep`;
* Improved efficiency of stick-breaking computation;
* Improved the initialization of the algorithms, streamlined some scripts;
* Changed `for`-loop indices in the `.cpp` files from `int` to `unsigned int` where needed;
Expand Down
20 changes: 10 additions & 10 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
@@ -1,23 +1,23 @@
# Generated by using Rcpp::compileAttributes() -> do not edit by hand
# Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393

sample_cam_arma <- function(nrep, y, group, maxK, maxL, m0, tau0, lambda0, gamma0, fixed_alpha, fixed_beta, alpha, beta, hyp_alpha1, hyp_alpha2, hyp_beta1, hyp_beta2, mu_start, sigma2_start, M_start, S_start, alpha_start, beta_start, progressbar) {
.Call(`_SANple_sample_cam_arma`, nrep, y, group, maxK, maxL, m0, tau0, lambda0, gamma0, fixed_alpha, fixed_beta, alpha, beta, hyp_alpha1, hyp_alpha2, hyp_beta1, hyp_beta2, mu_start, sigma2_start, M_start, S_start, alpha_start, beta_start, progressbar)
sample_cam_burn <- function(nrep, burn, y, group, maxK, maxL, m0, tau0, lambda0, gamma0, fixed_alpha, fixed_beta, alpha, beta, hyp_alpha1, hyp_alpha2, hyp_beta1, hyp_beta2, mu_start, sigma2_start, M_start, S_start, alpha_start, beta_start, progressbar) {
.Call(`_SANple_sample_cam_burn`, nrep, burn, y, group, maxK, maxL, m0, tau0, lambda0, gamma0, fixed_alpha, fixed_beta, alpha, beta, hyp_alpha1, hyp_alpha2, hyp_beta1, hyp_beta2, mu_start, sigma2_start, M_start, S_start, alpha_start, beta_start, progressbar)
}

sample_fcam_arma <- function(nrep, y, group, maxK, maxL, m0, tau0, lambda0, gamma0, fixed_alpha, fixed_beta, alpha, beta, hyp_alpha1, hyp_alpha2, hyp_beta1, hyp_beta2, mu_start, sigma2_start, M_start, S_start, alpha_start, beta_start, eps_alpha, eps_beta, progressbar) {
.Call(`_SANple_sample_fcam_arma`, nrep, y, group, maxK, maxL, m0, tau0, lambda0, gamma0, fixed_alpha, fixed_beta, alpha, beta, hyp_alpha1, hyp_alpha2, hyp_beta1, hyp_beta2, mu_start, sigma2_start, M_start, S_start, alpha_start, beta_start, eps_alpha, eps_beta, progressbar)
sample_fcam_burn <- function(nrep, burn, y, group, maxK, maxL, m0, tau0, lambda0, gamma0, fixed_alpha, fixed_beta, alpha, beta, hyp_alpha1, hyp_alpha2, hyp_beta1, hyp_beta2, mu_start, sigma2_start, M_start, S_start, alpha_start, beta_start, eps_alpha, eps_beta, progressbar) {
.Call(`_SANple_sample_fcam_burn`, nrep, burn, y, group, maxK, maxL, m0, tau0, lambda0, gamma0, fixed_alpha, fixed_beta, alpha, beta, hyp_alpha1, hyp_alpha2, hyp_beta1, hyp_beta2, mu_start, sigma2_start, M_start, S_start, alpha_start, beta_start, eps_alpha, eps_beta, progressbar)
}

sample_ficam_arma <- function(nrep, y, group, maxK, maxL, m0, tau0, lambda0, gamma0, fixed_alpha, fixed_beta, alpha, beta, hyp_alpha1, hyp_alpha2, hyp_beta1, hyp_beta2, mu_start, sigma2_start, M_start, S_start, alpha_start, beta_start, eps_beta, progressbar) {
.Call(`_SANple_sample_ficam_arma`, nrep, y, group, maxK, maxL, m0, tau0, lambda0, gamma0, fixed_alpha, fixed_beta, alpha, beta, hyp_alpha1, hyp_alpha2, hyp_beta1, hyp_beta2, mu_start, sigma2_start, M_start, S_start, alpha_start, beta_start, eps_beta, progressbar)
sample_ficam_burn <- function(nrep, burn, y, group, maxK, maxL, m0, tau0, lambda0, gamma0, fixed_alpha, fixed_beta, alpha, beta, hyp_alpha1, hyp_alpha2, hyp_beta1, hyp_beta2, mu_start, sigma2_start, M_start, S_start, alpha_start, beta_start, eps_beta, progressbar) {
.Call(`_SANple_sample_ficam_burn`, nrep, burn, y, group, maxK, maxL, m0, tau0, lambda0, gamma0, fixed_alpha, fixed_beta, alpha, beta, hyp_alpha1, hyp_alpha2, hyp_beta1, hyp_beta2, mu_start, sigma2_start, M_start, S_start, alpha_start, beta_start, eps_beta, progressbar)
}

sample_overcam_arma <- function(nrep, y, group, maxK, maxL, m0, tau0, lambda0, gamma0, fixed_alpha, fixed_beta, alpha, beta, hyp_alpha, hyp_beta, mu_start, sigma2_start, M_start, S_start, alpha_start, beta_start, eps_alpha, eps_beta, progressbar) {
.Call(`_SANple_sample_overcam_arma`, nrep, y, group, maxK, maxL, m0, tau0, lambda0, gamma0, fixed_alpha, fixed_beta, alpha, beta, hyp_alpha, hyp_beta, mu_start, sigma2_start, M_start, S_start, alpha_start, beta_start, eps_alpha, eps_beta, progressbar)
sample_overcam_burn <- function(nrep, burn, y, group, maxK, maxL, m0, tau0, lambda0, gamma0, fixed_alpha, fixed_beta, alpha, beta, hyp_alpha, hyp_beta, mu_start, sigma2_start, M_start, S_start, alpha_start, beta_start, eps_alpha, eps_beta, progressbar) {
.Call(`_SANple_sample_overcam_burn`, nrep, burn, y, group, maxK, maxL, m0, tau0, lambda0, gamma0, fixed_alpha, fixed_beta, alpha, beta, hyp_alpha, hyp_beta, mu_start, sigma2_start, M_start, S_start, alpha_start, beta_start, eps_alpha, eps_beta, progressbar)
}

sample_overficam_arma <- function(nrep, y, group, maxK, maxL, m0, tau0, lambda0, gamma0, fixed_alpha, fixed_beta, alpha, beta, hyp_alpha1, hyp_alpha2, hyp_beta, mu_start, sigma2_start, M_start, S_start, alpha_start, beta_start, eps_beta, progressbar) {
.Call(`_SANple_sample_overficam_arma`, nrep, y, group, maxK, maxL, m0, tau0, lambda0, gamma0, fixed_alpha, fixed_beta, alpha, beta, hyp_alpha1, hyp_alpha2, hyp_beta, mu_start, sigma2_start, M_start, S_start, alpha_start, beta_start, eps_beta, progressbar)
sample_overficam_burn <- function(nrep, burn, y, group, maxK, maxL, m0, tau0, lambda0, gamma0, fixed_alpha, fixed_beta, alpha, beta, hyp_alpha1, hyp_alpha2, hyp_beta, mu_start, sigma2_start, M_start, S_start, alpha_start, beta_start, eps_beta, progressbar) {
.Call(`_SANple_sample_overficam_burn`, nrep, burn, y, group, maxK, maxL, m0, tau0, lambda0, gamma0, fixed_alpha, fixed_beta, alpha, beta, hyp_alpha1, hyp_alpha2, hyp_beta, mu_start, sigma2_start, M_start, S_start, alpha_start, beta_start, eps_beta, progressbar)
}

18 changes: 13 additions & 5 deletions R/estimate_clusters.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
#' set.seed(123)
#' y <- c(rnorm(40,0,0.3), rnorm(20,5,0.3))
#' g <- c(rep(1,30), rep(2, 30))
#' out <- sample_fiSAN(nrep = 500, y = y, group = g,
#' out <- sample_fiSAN(nrep = 500, burn = 200,
#' y = y, group = g,
#' nclus_start = 2,
#' maxK = 20, maxL = 20,
#' beta = 1)
Expand All @@ -28,12 +29,19 @@
#' @export
#' @importFrom salso salso
#' @useDynLib SANple
estimate_clusters <- function(object, burnin = NULL, ncores = 0)
estimate_clusters <- function(object, burnin = 0, ncores = 0)
{

if(is.null(burnin)) { burnin <- 1:round(object$params$nrep/3*2) }
estimated_oc <- suppressWarnings(salso::salso(object$sim$obs_cluster[-burnin,], nCores = ncores))
estimated_dc <- suppressWarnings(salso::salso(object$sim$distr_cluster[-burnin,], nCores = ncores))
if(burnin>0) {
OC <- object$sim$obs_cluster[-burnin,]
DC <- object$sim$distr_cluster[-burnin,]
}else{
OC <- object$sim$obs_cluster
DC <- object$sim$distr_cluster
}

estimated_oc <- suppressWarnings(salso::salso(OC, nCores = ncores))
estimated_dc <- suppressWarnings(salso::salso(DC, nCores = ncores))

n_oc <- length(unique(estimated_oc))
n_dc <- length(unique(estimated_dc))
Expand Down
75 changes: 40 additions & 35 deletions R/mcmc_CAM.R → R/mcmc_CAM_burn.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#' The implemented algorithm is based on the nested slice sampler of Denti et al. (2023), which in turn builds on the algorithm of Kalli, Griffin and Walker (2011).
#'
#' @usage
#' sample_CAM(nrep, y, group,
#' sample_CAM(nrep, burn, y, group,
#' maxK = 50, maxL = 50,
#' m0 = 0, tau0 = 0.1, lambda0 = 3, gamma0 = 2,
#' hyp_alpha1 = 1, hyp_alpha2 = 1,
Expand All @@ -18,6 +18,7 @@
#' progress = TRUE, seed = NULL)
#'
#' @param nrep Number of MCMC iterations.
#' @param burn Number of initial MCMC iterations discarded as burn-in.
#' @param y Vector of observations.
#' @param group Vector of the same length as `y` indicating the group membership (numeric).
#' @param maxK Maximum number of distributional clusters (default = 50).
Expand Down Expand Up @@ -127,7 +128,7 @@
#' g <- c(rep(1,30), rep(2, 30))
#' plot(density(y[g==1]), xlim = c(-5,10))
#' lines(density(y[g==2]), col = 2)
#' out <- sample_CAM(nrep = 500, y = y, group = g,
#' out <- sample_CAM(nrep = 500, burn = 200, y = y, group = g,
#' nclus_start = 2,
#' maxL = 20, maxK = 20)
#' out
Expand All @@ -143,7 +144,7 @@
#' @export sample_CAM
#'
#' @importFrom stats cor var dist hclust cutree rgamma
sample_CAM = function(nrep, y, group,
sample_CAM = function(nrep, burn, y, group,
maxK = 50,
maxL = 50,
m0 = 0, tau0 = 0.1, lambda0 = 3, gamma0 = 2,
Expand All @@ -163,21 +164,21 @@ sample_CAM = function(nrep, y, group,
set.seed(seed)

params <- list(nrep = nrep,
y = y,
group = group+1,
maxK = maxK,
maxL = maxL,
m0 = m0, tau0 = tau0,
lambda0 = lambda0, gamma0 = gamma0,
seed = seed)
y = y,
group = group+1,
maxK = maxK,
maxL = maxL,
m0 = m0, tau0 = tau0,
lambda0 = lambda0, gamma0 = gamma0,
seed = seed)

if(!is.null(alpha)) { params$alpha <- alpha }
if(!is.null(beta)) { params$beta <- beta }
if(is.null(alpha)) { params$hyp_alpha1 <- hyp_alpha1 }
if(is.null(alpha)) { params$hyp_alpha2 <- hyp_alpha2 }
if(is.null(beta)) { params$hyp_beta1 <- hyp_beta1 }
if(is.null(beta)) { params$hyp_beta2 <- hyp_beta2 }

if(is.null(S_start)) { S_start <- rep(0,length(unique(group))) }

# if the initial cluster allocation is passed
Expand Down Expand Up @@ -207,9 +208,9 @@ sample_CAM = function(nrep, y, group,

if(is.null(nclus_start)) { nclus_start = min(c(maxL, 30))}
M_start <- stats::kmeans(y,
centers = nclus_start,
algorithm="MacQueen",
iter.max = 50)$cluster
centers = nclus_start,
algorithm="MacQueen",
iter.max = 50)$cluster

# if the initial cluster allocation is not passed
# and you want a warmstart
Expand All @@ -223,7 +224,7 @@ sample_CAM = function(nrep, y, group,
sigma2_start[1:nclus_start][sigma2_start[1:nclus_start]==0] <- 0.001
sigma2_start[is.na(sigma2_start)] <- 0.001
}

# if the initial cluster allocation is not passed
# and you don't want a warmstart
if(!warmstart){
Expand All @@ -234,7 +235,7 @@ sample_CAM = function(nrep, y, group,
sigma2_start[1] <- var(y)/2
}
}

M_start <- M_start-1
sigma2_start[is.na(sigma2_start)] <- 0.001

Expand All @@ -245,12 +246,16 @@ sample_CAM = function(nrep, y, group,
fixed_alpha <- F
fixed_beta <- F
if(!is.null(alpha) ) {
fixed_alpha <- T } else { alpha <- 1 }
fixed_alpha <- T ;
alpha_start <- alpha
} else { alpha <- 1 }
if(!is.null(beta) ) {
fixed_beta <- T } else { beta <- 1}
beta_start <- beta
fixed_beta <- T ;
eps_beta <- 1 } else { beta <- 1 }

start = Sys.time()
out = sample_cam_arma(nrep, y, group,
out = sample_cam_burn(nrep, burn, y, group,
maxK, maxL,
m0, tau0,
lambda0, gamma0,
Expand All @@ -262,7 +267,7 @@ sample_CAM = function(nrep, y, group,
M_start, S_start,
alpha_start, beta_start,
progress
)
)
end = Sys.time()

warnings <- out$warnings
Expand All @@ -273,34 +278,34 @@ sample_CAM = function(nrep, y, group,

if(length(warnings) == 2) {
output <- list( "model" = "CAM",
"params" = params,
"sim" = out,
"time" = end - start,
"warnings" = warnings)
"params" = params,
"sim" = out,
"time" = end - start,
"warnings" = warnings)
warning("Increase maxL and maxK: all the provided mixture components were used. Check $warnings to see when it happened.")
} else if (length(warnings) == 1) {
if((length(warnings$top_maxK)>0) & (length(warnings$top_maxL)==0)) {
output <- list( "model" = "CAM",
"params" = params,
"sim" = out,
"time" = end - start,
"warnings" = warnings)
"params" = params,
"sim" = out,
"time" = end - start,
"warnings" = warnings)
warning("Increase maxK: all the provided distributional mixture components were used. Check '$warnings' to see when it happened.")
}

if((length(warnings$top_maxK)==0) & (length(warnings$top_maxL)>0)) {
output <- list( "model" = "CAM",
"params" = params,
"sim" = out,
"time" = end - start,
"warnings" = warnings)
"params" = params,
"sim" = out,
"time" = end - start,
"warnings" = warnings)
warning("Increase maxL: all the provided observational mixture components were used. Check '$warnings' to see when it happened.")
}
} else {
output <- list( "model" = "CAM",
"params" = params,
"sim" = out,
"time" = end - start )
"params" = params,
"sim" = out,
"time" = end - start )
}

structure(output, class = c("SANmcmc",class(output)))
Expand Down
Loading

0 comments on commit 23d1c3a

Please sign in to comment.