Skip to content

Commit

Permalink
add the axes object to figure output
Browse files Browse the repository at this point in the history
  • Loading branch information
patrickleonardy committed Nov 6, 2023
1 parent c6c3c56 commit 4238080
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 22 deletions.
78 changes: 65 additions & 13 deletions cobra/evaluation/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def _compute_scalar_metrics(y_true: np.ndarray,
lift_at=lift_at), 2)
})

def plot_roc_curve(self, path: str=None, dim: tuple=(12, 8)) -> plt.Figure:
def plot_roc_curve(self, path: str=None, dim: tuple=(12, 8)) -> tuple[plt.Figure, plt.Axes]:
"""Plot ROC curve of the model.
Parameters
Expand All @@ -167,6 +167,13 @@ def plot_roc_curve(self, path: str=None, dim: tuple=(12, 8)) -> plt.Figure:
Path to store the figure.
dim : tuple, optional
Tuple with width and length of the plot.
Retruns
-------
fig : plt.Figure
figure object containing the ROC curve
ax : plt.Axes
axes object linked to the figure
"""

if self.roc_curve is None:
Expand Down Expand Up @@ -198,7 +205,7 @@ def plot_roc_curve(self, path: str=None, dim: tuple=(12, 8)) -> plt.Figure:
if path:
plt.savefig(path, format="png", dpi=300, bbox_inches="tight")
plt.close()
return fig
return fig, ax

