# Provenance: extracted from patch bc81691e ("Add: new functions for ONTraC
# integration"), which adds this function to R/ONTraC_wrapper.R.

#' @title getONTraCv2Input
#' @name getONTraCv2Input
#' @description generate the input data for ONTraC v2
#' @inheritParams data_access_params
#' @inheritParams read_data_params
#' @param output_path the directory in which the output file
#' `ONTraC_meta_data_input.csv` is saved. It is created if it does not exist.
#' @param cell_type the cell type column name in the metadata
#' @returns data.table with columns: Cell_ID, Sample, x, y, Cell_Type.
#' `NULL` (with a warning) when `cell_type` is not a column of the merged
#' spatial locations and cell metadata.
#' @details This function generates the input data for ONTraC v2. Spatial
#' locations and cell metadata are merged by `cell_ID`. When the merged table
#' has no `list_ID` column, every cell is assigned the sample name "ONTraC".
#' The table is written as CSV without quoting and is also returned.
#' @examples
#' g <- GiottoData::loadGiottoMini("visium")
#'
#' getONTraCv2Input(
#'     gobject = g,
#'     cell_type = "custom_leiden",
#'     output_path = tempdir()
#' )
#' @export
getONTraCv2Input <- function(gobject, # nolint: object_name_linter.
    cell_type,
    output_path = getwd(),
    spat_unit = NULL,
    feat_type = NULL,
    verbose = TRUE) {
    # fail early with a readable message instead of an obscure `if` error
    if (!is.character(cell_type) || length(cell_type) != 1L ||
        is.na(cell_type)) {
        stop("`cell_type` must be a single metadata column name.",
            call. = FALSE
        )
    }

    # Set feat_type and spat_unit
    spat_unit <- set_default_spat_unit(
        gobject = gobject,
        spat_unit = spat_unit
    )
    feat_type <- set_default_feat_type(
        gobject = gobject,
        spat_unit = spat_unit,
        feat_type = feat_type
    )

    pos_df <- getSpatialLocations(
        gobject = gobject,
        spat_unit = spat_unit,
        output = "data.table"
    )
    meta_df <- pDataDT(
        gobject = gobject,
        spat_unit = spat_unit,
        feat_type = feat_type
    )
    output_df <- merge(x = pos_df, y = meta_df, by = "cell_ID")

    # check that the cell_type column exists; always signal the problem so
    # that a NULL result is never returned silently when verbose = FALSE
    if (!cell_type %in% colnames(output_df)) {
        warning(
            paste(
                "Given",
                cell_type,
                "do not exist in giotto object's metadata!"
            ),
            call. = FALSE
        )
        return(NULL)
    }

    # add default sample name for one sample obj
    if (!"list_ID" %in% colnames(output_df)) {
        output_df$list_ID <- "ONTraC"
    }

    keep_cols <- c("cell_ID", "list_ID", "sdimx", "sdimy", cell_type)
    output_df <- output_df[, keep_cols, with = FALSE]
    colnames(output_df) <- c("Cell_ID", "Sample", "x", "y", "Cell_Type")

    # make sure the target directory exists before writing
    if (!dir.exists(output_path)) {
        dir.create(output_path, recursive = TRUE)
    }
    file_path <- file.path(output_path, "ONTraC_meta_data_input.csv")
    utils::write.csv(
        output_df,
        file = file_path,
        quote = FALSE,
        row.names = FALSE
    )
    vmsg(.v = verbose, paste("ONTraC input file was saved as", file_path))

    return(output_df)
}
# NOTE(review): the source patch also corrects the roxygen of the existing
# plotCTCompositionInNicheCluster() (feat_type -> "stored probability matrix",
# values -> "probability of each cell assigned to each niche cluster",
# details -> "cell type composition within each niche cluster"). That
# function's body is outside this view and is not reproduced here.

