Skip to content

Commit

Permalink
change set_theme to axes_style to style locally
Browse files Browse the repository at this point in the history
  • Loading branch information
patrickleonardy committed Oct 24, 2023
1 parent a41d8c0 commit ec98ccf
Show file tree
Hide file tree
Showing 3 changed files with 287 additions and 287 deletions.
227 changes: 112 additions & 115 deletions cobra/evaluation/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,28 +178,27 @@ def plot_roc_curve(self, path: str=None, dim: tuple=(12, 8)):
auc = float(self.scalar_metrics.loc["AUC"])


sns.set_theme(style="whitegrid")
with sns.axes_style("whitegrid"):
fig, ax = plt.subplots(figsize=dim)

fig, ax = plt.subplots(figsize=dim)

ax.plot(self.roc_curve["fpr"],
self.roc_curve["tpr"],
color="cornflowerblue", linewidth=3,
label="ROC curve (area = {s:.3})".format(s=auc))
ax.plot(self.roc_curve["fpr"],
self.roc_curve["tpr"],
color="cornflowerblue", linewidth=3,
label="ROC curve (area = {s:.3})".format(s=auc))

ax.plot([0, 1], [0, 1], color="darkorange", linewidth=3,
linestyle="--", label="random selection")
ax.set_xlabel("False positive rate", fontsize=15)
ax.set_ylabel("True positive rate", fontsize=15)
ax.legend(loc="lower right")
ax.set_title("ROC curve", fontsize=20)
ax.plot([0, 1], [0, 1], color="darkorange", linewidth=3,
linestyle="--", label="random selection")
ax.set_xlabel("False positive rate", fontsize=15)
ax.set_ylabel("True positive rate", fontsize=15)
ax.legend(loc="lower right")
ax.set_title("ROC curve", fontsize=20)

ax.set_ylim([0, 1])
ax.set_ylim([0, 1])

if path:
plt.savefig(path, format="png", dpi=300, bbox_inches="tight")
if path:
plt.savefig(path, format="png", dpi=300, bbox_inches="tight")

plt.show()
plt.show()

def plot_confusion_matrix(self, path: str=None, dim: tuple=(12, 8),
labels: list=["0", "1"]):
Expand Down Expand Up @@ -257,36 +256,35 @@ def plot_cumulative_response_curve(self, path: str=None, dim: tuple=(12, 8)):
lifts = np.array(lifts)*inc_rate*100


sns.set_theme(style="ticks")

fig, ax = plt.subplots(figsize=dim)
with sns.axes_style("ticks"):
fig, ax = plt.subplots(figsize=dim)

plt.bar(x_labels[::-1], lifts, align="center",
color="cornflowerblue")
plt.ylabel("Response (%)", fontsize=15)
plt.xlabel("Decile", fontsize=15)
ax.set_xticks(x_labels)
ax.set_xticklabels(x_labels)
plt.bar(x_labels[::-1], lifts, align="center",
color="cornflowerblue")
plt.ylabel("Response (%)", fontsize=15)
plt.xlabel("Decile", fontsize=15)
ax.set_xticks(x_labels)
ax.set_xticklabels(x_labels)

plt.axhline(y=inc_rate*100, color="darkorange", linestyle="--",
xmin=0.05, xmax=0.95, linewidth=3, label="incidence")
plt.axhline(y=inc_rate*100, color="darkorange", linestyle="--",
xmin=0.05, xmax=0.95, linewidth=3, label="incidence")

# Legend
ax.legend(loc="upper right")
# Legend
ax.legend(loc="upper right")

# Set Axis - make them pretty
sns.despine(ax=ax, right=True, left=True)
# Set Axis - make them pretty
sns.despine(ax=ax, right=True, left=True)

# Remove white lines from the second axis
ax.grid(False)
# Remove white lines from the second axis
ax.grid(False)

# Description
ax.set_title("Cumulative Response curve", fontsize=20)
# Description
ax.set_title("Cumulative Response curve", fontsize=20)

if path is not None:
plt.savefig(path, format="png", dpi=300, bbox_inches="tight")
if path is not None:
plt.savefig(path, format="png", dpi=300, bbox_inches="tight")

plt.show()
plt.show()

