Skip to content

Commit

Permalink
Merge pull request #1000 from wwang-chcn/suite_dev
Browse files Browse the repository at this point in the history
Add: three functions for ONTraC integration
  • Loading branch information
jiajic authored Jul 31, 2024
2 parents bb14457 + bc81691 commit 3a5667e
Showing 1 changed file with 246 additions and 3 deletions.
249 changes: 246 additions & 3 deletions R/ONTraC_wrapper.R
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,83 @@ getONTraCv1Input <- function(gobject, # nolint: object_name_linter.
}


#' @title getONTraCv2Input
#' @name getONTraCv2Input
#' @description Generate the input data for ONTraC v2 and save it as a CSV
#' file.
#' @inheritParams data_access_params
#' @inheritParams read_data_params
#' @param output_path the directory in which the output file
#' `ONTraC_meta_data_input.csv` is saved. It is created if it does not exist.
#' @param cell_type the cell type column name in the metadata
#' @returns data.table with columns: Cell_ID, Sample, x, y, Cell_Type.
#' `NULL` (without writing a file) when `cell_type` is not a column of the
#' merged spatial locations and cell metadata.
#' @details This function generates the input data for ONTraC v2 by merging
#' the spatial locations with the cell metadata on `cell_ID`. When the merged
#' table has no `list_ID` column, every cell is assigned the sample name
#' `"ONTraC"`.
#' @examples
#' g <- GiottoData::loadGiottoMini("visium")
#'
#' getONTraCv2Input(
#'     gobject = g,
#'     cell_type = "custom_leiden",
#'     output_path = tempdir()
#' )
#' @export
getONTraCv2Input <- function(gobject, # nolint: object_name_linter.
    cell_type,
    output_path = getwd(),
    spat_unit = NULL,
    feat_type = NULL,
    verbose = TRUE) {
    # Set feat_type and spat_unit
    spat_unit <- set_default_spat_unit(
        gobject = gobject,
        spat_unit = spat_unit
    )
    feat_type <- set_default_feat_type(
        gobject = gobject,
        spat_unit = spat_unit,
        feat_type = feat_type
    )

    # combine spatial coordinates and cell metadata per cell
    pos_df <- getSpatialLocations(
        gobject = gobject,
        spat_unit = spat_unit,
        output = "data.table"
    )
    meta_df <- pDataDT(
        gobject = gobject,
        spat_unit = spat_unit,
        feat_type = feat_type
    )
    output_df <- merge(x = pos_df, y = meta_df, by = "cell_ID")

    # check if the cell_type column exists
    if (!cell_type %in% colnames(output_df)) {
        vmsg(.v = verbose, paste(
            "Given",
            cell_type,
            "do not exist in giotto object's metadata!"
        ))
        return(NULL)
    }

    # add default sample name for one sample obj
    if (!"list_ID" %in% colnames(output_df)) {
        output_df$list_ID <- "ONTraC"
    }

    output_df <- output_df[, .SD, .SDcols = c(
        "cell_ID",
        "list_ID",
        "sdimx",
        "sdimy",
        cell_type
    )]
    colnames(output_df) <- c("Cell_ID", "Sample", "x", "y", "Cell_Type")

    # write.csv cannot create missing directories, so make sure the target
    # directory exists before writing
    if (!dir.exists(output_path)) {
        dir.create(output_path, recursive = TRUE, showWarnings = FALSE)
    }
    file_path <- file.path(output_path, "ONTraC_meta_data_input.csv")
    write.csv(output_df, file = file_path, quote = FALSE, row.names = FALSE)
    vmsg(.v = verbose, paste("ONTraC input file was saved as", file_path))

    return(output_df)
}


