Skip to content

Commit

Permalink
fix: new evaluation functions
Browse files Browse the repository at this point in the history
  • Loading branch information
tiadams committed Sep 30, 2024
1 parent e2ac0c2 commit 4253869
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions syndat/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,7 @@ def plot_correlations(real: pandas.DataFrame, synthetic: pandas.DataFrame, store
fig.savefig(store_destination + "/" + names[idx] + '.png', bbox_inches="tight")

def plot_shap_discrimination(real: pandas.DataFrame, synthetic: pandas.DataFrame) -> None:
"""
Trains a Random Forest Classifier to discriminate between real and synthetic data and plots SHAP summary values.
:param real: The real data
:param synthetic: The synthetic data
"""
# Assuming 'real' and 'synthetic' are your datasets and are pandas DataFrames
# Assuming 'real' and 'synthetic_no_dp' are your datasets and are pandas DataFrames
# Add a label column to each dataset
real['label'] = 1
synthetic['label'] = 0
Expand Down Expand Up @@ -105,7 +99,7 @@ def plot_shap_discrimination(real: pandas.DataFrame, synthetic: pandas.DataFrame
shap_values = explainer.shap_values(X_test)

# Plot SHAP summary
shap.summary_plot(shap_values[:, :, 1], X_test)
shap.summary_plot(shap_values[1], X_test)


def plot_categorical_feature(feature: str, real_data: pandas.DataFrame, synthetic_data: pandas.DataFrame) -> None:
Expand All @@ -126,6 +120,9 @@ def plot_categorical_feature(feature: str, real_data: pandas.DataFrame, syntheti
plt.title(f'Real Data - {feature}')
plt.xticks(rotation=90)

# Get the maximum y value for the real dataset
real_max_y = plt.gca().get_ylim()[1] # Get the current y limits of the real plot

# Plot for the synthetic dataset
plt.subplot(1, 2, 2)
sns.countplot(x=feature, data=synthetic_data, color='orange')
Expand All @@ -134,6 +131,9 @@ def plot_categorical_feature(feature: str, real_data: pandas.DataFrame, syntheti
plt.title(f'Synthetic Data - {feature}')
plt.xticks(rotation=90)

# Set the y-axis limits to be the same as the real dataset
plt.ylim(0, real_max_y) # Align y-axis limits

plt.tight_layout()
plt.show()

Expand All @@ -147,6 +147,7 @@ def plot_numerical_feature(feature: str, real_data: pandas.DataFrame, synthetic_
:param synthetic_data: The synthetic data
"""
# Calculate summary statistics
# Calculate summary statistics
def get_summary_stats(data, feature):
return {
'Mean': data[feature].mean(),
Expand All @@ -168,7 +169,7 @@ def get_summary_stats(data, feature):

plt.figure(figsize=(14, 8))

# Compute the combined range for x-axis limits
# Compute the combined range for x-axis limits across both datasets
min_value = min(real_data[feature].min(), synthetic_data[feature].min())
max_value = max(real_data[feature].max(), synthetic_data[feature].max())

Expand Down

0 comments on commit 4253869

Please sign in to comment.