Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix typechecking with matplotlib 3.8.0 #969

Merged
merged 3 commits into from
Sep 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 15 additions & 12 deletions alibi/explainers/pd_variance.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,14 +458,14 @@ def _plot_hbar(exp_values: np.ndarray,

if isinstance(ax, plt.Axes) and n_targets != 1:
ax.set_axis_off() # treat passed axis as a canvas for subplots
fig = ax.figure
fig = ax.figure # type: ignore[assignment]
n_cols = min(n_cols, n_targets)
n_rows = math.ceil(n_targets / n_cols)
axes = np.empty((n_rows, n_cols), dtype=object)
axes_ravel = axes.ravel()
gs = GridSpec(n_rows, n_cols)

for i, spec in enumerate(list(gs)[:n_targets]):
for i, spec in enumerate(list(gs)[:n_targets]): # type: ignore[call-overload]
axes_ravel[i] = fig.add_subplot(spec)
else:
if isinstance(ax, plt.Axes):
Expand All @@ -489,15 +489,15 @@ def _plot_hbar(exp_values: np.ndarray,
default_bar_kw = {'align': 'center'}
bar_kw = default_bar_kw if bar_kw is None else {**default_bar_kw, **bar_kw}

ax.barh(y=y, width=width, **bar_kw)
ax.set_yticks(y)
ax.set_yticklabels(y_labels)
ax.invert_yaxis() # labels read top-to-bottom
ax.set_xlabel(title)
ax.set_title(target_name)
ax.barh(y=y, width=width, **bar_kw) # type: ignore[union-attr,arg-type]
ax.set_yticks(y) # type: ignore[union-attr]
ax.set_yticklabels(y_labels) # type: ignore[union-attr]
ax.invert_yaxis() # type: ignore[union-attr] # labels read top-to-bottom
ax.set_xlabel(title) # type: ignore[union-attr]
ax.set_title(target_name) # type: ignore[union-attr]

fig.set(**fig_kw)
return axes
return axes # type: ignore[return-value]


def _plot_feature_importance(exp: Explanation,
Expand Down Expand Up @@ -697,15 +697,18 @@ def _plot_feature_interaction(exp: Explanation,
# set title for the 2-way pdp
ax = axes_flatten[step * i]
(ft_name1, ft_name2) = feature_names[features[i]] # type: Tuple[str, str] # type: ignore[misc]
ax.set_title('inter({},{}) = {:.3f}'.format(ft_name1, ft_name2, feature_interaction[i]))
ax.set_title('inter({},{}) = {:.3f}'.format(ft_name1, ft_name2, # type: ignore[union-attr]
feature_interaction[i]))

# set title for the first conditional importance plot
ax = axes.flatten()[step * i + 1]
ax.set_title('inter({}|{}) = {:.3f}'.format(ft_name2, ft_name1, conditional_importance[i][0][target_idx]))
ax.set_title('inter({}|{}) = {:.3f}'.format(ft_name2, ft_name1, # type: ignore[union-attr]
conditional_importance[i][0][target_idx]))

# set title for the second conditional importance plot
ax = axes.flatten()[step * i + 2]
ax.set_title('inter({}|{}) = {:.3f}'.format(ft_name1, ft_name2, conditional_importance[i][1][target_idx]))
ax.set_title('inter({}|{}) = {:.3f}'.format(ft_name1, ft_name2, # type: ignore[union-attr]
conditional_importance[i][1][target_idx]))

return axes

Expand Down
8 changes: 4 additions & 4 deletions alibi/explainers/tests/test_partial_dependence.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,15 +617,15 @@ def assert_deciles(xsegments: Optional[List[np.ndarray]] = None,

def assert_pd_values(feature_values: np.ndarray, pd_values: np.ndarray, line: plt.Line2D):
""" Checks if the plotted pd values are correct. """
x, y = line.get_xydata().T
x, y = line.get_xydata().T # type: ignore[union-attr]
assert np.allclose(x, feature_values)
assert np.allclose(y, pd_values)


def assert_ice_values(feature_values: np.ndarray, ice_values: np.ndarray, lines: List[plt.Line2D]):
""" Checks if the plotted ice values are correct. """
for ice_vals, line in zip(ice_values, lines):
x, y = line.get_xydata().T
x, y = line.get_xydata().T # type: ignore[union-attr]
assert np.allclose(x, feature_values)
assert np.allclose(y, ice_vals)

Expand All @@ -637,14 +637,14 @@ def assert_pd_ice_values(feature: int, target_idx: int, kind: str, explanation:
line = ax.lines[0] if kind == 'average' else ax.lines[2]
assert_pd_values(feature_values=explanation.data['feature_values'][feature],
pd_values=explanation.data['pd_values'][feature][target_idx],
line=line)
line=line) # type: ignore[arg-type]

if kind in ['individual', 'both']:
# check the ice values
lines = ax.lines if kind == 'individual' else ax.lines[:2]
assert_ice_values(feature_values=explanation.data['feature_values'][feature],
ice_values=explanation.data['ice_values'][feature][target_idx],
lines=lines)
lines=lines) # type: ignore[arg-type]


@pytest.mark.parametrize('explanation', ['average', 'individual', 'both'], indirect=True)
Expand Down
2 changes: 1 addition & 1 deletion alibi/prototypes/protoselect.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,7 @@ def _imscatterplot(x: np.ndarray,
ax.set_xticks([])
ax.set_yticks([])
else:
fig = ax.figure
fig = ax.figure # type: ignore[assignment]

resized_imgs = [resize(images[i], image_size) for i in range(len(images))]
imgs = [OffsetImage(img, zoom=zoom[i], cmap='gray') for i, img in enumerate(resized_imgs)] # type: ignore
Expand Down
24 changes: 12 additions & 12 deletions alibi/utils/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.figure import Figure
from matplotlib.pyplot import axis, figure
from matplotlib.pyplot import Axes, Figure
from mpl_toolkits.axes_grid1 import make_axes_locatable
from numpy import ndarray

Expand Down Expand Up @@ -83,15 +82,15 @@ def visualize_image_attr(
original_image: Union[None, ndarray] = None,
method: str = "heat_map",
sign: str = "absolute_value",
plt_fig_axis: Union[None, Tuple[figure, axis]] = None,
plt_fig_axis: Union[None, Tuple[Figure, Axes]] = None,
outlier_perc: Union[int, float] = 2,
cmap: Union[None, str] = None,
alpha_overlay: float = 0.5,
show_colorbar: bool = False,
title: Union[None, str] = None,
fig_size: Tuple[int, int] = (6, 6),
use_pyplot: bool = True,
):
) -> Tuple[Figure, Axes]:
"""
Visualizes attribution for a given image by normalizing attribution values of the desired sign
(``'positive'`` | ``'negative'`` | ``'absolute_value'`` | ``'all'``) and displaying them using the desired mode
Expand Down Expand Up @@ -163,10 +162,10 @@ def visualize_image_attr(
Returns
-------
2-element tuple of consisting of
- `figure` : ``matplotlib.pyplot.figure`` - Figure object on which visualization is created. If `plt_fig_axis` \
- `figure` : ``matplotlib.pyplot.Figure`` - Figure object on which visualization is created. If `plt_fig_axis` \
argument is given, this is the same figure provided.

- `axis` : ``matplotlib.pyplot.axis`` - Axis object on which visualization is created. If `plt_fig_axis` argument \
- `axis` : ``matplotlib.pyplot.Axes`` - Axes object on which visualization is created. If `plt_fig_axis` argument \
is given, this is the same axis provided.

"""
Expand All @@ -178,7 +177,7 @@ def visualize_image_attr(
plt_fig, plt_axis = plt.subplots(figsize=fig_size)
else:
plt_fig = Figure(figsize=fig_size)
plt_axis = plt_fig.subplots()
plt_axis = plt_fig.subplots() # type: ignore[assignment]

if original_image is not None:
if np.max(original_image) <= 1.0:
Expand All @@ -204,7 +203,7 @@ def visualize_image_attr(

# Set default colormap and bounds based on sign.
if VisualizeSign[sign] == VisualizeSign.all:
default_cmap = LinearSegmentedColormap.from_list(
default_cmap: Union[LinearSegmentedColormap, str] = LinearSegmentedColormap.from_list(
"RdWhGn", ["red", "white", "green"]
)
vmin, vmax = -1, 1
Expand All @@ -219,7 +218,7 @@ def visualize_image_attr(
vmin, vmax = 0, 1
else:
raise AssertionError("Visualize Sign type is not valid.")
cmap = cmap if cmap is not None else default_cmap
cmap = cmap if cmap is not None else default_cmap # type: ignore[assignment]

# Show appropriate image visualization.
if ImageVisualizationMethod[method] == ImageVisualizationMethod.heat_map:
Expand Down Expand Up @@ -339,7 +338,7 @@ def _create_heatmap(data: np.ndarray,
if cbar:
if cbar_ax is None:
cbar_ax = ax
cbar_obj = ax.figure.colorbar(im, ax=cbar_ax, **cbar_kws)
cbar_obj = ax.figure.colorbar(im, ax=cbar_ax, **cbar_kws) # type: ignore[union-attr]
cbar_obj.ax.set_ylabel(cbar_label, rotation=-90, va="bottom")

# show all ticks and label them with the respective list entries.
Expand All @@ -355,7 +354,7 @@ def _create_heatmap(data: np.ndarray,
plt.setp(ax.get_xticklabels(), rotation=90, ha="right", rotation_mode="anchor")

# turn spines off and create white grid.
ax.spines[:].set_visible(False)
ax.spines[:].set_visible(False) # type: ignore[call-overload]

ax.set_xticks(np.arange(data.shape[1]+1)-.5, minor=True)
ax.set_yticks(np.arange(data.shape[0]+1)-.5, minor=True)
Expand Down Expand Up @@ -401,6 +400,7 @@ def _annotate_heatmap(im: matplotlib.image.AxesImage,

if not isinstance(data, (list, np.ndarray)):
data = im.get_array()
assert isinstance(data, np.ndarray) # for mypy, since get_array() can sometimes be None

# normalize the threshold to the images color range.
if threshold is not None:
Expand All @@ -422,7 +422,7 @@ def _annotate_heatmap(im: matplotlib.image.AxesImage,
for i in range(data.shape[0]):
for j in range(data.shape[1]):
kw.update(color=textcolors[int(im.norm(data[i, j]) > threshold)])
text = im.axes.text(j, i, fmt(data[i, j], None), **kw)
text = im.axes.text(j, i, fmt(data[i, j], None), **kw) # type: ignore[arg-type]
texts.append(text)

return texts
Expand Down