#' @title load_cell_bin_niche_cluster
#' @name load_cell_bin_niche_cluster
#' @description load cell-level binarized niche cluster
Expand Down Expand Up @@ -366,9 +443,9 @@ plotNicheClusterConnectivity <- function( # nolint: object_name_linter.
#' @inheritParams data_access_params
#' @inheritParams plot_output_params
#' @param spat_unit name of spatial unit niche stored cluster features
#' @param feat_type name of the feature type stored niche cluster connectivities
#' @param values name of the expression matrix stored connectivity values
#' @details This function plots the niche cluster connectivity matrix
#' @param feat_type name of the feature type stored probability matrix
#' @param values name of the expression matrix stored probability of each cell assigned to each niche cluster
#' @details This function plots the cell type composition within each niche cluster
#' @export
plotCTCompositionInNicheCluster <- function( # nolint: object_name_linter.
gobject,
Expand Down Expand Up @@ -465,6 +542,112 @@ plotCTCompositionInNicheCluster <- function( # nolint: object_name_linter.
}


#' @title plotCTCompositionInProbCluster
#' @name plotCTCompositionInProbCluster
#' @description plot cell type composition within each probabilistic cluster
#' @param cell_type the cell type column name in the metadata
#' @inheritParams data_access_params
#' @inheritParams plot_output_params
#' @param spat_unit name of spatial unit niche stored cluster features
#' @param feat_type name of the feature type stored probability matrix
#' @param values name of the expression matrix stored probability of each
#' cell assigned to each probabilistic cluster
#' @returns the result of `GiottoVisuals::plot_output_handler()` applied to
#' the generated ggplot heatmap
#' @details This function plots the cell type composition within each
#' probabilistic cluster. The per-cell cluster probabilities are summed per
#' cell type, and each cluster is then normalized so that its cell type
#' fractions sum to 1. Only features whose names start with `NicheCluster_`
#' are used.
#' @export
plotCTCompositionInProbCluster <- function( # nolint: object_name_linter.
    gobject,
    cell_type,
    values = "prob",
    spat_unit = "cell",
    feat_type = "niche cluster",
    show_plot = NULL,
    return_plot = NULL,
    save_plot = NULL,
    save_param = list(),
    theme_param = list(),
    default_save_name = "plotCTCompositionInProbCluster") {
    # Get the cell type composition within each niche cluster
    ## extract the cell-level niche cluster probability matrix
    prob_expr <- getExpression(
        gobject = gobject,
        values = values,
        spat_unit = spat_unit,
        feat_type = feat_type,
        output = "exprObj"
    )
    ## transpose so that rows are cells and columns are clusters
    prob_df <- as.data.frame(t(as.matrix(prob_expr@exprMat)))
    prob_df$cell_ID <- rownames(prob_df)
    ## combine the cell type and niche cluster probability matrix
    ## (metadata is read from the same spat_unit as the probabilities)
    meta_df <- as.data.frame(pDataDT(
        gobject = gobject,
        spat_unit = spat_unit,
        feat_type = feat_type
    ))
    combined_df <- merge(
        meta_df[, c("cell_ID", cell_type)],
        prob_df,
        by = "cell_ID"
    )

    # Calculate the normalized cell type composition within each niche cluster
    cell_type_counts_df <- combined_df %>%
        tidyr::pivot_longer(
            cols = dplyr::starts_with("NicheCluster_"),
            names_to = "Cluster",
            values_to = "Probability"
        ) %>%
        dplyr::group_by(
            !!rlang::sym(cell_type),
            Cluster # nolint: object_usage_linter.
        ) %>%
        dplyr::summarise(
            Sum = sum(Probability, # nolint: object_usage_linter.
                na.rm = TRUE
            ),
            .groups = "drop"
        ) %>%
        tidyr::pivot_wider(
            names_from = "Cluster",
            values_from = "Sum",
            values_fill = 0
        )
    cell_type_counts_df <- as.data.frame(cell_type_counts_df)
    rownames(cell_type_counts_df) <- cell_type_counts_df[[cell_type]]
    cell_type_counts_df[[cell_type]] <- NULL
    # divide every cluster column by its total: after the inner transpose the
    # clusters are rows, so the recycled colSums vector lines up with them
    normalized_df <- as.data.frame(t(
        t(cell_type_counts_df) / colSums(cell_type_counts_df)
    ))


    # Reshape the data frame into long format
    normalized_df[[cell_type]] <- rownames(normalized_df)
    df_long <- normalized_df %>%
        tidyr::pivot_longer(
            cols = -!!rlang::sym(cell_type), # nolint: object_usage_linter.
            names_to = "Cluster",
            values_to = "Composition"
        )

    # Create the heatmap using ggplot2
    pl <- ggplot(df_long, aes(
        x = !!rlang::sym(cell_type), # nolint: object_usage_linter.
        y = Cluster, # nolint: object_usage_linter.
        fill = Composition # nolint: object_usage_linter.
    )) +
        geom_tile() +
        viridis::scale_fill_viridis(option = "inferno", limits = c(0, 1)) +
        theme_minimal() +
        labs(
            title = "Normalized cell type compositions within each niche cluster",
            x = "Cell_Type",
            y = "Cluster"
        ) +
        theme(axis.text.x = element_text(angle = 45, hjust = 1))

    # return or save
    return(GiottoVisuals::plot_output_handler(
        gobject = gobject,
        plot_object = pl,
        save_plot = save_plot,
        return_plot = return_plot,
        show_plot = show_plot,
        default_save_name = default_save_name,
        save_param = save_param,
        else_return = NULL
    ))
}


#' @title plotCellTypeNTScore
#' @name plotCellTypeNTScore
#' @description plot NTScore by cell type
Expand Down Expand Up @@ -522,3 +705,63 @@ plotCellTypeNTScore <- function(gobject, # nolint: object_name_linter.
else_return = NULL
))
}


