Skip to content

Commit

Permalink
update from mudata to spatialdata
Browse files Browse the repository at this point in the history
  • Loading branch information
SarahOuologuem committed Jan 14, 2025
1 parent 7e04d41 commit 1e002b5
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 72 deletions.
6 changes: 3 additions & 3 deletions panpipes/panpipes/pipeline_deconvolution_spatial.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,12 @@ def get_logger():


def gen_filter_jobs():
input_paths_spatial=glob.glob(os.path.join(PARAMS["input_spatial"],"*.h5mu"))
input_paths_spatial=glob.glob(os.path.join(PARAMS["input_spatial"],"*.zarr"))
input_singlecell = PARAMS["input_singlecell"]
for input_spatial in input_paths_spatial:
sample_prefix = os.path.basename(input_spatial)
sample_prefix = sample_prefix.replace(".h5mu","")
outfile_spatial = "cell2location.output/" + sample_prefix + "/Cell2Loc_spatial_output.h5mu"
sample_prefix = sample_prefix.replace(".zarr","")
outfile_spatial = "cell2location.output/" + sample_prefix + "/Cell2Loc_spatial_output.zarr"
yield input_spatial, outfile_spatial, sample_prefix, input_singlecell


Expand Down
95 changes: 49 additions & 46 deletions panpipes/python_scripts/run_cell2location.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import cell2location as c2l
import scanpy as sc
import pandas as pd
import muon as mu
import spatialdata as sd

import os
import argparse
Expand All @@ -20,6 +20,7 @@
from panpipes.funcs.scmethods import cell2loc_filter_genes



L = logging.getLogger()
L.setLevel(logging.INFO)
log_handler = logging.StreamHandler(sys.stdout)
Expand Down Expand Up @@ -197,13 +198,15 @@

#1. read in the data
#spatial:
L.info("Reading in spatial MuData from '%s'" % args.input_spatial)
mdata_spatial = mu.read(args.input_spatial)
adata_st = mdata_spatial.mod['spatial']
L.info("Reading in spatial SpatialData from '%s'" % args.input_spatial)
sdata_st = sd.read_zarr(args.input_spatial)
#mdata_spatial = mu.read(args.input_spatial)
#adata_st = mdata_spatial.mod['spatial']
#single-cell:
L.info("Reading in reference MuData from '%s'" % args.input_singlecell)
mdata_singlecell = mu.read(args.input_singlecell)
adata_sc = mdata_singlecell.mod['rna']
L.info("Reading in reference SpatialData from '%s'" % args.input_singlecell)
sdata_sc = sd.read_zarr(args.input_singlecell)
#mdata_singlecell = mu.read(args.input_singlecell)
#adata_sc = mdata_singlecell.mod['rna']



