Skip to content

Commit

Permalink
with non negative
Browse files Browse the repository at this point in the history
  • Loading branch information
kalidouBA committed Dec 4, 2023
1 parent cd96f01 commit ab2bd50
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions R/NMF_optim.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#' @param k Number of components.
#' @param W Initial matrix W (optional, defaults to random non-negative values).
#' @param H Initial matrix H (optional, defaults to random non-negative values).
#' @param notNegative Logical. If TRUE, the factorization is constrained to non-negative elements.
#' @param max_iter Maximum number of iterations (default is 100).
#' @return A list containing factorized matrices W and H.
#'
Expand All @@ -22,7 +23,7 @@
#'
#' @export

nmf_conjugate_gradient <- function(V, k, W = NULL, H = NULL, max_iter = 100) {
nmf_conjugate_gradient <- function(V, k, W = NULL, H = NULL, notNegative = FALSE, max_iter = 100) {

# Initialize W and H with random non-negative values
if(is.null(W) || is.null(H)){
Expand Down Expand Up @@ -63,13 +64,14 @@ nmf_conjugate_gradient <- function(V, k, W = NULL, H = NULL, max_iter = 100) {
# Perform projected conjugate gradient optimization
result <- optim(par = theta, fn = obj_fun, gr = grad_obj_fun, method = "L-BFGS-B", control = list(maxit = max_iter))

# Extract factorized matrices W and H and enforce non-negativity
# W <- project_to_non_negative(matrix(result$par[1:(nrow(V) * k)], nrow = nrow(V), ncol = k))
# H <- project_to_non_negative(matrix(result$par[(nrow(V) * k + 1):length(result$par)], nrow = k, ncol = ncol(V)))

W <- matrix(result$par[1:(nrow(V) * k)], nrow = nrow(V), ncol = k)
H <- matrix(result$par[(nrow(V) * k + 1):length(result$par)], nrow = k, ncol = ncol(V))

# Extract factorized matrices W and H and enforce non-negativity
if(notNegative){
W <- project_to_non_negative(W)
H <- project_to_non_negative(H)
}
n <- length(V)
squared_diff <- (V - W%*%H)^2
mse <- sum(squared_diff) / n
Expand Down

0 comments on commit ab2bd50

Please sign in to comment.