diff --git a/copulas/visualization.py b/copulas/visualization.py index 2cc0601e..aed9f189 100644 --- a/copulas/visualization.py +++ b/copulas/visualization.py @@ -93,6 +93,8 @@ def compare_2d(real, synth, columns=None, figsize=None): ax = real.plot.scatter(x, y, color='blue', alpha=0.5, figsize=figsize) ax = synth.plot.scatter(x, y, ax=ax, color='orange', alpha=0.5, figsize=figsize) ax.legend(['Real', 'Synthetic']) + + return ax def compare_1d(real, synth, columns=None, figsize=None):