diff --git a/syndat/visualization.py b/syndat/visualization.py index f30d7c9..8d7c42c 100644 --- a/syndat/visualization.py +++ b/syndat/visualization.py @@ -70,13 +70,7 @@ def plot_correlations(real: pandas.DataFrame, synthetic: pandas.DataFrame, store fig.savefig(store_destination + "/" + names[idx] + '.png', bbox_inches="tight") def plot_shap_discrimination(real: pandas.DataFrame, synthetic: pandas.DataFrame) -> None: - """ - Trains a Random Forest Classifier to discriminate between real and synthetic data and plots SHAP summary values. - - :param real: The real data - :param synthetic: The synthetic data - """ - # Assuming 'real' and 'synthetic' are your datasets and are pandas DataFrames + # Assuming 'real' and 'synthetic_no_dp' are your datasets and are pandas DataFrames # Add a label column to each dataset real['label'] = 1 synthetic['label'] = 0 @@ -105,7 +99,7 @@ def plot_shap_discrimination(real: pandas.DataFrame, synthetic: pandas.DataFrame shap_values = explainer.shap_values(X_test) # Plot SHAP summary - shap.summary_plot(shap_values[:, :, 1], X_test) + shap.summary_plot(shap_values[1], X_test) def plot_categorical_feature(feature: str, real_data: pandas.DataFrame, synthetic_data: pandas.DataFrame) -> None: @@ -126,6 +120,9 @@ def plot_categorical_feature(feature: str, real_data: pandas.DataFrame, syntheti plt.title(f'Real Data - {feature}') plt.xticks(rotation=90) + # Get the maximum y value for the real dataset + real_max_y = plt.gca().get_ylim()[1] # Get the current y limits of the real plot + # Plot for the synthetic dataset plt.subplot(1, 2, 2) sns.countplot(x=feature, data=synthetic_data, color='orange') @@ -134,6 +131,9 @@ def plot_categorical_feature(feature: str, real_data: pandas.DataFrame, syntheti plt.title(f'Synthetic Data - {feature}') plt.xticks(rotation=90) + # Set the y-axis limits to be the same as the real dataset + plt.ylim(0, real_max_y) # Align y-axis limits + plt.tight_layout() plt.show() @@ -147,6 +147,7 @@ def plot_numerical_feature(feature: str, real_data: pandas.DataFrame, synthetic_ :param synthetic_data: The synthetic data """ # Calculate summary statistics + # Calculate summary statistics def get_summary_stats(data, feature): return { 'Mean': data[feature].mean(), @@ -168,7 +169,7 @@ def get_summary_stats(data, feature): plt.figure(figsize=(14, 8)) - # Compute the combined range for x-axis limits + # Compute the combined range for x-axis limits across both datasets min_value = min(real_data[feature].min(), synthetic_data[feature].min()) max_value = max(real_data[feature].max(), synthetic_data[feature].max())