diff --git a/networkcommons/visual/_network_stats.py b/networkcommons/visual/_network_stats.py index cb5e33a..fa8cec8 100644 --- a/networkcommons/visual/_network_stats.py +++ b/networkcommons/visual/_network_stats.py @@ -15,9 +15,14 @@ """ Plot network (graph) metrics. + +This module provides several functions to visualize network metrics and generate +plots for data analysis. Functions include plotting the number of nodes and edges, +creating heatmaps, generating scatter plots, and more. """ from __future__ import annotations +from networkcommons._session import _log __all__ = [ 'plot_n_nodes_edges', @@ -36,51 +41,78 @@ import seaborn as sns import numpy as np import os -import logging from scipy.cluster.hierarchy import linkage from scipy.spatial.distance import squareform -def plot_rank(df, - bio_ids=None, - figsize=(12, 6), - x_label='Proteins', - y_label='Average Intensity', - title='Protein abundance Rank Plot', - legend_labels=None, - id_column='idx', - average_color='blue', - stdev_color='gray', - stdev_alpha=0.2, - highlight_color='red', - highlight_size=5, - highlight_zorder=5, - filepath=None, - render=False): +def plot_rank(df: pd.DataFrame, + bio_ids: List[str] = None, + figsize: tuple = (12, 6), + x_label: str = 'Proteins', + y_label: str = 'Average Intensity', + title: str = 'Protein abundance Rank Plot', + legend_labels: dict = None, + id_column: str = 'idx', + average_color: str = 'blue', + stdev_color: str = 'gray', + stdev_alpha: float = 0.2, + highlight_color: str = 'red', + highlight_size: int = 5, + highlight_zorder: int = 5, + filepath: str = None, + render: bool = False) -> plt.Figure: """ - Plots a protein abundance rank plot with customizable attributes. - - Args: - df (pd.DataFrame): Input DataFrame containing gene or protein data. - bio_ids (list of str, optional): List of specific genes or proteins to highlight. Defaults to None. - figsize (tuple, optional): Size of the figure. Defaults to (12, 6). - x_label (str, optional): Label for the x-axis. Defaults to 'Proteins'. - y_label (str, optional): Label for the y-axis. Defaults to 'Average Intensity'. - title (str, optional): Title of the plot. Defaults to 'Protein abundance Rank Plot'. - legend_labels (dict, optional): Dictionary with legend labels for 'average', 'stdev', and 'highlight'. Defaults to None. - id_column (str, optional): Name of the column containing the IDs (e.g., gene_symbols or proteins). Defaults to 'idx'. - average_color (str, optional): Color of the average intensity line. Defaults to 'blue'. - stdev_color (str, optional): Color of the standard deviation shaded area. Defaults to 'gray'. - stdev_alpha (float, optional): Alpha (transparency) for the standard deviation shaded area. Defaults to 0.2. - highlight_color (str, optional): Color for highlighting specific genes or proteins. Defaults to 'red'. - highlight_size (int, optional): Size of the highlighted points. Defaults to 5. - highlight_zorder (int, optional): Z-order for the highlighted points. Defaults to 5. - filepath (str, optional): Path to save the plot. If None, the plot will not be saved. Defaults to None. - render (bool, optional): Whether to display the plot. Defaults to False. - - Returns: - matplotlib.figure.Figure: The figure object for the plot. + Plot a protein abundance rank plot. + + This function generates a plot showing protein abundance ranked by their average intensity + with an option to highlight specific proteins. + + Parameters + ---------- + df : pd.DataFrame + Input DataFrame containing gene or protein data. + bio_ids : list of str, optional + List of specific genes or proteins to highlight. Defaults to None. + figsize : tuple, optional + Size of the figure. Defaults to (12, 6). + x_label : str, optional + Label for the x-axis. Defaults to 'Proteins'. + y_label : str, optional + Label for the y-axis. Defaults to 'Average Intensity'. + title : str, optional + Title of the plot. Defaults to 'Protein abundance Rank Plot'. + legend_labels : dict, optional + Dictionary with legend labels for 'average', 'stdev', and 'highlight'. Defaults to None. + id_column : str, optional + Name of the column containing the IDs (e.g., gene symbols or proteins). Defaults to 'idx'. + average_color : str, optional + Color of the average intensity line. Defaults to 'blue'. + stdev_color : str, optional + Color of the standard deviation shaded area. Defaults to 'gray'. + stdev_alpha : float, optional + Transparency for the standard deviation shaded area. Defaults to 0.2. + highlight_color : str, optional + Color for highlighting specific genes or proteins. Defaults to 'red'. + highlight_size : int, optional + Size of the highlighted points. Defaults to 5. + highlight_zorder : int, optional + Z-order for the highlighted points. Defaults to 5. + filepath : str, optional + Path to save the plot. If None, the plot will not be saved. Defaults to None. + render : bool, optional + Whether to display the plot. Defaults to False. + + Returns + ------- + matplotlib.figure.Figure + The figure object for the plot. """ + if id_column not in df.columns: + _log(f"Column '{id_column}' not found in the DataFrame. Using the index as the ID column.", level=30) + df = df.reset_index() # Reset the index, moving it to a new column + id_column = 'index' # Set id_column to 'index' as it's now a column + df[id_column] = df[id_column].astype(str) + df = df.copy() # Compute average and standard deviation across columns, ignoring the non-numeric columns df['average'] = df.select_dtypes(include=[np.number]).mean(axis=1) @@ -137,46 +169,62 @@ def plot_rank(df, plt.show() if not filepath and not render: - logging.warning("No output specified. Returning the plot object.") + _log("No output specified. Returning the plot object.", level=30) return plt.gcf() - -def plot_scatter(df, - summarise_df=True, - x_col='diff_dysregulation', - y_col='coverage', - size_col='nodes_with_phosphoinfo', - hue_col='method', - style_col='type', - numeric_cols=None, - xlabel='Difference in Activity scores', - ylabel='Coverage', - title='Coverage vs Difference in Activity scores', - figsize=(10, 6), - filepath="scatter_plot.png", - render=False): +def plot_scatter(df: pd.DataFrame, + summarise_df: bool = True, + x_col: str = 'diff_dysregulation', + y_col: str = 'coverage', + size_col: str = 'nodes_with_phosphoinfo', + hue_col: str = 'method', + style_col: str = 'type', + numeric_cols: List[str] = None, + xlabel: str = 'Difference in Activity scores', + ylabel: str = 'Coverage', + title: str = 'Coverage vs Difference in Activity scores', + figsize: tuple = (10, 6), + filepath: str = "scatter_plot.png", + render: bool = False) -> plt.Figure: """ - Plots a scatter plot with customizable column labels. - - Args: - df (pd.DataFrame): Input DataFrame containing the data to plot. - summarise_df (bool, optional): Whether to summarize the random networks if a random control has been performed. Defaults to True. - x_col (str, optional): Column name for the x-axis. Defaults to 'diff_dysregulation'. - y_col (str, optional): Column name for the y-axis. Defaults to 'coverage'. - size_col (str, optional): Column name for the size of the points. Defaults to 'nodes_with_phosphoinfo'. - hue_col (str, optional): Column name for the hue (color) of the points. Defaults to 'method'. - style_col (str, optional): Column name for the style of the points. Defaults to 'type'. - numeric_cols (list of str, optional): List of numeric columns to summarize. Defaults to all numeric columns in the DataFrame. - xlabel (str, optional): Label for the x-axis. Defaults to 'Difference in Activity scores'. - ylabel (str, optional): Label for the y-axis. Defaults to 'Coverage'. - title (str, optional): Title of the plot. Defaults to 'Coverage vs Difference in Activity scores'. - figsize (tuple, optional): Figure size of the plot. Defaults to (10, 6). - filepath (str, optional): Path to save the plot. If None, the plot will be saved as "scatter_plot.png". Defaults to "scatter_plot.png". - render (bool, optional): Whether to display the plot. Defaults to False. - - Returns: - matplotlib.figure.Figure: The figure object for the plot. + Plot a scatter plot with customizable column labels. + + Parameters + ---------- + df : pd.DataFrame + Input DataFrame containing the data to plot. + summarise_df : bool, optional + Whether to summarize the random networks if a random control has been performed. Defaults to True. + x_col : str, optional + Column name for the x-axis. Defaults to 'diff_dysregulation'. + y_col : str, optional + Column name for the y-axis. Defaults to 'coverage'. + size_col : str, optional + Column name for the size of the points. Defaults to 'nodes_with_phosphoinfo'. + hue_col : str, optional + Column name for the hue (color) of the points. Defaults to 'method'. + style_col : str, optional + Column name for the style of the points. Defaults to 'type'. + numeric_cols : list of str, optional + List of numeric columns to summarize. Defaults to all numeric columns in the DataFrame. + xlabel : str, optional + Label for the x-axis. Defaults to 'Difference in Activity scores'. + ylabel : str, optional + Label for the y-axis. Defaults to 'Coverage'. + title : str, optional + Title of the plot. Defaults to 'Coverage vs Difference in Activity scores'. + figsize : tuple, optional + Figure size of the plot. Defaults to (10, 6). + filepath : str, optional + Path to save the plot. If None, the plot will be saved as "scatter_plot.png". Defaults to "scatter_plot.png". + render : bool, optional + Whether to display the plot. Defaults to False. + + Returns + ------- + matplotlib.figure.Figure + The figure object for the plot. """ if numeric_cols is None: numeric_cols = df.select_dtypes(include=['number']).columns @@ -205,7 +253,7 @@ def plot_scatter(df, plt.show() if not filepath and not render: - logging.warning("No output specified. Returning the plot object.") + _log("No output specified. Returning the plot object.", level=30) return plt.gcf() @@ -265,13 +313,11 @@ def lollipop_plot( if orientation == 'vertical': ax.vlines(x=positions, ymin=0, ymax=values, color=color, linewidth=linewidth, label=category) ax.scatter(positions, values, color=color, s=size ** 2, marker=marker, zorder=3) - for j, value in enumerate(values): ax.text(positions[j], value + size * 0.2, str(value), ha='center', va='bottom', fontsize=size) else: ax.hlines(y=positions, xmin=0, xmax=values, color=color, linewidth=linewidth, label=category) ax.scatter(values, positions, color=color, s=size ** 2, marker=marker, zorder=3) - for j, value in enumerate(values): ax.text(value + size * 0.2, positions[j], str(value), va='center', ha='left', fontsize=size) @@ -298,35 +344,49 @@ def lollipop_plot( return fig -def plot_n_nodes_edges( - networks: Dict[str, nx.DiGraph], - filepath=None, - render=False, - orientation='vertical', - color_palette='Set2', - size=10, - linewidth=2, - marker='o', - show_nodes=True, - show_edges=True -): +def plot_n_nodes_edges(networks: Dict[str, nx.DiGraph], + filepath: str = None, + render: bool = False, + orientation: str = 'vertical', + color_palette: str = 'Set2', + size: int = 10, + linewidth: int = 2, + marker: str = 'o', + show_nodes: bool = True, + show_edges: bool = True) -> plt.Figure: """ Plot the number of nodes and edges in the networks using a lollipop plot. - Args: - networks (Dict[str, nx.DiGraph]): A dictionary of network names and their corresponding graphs. - filepath (str): Path to save the plot. Default is None. - render (bool): Whether to display the plot. Default is False. - orientation (str): 'vertical' or 'horizontal'. Default is 'vertical'. - color_palette (str): Matplotlib color palette. Default is 'Set2'. - size (int): Size of the markers. Default is 10. - linewidth (int): Line width of the lollipops. Default is 2. - marker (str): Marker style for the lollipops. Default is 'o'. - show_nodes (bool): Whether to show nodes count. Default is True. - show_edges (bool): Whether to show edges count. Default is True. + Parameters + ---------- + networks : dict of nx.DiGraph + A dictionary of network names and their corresponding graphs. + filepath : str, optional + Path to save the plot. Defaults to None. + render : bool, optional + Whether to display the plot. Defaults to False. + orientation : str, optional + Orientation of the plot ('vertical' or 'horizontal'). Defaults to 'vertical'. + color_palette : str, optional + Matplotlib color palette to use. Defaults to 'Set2'. + size : int, optional + Size of the markers. Defaults to 10. + linewidth : int, optional + Line width of the lollipops. Defaults to 2. + marker : str, optional + Marker style for the lollipops. Defaults to 'o'. + show_nodes : bool, optional + Whether to show the number of nodes. Defaults to True. + show_edges : bool, optional + Whether to show the number of edges. Defaults to True. + + Returns + ------- + matplotlib.figure.Figure + The figure object for the plot. """ if not show_nodes and not show_edges: - logging.warning("Both 'show_nodes' and 'show_edges' are False. Using show nodes as default.") + _log("Both 'show_nodes' and 'show_edges' are False. Using show nodes as default.", level=30) show_nodes = True labels = [] @@ -375,30 +435,50 @@ def plot_n_nodes_edges( return lolli_plot -def plot_n_nodes_edges_from_df( - metrics_df: pd.DataFrame, - metrics: List[str], - filepath=None, - render=False, - orientation='vertical', - color_palette='Set2', - size=10, - linewidth=2, - marker='o' -): +def plot_n_nodes_edges_from_df(metrics_df: pd.DataFrame, + metrics: List[str], + filepath: str = None, + render: bool = False, + orientation: str = 'vertical', + color_palette: str = 'Set2', + size: int = 10, + linewidth: int = 2, + marker: str = 'o') -> plt.Figure: """ Plot the specified metrics from a DataFrame using a lollipop plot. - Args: - metrics_df (pd.DataFrame): DataFrame containing metrics with networks as rows and specified metrics in columns. - metrics (List[str]): List of column names in the DataFrame to plot. - filepath (str): Path to save the plot. Default is None. - render (bool): Whether to display the plot. Default is False. - orientation (str): 'vertical' or 'horizontal'. Default is 'vertical'. - color_palette (str): Matplotlib color palette. Default is 'Set2'. - size (int): Size of the markers. Default is 10. - linewidth (int): Line width of the lollipops. Default is 2. - marker (str): Marker style for the lollipops. Default is 'o'. + Parameters + ---------- + metrics_df : pd.DataFrame + DataFrame containing metrics with networks as rows and specified metrics in columns. + metrics : list of str + List of column names in the DataFrame to plot. + filepath : str, optional + Path to save the plot. Defaults to None. + render : bool, optional + Whether to display the plot. Defaults to False. + orientation : str, optional + Orientation of the plot ('vertical' or 'horizontal'). Defaults to 'vertical'. + color_palette : str, optional + Matplotlib color palette to use. Defaults to 'Set2'. + size : int, optional + Size of the markers. Defaults to 10. + linewidth : int, optional + Line width of the lollipops. Defaults to 2. + marker : str, optional + Marker style for the lollipops. Defaults to 'o'. + + Returns + ------- + matplotlib.figure.Figure + The figure object for the plot. + + Raises + ------ + ValueError + If no metrics are provided or if the DataFrame is empty. + TypeError + If elements in 'metrics' are not strings. """ if not metrics: raise ValueError("At least one metric must be specified.") @@ -493,7 +573,7 @@ def build_heatmap_with_tree(distance_df: pd.DataFrame, g.fig.show() if not save and not render: - logging.warning("No output specified. Returning the plot") + _log("No output specified. Returning the plot", level=30) return g.fig @@ -543,6 +623,6 @@ def create_rank_heatmap(ora_results: pd.DataFrame, plt.show() if not filepath and not render: - logging.warning("No output specified. Returning the plot object.") + _log("No output specified. Returning the plot object.", level=30) - return plt.gcf() + return plt.gcf() \ No newline at end of file diff --git a/networkcommons/visual/_vis_networkx.py b/networkcommons/visual/_vis_networkx.py index 83ab9e7..be2354d 100644 --- a/networkcommons/visual/_vis_networkx.py +++ b/networkcommons/visual/_vis_networkx.py @@ -102,8 +102,9 @@ def visualize_network_simple(self, A (pygraphviz.AGraph): The visualized network graph. """ if len(self.network.nodes) > max_nodes: - _log("The network is too large to visualize.", level=40) - return + _log("The network is too large to visualize, you can increase the max_nodes parameter if needed.", level=40) + print("The network is too large to visualize, you can increase the max_nodes parameter if needed.") + return None A = nx.nx_agraph.to_agraph(self.network) A.graph_attr['ratio'] = '1.2' @@ -224,7 +225,8 @@ def visualize_network_default(self, source_dict, target_dict, prog='dot', - custom_style=None): + custom_style=None, + max_nodes=75): """ Visualizes the network using default styles. @@ -234,10 +236,16 @@ def visualize_network_default(self, target_dict (dict): Targets and measurement signs. prog (str, optional): Layout program to use. Defaults to 'dot'. custom_style (dict, optional): Custom style dictionary to apply. + max_nodes (int, optional): Maximum number of nodes to visualize. Defaults to 75. Returns: A (pygraphviz.AGraph): The visualized network graph. """ + if len(self.network.nodes) > max_nodes: + _log("The network is too large to visualize, you can increase the max_nodes parameter if needed.", level=40) + print("The network is too large to visualize, you can increase the max_nodes parameter if needed.") + return None + default_style = _styles.get_styles()['default'] style = _styles.merge_styles(default_style, custom_style) @@ -275,7 +283,8 @@ def visualize_network_sign_consistent(self, source_dict, target_dict, prog='dot', - custom_style=None): + custom_style=None, + max_nodes=75): """ Visualizes the network considering sign consistency. @@ -285,10 +294,15 @@ def visualize_network_sign_consistent(self, target_dict (dict): Targets and measurement signs. prog (str, optional): Layout program to use. Defaults to 'dot'. custom_style (dict, optional): Custom style dictionary to apply. - + max_nodes (int, optional): Maximum number of nodes to visualize. Defaults to 75. Returns: A (pygraphviz.AGraph): The visualized network graph. """ + if len(self.network.nodes) > max_nodes: + _log("The network is too large to visualize, you can increase the max_nodes parameter if needed.", level=40) + print("The network is too large to visualize, you can increase the max_nodes parameter if needed.") + return None + default_style = _styles.get_styles()['sign_consistent'] style = _styles.merge_styles(default_style, custom_style) diff --git a/tests/test_network_stats.py b/tests/test_network_stats.py index 333d79d..1b31282 100644 --- a/tests/test_network_stats.py +++ b/tests/test_network_stats.py @@ -1,10 +1,14 @@ -import unittest +import pytest from unittest.mock import patch, MagicMock import pandas as pd import networkx as nx import numpy as np +import matplotlib import matplotlib.pyplot as plt +# Set the matplotlib backend to 'Agg' for headless environments (like GitHub Actions) +matplotlib.use('Agg') + from networkcommons.visual._network_stats import ( plot_rank, plot_scatter, @@ -15,237 +19,204 @@ create_rank_heatmap ) +@pytest.fixture +def setup_data(): + # Sample data setup + df = pd.DataFrame({ + 'idx': ['Gene1', 'Gene2', 'Gene3'], + 'Value1': [10, 20, 15], + 'Value2': [12, 18, 14], + 'method': ['Method1', 'Method2', 'Method3'], + 'type': ['Type1', 'Type2', 'Type1'], + 'diff_dysregulation': [0.1, 0.2, -0.1], + 'coverage': [0.8, 0.9, 0.7], + 'nodes_with_phosphoinfo': [5, 10, 15] + }) + + networks = { + 'Network1': nx.DiGraph([(1, 2), (2, 3)]), + 'Network2': nx.DiGraph([(1, 2), (3, 4)]) + } + + metrics_df = pd.DataFrame({ + 'Nodes': [3, 4], + 'Edges': [2, 2] + }, index=['Network1', 'Network2']) + + jaccard_df = pd.DataFrame( + np.array([[0.0, 0.2, 0.5], + [0.2, 0.0, 0.3], + [0.5, 0.3, 0.0]]), + columns=['A', 'B', 'C'], + index=['A', 'B', 'C'] + ) + + ora_results = pd.DataFrame({ + 'ora_Term': ['Term1', 'Term2', 'Term1', 'Term2'], + 'network': ['Net1', 'Net1', 'Net2', 'Net2'], + 'ora_rank': [1, 2, 3, 4] + }) + ora_terms = ['Term1', 'Term2'] + + return df, networks, metrics_df, jaccard_df, ora_results, ora_terms + + +@patch('networkcommons.visual._network_stats._log') +@patch('networkcommons.visual._network_stats.plt.savefig') +def test_plot_rank_logs_missing_column(mock_savefig, mock_log, setup_data): + df, _, _, _, _, _ = setup_data + df = df.drop(columns=['idx']) # Simulate missing 'idx' column + + print(df) + + plot_rank(df, bio_ids=['Gene1'], filepath='test.png') + + # Check if _log was called with the correct warning message about the missing column + mock_log.assert_called_once_with("Column 'idx' not found in the DataFrame. Using the index as the ID column.", + level=30) + + +@patch('networkcommons.visual._network_stats._log') +def test_plot_rank_no_output_warning(mock_log, setup_data): + df, _, _, _, _, _ = setup_data + plot_rank(df, filepath=None, render=False) + + # Check if _log was called with the correct warning message about no output being specified + mock_log.assert_called_once_with("No output specified. Returning the plot object.", level=30) + + +@patch('networkcommons.visual._network_stats.plt.savefig') +def test_plot_rank_saves_figure(mock_savefig, setup_data): + df, _, _, _, _, _ = setup_data + filepath = 'test_rank_plot.png' + plot_rank(df, bio_ids=['Gene1'], filepath=filepath) + mock_savefig.assert_called_once_with(filepath) + + +@patch('networkcommons.visual._network_stats.plt.savefig') +def test_plot_scatter_saves_figure(mock_savefig, setup_data): + df, _, _, _, _, _ = setup_data + filepath = 'test_scatter_plot.png' + plot_scatter(df, filepath=filepath) + mock_savefig.assert_called_once_with(filepath) + + +@patch('matplotlib.figure.Figure.savefig') +def test_lollipop_plot_saves_figure(mock_savefig): + data = { + 'Label': ['A', 'B', 'C', 'D'], + 'Value': [10, 20, 30, 40] + } + df = pd.DataFrame(data) + filepath = 'test_lollipop_plot.png' + + # Call the lollipop_plot function + fig = lollipop_plot( + df=df, + label_col='Label', + value_col='Value', + filepath=filepath, + render=False + ) + + # Assert that savefig was called on the figure object with the correct filepath + mock_savefig.assert_called_once_with(filepath) + + +@patch('networkcommons.visual._network_stats._log') +def test_plot_n_nodes_edges_no_nodes_or_edges(mock_log, setup_data): + _, networks, _, _, _, _ = setup_data + # Call the function with both show_nodes and show_edges set to False + plot_n_nodes_edges(networks, filepath='test.png', show_nodes=False, show_edges=False) + + # Check if _log was called with the correct warning message about no nodes or edges being selected + mock_log.assert_called_once_with("Both 'show_nodes' and 'show_edges' are False. Using show nodes as default.", level=30) + + +@patch('networkcommons.visual._network_stats.lollipop_plot') +def test_plot_n_nodes_edges(mock_lollipop_plot, setup_data): + _, networks, _, _, _, _ = setup_data + filepath = 'test_nodes_edges_plot.png' + + # Call the plot_n_nodes_edges function + plot_n_nodes_edges(networks, filepath=filepath) + + # Prepare expected DataFrame passed to lollipop_plot + expected_df = pd.DataFrame({ + 'Network': ['Network1', 'Network1', 'Network2', 'Network2'], + 'Category': ['Nodes', 'Edges', 'Nodes', 'Edges'], + 'Values': [3, 2, 4, 2] + }) + + # Get the actual DataFrame that was passed to lollipop_plot + actual_df = mock_lollipop_plot.call_args[0][0] + + # Ensure both DataFrames have the same column order and reset index before comparing + expected_df = expected_df[['Network', 'Category', 'Values']].reset_index(drop=True) + actual_df_ordered = actual_df[['Network', 'Category', 'Values']].reset_index(drop=True) + + # Assert that the DataFrames are equal + pd.testing.assert_frame_equal(actual_df_ordered, expected_df, check_like=True) + + +@patch('networkcommons.visual._network_stats.lollipop_plot') +def test_plot_n_nodes_edges_from_df(mock_lollipop_plot, setup_data): + _, _, metrics_df, _, _, _ = setup_data + filepath = 'test_nodes_edges_df_plot.png' + + # Call the function to plot nodes and edges from the DataFrame + plot_n_nodes_edges_from_df(metrics_df, ['Nodes', 'Edges'], filepath=filepath) + + # Prepare expected DataFrame passed to lollipop_plot + expected_df = pd.DataFrame({ + 'Network': ['Network1', 'Network1', 'Network2', 'Network2'], + 'Category': ['Nodes', 'Edges', 'Nodes', 'Edges'], + 'Values': [3, 2, 4, 2] + }) + + # Get the actual DataFrame that was passed to lollipop_plot + actual_df = mock_lollipop_plot.call_args[0][0] + + # Ensure both DataFrames have the same column order and reset index before comparing + expected_df = expected_df[['Network', 'Category', 'Values']].reset_index(drop=True) + actual_df_ordered = actual_df[['Network', 'Category', 'Values']].reset_index(drop=True) + + # Assert that the DataFrames are equal + pd.testing.assert_frame_equal(actual_df_ordered, expected_df, check_like=True) + + +@patch('networkcommons.visual._network_stats.sns.clustermap') +def test_build_heatmap_with_tree(mock_clustermap, setup_data): + _, _, _, jaccard_df, _, _ = setup_data + mock_fig = MagicMock() + mock_clustermap.return_value = MagicMock(fig=mock_fig) + + output_dir = "." + filepath = f"{output_dir}/heatmap_with_tree.png" + + # Call the function with save=True + build_heatmap_with_tree(jaccard_df, save=True, output_dir=output_dir) + + # Assert that savefig was called correctly + mock_fig.savefig.assert_called_once_with(filepath, bbox_inches='tight') + + +@patch('networkcommons.visual._network_stats.sns.clustermap') +def test_build_heatmap_with_tree_render(mock_clustermap, setup_data): + _, _, _, jaccard_df, _, _ = setup_data + mock_fig = MagicMock() + mock_clustermap.return_value = MagicMock(fig=mock_fig) + + # Call the function with render=True + build_heatmap_with_tree(jaccard_df, render=True) + + # Assert that show was called correctly + mock_fig.show.assert_called_once() + -class TestPlotFunctions(unittest.TestCase): - - def setUp(self): - # Sample data setup - self.df = pd.DataFrame({ - 'idx': ['Gene1', 'Gene2', 'Gene3'], - 'Value1': [10, 20, 15], - 'Value2': [12, 18, 14], - 'method': ['Method1', 'Method2', 'Method3'], - 'type': ['Type1', 'Type2', 'Type1'], - 'diff_dysregulation': [0.1, 0.2, -0.1], - 'coverage': [0.8, 0.9, 0.7], - 'nodes_with_phosphoinfo': [5, 10, 15] - }) - - self.networks = { - 'Network1': nx.DiGraph([(1, 2), (2, 3)]), - 'Network2': nx.DiGraph([(1, 2), (3, 4)]) - } - - self.metrics_df = pd.DataFrame({ - 'Nodes': [3, 4], - 'Edges': [2, 2] - }, index=['Network1', 'Network2']) - - self.jaccard_df = pd.DataFrame( - np.array([[0.0, 0.2, 0.5], - [0.2, 0.0, 0.3], - [0.5, 0.3, 0.0]]), - columns=['A', 'B', 'C'], - index=['A', 'B', 'C'] - ) - - self.ora_results = pd.DataFrame({ - 'ora_Term': ['Term1', 'Term2', 'Term1', 'Term2'], - 'network': ['Net1', 'Net1', 'Net2', 'Net2'], - 'ora_rank': [1, 2, 3, 4] - }) - self.ora_terms = ['Term1', 'Term2'] - - @patch('networkcommons.visual._network_stats.plt.savefig') - def test_plot_rank(self, mock_savefig): - filepath = 'test_rank_plot.png' - plot_rank(self.df, bio_ids=['Gene1'], filepath=filepath) - mock_savefig.assert_called_once_with(filepath) - - @patch('networkcommons.visual._network_stats.plt.savefig') - def test_plot_scatter(self, mock_savefig): - filepath = 'test_scatter_plot.png' - plot_scatter(self.df, filepath=filepath) - mock_savefig.assert_called_once_with(filepath) - - @patch('matplotlib.figure.Figure.savefig') - def test_lollipop_plot_saves_figure(self, mock_savefig): - data = { - 'Label': ['A', 'B', 'C', 'D'], - 'Value': [10, 20, 30, 40] - } - df = pd.DataFrame(data) - filepath = 'test_lollipop_plot.png' - - # Call the lollipop_plot function - fig = lollipop_plot( - df=df, - label_col='Label', - value_col='Value', - filepath=filepath, - render=False - ) - - # Assert that savefig was called on the figure object with the correct filepath - mock_savefig.assert_called_once_with(filepath) - - @patch('networkcommons.visual._network_stats.lollipop_plot') - def test_plot_n_nodes_edges(self, mock_lollipop_plot): - filepath = 'test_nodes_edges_plot.png' - - # Call the plot_n_nodes_edges function - plot_n_nodes_edges(self.networks, filepath=filepath) - - # Prepare expected DataFrame passed to lollipop_plot - expected_df = pd.DataFrame({ - 'Network': ['Network1', 'Network1', 'Network2', 'Network2'], - 'Category': ['Nodes', 'Edges', 'Nodes', 'Edges'], - 'Values': [3, 2, 4, 2] - }) - - # Get the actual DataFrame that was passed to lollipop_plot - actual_df = mock_lollipop_plot.call_args[0][0] - - # Ensure both DataFrames have the same column order and reset index before comparing - expected_df = expected_df[['Network', 'Category', 'Values']].reset_index(drop=True) - actual_df_ordered = actual_df[['Network', 'Category', 'Values']].reset_index(drop=True) - - # Assert that the DataFrames are equal - pd.testing.assert_frame_equal(actual_df_ordered, expected_df, check_like=True) - - # Assert that lollipop_plot was called with the correct arguments - mock_lollipop_plot.assert_called_once_with( - actual_df, - label_col='Network', - value_col='Values', - orientation='vertical', - color_palette='Set2', - size=10, - linewidth=2, - marker='o', - title="Number of Nodes and Edges", - filepath=filepath, - render=False - ) - - @patch('networkcommons.visual._network_stats.lollipop_plot') - def test_plot_n_nodes_edges_from_df(self, mock_lollipop_plot): - filepath = 'test_nodes_edges_df_plot.png' - - # Call the function to plot nodes and edges from the DataFrame - plot_n_nodes_edges_from_df(self.metrics_df, ['Nodes', 'Edges'], filepath=filepath) - - # Prepare expected DataFrame passed to lollipop_plot - expected_df = pd.DataFrame({ - 'Network': ['Network1', 'Network1', 'Network2', 'Network2'], - 'Category': ['Nodes', 'Edges', 'Nodes', 'Edges'], - 'Values': [3, 2, 4, 2] - }) - - # Get the actual DataFrame that was passed to lollipop_plot - actual_df = mock_lollipop_plot.call_args[0][0] - - # Ensure both DataFrames have the same column order and reset index before comparing - expected_df = expected_df[['Network', 'Category', 'Values']].reset_index(drop=True) - actual_df_ordered = actual_df[['Network', 'Category', 'Values']].reset_index(drop=True) - - # Assert that the DataFrames are equal - pd.testing.assert_frame_equal(actual_df_ordered, expected_df, check_like=True) - - # Assert that lollipop_plot was called with the correct arguments - mock_lollipop_plot.assert_called_once_with( - # actual and expected DataFrames are the same - checked above - actual_df, - label_col='Network', - value_col='Values', - orientation='vertical', - color_palette='Set2', - size=10, - linewidth=2, - marker='o', - title="Number of Nodes and Edges", - filepath=filepath, - render=False - ) - - @patch('networkcommons.visual._network_stats.sns.clustermap') - def test_build_heatmap_with_tree(self, mock_clustermap): - # Set up the mock return value for clustermap - mock_fig = MagicMock() - mock_clustermap.return_value = MagicMock(fig=mock_fig) - - output_dir = "." - filepath = f"{output_dir}/heatmap_with_tree.png" - - # Call the function with save=True - build_heatmap_with_tree(self.jaccard_df, save=True, output_dir=output_dir) - - # Assert that savefig was called correctly - mock_fig.savefig.assert_called_once_with(filepath, bbox_inches='tight') - - @patch('networkcommons.visual._network_stats.sns.clustermap') - def test_build_heatmap_with_tree_render(self, mock_clustermap): - # Set up the mock return value for clustermap - mock_fig = MagicMock() - mock_clustermap.return_value = MagicMock(fig=mock_fig) - - # Call the function with render=True - build_heatmap_with_tree(self.jaccard_df, render=True) - - # Assert that show was called correctly - mock_fig.show.assert_called_once() - - @patch('networkcommons.visual._network_stats.plt.savefig') - def test_create_rank_heatmap(self, mock_savefig): - filepath = 'test_rank_heatmap.png' - create_rank_heatmap(self.ora_results, self.ora_terms, filepath=filepath) - mock_savefig.assert_called_once_with(filepath) - - @patch('networkcommons.visual._network_stats.logging.warning') - @patch('networkcommons.visual._network_stats.lollipop_plot') - def test_plot_n_nodes_edges_invalid_input(self, - mock_lollipop_plot, - mock_logging_warning): - filepath = 'test_nodes_edges_plot.png' - - # Call the function with both show_nodes and show_edges set to False - plot_n_nodes_edges(self.networks, filepath=filepath, show_nodes=False, show_edges=False) - - # Check that a warning was logged - mock_logging_warning.assert_called_once_with( - "Both 'show_nodes' and 'show_edges' are False. Using show nodes as default." - ) - - # Verify that the lollipop_plot was called with a DataFrame that contains the nodes data - actual_df = mock_lollipop_plot.call_args[0][0] - expected_df = pd.DataFrame({ - 'Network': ['Network1', 'Network2'], # Assuming 'self.networks' contains these two networks - 'Category': ['Nodes', 'Nodes'], # Because 'show_nodes' was set to True by default - 'Values': [len(self.networks['Network1'].nodes), len(self.networks['Network2'].nodes)] - }) - - # Assert the DataFrame passed to lollipop_plot is correct - pd.testing.assert_frame_equal(actual_df, expected_df, check_like=True) - - # Check that lollipop_plot was called with the correct parameters - mock_lollipop_plot.assert_called_with( - actual_df, - label_col='Network', - value_col='Values', - orientation='vertical', - color_palette='Set2', - size=10, - linewidth=2, - marker='o', - title="Number of Nodes", - filepath=filepath, - render=False - ) - - def test_plot_n_nodes_edges_from_df_invalid_input(self): - with self.assertRaises(ValueError): - plot_n_nodes_edges_from_df(self.metrics_df, [], render=False) - - -if __name__ == '__main__': - unittest.main() +@patch('networkcommons.visual._network_stats.plt.savefig') +def test_create_rank_heatmap_saves_figure(mock_savefig, setup_data): + _, _, _, _, ora_results, ora_terms = setup_data + filepath = 'test_rank_heatmap.png' + create_rank_heatmap(ora_results, ora_terms, filepath=filepath) + mock_savefig.assert_called_once_with(filepath) \ No newline at end of file diff --git a/tests/test_rnaseq.py b/tests/test_vis_rnaseq.py similarity index 100% rename from tests/test_rnaseq.py rename to tests/test_vis_rnaseq.py