diff --git a/gemelli/joint_viz.py b/gemelli/joint_viz.py new file mode 100644 index 0000000..2b64d6a --- /dev/null +++ b/gemelli/joint_viz.py @@ -0,0 +1,79 @@ +import networkx as nx +import numpy as np +import matplotlib.pyplot as plt +from matplotlib.lines import Line2D +import seaborn as sns +from gemelli.rpca import joint_rpca + + +def create_graph(correlation_table, + feature_map, + features_use=None, + pos_corr_thresh=0.5, + neg_corr_thresh=-0.5): + if features_use is not None: + correlation_table = correlation_table.loc[features_use, features_use] + + idx = correlation_table.index.values + G = nx.from_numpy_matrix(correlation_table.values) + G = nx.relabel_nodes(G, lambda x: idx[x]) + for _id in idx: + G.remove_edge(_id, _id) + nx.set_node_attributes(G, feature_map, 'modality') + + edges_to_keep = [] + for node1, node2, attr in G.edges(data=True): + weight = attr['weight'] + if weight > pos_corr_thresh or weight < neg_corr_thresh: + edges_to_keep.append((node1, node2)) + + G = G.edge_subgraph(edges_to_keep) + return G + + +def visualize_graph(G, feature_map): + labels = nx.get_node_attributes(G, 'modality') + + modalities = list(set(feature_map.values())) + num_modalities = len(modalities) + + edge_weights = [] + edge_colors = [] + for u, v, attr in G.edges(data=True): + weight = attr['weight'] + if weight > 0: + ec = "blue" + elif weight < 0: + ec = "red" + else: + ec = "gray" + edge_colors.append(ec) + edge_weights.append(np.abs(weight)*0.5) + + palette = dict(zip(modalities, sns.color_palette("tab10", num_modalities))) + node_colors = [palette[G.nodes[node]["modality"]] for node in G.nodes] + labels = nx.get_node_attributes(G, "modality") + + fig, ax = plt.subplots(1, 1) + + nx.draw_networkx( + G, + node_color=node_colors, + edge_color=edge_colors, + width=edge_weights, + ax=ax, + with_labels=False + ) + + handles = [] + pos_line = Line2D([0], [0], label="positive", color="blue") + neg_line = Line2D([0], [0], label="negative", color="red") + handles.extend([pos_line, neg_line]) + + for modality, color in palette.items(): + p = Line2D([0], [0], mfc=color, label=modality, markersize=10, + marker="o", mew=0, linewidth=0) + handles.append(p) + + ax.legend(handles=handles) + return ax diff --git a/gemelli/rpca.py b/gemelli/rpca.py index 21211f2..dce059e 100644 --- a/gemelli/rpca.py +++ b/gemelli/rpca.py @@ -609,7 +609,24 @@ def frequency_filter(val, id_, md): return table +class JointOrdination: + def __init__(self, table_map, ordination): + self.table_map = table_map + self.ordination = ordination + self.feature_map = self._create_feature_map() + + def _create_feature_map(self): + feature_map = dict() + for table_name, table in self.table_map.items(): + # TODO: Check for overlaps + feature_map.update({ + feat: table_name for feat in table.ids("observation") + }) + return feature_map + + def joint_rpca(tables: biom.Table, + table_names: list = None, n_test_samples: int = DEFAULT_TESTS, sample_metadata: pd.DataFrame = DEFAULT_METACV, train_test_column: str = DEFAULT_COLCV, @@ -617,10 +634,8 @@ def joint_rpca(tables: biom.Table, min_sample_count: int = DEFAULT_MSC, min_feature_count: int = DEFAULT_MFC, min_feature_frequency: float = DEFAULT_MFF, - max_iterations: int = DEFAULT_OPTSPACE_ITERATIONS) -> ( - OrdinationResults, - DistanceMatrix, - pd.DataFrame): + max_iterations: int = DEFAULT_OPTSPACE_ITERATIONS + ) -> (JointOrdination, DistanceMatrix, pd.DataFrame): """ Performs joint-RPCA across data tables with shared samples. @@ -696,6 +711,12 @@ def joint_rpca(tables: biom.Table, """ + if table_names is None: + table_names = [f'table.{i}' for i, _ in enumerate(tables)] + else: + if len(table_names) != len(tables): + raise ValueError('Length of tables and table names must match.') + # filter each table for n, table_n in enumerate(tables): tables[n] = rpca_table_processing(table_n, @@ -747,7 +768,8 @@ def joint_rpca(tables: biom.Table, max_iterations, test_samples, train_samples) - return ord_res, U_dist_res, cv_dist + table_map = dict(zip(table_names, tables)) + return JointOrdination(table_map, ord_res), U_dist_res, cv_dist def joint_optspace_helper(tables,