#' @title plotCTCompositionInProbCluster
#' @name plotCTCompositionInProbCluster
#' @description plot cell type composition within each probabilistic cluster
#' @param cell_type the cell type column name in the metadata
#' @inheritParams data_access_params
#' @inheritParams plot_output_params
#' @param spat_unit name of spatial unit niche stored cluster features
#' @param feat_type name of the feature type stored probability matrix
#' @param values name of the expression matrix stored probability of each cell
#' assigned to each probabilistic cluster
#' @returns the result of `GiottoVisuals::plot_output_handler()` for a ggplot2
#' heatmap (cell type on x, cluster on y, normalized composition as fill)
#' @details This function plots the cell type composition within each
#' probabilistic cluster. For every feature whose name starts with
#' "NicheCluster_", the assignment probabilities are summed per cell type and
#' divided by the cluster total, so the compositions of one cluster sum to 1.
#' `theme_param` is accepted but not used by this function.
#' @export
plotCTCompositionInProbCluster <- function( # nolint: object_name_linter.
    gobject,
    cell_type,
    values = "prob",
    spat_unit = "cell",
    feat_type = "niche cluster",
    show_plot = NULL,
    return_plot = NULL,
    save_plot = NULL,
    save_param = list(),
    theme_param = list(),
    default_save_name = "plotCTCompositionInProbCluster") {
    # Get the cell type composition within each cluster
    ## extract the cell-level cluster probability matrix (features x cells)
    prob_expr <- getExpression(
        gobject = gobject,
        values = values,
        spat_unit = spat_unit,
        feat_type = feat_type,
        output = "exprObj"
    )
    prob_df <- as.data.frame(t(as.matrix(prob_expr@exprMat)))
    prob_df$cell_ID <- rownames(prob_df)

    ## metadata must come from the same spat_unit as the probabilities
    meta_df <- as.data.frame(pDataDT(
        gobject = gobject,
        spat_unit = spat_unit,
        feat_type = feat_type
    ))
    if (!cell_type %in% colnames(meta_df)) {
        stop(
            paste0(
                "Column '", cell_type,
                "' was not found in the cell metadata."
            ),
            call. = FALSE
        )
    }

    ## combine the cell type and cluster probability matrix
    combined_df <- merge(
        meta_df[, c("cell_ID", cell_type)],
        prob_df,
        by = "cell_ID"
    )
    # treat the annotation as discrete labels regardless of its stored type
    combined_df[[cell_type]] <- as.character(combined_df[[cell_type]])

    # Sum probabilities per (cluster, cell type) and normalize within cluster
    df_long <- combined_df %>%
        tidyr::pivot_longer(
            cols = dplyr::starts_with("NicheCluster_"),
            names_to = "Cluster",
            values_to = "Probability"
        ) %>%
        dplyr::group_by(
            Cluster, # nolint: object_usage_linter.
            !!rlang::sym(cell_type)
        ) %>%
        dplyr::summarise(
            Sum = sum(Probability, # nolint: object_usage_linter.
                na.rm = TRUE
            ),
            # keep the grouping by Cluster for the normalization below
            .groups = "drop_last"
        ) %>%
        dplyr::mutate(
            Composition = Sum / sum(Sum) # nolint: object_usage_linter.
        ) %>%
        dplyr::ungroup()

    # Create the heatmap using ggplot2
    pl <- ggplot2::ggplot(df_long, ggplot2::aes(
        x = !!rlang::sym(cell_type), # nolint: object_usage_linter.
        y = Cluster, # nolint: object_usage_linter.
        fill = Composition # nolint: object_usage_linter.
    )) +
        ggplot2::geom_tile() +
        viridis::scale_fill_viridis(option = "inferno", limits = c(0, 1)) +
        ggplot2::theme_minimal() +
        ggplot2::labs(
            title = "Normalized cell type compositions within each niche cluster",
            x = "Cell_Type",
            y = "Cluster"
        ) +
        ggplot2::theme(
            axis.text.x = ggplot2::element_text(angle = 45, hjust = 1)
        )

    # return or save
    return(GiottoVisuals::plot_output_handler(
        gobject = gobject,
        plot_object = pl,
        save_plot = save_plot,
        return_plot = return_plot,
        show_plot = show_plot,
        default_save_name = default_save_name,
        save_param = save_param,
        else_return = NULL
    ))
}


#' @title plotDiscreteAlongContinuous
#' @name plotDiscreteAlongContinuous
#' @description plot density of a discrete annotation along continuous values
#' (one violin per annotation level)
#' @param cell_type the column name of discrete annotation in cell metadata
#' @param values the column name of continuous values in cell metadata
#' @inheritParams data_access_params
#' @inheritParams plot_output_params
#' @returns the result of `GiottoVisuals::plot_output_handler()` for a ggplot2
#' violin plot
#' @details Annotation levels are ordered by the mean of `values` within each
#' level (missing values ignored). `theme_param` is accepted but not used by
#' this function.
#' @export
plotDiscreteAlongContinuous <- function(gobject, # nolint: object_name_linter.
    cell_type,
    values = "NTScore",
    spat_unit = "cell",
    feat_type = "niche cluster",
    show_plot = NULL,
    return_plot = NULL,
    save_plot = NULL,
    save_param = list(),
    theme_param = list(),
    default_save_name = "discreteAlongContinuous") {
    # Get the cell metadata holding both the annotation and the values
    data_df <- pDataDT(
        gobject = gobject,
        spat_unit = spat_unit,
        feat_type = feat_type
    )
    missing_cols <- setdiff(c(cell_type, values), colnames(data_df))
    if (length(missing_cols) > 0) {
        stop(
            paste0(
                "Column(s) not found in the cell metadata: ",
                paste(missing_cols, collapse = ", ")
            ),
            call. = FALSE
        )
    }

    # order the annotation levels by their average value
    avg_scores <- data_df %>%
        dplyr::group_by(!!rlang::sym(cell_type)) %>%
        dplyr::summarise(
            Avg_Value = mean(!!rlang::sym(values), na.rm = TRUE)
        )
    data_df[[cell_type]] <- factor(data_df[[cell_type]],
        levels = avg_scores[[cell_type]][order(avg_scores$Avg_Value)]
    )

    pl <- ggplot2::ggplot(data_df, ggplot2::aes(
        x = !!rlang::sym(values),
        y = !!rlang::sym(cell_type),
        fill = !!rlang::sym(cell_type)
    )) +
        ggplot2::geom_violin() +
        ggplot2::theme_minimal() +
        ggplot2::labs(
            # identical to the previous fixed labels when values = "NTScore"
            title = paste0("Violin Plot of ", values, " by Cell Type"),
            x = values,
            y = "Cell Type"
        ) +
        ggplot2::theme(
            axis.text.x = ggplot2::element_text(angle = 45, hjust = 1)
        )

    # return or save
    return(GiottoVisuals::plot_output_handler(
        gobject = gobject,
        plot_object = pl,
        save_plot = save_plot,
        return_plot = return_plot,
        show_plot = show_plot,
        default_save_name = default_save_name,
        save_param = save_param,
        else_return = NULL
    ))
}