Skip to content

Commit

Permalink
domino bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
yollct committed Feb 27, 2023
1 parent eac88fb commit 2e0fb2f
Show file tree
Hide file tree
Showing 8 changed files with 86 additions and 63 deletions.
17 changes: 11 additions & 6 deletions spycone/DOMINO/src/core/domino.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import pandas as pd
import numpy as np
from itertools import repeat
import pickle
import multiprocessing
import matplotlib
Expand Down Expand Up @@ -63,9 +64,9 @@ def add_scores_to_nodes(G, scores):
return G


def create_subgraph(params):
def create_subgraph(params, G):
    """Induce and return the subgraph of ``G`` on the nodes of one module.

    Side effect: publishes ``G`` as the module-level global
    ``G_modularity`` — other functions in this module read that global,
    so the assignment is kept even though the local would suffice here.
    NOTE(review): when called via ``multiprocessing.Pool``, this global is
    set in the *worker* process only — confirm the parent never relies on it
    being updated by this call.

    Parameters
    ----------
    params : iterable
        Node identifiers belonging to the module; duplicates are collapsed.
    G : graph object
        Graph exposing a ``subgraph(nodes)`` method (networkx-style).

    Returns
    -------
    The subgraph view induced on the unique nodes of ``params``.
    """
    global G_modularity
    G_modularity = G
    member_nodes = set(params)
    return G_modularity.subgraph(list(member_nodes))
Expand All @@ -86,7 +87,7 @@ def prune_network_by_modularity(G, modules):
#print(f"Before slicing: n of cc:{len(list(connected_components(G_modularity)))}, n of nodes: {len(G_modularity.nodes)}, n of edges, {len(G_modularity.edges)#}")
p = multiprocessing.Pool(N_OF_THREADS)

