Skip to content

Commit

Permalink
make plot_fn return the graphs and not show them
Browse files Browse the repository at this point in the history
  • Loading branch information
patrickleonardy committed Nov 3, 2023
1 parent 8025f7c commit 6130a4b
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 24 deletions.
34 changes: 20 additions & 14 deletions cobra/evaluation/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def _compute_scalar_metrics(y_true: np.ndarray,
lift_at=lift_at), 2)
})

def plot_roc_curve(self, path: str=None, dim: tuple=(12, 8)):
def plot_roc_curve(self, path: str=None, dim: tuple=(12, 8)) -> plt.Figure:
"""Plot ROC curve of the model.
Parameters
Expand Down Expand Up @@ -197,8 +197,8 @@ def plot_roc_curve(self, path: str=None, dim: tuple=(12, 8)):

if path:
plt.savefig(path, format="png", dpi=300, bbox_inches="tight")

plt.show()
plt.close()
return fig

def plot_confusion_matrix(self, path: str=None, dim: tuple=(12, 8),
labels: list=["0", "1"]):
Expand Down Expand Up @@ -232,9 +232,10 @@ def plot_confusion_matrix(self, path: str=None, dim: tuple=(12, 8),
if path:
plt.savefig(path, format="png", dpi=300, bbox_inches="tight")

plt.show()
plt.close()
return fig

def plot_cumulative_response_curve(self, path: str=None, dim: tuple=(12, 8)):
def plot_cumulative_response_curve(self, path: str=None, dim: tuple=(12, 8)) -> plt.Figure:
"""Plot cumulative response curve.
Parameters
Expand Down Expand Up @@ -283,9 +284,10 @@ def plot_cumulative_response_curve(self, path: str=None, dim: tuple=(12, 8)):
if path is not None:
plt.savefig(path, format="png", dpi=300, bbox_inches="tight")

plt.show()
plt.close()
return fig

def plot_lift_curve(self, path: str=None, dim: tuple=(12, 8)):
def plot_lift_curve(self, path: str=None, dim: tuple=(12, 8)) -> plt.Figure:
"""Plot lift per decile.
Parameters
Expand Down Expand Up @@ -332,9 +334,10 @@ def plot_lift_curve(self, path: str=None, dim: tuple=(12, 8)):
if path is not None:
plt.savefig(path, format="png", dpi=300, bbox_inches="tight")

plt.show()
plt.close()
return fig

def plot_cumulative_gains(self, path: str=None, dim: tuple=(12, 8)):
def plot_cumulative_gains(self, path: str=None, dim: tuple=(12, 8)) -> plt.Figure:
"""Plot cumulative gains per decile.
Parameters
Expand Down Expand Up @@ -376,7 +379,8 @@ def plot_cumulative_gains(self, path: str=None, dim: tuple=(12, 8)):

if path is not None:
plt.savefig(path, format="png", dpi=300, bbox_inches="tight")
plt.show()
plt.close()
return fig

@staticmethod
def _find_optimal_cutoff(y_true: np.ndarray,
Expand Down Expand Up @@ -658,7 +662,7 @@ def _compute_qq_residuals(y_true: np.ndarray,
"residuals": df["z_res"].values,
})

def plot_predictions(self, path: str=None, dim: tuple=(12, 8)):
def plot_predictions(self, path: str=None, dim: tuple=(12, 8)) -> plt.Figure:
"""Plot predictions from the model against actual values.
Parameters
Expand Down Expand Up @@ -692,9 +696,10 @@ def plot_predictions(self, path: str=None, dim: tuple=(12, 8)):
if path:
plt.savefig(path, format="png", dpi=300, bbox_inches="tight")

plt.show()
plt.close()
return fig

def plot_qq(self, path: str=None, dim: tuple=(12, 8)):
def plot_qq(self, path: str=None, dim: tuple=(12, 8)) -> plt.Figure:
"""Display a Q-Q plot from the standardized prediction residuals.
Parameters
Expand Down Expand Up @@ -733,4 +738,5 @@ def plot_qq(self, path: str=None, dim: tuple=(12, 8)):
if path:
plt.savefig(path, format="png", dpi=300, bbox_inches="tight")

plt.show()
plt.close()
return fig
6 changes: 3 additions & 3 deletions cobra/evaluation/pigs_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def plot_incidence(pig_tables: pd.DataFrame,
variable: str,
model_type: str,
column_order: list=None,
dim: tuple=(12, 8)):
dim: tuple=(12, 8)) -> plt.Figure:
"""Plots a Predictor Insights Graph (PIG), a graph in which the mean
target value is plotted for a number of bins constructed from a predictor
variable. When the target is a binary classification target,
Expand Down Expand Up @@ -257,5 +257,5 @@ def plot_incidence(pig_tables: pd.DataFrame,
plt.tight_layout()
plt.margins(0.01)

# Show
plt.show()
plt.close()
return fig
18 changes: 11 additions & 7 deletions cobra/evaluation/plotting_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

def plot_univariate_predictor_quality(df_metric: pd.DataFrame,
dim: tuple=(12, 8),
path: str=None):
path: str=None) -> plt.Figure:
"""Plot univariate quality of the predictors.
Parameters
Expand Down Expand Up @@ -58,7 +58,8 @@ def plot_univariate_predictor_quality(df_metric: pd.DataFrame,

plt.gca().legend().set_title("")

plt.show()
plt.close()
return fig

def plot_correlation_matrix(df_corr: pd.DataFrame,
dim: tuple=(12, 8),
Expand All @@ -81,15 +82,16 @@ def plot_correlation_matrix(df_corr: pd.DataFrame,
if path is not None:
plt.savefig(path, format="png", dpi=300, bbox_inches="tight")

plt.show()
plt.close()
return fig

def plot_performance_curves(model_performance: pd.DataFrame,
dim: tuple=(12, 8),
path: str=None,
colors: dict={"train": "#0099bf",
"selection": "#ff9500",
"validation": "#8064a2"},
metric_name: str=None):
metric_name: str=None) -> plt.Figure:
"""Plot performance curves generated by the forward feature selection
for the train-selection-validation sets.
Expand Down Expand Up @@ -159,12 +161,13 @@ def plot_performance_curves(model_performance: pd.DataFrame,
if path is not None:
plt.savefig(path, format="png", dpi=300, bbox_inches="tight")

plt.show()
plt.close()
return fig

def plot_variable_importance(df_variable_importance: pd.DataFrame,
title: str=None,
dim: tuple=(12, 8),
path: str=None):
path: str=None) -> plt.Figure:
"""Plot variable importance of a given model.
Parameters
Expand Down Expand Up @@ -199,4 +202,5 @@ def plot_variable_importance(df_variable_importance: pd.DataFrame,
if path is not None:
plt.savefig(path, format="png", dpi=300, bbox_inches="tight")

plt.show()
plt.close()
return fig

0 comments on commit 6130a4b

Please sign in to comment.