Skip to content

Commit

Permalink
add CDFs
Browse files Browse the repository at this point in the history
  • Loading branch information
thayeral committed Oct 17, 2023
1 parent 2d256a0 commit ab6cb9c
Showing 1 changed file with 42 additions and 3 deletions.
45 changes: 42 additions & 3 deletions src/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,6 +584,19 @@ def plot_heatmap_p2v(
bins=25,
color='dimgrey'
)

ax1t = ax1.twinx()
ax1t = sns.ecdfplot(
ax=ax1t,
data=x,
x=hist_col,
stat='proportion',
color='C2',
linestyle=':',
)
ax1t.tick_params(axis='y', labelcolor='C2', color='C2')
ax1t.set_ylabel('CDF', color='C2')

ax1.axvline(np.median(x[hist_col]), c='C0', ls='-', lw=2, label='Median')
ax1.axvline(np.mean(x[hist_col]), c='C1', ls='--', lw=2, label='Mean')
ax1.set_ylim(0, 30)
Expand Down Expand Up @@ -624,6 +637,18 @@ def plot_heatmap_p2v(
bins=25,
color='dimgrey'
)
ax2t = ax2.twinx()
ax2t = sns.ecdfplot(
ax=ax2t,
data=x,
x=hist_col,
stat='proportion',
color='C2',
linestyle=':',
)
ax2t.tick_params(axis='y', labelcolor='C2', color='C2')
ax2t.set_ylabel('CDF', color='C2')

ax2.axvline(np.median(x[hist_col]), c='C0', ls='-', lw=2)
ax2.axvline(np.mean(x[hist_col]), c='C1', ls='--', lw=2)
ax2.set_ylim(0, 30)
Expand Down Expand Up @@ -664,6 +689,19 @@ def plot_heatmap_p2v(
bins=25,
color='dimgrey'
)

ax3t = ax3.twinx()
ax3t = sns.ecdfplot(
ax=ax3t,
data=x,
x=hist_col,
stat='proportion',
color='C2',
linestyle=':',
)
ax3t.tick_params(axis='y', labelcolor='C2', color='C2')
ax3t.set_ylabel('CDF', color='C2')

ax3.axvline(np.median(x[hist_col]), c='C0', ls='-', lw=2)
ax3.axvline(np.mean(x[hist_col]), c='C1', ls='--', lw=2)
ax3.set_ylim(0, 30)
Expand Down Expand Up @@ -694,11 +732,11 @@ def plot_heatmap_p2v(

ax1.legend(frameon=False, ncol=1, loc='upper left')
ax1.yaxis.set_major_formatter(PercentFormatter(decimals=0))
ax1.grid(True, which="both", axis='both', lw=.25, ls='--', zorder=0)
ax1t.grid(True, which="both", axis='both', lw=.25, ls='--', zorder=0)
ax2.yaxis.set_major_formatter(PercentFormatter(decimals=0))
ax2.grid(True, which="both", axis='both', lw=.25, ls='--', zorder=0)
ax2t.grid(True, which="both", axis='both', lw=.25, ls='--', zorder=0)
ax3.yaxis.set_major_formatter(PercentFormatter(decimals=0))
ax3.grid(True, which="both", axis='both', lw=.25, ls='--', zorder=0)
ax3t.grid(True, which="both", axis='both', lw=.25, ls='--', zorder=0)

elif label == f'Number of iterations':

Expand All @@ -716,6 +754,7 @@ def plot_heatmap_p2v(
color='dimgrey'
)


ax1.axvline(np.median(x[hist_col]), c='C0', ls='-', lw=2)
ax1.axvline(np.mean(x[hist_col]), c='C1', ls='--', lw=2)
ax1.axvline(np.median(x[hist_col]), c='C0', ls='-', lw=2, label='Median')
Expand Down

0 comments on commit ab6cb9c

Please sign in to comment.