Skip to content

Commit

Permalink
add log info
Browse files Browse the repository at this point in the history
  • Loading branch information
SarahOuologuem committed Apr 18, 2024
1 parent ad90c0c commit abe15e7
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 34 deletions.
27 changes: 16 additions & 11 deletions panpipes/python_scripts/refmap_scib.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,38 +46,44 @@
covariates_use = args.covariate.split(",")
covariates_use = [a.strip() for a in covariates_use]
else:
sys.exit("i don't have covariates to calculate metrics on")
L.error("Covariates need to be specified.")
sys.exit("Covariates need to be specified.")


repuse = str(args.repuse)

L.info("Reading in query data from '%s'" % args.query_data)
mdata = mu.read(args.query)

if type(mdata) is mu.MuData:
if "rna" not in mdata.mod.keys():
sys.exit("we only support querying using RNA but your mdata doesn't contain rna")
L.error("Modality 'rna' could not be found in MuData '%s'. We only support querying using RNA." % args.query_data)
sys.exit("Modality 'rna' could not be found in MuData '%s'. We only support querying using RNA." % args.query_data)
else:
input_adata = mdata["rna"].copy()
del mdata


adata_query = input_adata[input_adata.obs['is_reference'] == 'Query'].copy()
L.info("repuse is %s" %(repuse))
L.info("Repuse is %s" %(repuse))
if repuse not in adata_query.obsm.keys():
sys.exit("the latent representation is not in the obsm.keys of this query")
L.error("The latent representation '%s' could not be found in the obsm.keys of query '%s'" % (repuse, args.query_data))
sys.exit("The latent representation '%s' could not be found in the obsm.keys of query '%s'" % (repuse, args.query_data))

L.info("query is:")
L.info("The query AnnData is:")
print(adata_query)

L.info("Calculating scib metrics using ground truth covariates:")
L.info("Calculating scib metrics using ground truth covariates: ")
print(covariates_use)

if args.cluster_key is None:
cluster_key = "leiden_" + repuse
L.info("you didn't specify cluster_key, so i'm using %s " % cluster_key)
L.warning("No cluster key specified. Using %s " % cluster_key)
else:
cluster_key = str(args.cluster_key)
L.info("cluster_key used is: %s" %cluster_key)

