diff --git a/mesa/agent.py b/mesa/agent.py index 6155284ba42..243a187337b 100644 --- a/mesa/agent.py +++ b/mesa/agent.py @@ -188,15 +188,21 @@ def shuffle(self, inplace: bool = False) -> AgentSet: Returns: AgentSet: A shuffled AgentSet. Returns the current AgentSet if inplace is True. - """ - shuffled_agents = list(self) - self.random.shuffle(shuffled_agents) - return ( - AgentSet(shuffled_agents, self.model) - if not inplace - else self._update(shuffled_agents) - ) + Note: + Using inplace = True is more performant + + """ + weakrefs = list(self._agents.keyrefs()) + self.random.shuffle(weakrefs) + + if inplace: + self._agents.data = {entry: None for entry in weakrefs} + return self + else: + return AgentSet( + (agent for ref in weakrefs if (agent := ref()) is not None), self.model + ) def sort( self, @@ -251,9 +257,9 @@ def do( """ # we iterate over the actual weakref keys and check if weakref is alive before calling the method res = [ - getattr(agentref(), method_name)(*args, **kwargs) + getattr(agent, method_name)(*args, **kwargs) for agentref in self._agents.keyrefs() - if agentref() + if (agent := agentref()) is not None ] return res if return_results else self diff --git a/tests/test_agent.py b/tests/test_agent.py index 60a0eec15fa..7ad538eba27 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -66,7 +66,7 @@ def test_function(agent): assert all(a1 == a2 for a1, a2 in zip(agentset.select(), agentset)) assert all(a1 == a2 for a1, a2 in zip(agentset.select(n=5), agentset[:5])) - assert len(agentset.shuffle().select(n=5)) == 5 + assert len(agentset.shuffle(inplace=False).select(n=5)) == 5 def test_function(agent): return agent.unique_id