Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix the bug by taking the style directly from sns #186

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
198 changes: 198 additions & 0 deletions CHANGELOG.md

Large diffs are not rendered by default.

127 changes: 107 additions & 20 deletions cobra/evaluation/evaluator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@

import numpy as np
import pandas as pd

from typing import Tuple
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import seaborn as sns
Expand Down Expand Up @@ -158,7 +158,7 @@ def _compute_scalar_metrics(y_true: np.ndarray,
lift_at=lift_at), 2)
})

def plot_roc_curve(self, path: str=None, dim: tuple=(12, 8)):
def plot_roc_curve(self, path: str=None, dim: tuple=(12, 8), show: bool=True) -> Tuple[plt.Figure, plt.Axes]:
"""Plot ROC curve of the model.

Parameters
Expand All @@ -167,6 +167,15 @@ def plot_roc_curve(self, path: str=None, dim: tuple=(12, 8)):
Path to store the figure.
dim : tuple, optional
Tuple with width and length of the plot.
show : bool, optional
Whether to show the plot or not.

Retruns
-------
fig : plt.Figure
figure object containing the ROC curve
ax : plt.Axes
axes object linked to the figure
"""

if self.roc_curve is None:
Expand All @@ -177,7 +186,7 @@ def plot_roc_curve(self, path: str=None, dim: tuple=(12, 8)):

auc = float(self.scalar_metrics.loc["AUC"])

with plt.style.context("seaborn-whitegrid"):
with sns.axes_style("whitegrid"):

fig, ax = plt.subplots(figsize=dim)

Expand All @@ -197,11 +206,13 @@ def plot_roc_curve(self, path: str=None, dim: tuple=(12, 8)):

if path:
plt.savefig(path, format="png", dpi=300, bbox_inches="tight")

plt.show()
if show:
plt.show()
plt.close()
return fig, ax

def plot_confusion_matrix(self, path: str=None, dim: tuple=(12, 8),
labels: list=["0", "1"]):
labels: list=["0", "1"], show: bool=True) -> Tuple[plt.Figure, plt.Axes]:
"""Plot the confusion matrix.

Parameters
Expand All @@ -212,6 +223,15 @@ def plot_confusion_matrix(self, path: str=None, dim: tuple=(12, 8),
Tuple with width and length of the plot.
labels : list, optional
Optional list of labels, default "0" and "1".
show : bool, optional
Whether to show the plot or not.

Retruns
-------
fig : plt.Figure
figure object containing confusion matrix
ax : plt.Axes
axes object linked to the figure
"""

