Skip to content

Commit

Permalink
Get mypy working again (#1457)
Browse files Browse the repository at this point in the history
* Get mypy working again

* Add test for new assert

* Test error string

* Inject get_data for testing

* More typing stuff for mypy to pass

* Ran black

* Added more mypy stuff to plot
  • Loading branch information
gaffney2010 authored Jan 13, 2025
1 parent 6a6c8a9 commit 6d2d465
Show file tree
Hide file tree
Showing 22 changed files with 162 additions and 84 deletions.
4 changes: 2 additions & 2 deletions axelrod/deterministic_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import pickle
from collections import UserDict
from typing import List, Tuple
from typing import List, Optional, Tuple

from axelrod import Classifiers

Expand Down Expand Up @@ -104,7 +104,7 @@ class DeterministicCache(UserDict):
methods to save/load the cache to/from a file.
"""

def __init__(self, file_name: str = None) -> None:
def __init__(self, file_name: Optional[str] = None) -> None:
"""Initialize a new cache.
Parameters
Expand Down
6 changes: 3 additions & 3 deletions axelrod/ecosystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
"""

import random
from typing import Callable, List
from typing import Callable, List, Optional

from axelrod.result_set import ResultSet

Expand All @@ -29,8 +29,8 @@ class Ecosystem(object):
def __init__(
self,
results: ResultSet,
fitness: Callable[[float], float] = None,
population: List[int] = None,
fitness: Optional[Callable[[float], float]] = None,
population: Optional[List[int]] = None,
) -> None:
"""Create a new ecosystem.
Expand Down
20 changes: 10 additions & 10 deletions axelrod/fingerprint.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
from collections import namedtuple
from tempfile import mkstemp
from typing import Any, List, Union
from typing import Any, List, Optional, Union

import dask.dataframe as dd
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -280,10 +280,10 @@ def fingerprint(
turns: int = 50,
repetitions: int = 10,
step: float = 0.01,
processes: int = None,
filename: str = None,
processes: Optional[int] = None,
filename: Optional[str] = None,
progress_bar: bool = True,
seed: int = None,
seed: Optional[int] = None,
) -> dict:
"""Build and play the spatial tournament.
Expand Down Expand Up @@ -358,7 +358,7 @@ def plot(
self,
cmap: str = "seismic",
interpolation: str = "none",
title: str = None,
title: Optional[str] = None,
colorbar: bool = True,
labels: bool = True,
) -> plt.Figure:
Expand Down Expand Up @@ -437,11 +437,11 @@ def fingerprint(
self,
turns: int = 50,
repetitions: int = 1000,
noise: float = None,
processes: int = None,
filename: str = None,
noise: Optional[float] = None,
processes: Optional[int] = None,
filename: Optional[str] = None,
progress_bar: bool = True,
seed: int = None,
seed: Optional[int] = None,
) -> np.ndarray:
"""Creates a spatial tournament to run the necessary matches to obtain
fingerprint data.
Expand Down Expand Up @@ -556,7 +556,7 @@ def plot(
self,
cmap: str = "viridis",
interpolation: str = "none",
title: str = None,
title: Optional[str] = None,
colorbar: bool = True,
labels: bool = True,
display_names: bool = False,
Expand Down
3 changes: 2 additions & 1 deletion axelrod/game.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Tuple, Union

import numpy as np
import numpy.typing as npt

from axelrod import Action

Expand All @@ -20,7 +21,7 @@ class AsymmetricGame(object):
"""

# pylint: disable=invalid-name
def __init__(self, A: np.array, B: np.array) -> None:
def __init__(self, A: npt.NDArray, B: npt.NDArray) -> None:
"""
Creates an asymmetric game from two matrices.
Expand Down
12 changes: 9 additions & 3 deletions axelrod/load_data_.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pathlib
import pkgutil
from typing import Dict, List, Text, Tuple
from typing import Callable, Dict, List, Optional, Tuple


def axl_filename(path: pathlib.Path) -> pathlib.Path:
Expand All @@ -20,12 +20,18 @@ def axl_filename(path: pathlib.Path) -> pathlib.Path:
return axl_path / path


def load_file(filename: str, directory: str) -> List[List[str]]:
def load_file(
filename: str,
directory: str,
get_data: Callable[[str, str], Optional[bytes]] = pkgutil.get_data,
) -> List[List[str]]:
"""Loads a data file stored in the Axelrod library's data subdirectory,
likely for parameters for a strategy."""

path = str(pathlib.Path(directory) / filename)
data_bytes = pkgutil.get_data(__name__, path)
data_bytes = get_data(__name__, path)
if data_bytes is None:
raise FileNotFoundError(f"Some loader issue for path {path}")
data = data_bytes.decode("UTF-8", "replace")

rows = []
Expand Down
4 changes: 2 additions & 2 deletions axelrod/mock_player.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from itertools import cycle
from typing import List
from typing import List, Optional

from axelrod.action import Action
from axelrod.player import Player
Expand All @@ -14,7 +14,7 @@ class MockPlayer(Player):

name = "Mock Player"

def __init__(self, actions: List[Action] = None) -> None:
def __init__(self, actions: Optional[List[Action]] = None) -> None:
super().__init__()
if not actions:
actions = []
Expand Down
11 changes: 5 additions & 6 deletions axelrod/moran.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@ def __init__(
self,
players: List[Player],
turns: int = DEFAULT_TURNS,
prob_end: float = None,
prob_end: Optional[float] = None,
noise: float = 0,
game: Game = None,
deterministic_cache: DeterministicCache = None,
mutation_rate: float = 0.0,
mode: str = "bd",
interaction_graph: Graph = None,
reproduction_graph: Graph = None,
fitness_transformation: Callable = None,
fitness_transformation: Optional[Callable] = None,
mutation_method="transition",
stop_on_fixation=True,
seed=None,
Expand Down Expand Up @@ -175,7 +175,7 @@ def set_players(self) -> None:
self.populations = [self.population_distribution()]

def fitness_proportionate_selection(
self, scores: List, fitness_transformation: Callable = None
self, scores: List, fitness_transformation: Optional[Callable] = None
) -> int:
"""Randomly selects an individual proportionally to score.
Expand Down Expand Up @@ -229,7 +229,7 @@ def mutate(self, index: int) -> Player:
# Just clone the player
return self.players[index].clone()

def death(self, index: int = None) -> int:
def death(self, index: Optional[int] = None) -> int:
"""
Selects the player to be removed.
Expand Down Expand Up @@ -258,7 +258,7 @@ def death(self, index: int = None) -> int:
i = self.index[vertex]
return i

def birth(self, index: int = None) -> int:
def birth(self, index: Optional[int] = None) -> int:
"""The birth event.
Parameters
Expand Down Expand Up @@ -349,7 +349,6 @@ def _matchup_indices(self) -> Set[Tuple[int, int]]:
# The other calculations are unnecessary
if self.mode == "db":
source = self.index[self.dead]
self.dead = None
sources = sorted(self.interaction_graph.out_vertices(source))
else:
# birth-death is global
Expand Down
71 changes: 52 additions & 19 deletions axelrod/plot.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pathlib
from typing import List, Union
from typing import Any, Callable, List, Optional, Union

import matplotlib
import matplotlib.pyplot as plt
Expand All @@ -10,7 +10,7 @@
from .load_data_ import axl_filename
from .result_set import ResultSet

titleType = List[str]
titleType = str
namesType = List[str]
dataType = List[List[Union[int, float]]]

Expand All @@ -25,8 +25,11 @@ def _violinplot(
self,
data: dataType,
names: namesType,
title: titleType = None,
ax: matplotlib.axes.SubplotBase = None,
title: Optional[titleType] = None,
ax: Optional[matplotlib.axes.Axes] = None,
get_figure: Callable[
[matplotlib.axes.Axes], Union[matplotlib.figure.Figure, Any, None]
] = lambda ax: ax.get_figure(),
) -> matplotlib.figure.Figure:
"""For making violinplots."""

Expand All @@ -35,7 +38,11 @@ def _violinplot(
else:
ax = ax

figure = ax.get_figure()
figure = get_figure(ax)
if not isinstance(figure, matplotlib.figure.Figure):
raise RuntimeError(
"get_figure unexpectedly returned a non-figure object"
)
width = max(self.num_players / 3, 12)
height = width / 2
spacing = 4
Expand All @@ -50,7 +57,7 @@ def _violinplot(
)
ax.set_xticks(positions)
ax.set_xticklabels(names, rotation=90)
ax.set_xlim([0, spacing * (self.num_players + 1)])
ax.set_xlim((0, spacing * (self.num_players + 1)))
ax.tick_params(axis="both", which="both", labelsize=8)
if title:
ax.set_title(title)
Expand All @@ -76,7 +83,9 @@ def _boxplot_xticks_labels(self):
return [str(n) for n in self.result_set.ranked_names]

def boxplot(
self, title: titleType = None, ax: matplotlib.axes.SubplotBase = None
self,
title: Optional[titleType] = None,
ax: Optional[matplotlib.axes.Axes] = None,
) -> matplotlib.figure.Figure:
"""For the specific mean score boxplot."""
data = self._boxplot_dataset
Expand All @@ -98,7 +107,9 @@ def _winplot_dataset(self):
return wins, ranked_names

def winplot(
self, title: titleType = None, ax: matplotlib.axes.SubplotBase = None
self,
title: Optional[titleType] = None,
ax: Optional[matplotlib.axes.Axes] = None,
) -> matplotlib.figure.Figure:
"""Plots the distributions for the number of wins for each strategy."""

Expand Down Expand Up @@ -126,7 +137,9 @@ def _sdv_plot_dataset(self):
return diffs, ranked_names

def sdvplot(
self, title: titleType = None, ax: matplotlib.axes.SubplotBase = None
self,
title: Optional[titleType] = None,
ax: Optional[matplotlib.axes.Axes] = None,
) -> matplotlib.figure.Figure:
"""Score difference violin plots to visualize the distributions of how
players attain their payoffs."""
Expand All @@ -143,7 +156,9 @@ def _lengthplot_dataset(self):
]

def lengthplot(
self, title: titleType = None, ax: matplotlib.axes.SubplotBase = None
self,
title: Optional[titleType] = None,
ax: Optional[matplotlib.axes.Axes] = None,
) -> matplotlib.figure.Figure:
"""For the specific match length boxplot."""
data = self._lengthplot_dataset
Expand Down Expand Up @@ -174,9 +189,12 @@ def _payoff_heatmap(
self,
data: dataType,
names: namesType,
title: titleType = None,
ax: matplotlib.axes.SubplotBase = None,
title: Optional[titleType] = None,
ax: Optional[matplotlib.axes.Axes] = None,
cmap: str = "viridis",
get_figure: Callable[
[matplotlib.axes.Axes], Union[matplotlib.figure.Figure, Any, None]
] = lambda ax: ax.get_figure(),
) -> matplotlib.figure.Figure:
"""Generic heatmap plot"""

Expand All @@ -185,7 +203,11 @@ def _payoff_heatmap(
else:
ax = ax

figure = ax.get_figure()
figure = get_figure(ax)
if not isinstance(figure, matplotlib.figure.Figure):
raise RuntimeError(
"get_figure unexpectedly returned a non-figure object"
)
width = max(self.num_players / 4, 12)
height = width
figure.set_size_inches(width, height)
Expand All @@ -202,15 +224,19 @@ def _payoff_heatmap(
return figure

def pdplot(
self, title: titleType = None, ax: matplotlib.axes.SubplotBase = None
self,
title: Optional[titleType] = None,
ax: Optional[matplotlib.axes.Axes] = None,
) -> matplotlib.figure.Figure:
"""Payoff difference heatmap to visualize the distributions of how
players attain their payoffs."""
matrix, names = self._pdplot_dataset
return self._payoff_heatmap(matrix, names, title=title, ax=ax)

def payoff(
self, title: titleType = None, ax: matplotlib.axes.SubplotBase = None
self,
title: Optional[titleType] = None,
ax: Optional[matplotlib.axes.Axes] = None,
) -> matplotlib.figure.Figure:
"""Payoff heatmap to visualize the distributions of how
players attain their payoffs."""
Expand All @@ -223,9 +249,12 @@ def payoff(
def stackplot(
self,
eco,
title: titleType = None,
title: Optional[titleType] = None,
logscale: bool = True,
ax: matplotlib.axes.SubplotBase = None,
ax: Optional[matplotlib.axes.Axes] = None,
get_figure: Callable[
[matplotlib.axes.Axes], Union[matplotlib.figure.Figure, Any, None]
] = lambda ax: ax.get_figure(),
) -> matplotlib.figure.Figure:

populations = eco.population_sizes
Expand All @@ -235,7 +264,11 @@ def stackplot(
else:
ax = ax

figure = ax.get_figure()
figure = get_figure(ax)
if not isinstance(figure, matplotlib.figure.Figure):
raise RuntimeError(
"get_figure unexpectedly returned a non-figure object"
)
turns = range(len(populations))
pops = [
[populations[iturn][ir] for iturn in turns]
Expand All @@ -247,7 +280,7 @@ def stackplot(
ax.yaxis.set_label_position("right")
ax.yaxis.labelpad = 25.0

ax.set_ylim([0.0, 1.0])
ax.set_ylim((0.0, 1.0))
ax.set_ylabel("Relative population size")
ax.set_xlabel("Turn")
if title is not None:
Expand Down
Loading

0 comments on commit 6d2d465

Please sign in to comment.