Skip to content

Commit

Permalink
Merge pull request #18 from SCAI-BIO/update-shap-viisualization
Browse files Browse the repository at this point in the history
feat: add persistence to shap plots and add unit test
  • Loading branch information
tiadams authored Nov 11, 2024
2 parents 4253869 + 2812a90 commit daa3a47
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 3 deletions.
25 changes: 22 additions & 3 deletions syndat/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,18 @@ def plot_correlations(real: pandas.DataFrame, synthetic: pandas.DataFrame, store
fig = ax.get_figure()
fig.savefig(store_destination + "/" + names[idx] + '.png', bbox_inches="tight")

def plot_shap_discrimination(real: pandas.DataFrame, synthetic: pandas.DataFrame) -> None:
# Assuming 'real' and 'synthetic_no_dp' are your datasets and are pandas DataFrames
def plot_shap_discrimination(real: pd.DataFrame, synthetic: pd.DataFrame, save_path: str = None) -> None:
"""
Generates a SHAP summary plot to illustrate the discrimination between real and synthetic datasets
using a Random Forest classifier.
:param real: The real data
:param synthetic: The synthetic data
:param save_path: Path to the file where the resulting plot should be saved. If None, the plot will not be saved.
:return: None
"""
# Assuming 'real' and 'synthetic' are your datasets and are pandas DataFrames
# Add a label column to each dataset
real['label'] = 1
synthetic['label'] = 0
Expand Down Expand Up @@ -99,7 +109,16 @@ def plot_shap_discrimination(real: pandas.DataFrame, synthetic: pandas.DataFrame
shap_values = explainer.shap_values(X_test)

# Plot SHAP summary
shap.summary_plot(shap_values[1], X_test)
plt.figure()
shap.summary_plot(shap_values[1], X_test, show=False)

# Save the plot if save_path is specified
if save_path:
plt.savefig(save_path, bbox_inches='tight')
print(f"Plot saved to {save_path}")

# Show the plot
plt.show()


def plot_categorical_feature(feature: str, real_data: pandas.DataFrame, synthetic_data: pandas.DataFrame) -> None:
Expand Down
30 changes: 30 additions & 0 deletions tests/test_visualization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import os
import shutil
import tempfile
import unittest

import numpy as np
import pandas as pd

from syndat import plot_shap_discrimination


class TestPlotShapDiscrimination(unittest.TestCase):
    """Smoke test checking that ``plot_shap_discrimination`` persists its plot."""

    def setUp(self):
        """Create reproducible sample frames and an isolated output location."""
        # Seeded generator: every run feeds the same data to the function
        # under test, so a failure can be reproduced.
        rng = np.random.default_rng(42)
        columns = [f"feature_{i}" for i in range(5)]
        self.real = pd.DataFrame(rng.normal(size=(100, 5)), columns=columns)
        self.synthetic = pd.DataFrame(rng.normal(size=(100, 5)), columns=columns)

        # Save into a fresh private directory instead of the current working
        # directory: a stale file left over from an earlier run can no longer
        # make the existence check pass, and cleanup can never delete an
        # unrelated file that happens to share the name.
        self.tmp_dir = tempfile.mkdtemp()
        self.save_path = os.path.join(self.tmp_dir, "test_shap_plot.png")

    def test_plot_shap_discrimination(self):
        """The function writes a non-empty file to ``save_path``."""
        # Call the function with test data and save_path
        plot_shap_discrimination(self.real, self.synthetic, save_path=self.save_path)

        # Check that the plot file was created and actually has content
        self.assertTrue(os.path.exists(self.save_path), "SHAP plot file was not created.")
        self.assertGreater(os.path.getsize(self.save_path), 0, "SHAP plot file is empty.")

    def tearDown(self):
        """Remove the temporary directory and anything written into it."""
        # ignore_errors keeps cleanup best-effort, like the original
        # "remove only if it exists" guard.
        shutil.rmtree(self.tmp_dir, ignore_errors=True)

0 comments on commit daa3a47

Please sign in to comment.