Skip to content

Commit

Permalink
fix forward alg start for AMEs
Browse files Browse the repository at this point in the history
  • Loading branch information
helske committed Nov 21, 2024
1 parent f36001a commit d36d4c7
Show file tree
Hide file tree
Showing 17 changed files with 138 additions and 120 deletions.
12 changes: 6 additions & 6 deletions src/mnhmm_EM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ double mnhmm_base::objective_pi(const arma::vec& x, arma::vec& grad) {
update_pi(i, current_d);
}
double sum_epi = arma::accu(E_Pi(current_d).col(i)); // this is != 1 if pseudocounts are used
double val = arma::dot(E_Pi(current_d).col(i), log_Pi(current_d));
double val = arma::dot(E_Pi(current_d).col(i), log_pi(current_d));
if (!std::isfinite(val)) {
if (!grad.is_empty()) {
grad.fill(std::numeric_limits<double>::max());
Expand All @@ -122,7 +122,7 @@ double mnhmm_base::objective_pi(const arma::vec& x, arma::vec& grad) {
// Only update grad if it's non-empty (i.e., for gradient-based optimization)
if (!grad.is_empty()) {
tmpgrad -= sum_epi * (E_Pi(current_d).col(i) / sum_epi -
Pi(current_d)) * X_pi.col(i).t();
pi(current_d)) * X_pi.col(i).t();
if (!tmpgrad.is_finite()) {
grad.fill(std::numeric_limits<double>::max());
return std::numeric_limits<double>::max();
Expand Down Expand Up @@ -645,7 +645,7 @@ Rcpp::List EM_LBFGS_mnhmm_singlechannel(
model.update_log_py(i);
for (arma::uword d = 0; d < model.D; d++) {
univariate_forward_nhmm(
log_alpha, model.log_Pi(d), model.log_A(d),
log_alpha, model.log_pi(d), model.log_A(d),
model.log_py.slice(d).cols(0, model.Ti(i) - 1)
);
univariate_backward_nhmm(
Expand Down Expand Up @@ -720,7 +720,7 @@ Rcpp::List EM_LBFGS_mnhmm_singlechannel(
model.update_log_py(i);
for (arma::uword d = 0; d < model.D; d++) {
univariate_forward_nhmm(
log_alpha, model.log_Pi(d), model.log_A(d),
log_alpha, model.log_pi(d), model.log_A(d),
model.log_py.slice(d).cols(0, model.Ti(i) - 1)
);
univariate_backward_nhmm(
Expand Down Expand Up @@ -866,7 +866,7 @@ Rcpp::List EM_LBFGS_mnhmm_multichannel(
model.update_log_py(i);
for (arma::uword d = 0; d < model.D; d++) {
univariate_forward_nhmm(
log_alpha, model.log_Pi(d), model.log_A(d),
log_alpha, model.log_pi(d), model.log_A(d),
model.log_py.slice(d).cols(0, model.Ti(i) - 1)
);
univariate_backward_nhmm(
Expand Down Expand Up @@ -941,7 +941,7 @@ Rcpp::List EM_LBFGS_mnhmm_multichannel(
model.update_log_py(i);
for (arma::uword d = 0; d < model.D; d++) {
univariate_forward_nhmm(
log_alpha, model.log_Pi(d), model.log_A(d),
log_alpha, model.log_pi(d), model.log_A(d),
model.log_py.slice(d).cols(0, model.Ti(i) - 1)
);
univariate_backward_nhmm(
Expand Down
28 changes: 14 additions & 14 deletions src/mnhmm_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,11 @@ struct mnhmm_base {
arma::field<arma::mat> gamma_pi;
arma::field<arma::cube> eta_A;
arma::field<arma::cube> gamma_A;
// Pi, A, and log_p(y) of _one_ id we are currently working with
// pi, A, and log_p(y) of _one_ id we are currently working with
arma::vec omega;
arma::vec log_omega;
arma::field<arma::vec> Pi;
arma::field<arma::vec> log_Pi;
arma::field<arma::vec> pi;
arma::field<arma::vec> log_pi;
arma::field<arma::cube> A;
arma::field<arma::cube> log_A;
arma::cube log_py;
Expand Down Expand Up @@ -109,8 +109,8 @@ struct mnhmm_base {
gamma_A(eta_to_gamma(eta_A, Qs)),
omega(D),
log_omega(D),
Pi(D),
log_Pi(D),
pi(D),
log_pi(D),
A(D),
log_A(D),
log_py(S, T, D),
Expand All @@ -122,8 +122,8 @@ struct mnhmm_base {
n_obs(sum(Ti)),
lambda(lambda){
for (arma::uword d = 0; d < D; d++) {
Pi(d) = arma::vec(S);
log_Pi(d) = arma::vec(S);
pi(d) = arma::vec(S);
log_pi(d) = arma::vec(S);
A(d) = arma::cube(S, S, T);
log_A(d) = arma::cube(S, S, T);
E_Pi(d) = arma::mat(S, N);
Expand Down Expand Up @@ -157,23 +157,23 @@ struct mnhmm_base {
void update_pi(const arma::uword i) {
if (icpt_only_pi) {
for (arma::uword d = 0; d < D; d++) {
Pi(d) = softmax(gamma_pi(d).col(0));
log_Pi(d) = arma::log(Pi(d));
pi(d) = softmax(gamma_pi(d).col(0));
log_pi(d) = arma::log(pi(d));
}
} else {
for (arma::uword d = 0; d < D; d++) {
Pi(d) = softmax(gamma_pi(d) * X_pi.col(i));
log_Pi(d) = arma::log(Pi(d));
pi(d) = softmax(gamma_pi(d) * X_pi.col(i));
log_pi(d) = arma::log(pi(d));
}
}
}
void update_pi(const arma::uword i, const arma::uword d) {
if (icpt_only_pi) {
Pi(d) = softmax(gamma_pi(d).col(0));
pi(d) = softmax(gamma_pi(d).col(0));
} else {
Pi(d) = softmax(gamma_pi(d) * X_pi.col(i));
pi(d) = softmax(gamma_pi(d) * X_pi.col(i));
}
log_Pi(d) = arma::log(Pi(d));
log_pi(d) = arma::log(pi(d));
}
void update_A(const arma::uword i) {
arma::mat Atmp(S, S);
Expand Down
2 changes: 1 addition & 1 deletion src/mnhmm_mc.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ struct mnhmm_mc : public mnhmm_base {
arma::uvec M;
arma::field<arma::mat> Qm;
arma::field<arma::cube> gamma_B;
// these store Pi, A, B, and log_p(y) of _one_ id and cluster we are currently working with
// these store pi, A, B, and log_p(y) of _one_ id and cluster we are currently working with
arma::field<arma::cube> B;
arma::field<arma::cube> log_B;
// expected counts for EM algorithm
Expand Down
2 changes: 1 addition & 1 deletion src/mnhmm_sc.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ struct mnhmm_sc : public mnhmm_base {
const arma::uword M;
arma::mat Qm;
arma::field<arma::cube> gamma_B;
// these store Pi, A, B, and log_p(y) of _one_ id we are currently working with
// these store pi, A, B, and log_p(y) of _one_ id we are currently working with
arma::field<arma::cube> B;
arma::field<arma::cube> log_B;
// expected counts for EM algorithm
Expand Down
12 changes: 6 additions & 6 deletions src/nhmm_EM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ double nhmm_base::objective_pi(const arma::vec& x, arma::vec& grad) {
update_pi(i);
}
double sum_epi = arma::accu(E_Pi.col(i)); // this is != 1 if pseudocounts are used
double val = arma::dot(E_Pi.col(i), log_Pi);
double val = arma::dot(E_Pi.col(i), log_pi);
if (!std::isfinite(val)) {
if (!grad.is_empty()) {
grad.fill(std::numeric_limits<double>::max());
Expand All @@ -33,7 +33,7 @@ double nhmm_base::objective_pi(const arma::vec& x, arma::vec& grad) {
value -= val;
// Only update grad if it's non-empty (i.e., for gradient-based optimization)
if (!grad.is_empty()) {
tmpgrad -= sum_epi * (E_Pi.col(i) / sum_epi - Pi) * X_pi.col(i).t();
tmpgrad -= sum_epi * (E_Pi.col(i) / sum_epi - pi) * X_pi.col(i).t();
if (!tmpgrad.is_finite()) {
grad.fill(std::numeric_limits<double>::max());
return std::numeric_limits<double>::max();
Expand Down Expand Up @@ -520,7 +520,7 @@ Rcpp::List EM_LBFGS_nhmm_singlechannel(
model.update_log_py(i);

univariate_forward_nhmm(
log_alpha, model.log_Pi, model.log_A,
log_alpha, model.log_pi, model.log_A,
model.log_py.cols(0, model.Ti(i) - 1)
);

Expand Down Expand Up @@ -592,7 +592,7 @@ Rcpp::List EM_LBFGS_nhmm_singlechannel(
}
model.update_log_py(i);
univariate_forward_nhmm(
log_alpha, model.log_Pi, model.log_A,
log_alpha, model.log_pi, model.log_A,
model.log_py.cols(0, model.Ti(i) - 1)
);
univariate_backward_nhmm(
Expand Down Expand Up @@ -714,7 +714,7 @@ Rcpp::List EM_LBFGS_nhmm_multichannel(
}
model.update_log_py(i);
univariate_forward_nhmm(
log_alpha, model.log_Pi, model.log_A,
log_alpha, model.log_pi, model.log_A,
model.log_py.cols(0, model.Ti(i) - 1)
);
univariate_backward_nhmm(
Expand Down Expand Up @@ -783,7 +783,7 @@ Rcpp::List EM_LBFGS_nhmm_multichannel(
}
model.update_log_py(i);
univariate_forward_nhmm(
log_alpha, model.log_Pi, model.log_A,
log_alpha, model.log_pi, model.log_A,
model.log_py.cols(0, model.Ti(i) - 1)
);
univariate_backward_nhmm(
Expand Down
16 changes: 8 additions & 8 deletions src/nhmm_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ struct nhmm_base {
arma::mat gamma_pi;
arma::cube eta_A;
arma::cube gamma_A;
// these store Pi, A, B, and log_p(y) of _one_ id we are currently working with
arma::vec Pi;
arma::vec log_Pi;
// these store pi, A, B, and log_p(y) of _one_ id we are currently working with
arma::vec pi;
arma::vec log_pi;
arma::cube A;
arma::cube log_A;
arma::mat log_py;
Expand Down Expand Up @@ -86,8 +86,8 @@ struct nhmm_base {
gamma_pi(eta_to_gamma(eta_pi, Qs)),
eta_A(eta_A_),
gamma_A(eta_to_gamma(eta_A, Qs)),
Pi(S),
log_Pi(S),
pi(S),
log_pi(S),
A(S, S, T),
log_A(S, S, T),
log_py(S, T),
Expand All @@ -108,11 +108,11 @@ struct nhmm_base {
}
void update_pi(arma::uword i) {
if (icpt_only_pi) {
Pi = softmax(gamma_pi.col(0));
pi = softmax(gamma_pi.col(0));
} else {
Pi = softmax(gamma_pi * X_pi.col(i));
pi = softmax(gamma_pi * X_pi.col(i));
}
log_Pi = arma::log(Pi);
log_pi = arma::log(pi);
}
void update_A(arma::uword i) {
arma::mat Atmp(S, S);
Expand Down
4 changes: 2 additions & 2 deletions src/nhmm_forward.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ void forward_nhmm(Model& model, arma::cube& log_alpha) {
model.update_log_py(i);
univariate_forward_nhmm(
log_alpha.slice(i),
model.log_Pi,
model.log_pi,
model.log_A,
model.log_py.cols(0, model.Ti(i) - 1)
);
Expand Down Expand Up @@ -67,7 +67,7 @@ void forward_mnhmm(Model& model, arma::cube& log_alpha) {
log_alpha.slice(i).rows(d * model.S, (d + 1) * model.S - 1);
univariate_forward_nhmm(
submat,
model.log_omega(d) + model.log_Pi(d),
model.log_omega(d) + model.log_pi(d),
model.log_A(d),
model.log_py.slice(d).cols(0, model.Ti(i) - 1)
);
Expand Down
28 changes: 14 additions & 14 deletions src/nhmm_gradients.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,22 @@ void gradient_wrt_omega(
void gradient_wrt_pi(
arma::mat& grad, arma::mat& tmpmat,
const arma::mat& log_py, const arma::mat& log_beta, const double ll,
const arma::vec& Pi, const arma::mat& X, const arma::uword i) {
const arma::vec& pi, const arma::mat& X, const arma::uword i) {

tmpmat = -Pi * Pi.t();
tmpmat.diag() += Pi;
tmpmat = -pi * pi.t();
tmpmat.diag() += pi;
grad += tmpmat * exp(log_py.col(0) + log_beta.col(0) - ll) * X.col(i).t();
}

void gradient_wrt_pi(
arma::mat& grad, arma::mat& tmpmat,
const arma::vec& log_omega, const arma::cube& log_py,
const arma::cube& log_beta, const arma::vec& loglik,
const arma::field<arma::vec>& Pi, const arma::mat& X, const arma::uword i,
const arma::field<arma::vec>& pi, const arma::mat& X, const arma::uword i,
const arma::uword d) {

tmpmat = -Pi(d) * Pi(d).t();
tmpmat.diag() += Pi(d);
tmpmat = -pi(d) * pi(d).t();
tmpmat.diag() += pi(d);
grad += tmpmat * exp(log_omega(d) + log_py.slice(d).col(0) +
log_beta.slice(d).col(0) - loglik(i)) * X.col(i).t();
}
Expand Down Expand Up @@ -63,7 +63,7 @@ void gradient_wrt_A(
// NHMM singlechannel
void gradient_wrt_B_t0(
arma::mat& grad, arma::vec& tmpvec,
const arma::umat& obs, const arma::vec& log_Pi, const arma::mat& log_beta,
const arma::umat& obs, const arma::vec& log_pi, const arma::mat& log_beta,
const double ll, const arma::cube& B, const arma::cube& X,
const arma::uword i, const arma::uword s) {

Expand All @@ -72,7 +72,7 @@ void gradient_wrt_B_t0(
double brow = Brow(idx);
tmpvec = -Brow.t() * brow;
tmpvec(idx) += brow;
grad += exp(log_Pi(s) + log_beta(s, 0) - ll) * tmpvec * X.slice(i).col(0).t();
grad += exp(log_pi(s) + log_beta(s, 0) - ll) * tmpvec * X.slice(i).col(0).t();
}
void gradient_wrt_B(
arma::mat& grad, arma::vec& tmpvec,
Expand All @@ -92,7 +92,7 @@ void gradient_wrt_B(
// NHMM multichannel
void gradient_wrt_B_t0(
arma::mat& grad, arma::vec& tmpvec,
const arma::ucube& obs, const arma::vec& log_Pi, const arma::mat& log_beta,
const arma::ucube& obs, const arma::vec& log_pi, const arma::mat& log_beta,
const double ll, const arma::field<arma::cube>& log_B,
const arma::field<arma::cube>& B, const arma::cube& X,
const arma::uvec& M, const arma::uword i, const arma::uword s,
Expand All @@ -110,7 +110,7 @@ void gradient_wrt_B_t0(
logpy += log_B(cc)(s, obs(cc, 0, i), 0);
}
}
grad += exp(log_Pi(s) + logpy + log_beta(s, 0) - ll) * tmpvec *
grad += exp(log_pi(s) + logpy + log_beta(s, 0) - ll) * tmpvec *
X.slice(i).col(0).t();
}
void gradient_wrt_B(
Expand Down Expand Up @@ -140,7 +140,7 @@ void gradient_wrt_B(
void gradient_wrt_B_t0(
arma::mat& grad, arma::vec& tmpvec,
const arma::vec& log_omega,
const arma::umat& obs, const arma::field<arma::vec>& log_Pi,
const arma::umat& obs, const arma::field<arma::vec>& log_pi,
const arma::cube& log_beta,
const arma::vec& loglik, const arma::field<arma::cube>& B, const arma::cube& X,
const arma::uword i, const arma::uword s, const unsigned d) {
Expand All @@ -151,7 +151,7 @@ void gradient_wrt_B_t0(
double brow = Brow(idx);
tmpvec = -Brow.t() * brow;
tmpvec(idx) += brow;
grad += exp(log_omega(d) + log_Pi(d)(s) + log_beta(s, 0, d) -
grad += exp(log_omega(d) + log_pi(d)(s) + log_beta(s, 0, d) -
loglik(i)) * tmpvec * X.slice(i).col(0).t();
}
void gradient_wrt_B(
Expand All @@ -177,7 +177,7 @@ void gradient_wrt_B(
void gradient_wrt_B_t0(
arma::mat& grad, arma::vec& tmpvec,
const arma::vec& log_omega,
const arma::ucube& obs, const arma::field<arma::vec>& log_Pi,
const arma::ucube& obs, const arma::field<arma::vec>& log_pi,
const arma::cube& log_beta,
const arma::vec& loglik, const arma::field<arma::cube>& log_B,
const arma::field<arma::cube>& B, const arma::cube& X,
Expand All @@ -196,7 +196,7 @@ void gradient_wrt_B_t0(
logpy += log_B(d * C + cc)(s, obs(cc, 0, i), 0);
}
}
grad += exp(log_omega(d) + log_Pi(d)(s) + logpy +
grad += exp(log_omega(d) + log_pi(d)(s) + logpy +
log_beta(s, 0, d) - loglik(i)) * tmpvec * X.slice(i).col(0).t();
}

Expand Down
Loading

0 comments on commit d36d4c7

Please sign in to comment.