def plot_confusion_matrix(self, path: str=None, dim: tuple=(12, 8),
labels: list=["0", "1"]):
Expand All @@ -212,6 +219,13 @@ def plot_confusion_matrix(self, path: str=None, dim: tuple=(12, 8),
Tuple with width and length of the plot.
labels : list, optional
Optional list of labels, default "0" and "1".
Retruns
-------
fig : plt.Figure
figure object containing confusion matrix
ax : plt.Axes
axes object linked to the figure
"""

if self.confusion_matrix is None:
Expand All @@ -233,9 +247,12 @@ def plot_confusion_matrix(self, path: str=None, dim: tuple=(12, 8),
plt.savefig(path, format="png", dpi=300, bbox_inches="tight")

plt.close()
return fig
return fig, ax

def plot_cumulative_response_curve(self, path: str=None, dim: tuple=(12, 8)) -> plt.Figure:
def plot_cumulative_response_curve(self,
path: str=None,
dim: tuple=(12, 8)
) -> tuple[plt.Figure, plt.Axes]:
"""Plot cumulative response curve.
Parameters
Expand All @@ -244,6 +261,13 @@ def plot_cumulative_response_curve(self, path: str=None, dim: tuple=(12, 8)) ->
Path to store the figure.
dim : tuple, optional
Tuple with width and length of the plot.
Retruns
-------
fig : plt.Figure
figure object containing the cumulative response curve
ax : plt.Axes
axes object linked to the figure
"""

if self.lift_curve is None:
Expand Down Expand Up @@ -285,9 +309,9 @@ def plot_cumulative_response_curve(self, path: str=None, dim: tuple=(12, 8)) ->
plt.savefig(path, format="png", dpi=300, bbox_inches="tight")

plt.close()
return fig
return fig, ax

def plot_lift_curve(self, path: str=None, dim: tuple=(12, 8)) -> plt.Figure:
def plot_lift_curve(self, path: str=None, dim: tuple=(12, 8)) -> tuple[plt.Figure, plt.Axes]:
"""Plot lift per decile.
Parameters
Expand All @@ -296,6 +320,13 @@ def plot_lift_curve(self, path: str=None, dim: tuple=(12, 8)) -> plt.Figure:
Path to store the figure.
dim : tuple, optional
Tuple with width and length of the plot.
Retruns
-------
fig : plt.Figure
figure object the livt curve
ax : plt.Axes
axes object linked to the figure
"""

if self.lift_curve is None:
Expand Down Expand Up @@ -335,9 +366,9 @@ def plot_lift_curve(self, path: str=None, dim: tuple=(12, 8)) -> plt.Figure:
plt.savefig(path, format="png", dpi=300, bbox_inches="tight")

plt.close()
return fig
return fig, ax

def plot_cumulative_gains(self, path: str=None, dim: tuple=(12, 8)) -> plt.Figure:
def plot_cumulative_gains(self, path: str=None, dim: tuple=(12, 8)) -> tuple[plt.Figure, plt.Axes]:
"""Plot cumulative gains per decile.
Parameters
Expand All @@ -346,6 +377,13 @@ def plot_cumulative_gains(self, path: str=None, dim: tuple=(12, 8)) -> plt.Figur
Path to store the figure.
dim : tuple, optional
Tuple with width and length of the plot.
Retruns
-------
fig : plt.Figure
figure object containing the cumulative gains curve
ax : plt.Axes
axes object linked to the figure
"""

with sns.axes_style("whitegrid"):
Expand Down Expand Up @@ -380,7 +418,7 @@ def plot_cumulative_gains(self, path: str=None, dim: tuple=(12, 8)) -> plt.Figur
if path is not None:
plt.savefig(path, format="png", dpi=300, bbox_inches="tight")
plt.close()
return fig
return fig, ax

@staticmethod
def _find_optimal_cutoff(y_true: np.ndarray,
Expand Down Expand Up @@ -662,7 +700,7 @@ def _compute_qq_residuals(y_true: np.ndarray,
"residuals": df["z_res"].values,
})

def plot_predictions(self, path: str=None, dim: tuple=(12, 8)) -> plt.Figure:
def plot_predictions(self, path: str=None, dim: tuple=(12, 8)) -> tuple[plt.Figure, plt.Axes]:
"""Plot predictions from the model against actual values.
Parameters
Expand All @@ -671,6 +709,13 @@ def plot_predictions(self, path: str=None, dim: tuple=(12, 8)) -> plt.Figure:
Path to store the figure.
dim : tuple, optional
Tuple with width and length of the plot.
Retruns
-------
fig : plt.Figure
figure object containing the predictions vs the actual values
ax : plt.Axes
axes object linked to the figure
"""
if self.y_true is None and self.y_pred is None:
msg = ("This {} instance is not fitted yet. Call 'fit' with "
Expand All @@ -697,9 +742,9 @@ def plot_predictions(self, path: str=None, dim: tuple=(12, 8)) -> plt.Figure:
plt.savefig(path, format="png", dpi=300, bbox_inches="tight")

plt.close()
return fig
return fig, ax

def plot_qq(self, path: str=None, dim: tuple=(12, 8)) -> plt.Figure:
def plot_qq(self, path: str=None, dim: tuple=(12, 8)) -> tuple[plt.Figure, plt.Axes]:
"""Display a Q-Q plot from the standardized prediction residuals.
Parameters
Expand All @@ -708,6 +753,13 @@ def plot_qq(self, path: str=None, dim: tuple=(12, 8)) -> plt.Figure:
Path to store the figure.
dim : tuple, optional
Tuple with width and length of the plot.
Retruns
-------
fig : plt.Figure
figure object containing the QQ-plot
ax : plt.Axes
axes object linked to the figure
"""

if self.qq is None:
Expand Down Expand Up @@ -739,4 +791,4 @@ def plot_qq(self, path: str=None, dim: tuple=(12, 8)) -> plt.Figure:
plt.savefig(path, format="png", dpi=300, bbox_inches="tight")

plt.close()
return fig
return fig, ax
11 changes: 9 additions & 2 deletions cobra/evaluation/pigs_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def plot_incidence(pig_tables: pd.DataFrame,
variable: str,
model_type: str,
column_order: list=None,
dim: tuple=(12, 8)) -> plt.Figure:
dim: tuple=(12, 8)) -> tuple[plt.Figure, plt.Axes]:
"""Plots a Predictor Insights Graph (PIG), a graph in which the mean
target value is plotted for a number of bins constructed from a predictor
variable. When the target is a binary classification target,
Expand All @@ -130,6 +130,13 @@ def plot_incidence(pig_tables: pd.DataFrame,
on the PIG.
dim: tuple, default=(12, 8)
Optional tuple to configure the width and length of the plot.
Retruns
-------
fig : plt.Figure
figure object contining the PIG
ax : plt.Axes
axes object linked to the figure
"""
if model_type not in ["classification", "regression"]:
raise ValueError("An unexpected value was set for the model_type "
Expand Down Expand Up @@ -258,4 +265,4 @@ def plot_incidence(pig_tables: pd.DataFrame,
plt.margins(0.01)

plt.close()
return fig
return fig, ax
42 changes: 35 additions & 7 deletions cobra/evaluation/plotting_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

def plot_univariate_predictor_quality(df_metric: pd.DataFrame,
dim: tuple=(12, 8),
path: str=None) -> plt.Figure:
path: str=None) -> tuple[plt.Figure, plt.Axes]:
"""Plot univariate quality of the predictors.
Parameters
Expand All @@ -21,6 +21,13 @@ def plot_univariate_predictor_quality(df_metric: pd.DataFrame,
Width and length of the plot.
path : str, optional
Path to store the figure.
Retruns
-------
fig : plt.Figure
figure object containing the univariate predictor quality
ax : plt.Axes
axes object linked to the figure
"""

if "AUC selection" in df_metric.columns:
Expand Down Expand Up @@ -59,7 +66,7 @@ def plot_univariate_predictor_quality(df_metric: pd.DataFrame,
plt.gca().legend().set_title("")

plt.close()
return fig
return fig, ax

def plot_correlation_matrix(df_corr: pd.DataFrame,
dim: tuple=(12, 8),
Expand All @@ -74,6 +81,13 @@ def plot_correlation_matrix(df_corr: pd.DataFrame,
Width and length of the plot.
path : str, optional
Path to store the figure.
Retruns
-------
fig : plt.Figure
figure object containing the correlation matrix
ax : plt.Axes
axes object linked to the figure
"""
fig, ax = plt.subplots(figsize=dim)
ax = sns.heatmap(df_corr, cmap="Blues")
Expand All @@ -83,15 +97,15 @@ def plot_correlation_matrix(df_corr: pd.DataFrame,
plt.savefig(path, format="png", dpi=300, bbox_inches="tight")

plt.close()
return fig
return fig, ax

def plot_performance_curves(model_performance: pd.DataFrame,
dim: tuple=(12, 8),
path: str=None,
colors: dict={"train": "#0099bf",
"selection": "#ff9500",
"validation": "#8064a2"},
metric_name: str=None) -> plt.Figure:
metric_name: str=None) -> tuple[plt.Figure, plt.Axes]:
"""Plot performance curves generated by the forward feature selection
for the train-selection-validation sets.
Expand All @@ -110,6 +124,13 @@ def plot_performance_curves(model_performance: pd.DataFrame,
Name to indicate the metric used in model_performance.
Defaults to RMSE in case of regression and AUC in case of
classification.
Retruns
-------
fig : plt.Figure
figure object that contains the performance curves
ax : plt.Axes
axes object linked to the figure
"""

model_type = model_performance["model_type"][0]
Expand Down Expand Up @@ -162,12 +183,12 @@ def plot_performance_curves(model_performance: pd.DataFrame,
plt.savefig(path, format="png", dpi=300, bbox_inches="tight")

plt.close()
return fig
return fig, ax

def plot_variable_importance(df_variable_importance: pd.DataFrame,
title: str=None,
dim: tuple=(12, 8),
path: str=None) -> plt.Figure:
path: str=None) -> tuple[plt.Figure, plt.Axes]:
"""Plot variable importance of a given model.
Parameters
Expand All @@ -180,6 +201,13 @@ def plot_variable_importance(df_variable_importance: pd.DataFrame,
Width and length of the plot.
path : str, optional
Path to store the figure.
Retruns
-------
fig : plt.Figure
figure object containing the variable importance
ax : plt.Axes
axes object linked to the figure
"""
with sns.axes_style("ticks"):
fig, ax = plt.subplots(figsize=dim)
Expand All @@ -203,4 +231,4 @@ def plot_variable_importance(df_variable_importance: pd.DataFrame,
plt.savefig(path, format="png", dpi=300, bbox_inches="tight")

plt.close()
return fig
return fig, ax

0 comments on commit 4238080

Please sign in to comment.