L.info("Calculating scIB metrics")
results = {}
for labelk in covariates_use:
m={"ASW_scaled":scib.metrics.silhouette(adata_query, labelk, repuse, metric='euclidean', scale=True),
Expand All @@ -91,10 +97,9 @@

file_name= os.path.splitext(os.path.basename(args.query).replace("query_to_reference_", "").replace(".h5mu", ""))[0]
#"query_to_reference_" + model_name + "_" + latent_choice + ".h5mu"
L.info("saving file output")
file_out = os.path.join(args.outdir,("scib.query_"+ file_name+".csv"))
print(file_out)
file_out = os.path.join(args.outdir,("scib.query_"+ file_name+".tsv"))
#save file to txt
pres = pd.DataFrame.from_dict(results,orient='index')
L.info("Saving output to tsv file '%s'" % file_out)
pres.to_csv(file_out, sep="\t")
L.info("Finished scib")
L.info("Done")
68 changes: 45 additions & 23 deletions panpipes/python_scripts/refmap_scvitools.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@
args, opt = parser.parse_known_args()
sc.settings.figdir = "figures/"
sc.set_figure_params(figsize=(8, 6), dpi=300)
L.info("running with args:")
args.predict_rf = check_for_bool(args.predict_rf)
args.impute_proteins = check_for_bool(args.impute_proteins)

L.info(args)

threads_available = multiprocessing.cpu_count()
Expand All @@ -76,33 +76,39 @@
reference_architecture = str(args.reference_architecture)



L.info("Reading in query data from '%s'" % args.query_data)
mdata = mu.read(args.query_data)

if type(mdata) is mu.MuData:
if "rna" not in mdata.mod.keys():
sys.exit("we only support querying using RNA but your mdata doesn't contain rna")
L.error("Modality 'rna' could not be found in MuData '%s'. We only support querying using RNA." % args.query_data)
sys.exit("Modality 'rna' could not be found in MuData '%s'. We only support querying using RNA." % args.query_data)
else:
adata_query = mdata["rna"].copy()
if "prot" in mdata.mod.keys() and reference_architecture=="totalvi":
if X_is_raw(mdata['prot']):
L.info("Using raw protein data from mdata.mod['prot'].X")
X_array = mdata['prot'].X.copy()
elif "raw_counts" in mdata['prot'].layers.keys():
L.info("Using raw protein data from mdata.mod['prot'].layers['raw_counts']")
X_array = mdata['prot'].layers['raw_counts'].copy()

X_df = pd.DataFrame(X_array.todense(), index=mdata['prot'].obs_names, columns=mdata['prot'].var_names)
if X_df.shape[0] == adata_query.X.shape[0]:
L.info("adding protein_expression to obsm")
L.info("Adding raw protein expression to obsm")
#check the obs are in the correct order
X_df = X_df.loc[adata_query.obs_names,:]
adata_query.obsm['protein_expression'] = X_df
else:
L.error("dimensions do not match, cannot create the raw obsm counts in query")
L.error("Dimension 0 of protein expression matrix and RNA expression matrix do not match. Cannot create the raw obsm counts in query.")
sys.exit("Dimension 0 of protein expression matrix and RNA expression matrix do not match. Cannot create the raw obsm counts in query.")
else:
adata_query = mdata.copy()

L.info("this is your query anndata")
L.info("The query AnnData is:")
print(adata_query)

L.info("Saving raw counts to adata.layers['counts']")
if "counts" not in adata_query.layers.keys():
if "raw_counts" in adata_query.layers.keys():
adata_query.layers["counts"] = adata_query.layers["raw_counts"].copy()
Expand All @@ -114,17 +120,17 @@


if reference_architecture in ['totalvi','scvi','scanvi']:
L.info("running using %s" % reference_architecture)
L.info("Running using %s" % reference_architecture)
else:
sys.exit(" i don't recognise this architecture : %s") %reference_architecture
L.error("Architecture '%s' could not be found. The following architectures are available: 'totalvi', 'scvi', 'scanvi'" % reference_architecture)
sys.exit("Architecture '%s' could not be found. The following architectures are available: 'totalvi', 'scvi', 'scanvi'" % reference_architecture)

reference_path = args.reference_path
if not os.path.exists(reference_path):
L.info("The reference path you provided does not exist")
sys.exit("The reference path you provided does not exist")
L.info("The provided reference path '%s' does not exist" % reference_path)
sys.exit("The provided reference path '%s' does not exist" % reference_path)
else:
reference_path = os.path.dirname(reference_path)
L.debug("Reference path: %s" % reference_path)

train_kwargs = {}
if params["training_plan"][reference_architecture] is not None:
Expand All @@ -140,29 +146,33 @@


if reference_architecture=="scvi":
L.info("Running scVI")
scvi.model.SCVI.prepare_query_anndata(adata_query, reference_path)
vae_q = scvi.model.SCVI.load_query_data(
adata_query,
reference_path)
latent_choice= "X_scvi"
vae_q.train(max_epochs= max_epochs , plan_kwargs=train_kwargs)
L.info("Saving latent to .obsm['X_scvi']")
adata_query.obsm["X_scvi"] = vae_q.get_latent_representation()


if reference_architecture=="scanvi":
# Notice that adata_query.obs["labels_scanvi"] does not exist.
# The load_query_data method detects this and fills it in adata_query with the unlabeled category (here "Unknown").
L.info("Running scANVI")
scvi.model.SCANVI.prepare_query_anndata(adata_query, reference_path)
vae_q = scvi.model.SCANVI.load_query_data(
adata_query,
reference_path)
latent_choice= "X_scanvi"
#vae_q.train(**train_kwargs) this doesn't work anymore cause max_epochs is not recognised as part of the plan kwargs
vae_q.train(max_epochs= max_epochs , plan_kwargs=train_kwargs)
L.info("Saving latent to .obsm['X_scanvi'] and predictions to .obs['predictions']")
adata_query.obsm["X_scanvi"] = vae_q.get_latent_representation()
adata_query.obs["predictions"] = vae_q.predict()
if args.query_celltype is not None:
L.info("Query has celltypes in column %s, i will plot what predictions look like from scanvi model" % args.query_celltype)
L.info("Plotting predictions from scANVI model")
df = adata_query.obs.groupby([str(args.query_celltype), "predictions"]).size().unstack(fill_value=0)
norm_df = df / df.sum(axis=0)

Expand All @@ -180,13 +190,13 @@
# temporary fix is disabled for now, need to modify to allow to fix query to match reference
fix_query = False
if fix_query:
L.info(" will do some manipulation of query data to make it match to referece structure")
if args.adata_reference is not None:
reference_data = os.path.basename(args.adata_reference)
mdata=mu.read(args.adata_reference)
if type(mdata) is MuData:
if "rna" not in mdata.mod.keys():
sys.exit("we only support querying using RNA but your mdata doesn't contain rna")
L.error("Modality 'rna' could not be found in MuData '%s'. We only support querying using RNA." % args.adata_reference)
sys.exit("Modality 'rna' could not be found in MuData '%s'. We only support querying using RNA." % args.adata_reference)
else:
adata_ref = mdata["rna"].copy()
adata_ref.obsm = mdata.obsm.copy()
Expand All @@ -198,6 +208,7 @@

#adata_query.layers["counts"] = adata_query.X.copy()#already taken care of
# are the following necessary?
L.info("Normalizing query data")
sc.pp.normalize_total(adata_query, target_sum=1e4)
sc.pp.log1p(adata_query)
adata_query.raw = adata_query
Expand Down Expand Up @@ -247,31 +258,34 @@
# L.info("name of obsm slot for reference is: %i" % args.reference_prot_assay)
# pname = args.reference_prot_assay
sys.exit("need a reference dataset to check for matching query entries")


L.info("Running totalVI")
scvi.model.TOTALVI.prepare_query_anndata(adata_query, reference_path)
vae_q = scvi.model.TOTALVI.load_query_data(
adata_query,
reference_path,
freeze_expression=True)
latent_choice= "X_totalvi"
vae_q.train(max_epochs= max_epochs , plan_kwargs=train_kwargs)
L.info("Saving latent to .obsm['X_totalvi']")
adata_query.obsm["X_totalvi"] = vae_q.get_latent_representation()
#remove this after finishing
adata_query.write(os.path.join( "query_temp_check_totvi.h5ad"))

if args.predict_rf :
L.info("predicting celltypes")
L.info("Predicting cell types")
predictions = (
vae_q.latent_space_classifer_.predict(
adata_query.obsm["X_totalvi"]
)
)
L.info("Saving predictions to .obs['predictions']")
adata_query.obs["predictions"] = predictions
#remove this after finishing
adata_query.write(os.path.join( "query_temp_check_predictions_totvi.h5ad"))

if args.query_celltype is not None:
L.info("Query has celltypes in column %s, i will plot what predictions look like from totalvi model" % args.query_celltype)
L.info("Plotting predictions from totalVI model")
df = adata_query.obs.groupby([str(args.query_celltype), "predictions"]).size().unstack(fill_value=0)
norm_df = df / df.sum(axis=0)

Expand All @@ -291,10 +305,12 @@

if args.adata_reference is not None:
reference_data = os.path.basename(args.adata_reference)
L.info("Reading in data from '%s'" % args.adata_reference)
mdata=mu.read(args.adata_reference)
if type(mdata) is MuData:
if "rna" not in mdata.mod.keys():
sys.exit("we only support querying using RNA but your mdata doesn't contain rna")
L.error("Modality 'rna' could not be found in MuData '%s'. We only support querying using RNA." % args.adata_reference)
sys.exit("Modality 'rna' could not be found in MuData '%s'. We only support querying using RNA." % args.adata_reference)
else:
adata_ref = mdata["rna"].copy()
adata_ref.obsm = mdata.obsm.copy()
Expand All @@ -305,6 +321,7 @@
adata_ref.obs.loc[:, 'is_reference'] = 'Reference'
adata_query.obs.loc[:, 'is_reference'] = 'Query'
#expect the batch to be always encoded as `batch` in both Q and R
L.info("Concatenating query and reference data")
adata_full = ad.concat( [adata_ref,adata_query])
# add param to decide if to recompute the total embedding or to recalc the embedding
#if "X_totalvi" not in adata_ref.obsm.keys():
Expand All @@ -320,30 +337,33 @@
n_samples=25,
return_mean=True,
transform_batch= transform_batch)
L.info("Saving denoised data to .obsm['totalvi_denoised_rna'] and .obsm['totalvi_denoised_protein']")
adata_full.obsm["totalvi_denoised_rna"], adata_full.obsm["totalvi_denoised_protein"] = normX, protein