Expand All @@ -217,12 +220,12 @@
reduced_gene_set = pd.read_csv(args.gene_list, header = 0)
reduced_gene_set.columns = ["HVGs"]
L.info("Subsetting data on gene list")
adata_sc.var["selected_gene"] = adata_sc.var.index.isin(reduced_gene_set["HVGs"])
adata_st.var["selected_gene"] = adata_st.var.index.isin(reduced_gene_set["HVGs"])
adata_sc = adata_sc[:, adata_sc.var["selected_gene"]]
adata_st = adata_st[:, adata_st.var["selected_gene"]]
sdata_sc["table"].var["selected_gene"] = sdata_sc["table"].var.index.isin(reduced_gene_set["HVGs"])
sdata_st["table"].var["selected_gene"] = sdata_st["table"].var.index.isin(reduced_gene_set["HVGs"])
sdata_sc["table"] = sdata_sc["table"][:, sdata_sc["table"].var["selected_gene"]]
sdata_st["table"] = sdata_st["table"][:, sdata_st["table"].var["selected_gene"]]
# check whether all genes are present in both, spatial & reference
if set(adata_st.var.index) != set(adata_sc.var.index):
if set(sdata_st["table"].var.index) != set(sdata_sc["table"].var.index):
L.error(
"Not all genes of the gene list %s are present in the reference as well as in the ST data. Please provide a gene list where all genes are present in both, reference and ST.", args.gene_list)
sys.exit(
Expand All @@ -231,34 +234,34 @@
else: # perform feature selection according to cell2loc
if remove_mt is True:
L.info("Removing MT genes")
adata_st.var["MT_gene"] = [gene.startswith("MT-") for gene in adata_st.var.index]
adata_st.obsm["MT"] = adata_st[:, adata_st.var["MT_gene"].values].X.toarray()
adata_st = adata_st[:, ~adata_st.var["MT_gene"].values]
sdata_st["table"].var["MT_gene"] = [gene.startswith("MT-") for gene in sdata_st["table"].var.index]
sdata_st["table"].obsm["MT"] = sdata_st["table"][:, sdata_st["table"].var["MT_gene"].values].X.toarray()
sdata_st["table"] = sdata_st["table"][:, ~sdata_st["table"].var["MT_gene"].values]
# intersect vars of reference and spatial
L.info("Intersecting vars of reference and spatial ")
shared_features = [feature for feature in adata_st.var_names if feature in adata_sc.var_names]
adata_sc = adata_sc[:, shared_features]
adata_st = adata_st[:, shared_features]
shared_features = [feature for feature in sdata_st["table"].var_names if feature in sdata_sc["table"].var_names]
sdata_sc["table"] = sdata_sc["table"][:, shared_features]
sdata_st["table"] = sdata_st["table"][:, shared_features]
# select features
L.info("Selecting features using 'cell2location.utils.filtering.filter_genes() function'")
selected = cell2loc_filter_genes(adata_sc, figdir + "/gene_filter.png", cell_count_cutoff=float(args.cell_count_cutoff),
selected = cell2loc_filter_genes(sdata_sc["table"], figdir + "/gene_filter.png", cell_count_cutoff=float(args.cell_count_cutoff),
cell_percentage_cutoff2=float(args.cell_percentage_cutoff2),
nonz_mean_cutoff=float(args.nonz_mean_cutoff))
L.info("Subsetting data on selected features")
adata_sc = adata_sc[:, selected]
adata_st = adata_st[:, selected]
sdata_sc["table"] = sdata_sc["table"][:, selected]
sdata_st["table"] = sdata_st["table"][:, selected]



# 3. Fit regression model
L.info("Setting up AnnData for the reference model")
c2l.models.RegressionModel.setup_anndata(adata=adata_sc,
c2l.models.RegressionModel.setup_anndata(adata=sdata_sc["table"],
labels_key = args.labels_key_reference,
layer= args.layer_reference,
batch_key= args.batch_key_reference,
categorical_covariate_keys = categorical_covariate_keys_reference,
continuous_covariate_keys = continuous_covariate_keys_reference)
model_ref = c2l.models.RegressionModel(adata_sc)
model_ref = c2l.models.RegressionModel(sdata_sc["table"])
L.info("Training the reference model")
model_ref.train(max_epochs=max_epochs_reference, use_gpu = use_gpu_reference)

Expand All @@ -268,23 +271,23 @@

# export results
L.info("Extracting the posterior of the reference model")
adata_sc = model_ref.export_posterior(adata_sc)
if "means_per_cluster_mu_fg" in adata_sc.varm.keys():
inf_aver = adata_sc.varm["means_per_cluster_mu_fg"][[f"means_per_cluster_mu_fg_{i}" for i in adata_sc.uns["mod"]["factor_names"]]].copy()
sdata_sc["table"] = model_ref.export_posterior(sdata_sc["table"])
if "means_per_cluster_mu_fg" in sdata_sc["table"].varm.keys():
inf_aver = sdata_sc["table"].varm["means_per_cluster_mu_fg"][[f"means_per_cluster_mu_fg_{i}" for i in sdata_sc["table"].uns["mod"]["factor_names"]]].copy()
else:
inf_aver = adata_sc.var[[f"means_per_cluster_mu_fg_{i}" for i in adata_sc.uns["mod"]["factor_names"]]].copy()
inf_aver.columns = adata_sc.uns["mod"]["factor_names"]
inf_aver = sdata_sc["table"].var[[f"means_per_cluster_mu_fg_{i}" for i in sdata_sc["table"].uns["mod"]["factor_names"]]].copy()
inf_aver.columns = sdata_sc["table"].uns["mod"]["factor_names"]
inf_aver.to_csv(output_dir+"/Cell2Loc_inf_aver.csv")

# plot QC
L.info("Plotting QC plots")
cell2loc_plot_QC_reference(model_ref, figdir + "/QC_reference_reconstruction_accuracy.png", figdir + "/QC_reference_expression signatures_vs_avg_expression.png")

# save model and update mudata
if adata_sc.var.index.names[0] in adata_sc.var.columns:
adata_sc.var.index.names = [None]
mdata_singlecell.mod["rna"] = adata_sc
mdata_singlecell.update()
# save model
if sdata_sc["table"].var.index.names[0] in sdata_sc["table"].var.columns:
sdata_sc["table"].var.index.names = [None]
#mdata_singlecell.mod["rna"] = adata_sc
#mdata_singlecell.update()
if save_models is True:
L.info("Saving reference model to '%s'" % output_dir)
model_ref.save(output_dir +"/Reference_model", overwrite=True)
Expand All @@ -293,15 +296,15 @@

# 4. Fit mapping model
L.info("Setting up AnnData for the spatial model")
c2l.models.Cell2location.setup_anndata(adata=adata_st,
c2l.models.Cell2location.setup_anndata(adata=sdata_st["table"],
labels_key = args.labels_key_st,
layer= args.layer_st,
batch_key= args.batch_key_st,
categorical_covariate_keys = categorical_covariate_keys_st,
continuous_covariate_keys = continuous_covariate_keys_st)


model_spatial = c2l.models.Cell2location(adata = adata_st, cell_state_df=inf_aver,
model_spatial = c2l.models.Cell2location(adata = sdata_st["table"], cell_state_df=inf_aver,
N_cells_per_location=float(args.N_cells_per_location),
detection_alpha=float(args.detection_alpha))
L.info("Training the spatial model")
Expand All @@ -312,32 +315,32 @@
cell2loc_plot_history(model_spatial, figdir + "/ELBO_spatial_model.png")
#extract posterior
L.info("Extracting the posterior of the spatial model")
adata_st = model_spatial.export_posterior(adata_st)
sdata_st["table"] = model_spatial.export_posterior(sdata_st["table"])
#plot QC
L.info("Plotting QC plots")
cell2loc_plot_QC_reconstr(model_spatial, figdir + "/QC_spatial_reconstruction_accuracy.png")


#plot output
L.info("Plotting spatial embedding plot coloured by 'q05_cell_abundance_w_sf'")
adata_st.obs[adata_st.uns["mod"]["factor_names"]] = adata_st.obsm["q05_cell_abundance_w_sf"]
sc.pl.spatial(adata_st,color=adata_st.uns["mod"]["factor_names"], show = False, save = "_Cell2Loc_q05_cell_abundance_w_sf.png")
sdata_st["table"].obs[sdata_st["table"].uns["mod"]["factor_names"]] = sdata_st["table"].obsm["q05_cell_abundance_w_sf"]
sc.pl.spatial(sdata_st["table"],color=sdata_st["table"].uns["mod"]["factor_names"], show = False, save = "_Cell2Loc_q05_cell_abundance_w_sf.png")


# save model and update mudata
if adata_st.var.index.names[0] in adata_st.var.columns:
adata_st.var.index.names = [None]
mdata_spatial.mod["spatial"] = adata_st
mdata_spatial.update()
# save model
if sdata_st["table"].var.index.names[0] in sdata_st["table"].var.columns:
sdata_st["table"].var.index.names = [None]
#mdata_spatial.mod["spatial"] = adata_st
#mdata_spatial.update()
if save_models is True:
L.info("Saving spatial model to '%s'" % output_dir)
model_spatial.save(output_dir+"/Spatial_mapping_model", overwrite=True)


#6. save SpatialDatas
L.info("Saving MuDatas to '%s'" % output_dir)
mdata_singlecell.write(output_dir+"/Cell2Loc_screference_output.h5mu")
mdata_spatial.write(output_dir+"/Cell2Loc_spatial_output.h5mu")
L.info("Saving SpatialDatas to '%s'" % output_dir)
sdata_sc.write(output_dir+"/Cell2Loc_screference_output.zarr")
sdata_st.write(output_dir+"/Cell2Loc_spatial_output.zarr")


L.info("Done")
Expand Down
49 changes: 26 additions & 23 deletions panpipes/python_scripts/run_tangram.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import scanpy as sc
import tangram as tg
import muon as mu
import spatialdata as sd

import os
import argparse
Expand Down Expand Up @@ -100,13 +101,15 @@

#1. read in the data
#spatial:
L.info("Reading in spatial MuData from '%s'" % args.input_spatial)
mdata_spatial = mu.read(args.input_spatial)
adata_st = mdata_spatial.mod['spatial']
L.info("Reading in spatial SpatialData from '%s'" % args.input_spatial)
sdata_st = sd.read_zarr(args.input_spatial)
#mdata_spatial = mu.read(args.input_spatial)
#adata_st = mdata_spatial.mod['spatial']
#single-cell:
L.info("Reading in reference MuData from '%s'" % args.input_singlecell)
mdata_singlecell = mu.read(args.input_singlecell)
adata_sc = mdata_singlecell.mod['rna']
L.info("Reading in reference SpatialData from '%s'" % args.input_singlecell)
sdata_sc = sd.read_zarr(args.input_singlecell)
#mdata_singlecell = mu.read(args.input_singlecell)
#adata_sc = mdata_singlecell.mod['rna']


#2. Perform gene selection:
Expand All @@ -121,43 +124,43 @@

else: # perform feature selection using sc.tl.rank_genes_groups()
L.info("Running 'scanpy.tl.rank_genes_groups()'")
sc.tl.rank_genes_groups(adata_sc, groupby=args.labels_key_rank_genes, layer=args.layer_rank_genes, method=args.method_rank_genes,corr_method = args.corr_method_rank_genes)
sc.tl.rank_genes_groups(sdata_sc["table"], groupby=args.labels_key_rank_genes, layer=args.layer_rank_genes, method=args.method_rank_genes,corr_method = args.corr_method_rank_genes)
L.info("Plotting rank genes group")
sc.pl.rank_genes_groups(adata_sc, show = False, save = ".png")
markers_df = pd.DataFrame(adata_sc.uns["rank_genes_groups"]["names"]).iloc[0:int(args.n_genes_rank), :]
sc.pl.rank_genes_groups(sdata_sc["table"], show = False, save = ".png")
markers_df = pd.DataFrame(sdata_sc["table"].uns["rank_genes_groups"]["names"]).iloc[0:int(args.n_genes_rank), :]
L.info("Saving rank genes to " + output_dir + "/rank_genes_groups.csv")
markers_df.to_csv(output_dir + "/rank_genes_groups.csv")
markers = list(np.unique(markers_df.melt().value.values))

# "Preprocess" anndatas
L.info("Preprocessing AnnDatas")
tg.pp_adatas(adata_sc=adata_sc, adata_sp=adata_st, genes=markers)
tg.pp_adatas(adata_sc=sdata_sc["table"], adata_sp=sdata_st["table"], genes=markers)

# 3. Run tangram
L.info("Training model")
adata_results = tg.mapping_utils.map_cells_to_space(
adata_sc=adata_sc, adata_sp=adata_st, num_epochs=int(args.num_epochs), device=args.device, **args.kwargs
adata_sc=sdata_sc["table"], adata_sp=sdata_st["table"], num_epochs=int(args.num_epochs), device=args.device, **args.kwargs
)

# 4. Extract and plot results
L.info("Extracting annotations")
tg.project_cell_annotations(adata_results, adata_st, annotation=args.labels_key_model)
tg.project_cell_annotations(adata_results, sdata_st["table"], annotation=args.labels_key_model)

L.info("Plotting spatial embedding plot coloured by 'tangram_ct_pred'")
annotation_list = list(pd.unique(adata_sc.obs[args.labels_key_model]))
df = adata_st.obsm["tangram_ct_pred"][annotation_list]
tg.construct_obs_plot(df, adata_st, perc=0.05)
if "spatial" in adata_st.uns:
sc.pl.spatial(adata_st, color=annotation_list, cmap="viridis", show=False, frameon=False, ncols=3, save = "_tangram_ct_pred.png")
annotation_list = list(pd.unique(sdata_sc["table"].obs[args.labels_key_model]))
df = sdata_st["table"].obsm["tangram_ct_pred"][annotation_list]
tg.construct_obs_plot(df, sdata_st["table"], perc=0.05)
if "spatial" in sdata_st["table"].uns:
sc.pl.spatial(sdata_st["table"], color=annotation_list, cmap="viridis", show=False, frameon=False, ncols=3, save = "_tangram_ct_pred.png")
else:
sc.pl.spatial(adata_st, color=annotation_list, cmap="viridis", show=False, frameon=False, ncols=3, save = "_tangram_ct_pred.png",spot_size=0.5)
sc.pl.spatial(sdata_st["table"], color=annotation_list, cmap="viridis", show=False, frameon=False, ncols=3, save = "_tangram_ct_pred.png",spot_size=0.5)


mdata_singlecell_results = mu.MuData({"rna": adata_sc})
mdata_spatial_results = mu.MuData({"spatial": adata_st})
#mdata_singlecell_results = mu.MuData({"rna": adata_sc})
#mdata_spatial_results = mu.MuData({"spatial": adata_st})

L.info("Saving MuDatas to '%s'" % output_dir)
mdata_singlecell_results.write(output_dir+"/Tangram_screference_output.h5mu")
mdata_spatial_results.write(output_dir+"/Tangram_spatial_output.h5mu")
L.info("Saving SpatialDatas to '%s'" % output_dir)
sdata_sc.write(output_dir+"/Tangram_screference_output.zarr")
sdata_st.write(output_dir+"/Tangram_spatial_output.zarr")

L.info("Done")

0 comments on commit 1e002b5

Please sign in to comment.