From c6dd836b80d79d5421f17727274e10366261a6a2 Mon Sep 17 00:00:00 2001 From: Janis Klaise Date: Wed, 27 Sep 2023 11:58:39 +0100 Subject: [PATCH 1/3] Fix typechecking with matplotlib 3.8.0 --- .../anchors/anchor_tabular_distributed.py | 2 +- alibi/explainers/pd_variance.py | 24 +++++++++---------- .../tests/test_partial_dependence.py | 8 +++---- alibi/prototypes/protoselect.py | 2 +- alibi/utils/visualization.py | 24 +++++++++---------- 5 files changed, 30 insertions(+), 30 deletions(-) diff --git a/alibi/explainers/anchors/anchor_tabular_distributed.py b/alibi/explainers/anchors/anchor_tabular_distributed.py index 56b7b9f0f..1481204c7 100644 --- a/alibi/explainers/anchors/anchor_tabular_distributed.py +++ b/alibi/explainers/anchors/anchor_tabular_distributed.py @@ -235,7 +235,7 @@ def fit(self, # type: ignore[override] d_samplers = [] for sampler in samplers: d_samplers.append( - ray.remote(RemoteSampler).remote( # type: ignore[call-arg] + ray.remote(RemoteSampler).remote( *(train_data_id, d_train_data_id, sampler) ) ) diff --git a/alibi/explainers/pd_variance.py b/alibi/explainers/pd_variance.py index 1d7ae8aae..4be02b42f 100644 --- a/alibi/explainers/pd_variance.py +++ b/alibi/explainers/pd_variance.py @@ -458,14 +458,14 @@ def _plot_hbar(exp_values: np.ndarray, if isinstance(ax, plt.Axes) and n_targets != 1: ax.set_axis_off() # treat passed axis as a canvas for subplots - fig = ax.figure + fig = ax.figure # type: ignore[assignment] n_cols = min(n_cols, n_targets) n_rows = math.ceil(n_targets / n_cols) axes = np.empty((n_rows, n_cols), dtype=object) axes_ravel = axes.ravel() gs = GridSpec(n_rows, n_cols) - for i, spec in enumerate(list(gs)[:n_targets]): + for i, spec in enumerate(list(gs)[:n_targets]): # type: ignore[call-overload] axes_ravel[i] = fig.add_subplot(spec) else: if isinstance(ax, plt.Axes): @@ -489,15 +489,15 @@ def _plot_hbar(exp_values: np.ndarray, default_bar_kw = {'align': 'center'} bar_kw = default_bar_kw if bar_kw is None else {**default_bar_kw, **bar_kw} - ax.barh(y=y, width=width, **bar_kw) - ax.set_yticks(y) - ax.set_yticklabels(y_labels) - ax.invert_yaxis() # labels read top-to-bottom - ax.set_xlabel(title) - ax.set_title(target_name) + ax.barh(y=y, width=width, **bar_kw) # type: ignore[union-attr,arg-type] + ax.set_yticks(y) # type: ignore[union-attr] + ax.set_yticklabels(y_labels) # type: ignore[union-attr] + ax.invert_yaxis() # type: ignore[union-attr] # labels read top-to-bottom + ax.set_xlabel(title) # type: ignore[union-attr] + ax.set_title(target_name) # type: ignore[union-attr] fig.set(**fig_kw) - return axes + return axes # type: ignore[return-value] def _plot_feature_importance(exp: Explanation, @@ -697,15 +697,15 @@ def _plot_feature_interaction(exp: Explanation, # set title for the 2-way pdp ax = axes_flatten[step * i] (ft_name1, ft_name2) = feature_names[features[i]] # type: Tuple[str, str] # type: ignore[misc] - ax.set_title('inter({},{}) = {:.3f}'.format(ft_name1, ft_name2, feature_interaction[i])) + ax.set_title('inter({},{}) = {:.3f}'.format(ft_name1, ft_name2, feature_interaction[i])) # type: ignore[union-attr] # set title for the first conditional importance plot ax = axes.flatten()[step * i + 1] - ax.set_title('inter({}|{}) = {:.3f}'.format(ft_name2, ft_name1, conditional_importance[i][0][target_idx])) + ax.set_title('inter({}|{}) = {:.3f}'.format(ft_name2, ft_name1, conditional_importance[i][0][target_idx])) # type: ignore[union-attr] # set title for the second conditional importance plot ax = axes.flatten()[step * i + 2] - ax.set_title('inter({}|{}) = {:.3f}'.format(ft_name1, ft_name2, conditional_importance[i][1][target_idx])) + ax.set_title('inter({}|{}) = {:.3f}'.format(ft_name1, ft_name2, conditional_importance[i][1][target_idx])) # type: ignore[union-attr] return axes diff --git a/alibi/explainers/tests/test_partial_dependence.py b/alibi/explainers/tests/test_partial_dependence.py index 7b42d6c28..7c75cff50 100644 --- a/alibi/explainers/tests/test_partial_dependence.py +++ b/alibi/explainers/tests/test_partial_dependence.py @@ -617,7 +617,7 @@ def assert_deciles(xsegments: Optional[List[np.ndarray]] = None, def assert_pd_values(feature_values: np.ndarray, pd_values: np.ndarray, line: plt.Line2D): """ Checks if the plotted pd values are correct. """ - x, y = line.get_xydata().T + x, y = line.get_xydata().T # type: ignore[union-attr] assert np.allclose(x, feature_values) assert np.allclose(y, pd_values) @@ -625,7 +625,7 @@ def assert_pd_values(feature_values: np.ndarray, pd_values: np.ndarray, line: pl def assert_ice_values(feature_values: np.ndarray, ice_values: np.ndarray, lines: List[plt.Line2D]): """ Checks if the plotted ice values are correct. """ for ice_vals, line in zip(ice_values, lines): - x, y = line.get_xydata().T + x, y = line.get_xydata().T # type: ignore[union-attr] assert np.allclose(x, feature_values) assert np.allclose(y, ice_vals) @@ -637,14 +637,14 @@ def assert_pd_ice_values(feature: int, target_idx: int, kind: str, explanation: line = ax.lines[0] if kind == 'average' else ax.lines[2] assert_pd_values(feature_values=explanation.data['feature_values'][feature], pd_values=explanation.data['pd_values'][feature][target_idx], - line=line) + line=line) # type: ignore[arg-type] if kind in ['individual', 'both']: # check the ice values lines = ax.lines if kind == 'individual' else ax.lines[:2] assert_ice_values(feature_values=explanation.data['feature_values'][feature], ice_values=explanation.data['ice_values'][feature][target_idx], - lines=lines) + lines=lines) # type: ignore[arg-type] @pytest.mark.parametrize('explanation', ['average', 'individual', 'both'], indirect=True) diff --git a/alibi/prototypes/protoselect.py b/alibi/prototypes/protoselect.py index d9e66f243..8e5ddaa21 100644 --- a/alibi/prototypes/protoselect.py +++ b/alibi/prototypes/protoselect.py @@ -530,7 +530,7 @@ def _imscatterplot(x: np.ndarray, ax.set_xticks([]) ax.set_yticks([]) else: - fig = ax.figure + fig = ax.figure # type: ignore[assignment] resized_imgs = [resize(images[i], image_size) for i in range(len(images))] imgs = [OffsetImage(img, zoom=zoom[i], cmap='gray') for i, img in enumerate(resized_imgs)] # type: ignore diff --git a/alibi/utils/visualization.py b/alibi/utils/visualization.py index 0a5cecd8b..83f681673 100644 --- a/alibi/utils/visualization.py +++ b/alibi/utils/visualization.py @@ -6,8 +6,7 @@ import numpy as np from matplotlib import pyplot as plt from matplotlib.colors import LinearSegmentedColormap -from matplotlib.figure import Figure -from matplotlib.pyplot import axis, figure +from matplotlib.pyplot import Axes, Figure from mpl_toolkits.axes_grid1 import make_axes_locatable from numpy import ndarray @@ -83,7 +82,7 @@ def visualize_image_attr( original_image: Union[None, ndarray] = None, method: str = "heat_map", sign: str = "absolute_value", - plt_fig_axis: Union[None, Tuple[figure, axis]] = None, + plt_fig_axis: Union[None, Tuple[Figure, Axes]] = None, outlier_perc: Union[int, float] = 2, cmap: Union[None, str] = None, alpha_overlay: float = 0.5, @@ -91,7 +90,7 @@ def visualize_image_attr( title: Union[None, str] = None, fig_size: Tuple[int, int] = (6, 6), use_pyplot: bool = True, -): +) -> Tuple[Figure, Axes]: """ Visualizes attribution for a given image by normalizing attribution values of the desired sign (``'positive'`` | ``'negative'`` | ``'absolute_value'`` | ``'all'``) and displaying them using the desired mode @@ -163,10 +162,10 @@ def visualize_image_attr( Returns ------- 2-element tuple of consisting of - - `figure` : ``matplotlib.pyplot.figure`` - Figure object on which visualization is created. If `plt_fig_axis` \ + - `figure` : ``matplotlib.pyplot.Figure`` - Figure object on which visualization is created. If `plt_fig_axis` \ argument is given, this is the same figure provided. - - `axis` : ``matplotlib.pyplot.axis`` - Axis object on which visualization is created. If `plt_fig_axis` argument \ + - `axis` : ``matplotlib.pyplot.Axes`` - Axes object on which visualization is created. If `plt_fig_axis` argument \ is given, this is the same axis provided. """ @@ -178,7 +177,7 @@ def visualize_image_attr( plt_fig, plt_axis = plt.subplots(figsize=fig_size) else: plt_fig = Figure(figsize=fig_size) - plt_axis = plt_fig.subplots() + plt_axis = plt_fig.subplots() # type: ignore[assignment] if original_image is not None: if np.max(original_image) <= 1.0: @@ -204,7 +203,7 @@ def visualize_image_attr( # Set default colormap and bounds based on sign. if VisualizeSign[sign] == VisualizeSign.all: - default_cmap = LinearSegmentedColormap.from_list( + default_cmap: Union[LinearSegmentedColormap, str] = LinearSegmentedColormap.from_list( "RdWhGn", ["red", "white", "green"] ) vmin, vmax = -1, 1 @@ -219,7 +218,7 @@ def visualize_image_attr( vmin, vmax = 0, 1 else: raise AssertionError("Visualize Sign type is not valid.") - cmap = cmap if cmap is not None else default_cmap + cmap = cmap if cmap is not None else default_cmap # type: ignore[assignment] # Show appropriate image visualization. if ImageVisualizationMethod[method] == ImageVisualizationMethod.heat_map: @@ -339,7 +338,7 @@ def _create_heatmap(data: np.ndarray, if cbar: if cbar_ax is None: cbar_ax = ax - cbar_obj = ax.figure.colorbar(im, ax=cbar_ax, **cbar_kws) + cbar_obj = ax.figure.colorbar(im, ax=cbar_ax, **cbar_kws) # type: ignore[union-attr] cbar_obj.ax.set_ylabel(cbar_label, rotation=-90, va="bottom") # show all ticks and label them with the respective list entries. @@ -355,7 +354,7 @@ def _create_heatmap(data: np.ndarray, plt.setp(ax.get_xticklabels(), rotation=90, ha="right", rotation_mode="anchor") # turn spines off and create white grid. - ax.spines[:].set_visible(False) + ax.spines[:].set_visible(False) # type: ignore[call-overload] ax.set_xticks(np.arange(data.shape[1]+1)-.5, minor=True) ax.set_yticks(np.arange(data.shape[0]+1)-.5, minor=True) @@ -401,6 +400,7 @@ def _annotate_heatmap(im: matplotlib.image.AxesImage, if not isinstance(data, (list, np.ndarray)): data = im.get_array() + assert isinstance(data, np.ndarray) # for mypy, since get_array() can sometimes be None # normalize the threshold to the images color range. if threshold is not None: @@ -422,7 +422,7 @@ def _annotate_heatmap(im: matplotlib.image.AxesImage, for i in range(data.shape[0]): for j in range(data.shape[1]): kw.update(color=textcolors[int(im.norm(data[i, j]) > threshold)]) - text = im.axes.text(j, i, fmt(data[i, j], None), **kw) + text = im.axes.text(j, i, fmt(data[i, j], None), **kw) # type: ignore[arg-type] texts.append(text) return texts From e03bb024eb286ff6330f729c9002cd84e4fbbcb9 Mon Sep 17 00:00:00 2001 From: Janis Klaise Date: Wed, 27 Sep 2023 12:12:23 +0100 Subject: [PATCH 2/3] Lint --- alibi/explainers/pd_variance.py | 9 ++++++--- alibi/prototypes/protoselect.py | 2 +- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/alibi/explainers/pd_variance.py b/alibi/explainers/pd_variance.py index 4be02b42f..44d7fcf36 100644 --- a/alibi/explainers/pd_variance.py +++ b/alibi/explainers/pd_variance.py @@ -697,15 +697,18 @@ def _plot_feature_interaction(exp: Explanation, # set title for the 2-way pdp ax = axes_flatten[step * i] (ft_name1, ft_name2) = feature_names[features[i]] # type: Tuple[str, str] # type: ignore[misc] - ax.set_title('inter({},{}) = {:.3f}'.format(ft_name1, ft_name2, feature_interaction[i])) # type: ignore[union-attr] + ax.set_title('inter({},{}) = {:.3f}'.format(ft_name1, ft_name2, # type: ignore[union-attr] + feature_interaction[i])) # set title for the first conditional importance plot ax = axes.flatten()[step * i + 1] - ax.set_title('inter({}|{}) = {:.3f}'.format(ft_name2, ft_name1, conditional_importance[i][0][target_idx])) # type: ignore[union-attr] + ax.set_title('inter({}|{}) = {:.3f}'.format(ft_name2, ft_name1, # type: ignore[union-attr] + conditional_importance[i][0][target_idx])) # set title for the second conditional importance plot ax = axes.flatten()[step * i + 2] - ax.set_title('inter({}|{}) = {:.3f}'.format(ft_name1, ft_name2, conditional_importance[i][1][target_idx])) # type: ignore[union-attr] + ax.set_title('inter({}|{}) = {:.3f}'.format(ft_name1, ft_name2, # type: ignore[union-attr] + conditional_importance[i][1][target_idx])) return axes diff --git a/alibi/prototypes/protoselect.py b/alibi/prototypes/protoselect.py index 8e5ddaa21..5791da8f4 100644 --- a/alibi/prototypes/protoselect.py +++ b/alibi/prototypes/protoselect.py @@ -530,7 +530,7 @@ def _imscatterplot(x: np.ndarray, ax.set_xticks([]) ax.set_yticks([]) else: - fig = ax.figure # type: ignore[assignment] + fig = ax.figure # type: ignore[assignment] resized_imgs = [resize(images[i], image_size) for i in range(len(images))] imgs = [OffsetImage(img, zoom=zoom[i], cmap='gray') for i, img in enumerate(resized_imgs)] # type: ignore From 50291aaee518c0e502f0a1141ed6249806b9415b Mon Sep 17 00:00:00 2001 From: Janis Klaise Date: Wed, 27 Sep 2023 12:27:40 +0100 Subject: [PATCH 3/3] undo --- alibi/explainers/anchors/anchor_tabular_distributed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/alibi/explainers/anchors/anchor_tabular_distributed.py b/alibi/explainers/anchors/anchor_tabular_distributed.py index 1481204c7..56b7b9f0f 100644 --- a/alibi/explainers/anchors/anchor_tabular_distributed.py +++ b/alibi/explainers/anchors/anchor_tabular_distributed.py @@ -235,7 +235,7 @@ def fit(self, # type: ignore[override] d_samplers = [] for sampler in samplers: d_samplers.append( - ray.remote(RemoteSampler).remote( + ray.remote(RemoteSampler).remote( # type: ignore[call-arg] *(train_data_id, d_train_data_id, sampler) ) )