Skip to content

Commit

Permalink
updated test for rna-seq plots
Browse files Browse the repository at this point in the history
  • Loading branch information
ivanovaos committed Sep 10, 2024
1 parent 5c2fd3b commit af036c1
Show file tree
Hide file tree
Showing 2 changed files with 160 additions and 27 deletions.
65 changes: 38 additions & 27 deletions networkcommons/visual/_rnaseq.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
'build_volcano_plot',
'build_ma_plot',
'plot_pca',
'build_heatmap_with_tree',
'plot_heatmap_with_tree',
'plot_density'
]

Expand All @@ -51,7 +51,8 @@ def plot_density(df,
xlabel='Intensity',
ylabel='Density'):
"""
Plots density of intensity values for specified genes, including mean and quantile lines, and separates distributions by groups if metadata is provided.
Plots density of intensity values for specified genes, including mean and quantile lines,
and separates distributions by groups if metadata is provided.
Each gene is displayed in a separate subplot.
Args:
Expand Down Expand Up @@ -319,42 +320,52 @@ def plot_pca(dataframe, metadata, feature_col='idx', **kwargs):
return pca_df


def build_heatmap_with_tree(
def plot_heatmap_with_tree(
data: pd.DataFrame,
top_n: int = 50,
value_column: str = 'log2FoldChange_condition_1',
conditions: list[str] = None,
title: str = "Heatmap of Top Differentially Expressed Genes",
clustering_method: str = 'ward',
metric: str = 'euclidean',
title: str = 'Heatmap with Hierarchical Clustering',
xlabel: str = 'Samples',
ylabel: str = 'Genes',
cmap: str = 'viridis',
save: bool = False,
output_dir: str = "."
output_dir: str = ".",
render: bool = False
):
"""
Build a heatmap with hierarchical clustering for the top differentially expressed genes across multiple conditions.
Creates a heatmap with hierarchical clustering for rows and columns.
Args:
data (pd.DataFrame): DataFrame containing RNA-seq results.
top_n (int): Number of top differentially expressed genes to include in the heatmap.
value_column (str): Column name for the values to rank and select the top genes.
conditions (list[str]): List of condition columns to include in the heatmap.
title (str): Title of the plot.
save (bool): Whether to save the plot. Default is False.
output_dir (str): Directory to save the plot. Default is ".".
data (pd.DataFrame): DataFrame containing the data for the heatmap.
clustering_method (str, optional): Method for hierarchical clustering. Defaults to 'ward'.
metric (str, optional): Metric for distance calculation. Defaults to 'euclidean'.
title (str, optional): Title of the plot. Defaults to 'Heatmap with Hierarchical Clustering'.
xlabel (str, optional): Label for the x-axis. Defaults to 'Samples'.
ylabel (str, optional): Label for the y-axis. Defaults to 'Genes'.
cmap (str, optional): Colormap for the heatmap. Defaults to 'viridis'.
save (bool, optional): Whether to save the plot. Defaults to False.
output_dir (str, optional): Directory to save the plot if `save` is True. Defaults to ".".
render (bool, optional): Whether to show the plot. Defaults to False.
Returns:
matplotlib.figure.Figure: The created figure object.
"""
if conditions is None:
raise ValueError("Conditions must be provided as a list of column names.")
# Compute the distance matrices
row_linkage = sns.clustermap(data, method=clustering_method, metric=metric, cmap=cmap)
col_linkage = sns.clustermap(data.T, method=clustering_method, metric=metric, cmap=cmap)

# Select top differentially expressed genes
top_genes = data.nlargest(top_n, value_column).index
top_data = data.loc[top_genes, conditions]
fig = plt.figure(figsize=(10, 10))

# Create the clustermap
g = sns.clustermap(top_data, cmap="viridis", cbar=True, fmt=".2f", linewidths=.5)
sns.heatmap(data, cmap=cmap, ax=fig.gca(), cbar=True, annot=False, fmt=".2f")

plt.title(title)
plt.ylabel("Gene")
plt.xlabel("Condition")
plt.xlabel(xlabel)
plt.ylabel(ylabel)

if save:
plt.savefig(f"{output_dir}/heatmap_with_tree.png")
plt.savefig(f"{output_dir}/heatmap_with_tree.png", bbox_inches='tight')

if render:
plt.show()

plt.show()
return fig
122 changes: 122 additions & 0 deletions tests/test_rnaseq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
import pytest
import pandas as pd
import numpy as np
from unittest.mock import patch
import matplotlib.pyplot as plt
from networkcommons.visual import (plot_density,
build_volcano_plot,
build_ma_plot,
plot_pca,
plot_heatmap_with_tree)


@pytest.fixture
def example_dataframe():
"""Fixture for generating an example dataframe for testing."""
data = {
'idx': ['gene_1', 'gene_2', 'gene_3'],
'sample_1': [10, 15, 5],
'sample_2': [20, 18, 9],
'sample_3': [12, 22, 8]
}
return pd.DataFrame(data)


@pytest.fixture
def metadata_dataframe():
"""Fixture for generating an example metadata dataframe."""
metadata = {
'sample_ID': ['sample_1', 'sample_2', 'sample_3'],
'group': ['control', 'treated', 'control']
}
return pd.DataFrame(metadata)


def test_plot_density():
"""Test the plot_density function with valid data."""

# Create a sample dataframe with enough data points
example_dataframe = pd.DataFrame({
'idx': ['gene_1', 'gene_2', 'gene_3'],
'sample_1': [10, 15, 5],
'sample_2': [20, 18, 9],
'sample_3': [12, 22, 8],
'sample_4': [14, 19, 7] # Adding more samples to ensure enough data points
})

# Create metadata for grouping
metadata_dataframe = pd.DataFrame({
'sample_ID': ['sample_1', 'sample_2', 'sample_3', 'sample_4'],
'group': ['control', 'treated', 'control', 'treated']
})

gene_ids = ['gene_1', 'gene_2'] # Make sure this has genes present in the dataframe

# Mock plt.show to avoid blocking during the test
with patch('matplotlib.pyplot.show'):
plot_density(example_dataframe, gene_ids, metadata_dataframe)

# Assert if the plot was created by checking the number of axes
assert len(plt.gcf().get_axes()) == 2 # Should have 2 subplots for 2 genes

def test_build_volcano_plot():
"""Test the build_volcano_plot function."""
data = pd.DataFrame({
'log2FoldChange': [1.5, -2.0, 0.5, -0.3],
'pvalue': [0.01, 0.04, 0.20, 0.05]
})

with patch('matplotlib.pyplot.show'):
build_volcano_plot(data)

# Assert if the plot was created
assert len(plt.gcf().get_axes()) == 1 # Should have one main axis for the volcano plot


def test_build_ma_plot():
"""Test the build_ma_plot function."""
data = pd.DataFrame({
'log2FoldChange': [1.5, -2.0, 0.5, -0.3],
'meanExpression': [10, 15, 20, 25]
})

with patch('matplotlib.pyplot.show'):
build_ma_plot(data, log2fc='log2FoldChange', mean_exp='meanExpression')

# Assert if the plot was created
assert len(plt.gcf().get_axes()) == 1 # Should have one main axis for the MA plot


def test_plot_pca(example_dataframe, metadata_dataframe):
"""Test the plot_pca function."""
with patch('matplotlib.pyplot.show'):
pca_df = plot_pca(example_dataframe, metadata_dataframe)

# Assert that the returned dataframe has the correct shape
assert pca_df.shape[1] == 3 # Expecting PCA1, PCA2, and 'group' columns


def test_build_heatmap_with_tree():
"""Test the build_heatmap_with_tree function."""
data = pd.DataFrame({
'gene_1': [2.3, -1.1, 0.4],
'gene_2': [1.2, 0.5, -0.7],
'gene_3': [3.1, 0.9, -1.2]
}, index=['condition_1', 'condition_2', 'condition_3'])

with patch('matplotlib.pyplot.show'):
fig = plot_heatmap_with_tree(
data,
clustering_method='ward',
metric='euclidean',
title='Test Heatmap',
xlabel='Samples',
ylabel='Genes',
cmap='viridis',
save=False,
render=False
)

# Assert if the figure was created and contains an axes object
assert isinstance(fig, plt.Figure) # Check if the returned object is a matplotlib Figure
assert len(fig.get_axes()) > 0 # Assert that axes were created in the figure

0 comments on commit af036c1

Please sign in to comment.