#' @title plotDiscreteAlongContinuous
#' @name plotDiscreteAlongContinuous
#' @description plot density of a discrete annotation along continuous values
#' @param cell_type the column name of discrete annotation in cell metadata
#' @param values the column name of continuous values in cell metadata
#' @inheritParams data_access_params
#' @inheritParams plot_output_params
#' @returns the result of `GiottoVisuals::plot_output_handler()` applied to
#' the generated ggplot violin plot
#' @details The levels of the discrete annotation are ordered by the mean of
#' `values` within each level before plotting.
#' @export
plotDiscreteAlongContinuous <- function(gobject, # nolint: object_name_linter.
    cell_type,
    values = "NTScore",
    spat_unit = "cell",
    feat_type = "niche cluster",
    show_plot = NULL,
    return_plot = NULL,
    save_plot = NULL,
    save_param = list(),
    theme_param = list(),
    default_save_name = "discreteAlongContinuous") {
    # Get the cell metadata holding both the annotation and the values
    data_df <- pDataDT(
        gobject = gobject,
        spat_unit = spat_unit,
        feat_type = feat_type
    )

    # order the annotation levels by their average value so the violins are
    # drawn from the lowest to the highest mean
    avg_scores <- data_df %>%
        dplyr::group_by(!!rlang::sym(cell_type)) %>% # nolint: object_usage_linter.
        dplyr::summarise(
            Avg_Value = mean(!!rlang::sym(values)) # nolint: object_usage_linter.
        )
    data_df[[cell_type]] <- factor(data_df[[cell_type]],
        levels = avg_scores[[cell_type]][order(avg_scores$Avg_Value)]
    )

    pl <- ggplot(data_df, aes(
        x = !!rlang::sym(values), # nolint: object_usage_linter.
        y = !!rlang::sym(cell_type),
        fill = !!rlang::sym(cell_type)
    )) +
        geom_violin() +
        theme_minimal() +
        labs(
            # with the default values = "NTScore" these labels are identical
            # to the previously hard-coded ones
            title = paste("Violin Plot of", values, "by Cell Type"),
            x = values,
            y = "Cell Type"
        ) +
        ggplot2::theme(axis.text.x = element_text(angle = 45, hjust = 1))

    # return or save
    return(GiottoVisuals::plot_output_handler(
        gobject = gobject,
        plot_object = pl,
        save_plot = save_plot,
        return_plot = return_plot,
        show_plot = show_plot,
        default_save_name = default_save_name,
        save_param = save_param,
        else_return = NULL
    ))
}

0 comments on commit 3a5667e

Please sign in to comment.