G_modules = p.map(create_subgraph, [m for m in modules])
G_modules = p.starmap(create_subgraph, zip(modules, repeat(G)))
p.close()
# print(f'{modules}')
#print(f'# of modules after extraction: {len(G_modules)}')
Expand Down Expand Up @@ -234,6 +235,7 @@ def get_putative_modules(G, full_G=None, improvement_delta=0, modularity_score_o

def retain_relevant_slices(G_original, module_sig_th):
global G_modularity
print(G_modularity)

pertubed_nodes = []
for cur_node in G_modularity.nodes():
Expand Down Expand Up @@ -344,23 +346,26 @@ def main(active_genes_file, network_file, scores=None, slices_file=None, slice_t
G = build_network(network_file)
#pickle.dump(G, open(f'{network_file}.pkl', 'wb+'))
#print(f'network\' pkl is saved: {network_file}.pkl')

#print("done building network")

# assign activeness to nodes
scores = extract_scores(active_genes_file, scores)
G = add_scores_to_nodes(G, scores)

modularity_connected_components = read_preprocessed_slices(slices_file)

global G_modularity
G_modularity = G
prune_network_by_modularity(G, modularity_connected_components)
print("here",G_modularity)
G_modularity, relevant_slices, qvals = retain_relevant_slices(G, slice_threshold)

print("here2",G_modularity)
#print(f'{len(relevant_slices)} relevant slices were retained with threshold {slice_threshold}')
params = []
for i_cc, cc in enumerate(relevant_slices):
params.append([G, cc, i_cc, n_steps, relevant_slices, prize_factor, module_threshold])
p = multiprocessing.Pool(N_OF_THREADS)
putative_modules = reduce(lambda a, b: a + b, p.map(analyze_slice, params), [])
putative_modules = reduce(lambda a, b: a + b, p.starmap(analyze_slice, params), [])
p.close()
#print(f'n of putative modules: {len(putative_modules)}')
final_modules, sig_scores = get_final_modules(G, putative_modules)
Expand Down
Binary file modified spycone/_NEASE/nease/__pycache__/__init__.cpython-38.pyc
Binary file not shown.
Binary file modified spycone/_NEASE/nease/__pycache__/functions.cpython-38.pyc
Binary file not shown.
Binary file modified spycone/_NEASE/nease/__pycache__/load.cpython-38.pyc
Binary file not shown.
Binary file modified spycone/_NEASE/nease/__pycache__/nease.cpython-38.pyc
Binary file not shown.
Binary file modified spycone/_NEASE/nease/__pycache__/process.cpython-38.pyc
Binary file not shown.
71 changes: 35 additions & 36 deletions spycone/run_domino.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,18 +212,18 @@ def run_domino(target, name=None, is_results=None, scores = None, network_file =


sl.create_slices(network_file, output_file_path)
# if isinstance(target, list):
# ##check list
# checkedtarget = _check_nodes(list(map(str,target)), network_file)
if isinstance(target, list):
##check list
checkedtarget = _check_nodes(list(map(str,target)), network_file)

# ## clusterobj can also be list
# a = defaultdict(list)
## clusterobj can also be list
a = defaultdict(list)

# tmp, scores = domino.main(list(map(str,checkedtarget)), network_file, slices_file=output_file_path, slice_threshold=slice_threshold, module_threshold=module_threshold, prize_factor=prize_factor, n_steps=n_steps)
# a[name].append(tmp)
# a[name].append(scores)
tmp, scores = domino.main(list(map(str,checkedtarget)), network_file, slices_file=output_file_path, slice_threshold=slice_threshold, module_threshold=module_threshold, prize_factor=prize_factor, n_steps=n_steps)
a[name].append(tmp)
a[name].append(scores)

# return a
return a

# elif isinstance(target, clustering) and run_cluster is not None:
# a = defaultdict(list)
Expand All @@ -241,34 +241,33 @@ def run_domino(target, name=None, is_results=None, scores = None, network_file =

# return a

# else:
else:
a=defaultdict(list)
for u,v in target.genelist_clusters.items():
scores = []
for gene in target.symbs_clusters[u]:
if genescores is not None:
if gene in genescores.keys():
scores.append(genescores[gene])
else:
scores.append(1)

a=defaultdict(list)
for u,v in target.genelist_clusters.items():
scores = []
for gene in target.symbs_clusters[u]:
if genescores is not None:
if gene in genescores.keys():
scores.append(genescores[gene])
else:
scores.append(1)
##fornow emp
##TODO scores

#scoresdf.to_csv("/nfs/home/students/chit/lrz_ticone/domino_emp/{}_cluster{}_mod.csv".format(name, u), index=False)
###
checkedtarget = _check_nodes(list(map(str,v)), network_file)

tmp, scores = domino.main(list(map(str,checkedtarget)), network_file, slices_file=output_file_path, slice_threshold=slice_threshold, module_threshold=module_threshold, prize_factor=prize_factor, n_steps=n_steps)
a[u].append(tmp)
a[u].append(scores)

else:
scores.append(1)
##fornow emp
##TODO scores

#scoresdf.to_csv("/nfs/home/students/chit/lrz_ticone/domino_emp/{}_cluster{}_mod.csv".format(name, u), index=False)
###
checkedtarget = _check_nodes(list(map(str,v)), network_file)

tmp, scores = domino.main(list(map(str,checkedtarget)), network_file, slices_file=output_file_path, slice_threshold=slice_threshold, module_threshold=module_threshold, prize_factor=prize_factor, n_steps=n_steps)
a[u].append(tmp)
a[u].append(scores)

print("---------Network enrichment Result---------\n")
for u,v in a.items():
#for e, vv in enumerate(v[0]):
print(f"Cluster {u} found {len(v[0])} module(s).")
print("-----END-----")
return a
print("---------Network enrichment Result---------\n")
for u,v in a.items():
#for e, vv in enumerate(v[0]):
print(f"Cluster {u} found {len(v[0])} module(s).")
print("-----END-----")
return a
61 changes: 40 additions & 21 deletions spycone/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,14 @@ def vis_all_clusters(clusterObj, x_label="time points", y_label="expression", Ti
median_allts[~np.isfinite(median_allts)]=0

if plot_clusters is None:
plot_clusters = list(clusterObj.keys())
plot_clusters = list(clusterObj.index_clusters.keys())
return_plotly=False
width = 100*len(plot_clusters)
height=500
else:
return_plotly =True
width=300
height=300*len(plot_clusters)

##faster probably
reorder = []
Expand All @@ -68,31 +75,45 @@ def vis_all_clusters(clusterObj, x_label="time points", y_label="expression", Ti
clusters_sns = pd.concat([pd.Series(clusters,name="clusters"),cluster_tsarray], axis=1)
clusters_sns = pd.melt(clusters_sns, id_vars="clusters",value_vars=cluster_tsarray.columns, value_name="expression",var_name="timepoints", ignore_index=False)
clusters_sns['gene'] = clusters_sns.index
clusters_sns['type']="object"
clusters_sns = clusters_sns.reset_index()

#for prototype
cluster_pro = pd.DataFrame(np.array(list(clusterObj._prototype.values())))
cluster_pro.columns = ["tp{}".format(x) for x in range(tp)]
cluster_pro['clusters'] = list(clusterObj.symbs_clusters.keys())
cluster_pro = pd.melt(cluster_pro, id_vars='clusters', value_vars=cluster_pro.columns.to_list()[:-1], value_name="expression", var_name="timepoints")
cluster_pro['type']="prototype"
cluster_pro['gene']=cluster_pro.index
cluster_pro = cluster_pro[cluster_pro['clusters'].isin(plot_clusters)]
# allsns = pd.concat(cluster_sns_list)
# sns.relplot(x="timepoints", y="expression",kind="line", hue="cluster", col="cluster",col_wrap=5, height=3, aspect=.75, linewidth=2.5, palette= "Set2", data=allsns)

grey = ['lightgray'] * len(plot_clusters)
#g = sns.relplot(data=clusters_sns, x="timepoints", y="expression", col="clusters", hue="clusters", units="gene", kind="line", height=3, aspect=.75,estimator=None, linewidth=1, palette=grey, legend=False, **kwargs)
if ncol is None:
ncol = len(plot_clusters)
nrow =1

fig = px.line(clusters_sns, x="timepoints", y="expression", color="gene", facet_col="clusters", width=300, height=300*len(plot_clusters), facet_col_wrap= col_wrap)
pro = px.line(cluster_pro, x="timepoints", y="expression", color="clusters", facet_col="clusters", width=300, height=300*len(plot_clusters), facet_col_wrap= col_wrap)

fig.update_traces(line_color="lightgray")
for x in range(len(plot_clusters)):
fig.add_trace(pro.data[x])
fig.update_traces(showlegend=False)
fig.update_traces(hoverinfo='skip')

if not return_plotly:
grey = ['lightgray'] * len(plot_clusters)

g = sns.relplot(data=clusters_sns, x="timepoints", y="expression", col="clusters", hue="clusters", units="gene", kind="line", height=3, aspect=.75,estimator=None, linewidth=1, palette=grey, legend=False, col_wrap=col_wrap, **kwargs)
pal = sns.color_palette("dark", len(clusterObj.genelist_clusters.keys()))
i=0
for x,ax in g.axes_dict.items():
subdata = cluster_pro[cluster_pro['clusters']==x]

sns.lineplot(data=subdata, x="timepoints", y="expression", linewidth=4, color=pal[i], ax=ax, legend=False)
i+=1

else:
fig = px.line(clusters_sns, x="timepoints", y="expression", color="gene", facet_col="clusters", width=width, height=height, facet_col_wrap= col_wrap)
pro = px.line(cluster_pro, x="timepoints", y="expression", color="clusters", facet_col="clusters", width=width, height=height, facet_col_wrap= col_wrap)

fig.update_traces(line_color='lightgrey')
for x in range(len(plot_clusters)):
fig.add_trace(pro.data[x])
fig.update_traces(showlegend=False)
fig.update_traces(hoverinfo='skip')

# print(clusters_sns.head())
# fig = make_subplots(rows=nrow, cols=ncol, shared_yaxes=False, specs=[[{'type':'scatter'}]*len(plot_clusters)])
# sr = 1
Expand All @@ -115,15 +136,13 @@ def vis_all_clusters(clusterObj, x_label="time points", y_label="expression", Ti
# g.set_xticklabels(labels=xtickslabels,rotation=90, ha="right", fontsize=8)

##plot prototypes
#g = sns.relplot(data=cluster_pro, x="timepoints", y="expression", kind="line", col="clusters",hue="clusters",col_wrap=3, height=3, aspect=.75, linewidth=4, palette="Set2")
# pal = sns.color_palette("dark", len(clusterObj.genelist_clusters.keys()))
# i=0
# for x,ax in g.axes_dict.items():
# subdata = cluster_pro[cluster_pro['clusters']==x]
# g = sns.relplot(data=cluster_pro, x="timepoints", y="expression", kind="line", col="clusters",hue="clusters",col_wrap=3, height=3, aspect=.75, linewidth=4, palette="Set2")


# sns.lineplot(data=subdata, x="timepoints", y="expression", linewidth=4, color=pal[i], ax=ax, legend=False)
# i+=1
return fig
if return_plotly:
return fig
else:
return g



Expand Down

0 comments on commit 2e0fb2f

Please sign in to comment.