diff --git a/benchmarks/WolfSheep/wolf_sheep.py b/benchmarks/WolfSheep/wolf_sheep.py index 9fd71846b7f..f085ce429df 100644 --- a/benchmarks/WolfSheep/wolf_sheep.py +++ b/benchmarks/WolfSheep/wolf_sheep.py @@ -10,7 +10,7 @@ import math from mesa import Model -from mesa.experimental.cell_space import CellAgent, OrthogonalVonNeumannGrid +from mesa.experimental.cell_space import CellAgent, FixedAgent, OrthogonalVonNeumannGrid from mesa.experimental.devs import ABMSimulator @@ -87,7 +87,7 @@ def feed(self): sheep_to_eat.remove() -class GrassPatch(CellAgent): +class GrassPatch(FixedAgent): """A patch of grass that grows at a fixed rate and it is eaten by sheep.""" @property diff --git a/mesa/experimental/cell_space/__init__.py b/mesa/experimental/cell_space/__init__.py index 792dde611b9..69386a4cacf 100644 --- a/mesa/experimental/cell_space/__init__.py +++ b/mesa/experimental/cell_space/__init__.py @@ -6,7 +6,11 @@ """ from mesa.experimental.cell_space.cell import Cell -from mesa.experimental.cell_space.cell_agent import CellAgent +from mesa.experimental.cell_space.cell_agent import ( + CellAgent, + FixedAgent, + Grid2DMovingAgent, +) from mesa.experimental.cell_space.cell_collection import CellCollection from mesa.experimental.cell_space.discrete_space import DiscreteSpace from mesa.experimental.cell_space.grid import ( @@ -22,6 +26,8 @@ "CellCollection", "Cell", "CellAgent", + "Grid2DMovingAgent", + "FixedAgent", "DiscreteSpace", "Grid", "HexGrid", diff --git a/mesa/experimental/cell_space/cell.py b/mesa/experimental/cell_space/cell.py index 8c7fb9bb632..6c92afbe162 100644 --- a/mesa/experimental/cell_space/cell.py +++ b/mesa/experimental/cell_space/cell.py @@ -7,12 +7,12 @@ from random import Random from typing import TYPE_CHECKING, Any +from mesa.experimental.cell_space.cell_agent import CellAgent from mesa.experimental.cell_space.cell_collection import CellCollection from mesa.space import PropertyLayer if TYPE_CHECKING: from mesa.agent import Agent - from mesa.experimental.cell_space.cell_agent import CellAgent Coordinate = tuple[int, ...] @@ -69,7 +69,7 @@ def __init__( self.agents: list[ Agent ] = [] # TODO:: change to AgentSet or weakrefs? (neither is very performant, ) - self.capacity: int = capacity + self.capacity: int | None = capacity self.properties: dict[Coordinate, object] = {} self.random = random self._mesa_property_layers: dict[str, PropertyLayer] = {} @@ -136,7 +136,7 @@ def __repr__(self): # noqa return f"Cell({self.coordinate}, {self.agents})" @cached_property - def neighborhood(self) -> CellCollection: + def neighborhood(self) -> CellCollection[Cell]: """Returns the direct neighborhood of the cell. This is equivalent to cell.get_neighborhood(radius=1) @@ -148,7 +148,7 @@ def neighborhood(self) -> CellCollection: @cache # noqa: B019 def get_neighborhood( self, radius: int = 1, include_center: bool = False - ) -> CellCollection: + ) -> CellCollection[Cell]: """Returns a list of all neighboring cells for the given radius. For getting the direct neighborhood (i.e., radius=1) you can also use diff --git a/mesa/experimental/cell_space/cell_agent.py b/mesa/experimental/cell_space/cell_agent.py index fea71b02a85..5dc967ee76d 100644 --- a/mesa/experimental/cell_space/cell_agent.py +++ b/mesa/experimental/cell_space/cell_agent.py @@ -2,18 +2,24 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Protocol -from mesa import Agent +from mesa.agent import Agent if TYPE_CHECKING: - from mesa.experimental.cell_space.cell import Cell + from mesa.experimental.cell_space import Cell + + +class HasCellProtocol(Protocol): + """Protocol for discrete space cell holders.""" + + cell: Cell class HasCell: """Descriptor for cell movement behavior.""" - _mesa_cell: Cell = None + _mesa_cell: Cell | None = None @property def cell(self) -> Cell | None: # noqa: D102 @@ -33,17 +39,95 @@ def cell(self, cell: Cell | None) -> None: cell.add_agent(self) -class CellAgent(Agent, HasCell): +class BasicMovement: + """Mixin for moving agents in discrete space.""" + + def move_to(self: HasCellProtocol, cell: Cell) -> None: + """Move to a new cell.""" + self.cell = cell + + def move_relative(self: HasCellProtocol, direction: tuple[int, ...]): + """Move to a cell relative to the current cell. + + Args: + direction: The direction to move in. + """ + new_cell = self.cell.connections.get(direction) + if new_cell is not None: + self.cell = new_cell + else: + raise ValueError(f"No cell in direction {direction}") + + +class FixedCell(HasCell): + """Mixin for agents that are fixed to a cell.""" + + @property + def cell(self) -> Cell | None: # noqa: D102 + return self._mesa_cell + + @cell.setter + def cell(self, cell: Cell) -> None: + if self.cell is not None: + raise ValueError("Cannot move agent in FixedCell") + self._mesa_cell = cell + + cell.add_agent(self) + + +class CellAgent(Agent, HasCell, BasicMovement): """Cell Agent is an extension of the Agent class and adds behavior for moving in discrete spaces. Attributes: - unique_id (int): A unique identifier for this agent. - model (Model): The model instance to which the agent belongs - pos: (Position | None): The position of the agent in the space - cell: (Cell | None): the cell which the agent occupies + cell (Cell): The cell the agent is currently in. """ def remove(self): """Remove the agent from the model.""" super().remove() self.cell = None # ensures that we are also removed from cell + + +class FixedAgent(Agent, FixedCell): + """A patch in a 2D grid.""" + + def remove(self): + """Remove the agent from the model.""" + super().remove() + + # fixme we leave self._mesa_cell on the original value + # so you cannot hijack remove() to move patches + self.cell.remove_agent(self) + + +class Grid2DMovingAgent(CellAgent): + """Mixin for moving agents in 2D grids.""" + + # fmt: off + DIRECTION_MAP = { + "n": (-1, 0), "north": (-1, 0), "up": (-1, 0), + "s": (1, 0), "south": (1, 0), "down": (1, 0), + "e": (0, 1), "east": (0, 1), "right": (0, 1), + "w": (0, -1), "west": (0, -1), "left": (0, -1), + "ne": (-1, 1), "northeast": (-1, 1), "upright": (-1, 1), + "nw": (-1, -1), "northwest": (-1, -1), "upleft": (-1, -1), + "se": (1, 1), "southeast": (1, 1), "downright": (1, 1), + "sw": (1, -1), "southwest": (1, -1), "downleft": (1, -1) + } + # fmt: on + + def move(self, direction: str, distance: int = 1): + """Move the agent in a cardinal direction. + + Args: + direction: The cardinal direction to move in. + distance: The distance to move. + """ + direction = direction.lower() # Convert direction to lowercase + + if direction not in self.DIRECTION_MAP: + raise ValueError(f"Invalid direction: {direction}") + + move_vector = self.DIRECTION_MAP[direction] + for _ in range(distance): + self.move_relative(move_vector) diff --git a/mesa/experimental/devs/examples/wolf_sheep.py b/mesa/experimental/devs/examples/wolf_sheep.py index 8d7d16d671a..74318ef88af 100644 --- a/mesa/experimental/devs/examples/wolf_sheep.py +++ b/mesa/experimental/devs/examples/wolf_sheep.py @@ -1,6 +1,7 @@ """Example of using ABM simulator for Wolf-Sheep Predation Model.""" import mesa +from mesa.experimental.cell_space import FixedAgent from mesa.experimental.devs.simulator import ABMSimulator @@ -90,7 +91,7 @@ def feed(self): sheep_to_eat.die() -class GrassPatch(mesa.Agent): +class GrassPatch(FixedAgent): """A patch of grass that grows at a fixed rate and it is eaten by sheep.""" @property diff --git a/tests/test_cell_space.py b/tests/test_cell_space.py index a8e4abad336..4d52e159045 100644 --- a/tests/test_cell_space.py +++ b/tests/test_cell_space.py @@ -10,6 +10,8 @@ Cell, CellAgent, CellCollection, + FixedAgent, + Grid2DMovingAgent, HexGrid, Network, OrthogonalMooreGrid, @@ -641,3 +643,51 @@ def test_cell_agent(): # noqa: D103 assert agent not in model._all_agents assert agent not in cell1.agents assert agent not in cell2.agents + + model = Model() + agent = CellAgent(model) + agent.cell = cell1 + agent.move_to(cell2) + assert agent not in cell1.agents + assert agent in cell2.agents + + +def test_grid2DMovingAgent(): # noqa: D103 + # we first test on a moore grid because all directions are defined + grid = OrthogonalMooreGrid((10, 10), torus=False) + + model = Model() + agent = Grid2DMovingAgent(model) + + agent.cell = grid[4, 4] + agent.move("up") + assert agent.cell == grid[3, 4] + + grid = OrthogonalVonNeumannGrid((10, 10), torus=False) + + model = Model() + agent = Grid2DMovingAgent(model) + agent.cell = grid[4, 4] + + with pytest.raises(ValueError): # test for invalid direction + agent.move("upright") + + with pytest.raises(ValueError): # test for unknown direction + agent.move("back") + + +def test_patch(): # noqa: D103 + cell1 = Cell((1,), capacity=None, random=random.Random()) + cell2 = Cell((2,), capacity=None, random=random.Random()) + + # connect + # add_agent + model = Model() + agent = FixedAgent(model) + agent.cell = cell1 + + with pytest.raises(ValueError): + agent.cell = cell2 + + agent.remove() + assert agent not in model._agents