if reference_architecture=="scanvi":
full_predictions = vae_q.predict(adata_full)
L.warn("Acc: {}".format(np.mean(full_predictions == adata_full.obs.celltype)))
L.warning("Acc of scANVI predictions: {}".format(np.mean(full_predictions == adata_full.obs.celltype)))
adata_full.obs["predictions"] = full_predictions
else:
adata_query.obs.loc[:, 'is_reference'] = 'Query'
adata_full = adata_query.copy()

if int(args.neighbors_n_pcs) > adata_full.obsm[latent_choice].shape[1]:
L.warn("N PCs is larger than %i dimensions, reducing n PCs to " % adata_full.obsm[latent_choice].shape[1])
L.warning("N PCs is larger than the number of available dimensions, reducing n PCs to %i" % adata_full.obsm[latent_choice].shape[1])


n_pcs= min(int(args.neighbors_n_pcs), adata_full.obsm[latent_choice].shape[1])


L.info("Running neighbors")
run_neighbors_method_choice(adata_full,
method=args.neighbors_method,
n_neighbors=int(args.neighbors_k),
n_pcs=n_pcs,
metric=args.neighbors_metric,
use_rep=latent_choice,
nthreads=max([threads_available, 6]))
L.info("Running UMAP and Leiden")
sc.tl.umap(adata_full, min_dist=0.4)
sc.tl.leiden(adata_full, key_added="leiden_" + latent_choice)

Expand All @@ -354,21 +374,23 @@
file_name= "umap_" + model_name + "_" + latent_choice
L.info ("filename is %s" % file_name )

L.info("Plotting UMAP")
fig = sc.pl.embedding(adata_full, basis = "umap",color=["is_reference"],
show=False, return_fig=True)
fig.tight_layout()
fig.savefig(os.path.join("figures/", file_name + ".png"))

umap = pd.DataFrame(adata_full.obsm['X_umap'], adata_full.obs.index)

L.info("Saving UMAP coordinates to csv file")
umap.to_csv(os.path.join("refmap/", file_name + ".csv") )
file_name= "query_to_reference_" + model_name + "_" + latent_choice + ".h5mu"

file_name= "query_to_reference_" + model_name + "_" + latent_choice + ".h5mu"
mdata_save = MuData({"rna":adata_full})

L.info("Saving MuData to refmap/" + file_name)
mdata_save.write(os.path.join( "refmap/" , file_name))

L.info('done')
L.info('Done')



0 comments on commit abe15e7

Please sign in to comment.