Skip to content

Commit

Permalink
Update examples to use updated space drawing (#2442)
Browse files Browse the repository at this point in the history
* Update app.py

* Update app.py

* ruff fixes

* advanced examples updated

* ruff related fixes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
quaquel and pre-commit-ci[bot] authored Oct 30, 2024
1 parent e605a1f commit 9de5595
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 57 deletions.
18 changes: 15 additions & 3 deletions mesa/examples/advanced/epstein_civil_violence/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def citizen_cop_portrayal(agent):
return

portrayal = {
"size": 25,
"size": 50,
}

if isinstance(agent, Citizen):
Expand All @@ -36,6 +36,13 @@ def citizen_cop_portrayal(agent):
return portrayal


def post_process(ax):
ax.set_aspect("equal")
ax.set_xticks([])
ax.set_yticks([])
ax.get_figure().set_size_inches(10, 10)


model_params = {
"height": 40,
"width": 40,
Expand All @@ -47,8 +54,13 @@ def citizen_cop_portrayal(agent):
"max_jail_term": Slider("Max Jail Term", 30, 0, 50, 1),
}

space_component = make_space_component(citizen_cop_portrayal)
chart_component = make_plot_measure([state.name.lower() for state in CitizenState])
space_component = make_space_component(
citizen_cop_portrayal, post_process=post_process, draw_grid=False
)

chart_component = make_plot_measure(
{state.name.lower(): agent_colors[state] for state in CitizenState}
)

epstein_model = EpsteinCivilViolence()

Expand Down
2 changes: 1 addition & 1 deletion mesa/examples/advanced/pd_grid/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def pd_agent_portrayal(agent):
"""
return {
"color": "blue" if agent.move == "C" else "red",
"shape": "s", # square marker
"marker": "s", # square marker
"size": 25,
}

Expand Down
6 changes: 3 additions & 3 deletions mesa/examples/advanced/sugarscape_g1mt/app.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import numpy as np
import solara
from matplotlib.figure import Figure
from sugarscape_g1mt.model import SugarscapeG1mt
from sugarscape_g1mt.trader_agents import Trader

from mesa.examples.advanced.sugarscape_g1mt.agents import Trader
from mesa.examples.advanced.sugarscape_g1mt.model import SugarscapeG1mt
from mesa.visualization import SolaraViz, make_plot_measure


Expand Down Expand Up @@ -57,6 +57,6 @@ def portray(g):
model1,
components=[SpaceDrawer, make_plot_measure(["Trader", "Price"])],
name="Sugarscape {G1, M, T}",
play_interval=1500,
play_interval=150,
)
page # noqa
16 changes: 14 additions & 2 deletions mesa/examples/basic/conways_game_of_life/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,17 @@


def agent_portrayal(agent):
return {"c": "white" if agent.state == 0 else "black", "marker": "s"}
return {
"color": "white" if agent.state == 0 else "black",
"marker": "s",
"size": 25,
}


def post_process(ax):
ax.set_aspect("equal")
ax.set_xticks([])
ax.set_yticks([])


model_params = {
Expand All @@ -22,7 +32,9 @@ def agent_portrayal(agent):
# Under the hood these are just classes that receive the model instance.
# You can also author your own visualization elements, which can also be functions
# that receive the model instance and return a valid solara component.
SpaceGraph = make_space_component(agent_portrayal)
SpaceGraph = make_space_component(
agent_portrayal, post_process=post_process, draw_grid=False
)


# Create the SolaraViz page. This will automatically create a server and display the
Expand Down
4 changes: 2 additions & 2 deletions mesa/examples/basic/schelling/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,13 @@ def agent_portrayal(agent):

model1 = Schelling(20, 20, 0.8, 0.2, 3)

HappyPlot = make_plot_measure("happy")
HappyPlot = make_plot_measure({"happy": "tab:green"})

page = SolaraViz(
model1,
components=[
make_space_component(agent_portrayal),
make_plot_measure("happy"),
HappyPlot,
get_happy_agents,
],
model_params=model_params,
Expand Down
62 changes: 16 additions & 46 deletions mesa/examples/basic/virus_on_network/app.py
Original file line number Diff line number Diff line change
@@ -1,70 +1,37 @@
import math

import solara
from matplotlib.figure import Figure
from matplotlib.ticker import MaxNLocator

from mesa.examples.basic.virus_on_network.model import (
State,
VirusOnNetwork,
number_infected,
)
from mesa.visualization import Slider, SolaraViz, make_space_component

from mesa.visualization import (
Slider,
SolaraViz,
make_plot_measure,
make_space_component,
)

def agent_portrayal(graph):
def get_agent(node):
return graph.nodes[node]["agent"][0]

edge_width = []
edge_color = []
for u, v in graph.edges():
agent1 = get_agent(u)
agent2 = get_agent(v)
w = 2
ec = "#e8e8e8"
if State.RESISTANT in (agent1.state, agent2.state):
w = 3
ec = "black"
edge_width.append(w)
edge_color.append(ec)
def agent_portrayal(agent):
node_color_dict = {
State.INFECTED: "tab:red",
State.SUSCEPTIBLE: "tab:green",
State.RESISTANT: "tab:gray",
}
node_color = [node_color_dict[get_agent(node).state] for node in graph.nodes()]
return {
"width": edge_width,
"edge_color": edge_color,
"node_color": node_color,
}
return {"color": node_color_dict[agent.state], "size": 10}


def get_resistant_susceptible_ratio(model):
ratio = model.resistant_susceptible_ratio()
ratio_text = r"$\infty$" if ratio is math.inf else f"{ratio:.2f}"
infected_text = str(number_infected(model))

return f"Resistant/Susceptible Ratio: {ratio_text}<br>Infected Remaining: {infected_text}"


def make_plot(model):
# This is for the case when we want to plot multiple measures in 1 figure.
fig = Figure()
ax = fig.subplots()
measures = ["Infected", "Susceptible", "Resistant"]
colors = ["tab:red", "tab:green", "tab:gray"]
for i, m in enumerate(measures):
color = colors[i]
df = model.datacollector.get_model_vars_dataframe()
ax.plot(df.loc[:, m], label=m, color=color)
fig.legend()
# Set integer x axis
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
ax.set_xlabel("Step")
ax.set_ylabel("Number of Agents")
return solara.FigureMatplotlib(fig)
return solara.Markdown(
f"Resistant/Susceptible Ratio: {ratio_text}<br>Infected Remaining: {infected_text}"
)


model_params = {
Expand Down Expand Up @@ -120,15 +87,18 @@ def make_plot(model):
}

SpacePlot = make_space_component(agent_portrayal)
StatePlot = make_plot_measure(
{"Infected": "tab:red", "Susceptible": "tab:green", "Resistant": "tab:gray"}
)

model1 = VirusOnNetwork()

page = SolaraViz(
model1,
[
SpacePlot,
make_plot,
# get_resistant_susceptible_ratio, # TODO: Fix and uncomment
StatePlot,
get_resistant_susceptible_ratio,
],
model_params=model_params,
name="Virus Model",
Expand Down

0 comments on commit 9de5595

Please sign in to comment.