From ab2bd50e3978414f6355edf392ec99f8347b358f Mon Sep 17 00:00:00 2001 From: kalidouBA Date: Mon, 4 Dec 2023 09:56:04 +0100 Subject: [PATCH] with non negative --- R/NMF_optim.R | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/R/NMF_optim.R b/R/NMF_optim.R index cea3b18..3855e3b 100644 --- a/R/NMF_optim.R +++ b/R/NMF_optim.R @@ -7,6 +7,7 @@ #' @param k Number of components. #' @param W Initial matrix W (optional, defaults to random non-negative values). #' @param H Initial matrix H (optional, defaults to random non-negative values). +#' @param notNegative Logical. If TRUE, the factorization is constrained to non-negative elements. #' @param max_iter Maximum number of iterations (default is 100). #' @return A list containing factorized matrices W and H. #' @@ -22,7 +23,7 @@ #' #' @export -nmf_conjugate_gradient <- function(V, k, W = NULL, H = NULL, max_iter = 100) { +nmf_conjugate_gradient <- function(V, k, W = NULL, H = NULL, notNegative = FALSE, max_iter = 100) { # Initialize W and H with random non-negative values if(is.null(W) || is.null(H)){ @@ -63,13 +64,14 @@ nmf_conjugate_gradient <- function(V, k, W = NULL, H = NULL, max_iter = 100) { # Perform projected conjugate gradient optimization result <- optim(par = theta, fn = obj_fun, gr = grad_obj_fun, method = "L-BFGS-B", control = list(maxit = max_iter)) - # Extract factorized matrices W and H and enforce non-negativity - # W <- project_to_non_negative(matrix(result$par[1:(nrow(V) * k)], nrow = nrow(V), ncol = k)) - # H <- project_to_non_negative(matrix(result$par[(nrow(V) * k + 1):length(result$par)], nrow = k, ncol = ncol(V))) - W <- matrix(result$par[1:(nrow(V) * k)], nrow = nrow(V), ncol = k) H <- matrix(result$par[(nrow(V) * k + 1):length(result$par)], nrow = k, ncol = ncol(V)) + # Extract factorized matrices W and H and enforce non-negativity + if(notNegative){ + W <- project_to_non_negative(W) + H <- project_to_non_negative(H) + } n <- length(V) squared_diff <- (V - W%*%H)^2 mse <- sum(squared_diff) / n