Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Cythonize get_cell_list_contents #1995

Closed
wants to merge 11 commits into from
5 changes: 5 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@ repos:
hooks:
- id: pyupgrade
args: [--py38-plus]
- repo: https://github.com/MarcoGorelli/cython-lint
rev: v0.16.0
hooks:
- id: cython-lint
- id: double-quote-cython-strings
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0 # Use the ref you want to point at
hooks:
Expand Down
55 changes: 26 additions & 29 deletions mesa/space.py → mesa/space.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,6 @@
NetworkGrid: a network where each node contains zero or more agents.
"""

# Mypy; for the `|` operator purpose
# Remove this __future__ import once the oldest supported Python is 3.10
from __future__ import annotations

import collections
import contextlib
import inspect
Expand All @@ -25,7 +21,7 @@
import warnings
from collections.abc import Iterable, Iterator, Sequence
from numbers import Real
from typing import Any, Callable, TypeVar, Union, cast, overload
from typing import Any, Callable, TypeVar, Union, cast
from warnings import warn

with contextlib.suppress(ImportError):
Expand All @@ -49,6 +45,7 @@

GridContent = Union[Agent, None]
MultiGridContent = list[Agent]
GridIndex = tuple[int | slice, int | slice] | int | Sequence[Coordinate]

F = TypeVar("F", bound=Callable[..., Any])

Expand All @@ -66,12 +63,18 @@ def wrapper(grid_instance, positions) -> Any:
return cast(F, wrapper)


def ensure_positions_as_list(positions):
if len(positions) == 2 and not isinstance(positions[0], tuple):
return [positions]
return positions


def is_integer(x: Real) -> bool:
# Check if x is either a CPython integer or Numpy integer.
return isinstance(x, _types_integer)


class _Grid:
cdef class _Grid:
"""Base class for a rectangular grid.

Grid cells are indexed by [x, y], where [0, 0] is assumed to be the
Expand Down Expand Up @@ -134,17 +137,7 @@ def build_empties(self) -> None:
)
self._empties_built = True

@overload
def __getitem__(self, index: int | Sequence[Coordinate]) -> list[GridContent]:
...

@overload
def __getitem__(
self, index: tuple[int | slice, int | slice]
) -> GridContent | list[GridContent]:
...

def __getitem__(self, index):
def __getitem__(self, index: GridIndex) -> GridContent | list[GridContent]:
"""Access contents from the grid."""

if isinstance(index, int):
Expand Down Expand Up @@ -328,7 +321,8 @@ def iter_neighbors(
"""
default_val = self.default_val()
for x, y in self.get_neighborhood(pos, moore, include_center, radius):
if (cell := self._grid[x][y]) != default_val:
cell = self._grid[x][y]
if cell != default_val:
yield cell

def get_neighbors(
Expand Down Expand Up @@ -389,11 +383,11 @@ def iter_cell_list_contents(
# iter_cell_list_contents returns only non-empty contents.
default_val = self.default_val()
for x, y in cell_list:
if (cell := self._grid[x][y]) != default_val:
cell = self._grid[x][y]
if cell != default_val:
yield cell

@accept_tuple_argument
def get_cell_list_contents(self, cell_list: Iterable[Coordinate]) -> list[Agent]:
cpdef object get_cell_list_contents(self, cell_list: Iterable[Coordinate]):
"""Returns an iterator of the agents contained in the cells identified
in `cell_list`; cells with empty content are excluded.

Expand All @@ -403,7 +397,7 @@ def get_cell_list_contents(self, cell_list: Iterable[Coordinate]) -> list[Agent]
Returns:
A list of the agents contained in the cells identified in `cell_list`.
"""
return list(self.iter_cell_list_contents(cell_list))
return list(self.iter_cell_list_contents(ensure_positions_as_list(cell_list)))

def place_agent(self, agent: Agent, pos: Coordinate) -> None:
...
Expand Down Expand Up @@ -489,9 +483,11 @@ def _distance_squared(self, pos1: Coordinate, pos2: Coordinate) -> float:
def swap_pos(self, agent_a: Agent, agent_b: Agent) -> None:
"""Swap agents positions"""
agents_no_pos = []
if (pos_a := agent_a.pos) is None:
pos_a = agent_a.pos
if pos_a is None:
agents_no_pos.append(agent_a)
if (pos_b := agent_b.pos) is None:
pos_b = agent_b.pos
if pos_b is None:
agents_no_pos.append(agent_b)
if agents_no_pos:
agents_no_pos = [f"<Agent id: {a.unique_id}>" for a in agents_no_pos]
Expand Down Expand Up @@ -992,7 +988,8 @@ def place_agent(self, agent: Agent, pos: Coordinate) -> None:

def remove_agent(self, agent: Agent) -> None:
"""Remove the agent from the grid and set its pos attribute to None."""
if (pos := agent.pos) is None:
pos = agent.pos
if pos is None:
return
x, y = pos
self._grid[x][y] = self.default_val()
Expand Down Expand Up @@ -1072,7 +1069,7 @@ def iter_cell_list_contents(
"""
default_val = self.default_val()
return itertools.chain.from_iterable(
cell for x, y in cell_list if (cell := self._grid[x][y]) != default_val
self._grid[x][y] for x, y in cell_list if self._grid[x][y] != default_val
)


Expand All @@ -1099,7 +1096,7 @@ def torus_adj_2d(self, pos: Coordinate) -> Coordinate:

def get_neighborhood(
self, pos: Coordinate, include_center: bool = False, radius: int = 1
) -> list[Coordinate]:
) -> Sequence[Coordinate]:
"""Return a list of coordinates that are in the
neighborhood of a certain point. To calculate the neighborhood
for a HexGrid the parity of the x coordinate of the point is
Expand Down Expand Up @@ -1561,7 +1558,7 @@ def is_cell_empty(self, node_id: int) -> bool:
"""Returns a bool of the contents of a cell."""
return self.G.nodes[node_id]["agent"] == self.default_val()

def get_cell_list_contents(self, cell_list: list[int]) -> list[Agent]:
def get_cell_list_contents(self, cell_list: list[int] | nx.Graph) -> list[Agent]:
"""Returns a list of the agents contained in the nodes identified
in `cell_list`; nodes with empty content are excluded.
"""
Expand All @@ -1571,7 +1568,7 @@ def get_all_cell_contents(self) -> list[Agent]:
"""Returns a list of all the agents in the network."""
return self.get_cell_list_contents(self.G)

def iter_cell_list_contents(self, cell_list: list[int]) -> Iterator[Agent]:
def iter_cell_list_contents(self, cell_list: list[int] | nx.Graph) -> Iterator[Agent]:
"""Returns an iterator of the agents contained in the nodes identified
in `cell_list`; nodes with empty content are excluded.
"""
Expand Down
10 changes: 10 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,13 @@ packages = ["mesa"]
[tool.hatch.version]
path = "mesa/__init__.py"

[tool.hatch.build.hooks.cython]
dependencies = ["hatch-cython"]

[tool.hatch.build.hooks.cython.options]
src = "mesa"
compile_py = false

[tool.ruff]
# See https://github.com/charliermarsh/ruff#rules for error code definitions.
select = [
Expand Down Expand Up @@ -130,3 +137,6 @@ extend-exclude = ["docs", "build"]
# Hardcode to Python 3.9.
# Reminder to update mesa-examples if the value below is changed.
target-version = "py39"

[tool.cython-lint]
ignore = ["E501"]
Loading