if self.confusion_matrix is None:
Expand All @@ -232,9 +252,16 @@ def plot_confusion_matrix(self, path: str=None, dim: tuple=(12, 8),
if path:
plt.savefig(path, format="png", dpi=300, bbox_inches="tight")

plt.show()

def plot_cumulative_response_curve(self, path: str=None, dim: tuple=(12, 8)):
if show:
plt.show()
plt.close()
return fig, ax

def plot_cumulative_response_curve(self,
path: str=None,
dim: tuple=(12, 8),
show: bool=True
) -> Tuple[plt.Figure, plt.Axes]:
"""Plot cumulative response curve.

Parameters
Expand All @@ -243,6 +270,15 @@ def plot_cumulative_response_curve(self, path: str=None, dim: tuple=(12, 8)):
Path to store the figure.
dim : tuple, optional
Tuple with width and length of the plot.
show : bool, optional
Whether to show the plot or not.

Retruns
-------
fig : plt.Figure
figure object containing the cumulative response curve
ax : plt.Axes
axes object linked to the figure
"""

if self.lift_curve is None:
Expand All @@ -255,7 +291,7 @@ def plot_cumulative_response_curve(self, path: str=None, dim: tuple=(12, 8)):

lifts = np.array(lifts)*inc_rate*100

with plt.style.context("seaborn-ticks"):
with sns.axes_style("ticks"):
fig, ax = plt.subplots(figsize=dim)

plt.bar(x_labels[::-1], lifts, align="center",
Expand Down Expand Up @@ -283,9 +319,12 @@ def plot_cumulative_response_curve(self, path: str=None, dim: tuple=(12, 8)):
if path is not None:
plt.savefig(path, format="png", dpi=300, bbox_inches="tight")

if show:
plt.show()
plt.close()
return fig, ax

def plot_lift_curve(self, path: str=None, dim: tuple=(12, 8)):
def plot_lift_curve(self, path: str=None, dim: tuple=(12, 8), show: bool=True) -> Tuple[plt.Figure, plt.Axes]:
"""Plot lift per decile.

Parameters
Expand All @@ -294,6 +333,15 @@ def plot_lift_curve(self, path: str=None, dim: tuple=(12, 8)):
Path to store the figure.
dim : tuple, optional
Tuple with width and length of the plot.
show : bool, optional
Whether to show the plot or not.

Retruns
-------
fig : plt.Figure
figure object the livt curve
ax : plt.Axes
axes object linked to the figure
"""

if self.lift_curve is None:
Expand All @@ -304,7 +352,7 @@ def plot_lift_curve(self, path: str=None, dim: tuple=(12, 8)):

x_labels, lifts, _ = self.lift_curve

with plt.style.context("seaborn-ticks"):
with sns.axes_style("ticks"):
fig, ax = plt.subplots(figsize=dim)

plt.bar(x_labels[::-1], lifts, align="center",
Expand Down Expand Up @@ -332,9 +380,12 @@ def plot_lift_curve(self, path: str=None, dim: tuple=(12, 8)):
if path is not None:
plt.savefig(path, format="png", dpi=300, bbox_inches="tight")

if show:
plt.show()
plt.close()
return fig, ax

def plot_cumulative_gains(self, path: str=None, dim: tuple=(12, 8)):
def plot_cumulative_gains(self, path: str=None, dim: tuple=(12, 8), show: bool=True) -> Tuple[plt.Figure, plt.Axes]:
"""Plot cumulative gains per decile.

Parameters
Expand All @@ -343,9 +394,18 @@ def plot_cumulative_gains(self, path: str=None, dim: tuple=(12, 8)):
Path to store the figure.
dim : tuple, optional
Tuple with width and length of the plot.
show : bool, optional
Whether to show the plot or not.

Retruns
-------
fig : plt.Figure
figure object containing the cumulative gains curve
ax : plt.Axes
axes object linked to the figure
"""

with plt.style.context("seaborn-whitegrid"):
with sns.axes_style("whitegrid"):
fig, ax = plt.subplots(figsize=dim)

ax.plot(self.cumulative_gains[0]*100, self.cumulative_gains[1]*100,
Expand Down Expand Up @@ -376,7 +436,10 @@ def plot_cumulative_gains(self, path: str=None, dim: tuple=(12, 8)):

if path is not None:
plt.savefig(path, format="png", dpi=300, bbox_inches="tight")
if show:
plt.show()
plt.close()
return fig, ax

@staticmethod
def _find_optimal_cutoff(y_true: np.ndarray,
Expand Down Expand Up @@ -658,7 +721,7 @@ def _compute_qq_residuals(y_true: np.ndarray,
"residuals": df["z_res"].values,
})

def plot_predictions(self, path: str=None, dim: tuple=(12, 8)):
def plot_predictions(self, path: str=None, dim: tuple=(12, 8), show: bool=True) -> Tuple[plt.Figure, plt.Axes]:
"""Plot predictions from the model against actual values.

Parameters
Expand All @@ -667,6 +730,15 @@ def plot_predictions(self, path: str=None, dim: tuple=(12, 8)):
Path to store the figure.
dim : tuple, optional
Tuple with width and length of the plot.
show : bool, optional
Whether to show the plot or not.

Retruns
-------
fig : plt.Figure
figure object containing the predictions vs the actual values
ax : plt.Axes
axes object linked to the figure
"""
if self.y_true is None and self.y_pred is None:
msg = ("This {} instance is not fitted yet. Call 'fit' with "
Expand All @@ -675,7 +747,7 @@ def plot_predictions(self, path: str=None, dim: tuple=(12, 8)):
y_true = self.y_true
y_pred = self.y_pred

with plt.style.context("seaborn-whitegrid"):
with sns.axes_style("whitegrid"):

fig, ax = plt.subplots(figsize=dim)

Expand All @@ -692,9 +764,12 @@ def plot_predictions(self, path: str=None, dim: tuple=(12, 8)):
if path:
plt.savefig(path, format="png", dpi=300, bbox_inches="tight")

plt.show()
if show:
plt.show()
plt.close()
return fig, ax

def plot_qq(self, path: str=None, dim: tuple=(12, 8)):
def plot_qq(self, path: str=None, dim: tuple=(12, 8), show: bool=True) -> Tuple[plt.Figure, plt.Axes]:
"""Display a Q-Q plot from the standardized prediction residuals.

Parameters
Expand All @@ -703,6 +778,15 @@ def plot_qq(self, path: str=None, dim: tuple=(12, 8)):
Path to store the figure.
dim : tuple, optional
Tuple with width and length of the plot.
show : bool, optional
Whether to show the plot or not.

Retruns
-------
fig : plt.Figure
figure object containing the QQ-plot
ax : plt.Axes
axes object linked to the figure
"""

if self.qq is None:
Expand All @@ -711,7 +795,7 @@ def plot_qq(self, path: str=None, dim: tuple=(12, 8)):

raise NotFittedError(msg.format(self.__class__.__name__))

with plt.style.context("seaborn-whitegrid"):
with sns.axes_style("whitegrid"):

fig, ax = plt.subplots(figsize=dim)

Expand All @@ -733,4 +817,7 @@ def plot_qq(self, path: str=None, dim: tuple=(12, 8)):
if path:
plt.savefig(path, format="png", dpi=300, bbox_inches="tight")

plt.show()
if show:
plt.show()
plt.close()
return fig, ax
23 changes: 18 additions & 5 deletions cobra/evaluation/pigs_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import seaborn as sns
import numpy as np
from matplotlib.ticker import FuncFormatter

from typing import Tuple
import cobra.utils as utils

def generate_pig_tables(basetable: pd.DataFrame,
Expand Down Expand Up @@ -107,7 +107,8 @@ def plot_incidence(pig_tables: pd.DataFrame,
variable: str,
model_type: str,
column_order: list=None,
dim: tuple=(12, 8)):
dim: tuple=(12, 8),
show: bool=True) -> Tuple[plt.Figure, plt.Axes]:
"""Plots a Predictor Insights Graph (PIG), a graph in which the mean
target value is plotted for a number of bins constructed from a predictor
variable. When the target is a binary classification target,
Expand All @@ -130,6 +131,15 @@ def plot_incidence(pig_tables: pd.DataFrame,
on the PIG.
dim: tuple, default=(12, 8)
Optional tuple to configure the width and length of the plot.
show: bool, default=True
Whether to show the plot or not.

Retruns
-------
fig : plt.Figure
figure object contining the PIG
ax : plt.Axes
axes object linked to the figure
"""
if model_type not in ["classification", "regression"]:
raise ValueError("An unexpected value was set for the model_type "
Expand All @@ -154,7 +164,7 @@ def plot_incidence(pig_tables: pd.DataFrame,
df_plot.sort_values(by=['avg_target'], ascending=False, inplace=True)
df_plot.reset_index(inplace=True)

with plt.style.context("seaborn-ticks"):
with sns.axes_style("ticks"):
fig, ax = plt.subplots(figsize=dim)

# --------------------------
Expand Down Expand Up @@ -257,5 +267,8 @@ def plot_incidence(pig_tables: pd.DataFrame,
plt.tight_layout()
plt.margins(0.01)

# Show
plt.show()
if show:
plt.show()
plt.close()

return fig, ax
Loading
Loading