def plot_lift_curve(self, path: str=None, dim: tuple=(12, 8)):
"""Plot lift per decile.
Expand All @@ -308,36 +306,35 @@ def plot_lift_curve(self, path: str=None, dim: tuple=(12, 8)):
x_labels, lifts, _ = self.lift_curve


sns.set_theme(style="ticks")

fig, ax = plt.subplots(figsize=dim)
with sns.axes_style("ticks"):
fig, ax = plt.subplots(figsize=dim)

plt.bar(x_labels[::-1], lifts, align="center",
color="cornflowerblue")
plt.ylabel("Lift", fontsize=15)
plt.xlabel("Decile", fontsize=15)
ax.set_xticks(x_labels)
ax.set_xticklabels(x_labels)
plt.bar(x_labels[::-1], lifts, align="center",
color="cornflowerblue")
plt.ylabel("Lift", fontsize=15)
plt.xlabel("Decile", fontsize=15)
ax.set_xticks(x_labels)
ax.set_xticklabels(x_labels)

plt.axhline(y=1, color="darkorange", linestyle="--",
xmin=0.05, xmax=0.95, linewidth=3, label="baseline")
plt.axhline(y=1, color="darkorange", linestyle="--",
xmin=0.05, xmax=0.95, linewidth=3, label="baseline")

# Legend
ax.legend(loc="upper right")
# Legend
ax.legend(loc="upper right")

# Set Axis - make them pretty
sns.despine(ax=ax, right=True, left=True)
# Set Axis - make them pretty
sns.despine(ax=ax, right=True, left=True)

# Remove white lines from the second axis
ax.grid(False)
# Remove white lines from the second axis
ax.grid(False)

# Description
ax.set_title("Cumulative Lift curve", fontsize=20)
# Description
ax.set_title("Cumulative Lift curve", fontsize=20)

if path is not None:
plt.savefig(path, format="png", dpi=300, bbox_inches="tight")
if path is not None:
plt.savefig(path, format="png", dpi=300, bbox_inches="tight")

plt.show()
plt.show()

def plot_cumulative_gains(self, path: str=None, dim: tuple=(12, 8)):
"""Plot cumulative gains per decile.
Expand All @@ -350,39 +347,39 @@ def plot_cumulative_gains(self, path: str=None, dim: tuple=(12, 8)):
Tuple with width and length of the plot.
"""

sns.set_theme(style="whitegrid")

fig, ax = plt.subplots(figsize=dim)

ax.plot(self.cumulative_gains[0]*100, self.cumulative_gains[1]*100,
color="cornflowerblue", linewidth=3,
label="cumulative gains")
ax.plot(ax.get_xlim(), ax.get_ylim(), linewidth=3,
ls="--", color="darkorange", label="random selection")
with sns.axes_style("whitegrid"):
fig, ax = plt.subplots(figsize=dim)

ax.set_title("Cumulative Gains curve", fontsize=20)
ax.plot(self.cumulative_gains[0]*100, self.cumulative_gains[1]*100,
color="cornflowerblue", linewidth=3,
label="cumulative gains")
ax.plot(ax.get_xlim(), ax.get_ylim(), linewidth=3,
ls="--", color="darkorange", label="random selection")

# Format axes
ax.set_xlim([0, 100])
ax.set_ylim([0, 100])
plt.ylabel("Gain", fontsize=15)
plt.xlabel("Percentage", fontsize=15)
ax.set_title("Cumulative Gains curve", fontsize=20)

# Format ticks
ticks_loc_y = ax.get_yticks().tolist()
ax.yaxis.set_major_locator(mticker.FixedLocator(ticks_loc_y))
ax.set_yticklabels(["{:3.0f}%".format(x) for x in ticks_loc_y])
# Format axes
ax.set_xlim([0, 100])
ax.set_ylim([0, 100])
plt.ylabel("Gain", fontsize=15)
plt.xlabel("Percentage", fontsize=15)

ticks_loc_x = ax.get_xticks().tolist()
ax.xaxis.set_major_locator(mticker.FixedLocator(ticks_loc_x))
ax.set_xticklabels(["{:3.0f}%".format(x) for x in ticks_loc_x])
# Format ticks
ticks_loc_y = ax.get_yticks().tolist()
ax.yaxis.set_major_locator(mticker.FixedLocator(ticks_loc_y))
ax.set_yticklabels(["{:3.0f}%".format(x) for x in ticks_loc_y])

# Legend
ax.legend(loc="lower right")
ticks_loc_x = ax.get_xticks().tolist()
ax.xaxis.set_major_locator(mticker.FixedLocator(ticks_loc_x))
ax.set_xticklabels(["{:3.0f}%".format(x) for x in ticks_loc_x])

if path is not None:
plt.savefig(path, format="png", dpi=300, bbox_inches="tight")
plt.show()
# Legend
ax.legend(loc="lower right")

if path is not None:
plt.savefig(path, format="png", dpi=300, bbox_inches="tight")
plt.show()

@staticmethod
def _find_optimal_cutoff(y_true: np.ndarray,
Expand Down Expand Up @@ -681,25 +678,25 @@ def plot_predictions(self, path: str=None, dim: tuple=(12, 8)):
y_true = self.y_true
y_pred = self.y_pred

sns.set_theme(style="whitegrid")


fig, ax = plt.subplots(figsize=dim)

with sns.axes_style("whitegrid"):
fig, ax = plt.subplots(figsize=dim)

x = np.arange(1, len(y_true)+1)
x = np.arange(1, len(y_true)+1)

ax.plot(x, y_true, ls="--", label="actuals", color="darkorange", linewidth=3)
ax.plot(x, y_pred, label="predictions", color="cornflowerblue", linewidth=3)
ax.plot(x, y_true, ls="--", label="actuals", color="darkorange", linewidth=3)
ax.plot(x, y_pred, label="predictions", color="cornflowerblue", linewidth=3)

ax.set_xlabel("Index", fontsize=15)
ax.set_ylabel("Value", fontsize=15)
ax.legend(loc="best")
ax.set_title("Predictions vs. Actuals", fontsize=20)
ax.set_xlabel("Index", fontsize=15)
ax.set_ylabel("Value", fontsize=15)
ax.legend(loc="best")
ax.set_title("Predictions vs. Actuals", fontsize=20)

if path:
plt.savefig(path, format="png", dpi=300, bbox_inches="tight")
if path:
plt.savefig(path, format="png", dpi=300, bbox_inches="tight")

plt.show()
plt.show()

def plot_qq(self, path: str=None, dim: tuple=(12, 8)):
"""Display a Q-Q plot from the standardized prediction residuals.
Expand All @@ -718,26 +715,26 @@ def plot_qq(self, path: str=None, dim: tuple=(12, 8)):

raise NotFittedError(msg.format(self.__class__.__name__))

sns.set_theme(style="whitegrid")

fig, ax = plt.subplots(figsize=dim)
with sns.axes_style("whitegrid"):
fig, ax = plt.subplots(figsize=dim)

x = self.qq["quantiles"]
y = self.qq["residuals"]
x = self.qq["quantiles"]
y = self.qq["residuals"]

ax.plot(x, x, ls="--", label="perfect model", color="darkorange", linewidth=3)
ax.plot(x, y, label="current model", color="cornflowerblue", linewidth=3)
ax.plot(x, x, ls="--", label="perfect model", color="darkorange", linewidth=3)
ax.plot(x, y, label="current model", color="cornflowerblue", linewidth=3)

ax.set_xlabel("Theoretical quantiles", fontsize=15)
ax.set_xticks(range(int(np.floor(min(x))), int(np.ceil(max(x[x < float("inf")])))+1, 1))
ax.set_xlabel("Theoretical quantiles", fontsize=15)
ax.set_xticks(range(int(np.floor(min(x))), int(np.ceil(max(x[x < float("inf")])))+1, 1))

ax.set_ylabel("Standardized residuals", fontsize=15)
ax.set_yticks(range(int(np.floor(min(y))), int(np.ceil(max(y[x < float("inf")])))+1, 1))
ax.set_ylabel("Standardized residuals", fontsize=15)
ax.set_yticks(range(int(np.floor(min(y))), int(np.ceil(max(y[x < float("inf")])))+1, 1))

ax.legend(loc="best")
ax.set_title("Q-Q plot", fontsize=20)
ax.legend(loc="best")
ax.set_title("Q-Q plot", fontsize=20)

if path:
plt.savefig(path, format="png", dpi=300, bbox_inches="tight")
if path:
plt.savefig(path, format="png", dpi=300, bbox_inches="tight")

plt.show()
plt.show()
Loading

0 comments on commit ec98ccf

Please sign in to comment.