From 9de55959f8225cdc9124f0d1b7e2ddf5fbf7d0c9 Mon Sep 17 00:00:00 2001 From: Jan Kwakkel Date: Wed, 30 Oct 2024 20:08:17 +0100 Subject: [PATCH] Update examples to use updated space drawing (#2442) * Update app.py * Update app.py * ruff fixes * advanced examples updated * ruff related fixes * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../advanced/epstein_civil_violence/app.py | 18 +++++- mesa/examples/advanced/pd_grid/app.py | 2 +- mesa/examples/advanced/sugarscape_g1mt/app.py | 6 +- .../basic/conways_game_of_life/app.py | 16 ++++- mesa/examples/basic/schelling/app.py | 4 +- mesa/examples/basic/virus_on_network/app.py | 62 +++++-------------- 6 files changed, 51 insertions(+), 57 deletions(-) diff --git a/mesa/examples/advanced/epstein_civil_violence/app.py b/mesa/examples/advanced/epstein_civil_violence/app.py index 538ef186f57..99304cd4618 100644 --- a/mesa/examples/advanced/epstein_civil_violence/app.py +++ b/mesa/examples/advanced/epstein_civil_violence/app.py @@ -25,7 +25,7 @@ def citizen_cop_portrayal(agent): return portrayal = { - "size": 25, + "size": 50, } if isinstance(agent, Citizen): @@ -36,6 +36,13 @@ def citizen_cop_portrayal(agent): return portrayal +def post_process(ax): + ax.set_aspect("equal") + ax.set_xticks([]) + ax.set_yticks([]) + ax.get_figure().set_size_inches(10, 10) + + model_params = { "height": 40, "width": 40, @@ -47,8 +54,13 @@ def citizen_cop_portrayal(agent): "max_jail_term": Slider("Max Jail Term", 30, 0, 50, 1), } -space_component = make_space_component(citizen_cop_portrayal) -chart_component = make_plot_measure([state.name.lower() for state in CitizenState]) +space_component = make_space_component( + citizen_cop_portrayal, post_process=post_process, draw_grid=False +) + +chart_component = make_plot_measure( + {state.name.lower(): agent_colors[state] for state in CitizenState} +) epstein_model = EpsteinCivilViolence() diff --git a/mesa/examples/advanced/pd_grid/app.py b/mesa/examples/advanced/pd_grid/app.py index 6edf8140536..d5bfd626e3c 100644 --- a/mesa/examples/advanced/pd_grid/app.py +++ b/mesa/examples/advanced/pd_grid/app.py @@ -13,7 +13,7 @@ def pd_agent_portrayal(agent): """ return { "color": "blue" if agent.move == "C" else "red", - "shape": "s", # square marker + "marker": "s", # square marker "size": 25, } diff --git a/mesa/examples/advanced/sugarscape_g1mt/app.py b/mesa/examples/advanced/sugarscape_g1mt/app.py index 752998891bd..39969e24079 100644 --- a/mesa/examples/advanced/sugarscape_g1mt/app.py +++ b/mesa/examples/advanced/sugarscape_g1mt/app.py @@ -1,9 +1,9 @@ import numpy as np import solara from matplotlib.figure import Figure -from sugarscape_g1mt.model import SugarscapeG1mt -from sugarscape_g1mt.trader_agents import Trader +from mesa.examples.advanced.sugarscape_g1mt.agents import Trader +from mesa.examples.advanced.sugarscape_g1mt.model import SugarscapeG1mt from mesa.visualization import SolaraViz, make_plot_measure @@ -57,6 +57,6 @@ def portray(g): model1, components=[SpaceDrawer, make_plot_measure(["Trader", "Price"])], name="Sugarscape {G1, M, T}", - play_interval=1500, + play_interval=150, ) page # noqa diff --git a/mesa/examples/basic/conways_game_of_life/app.py b/mesa/examples/basic/conways_game_of_life/app.py index 7a45125a30a..168681b7ba6 100644 --- a/mesa/examples/basic/conways_game_of_life/app.py +++ b/mesa/examples/basic/conways_game_of_life/app.py @@ -6,7 +6,17 @@ def agent_portrayal(agent): - return {"c": "white" if agent.state == 0 else "black", "marker": "s"} + return { + "color": "white" if agent.state == 0 else "black", + "marker": "s", + "size": 25, + } + + +def post_process(ax): + ax.set_aspect("equal") + ax.set_xticks([]) + ax.set_yticks([]) model_params = { @@ -22,7 +32,9 @@ def agent_portrayal(agent): # Under the hood these are just classes that receive the model instance. # You can also author your own visualization elements, which can also be functions # that receive the model instance and return a valid solara component. -SpaceGraph = make_space_component(agent_portrayal) +SpaceGraph = make_space_component( + agent_portrayal, post_process=post_process, draw_grid=False +) # Create the SolaraViz page. This will automatically create a server and display the diff --git a/mesa/examples/basic/schelling/app.py b/mesa/examples/basic/schelling/app.py index 53fab7ba0f0..86f5a2941fd 100644 --- a/mesa/examples/basic/schelling/app.py +++ b/mesa/examples/basic/schelling/app.py @@ -28,13 +28,13 @@ def agent_portrayal(agent): model1 = Schelling(20, 20, 0.8, 0.2, 3) -HappyPlot = make_plot_measure("happy") +HappyPlot = make_plot_measure({"happy": "tab:green"}) page = SolaraViz( model1, components=[ make_space_component(agent_portrayal), - make_plot_measure("happy"), + HappyPlot, get_happy_agents, ], model_params=model_params, diff --git a/mesa/examples/basic/virus_on_network/app.py b/mesa/examples/basic/virus_on_network/app.py index 7cf54f308d5..8e82a72830c 100644 --- a/mesa/examples/basic/virus_on_network/app.py +++ b/mesa/examples/basic/virus_on_network/app.py @@ -1,44 +1,27 @@ import math import solara -from matplotlib.figure import Figure -from matplotlib.ticker import MaxNLocator from mesa.examples.basic.virus_on_network.model import ( State, VirusOnNetwork, number_infected, ) -from mesa.visualization import Slider, SolaraViz, make_space_component - +from mesa.visualization import ( + Slider, + SolaraViz, + make_plot_measure, + make_space_component, +) -def agent_portrayal(graph): - def get_agent(node): - return graph.nodes[node]["agent"][0] - edge_width = [] - edge_color = [] - for u, v in graph.edges(): - agent1 = get_agent(u) - agent2 = get_agent(v) - w = 2 - ec = "#e8e8e8" - if State.RESISTANT in (agent1.state, agent2.state): - w = 3 - ec = "black" - edge_width.append(w) - edge_color.append(ec) +def agent_portrayal(agent): node_color_dict = { State.INFECTED: "tab:red", State.SUSCEPTIBLE: "tab:green", State.RESISTANT: "tab:gray", } - node_color = [node_color_dict[get_agent(node).state] for node in graph.nodes()] - return { - "width": edge_width, - "edge_color": edge_color, - "node_color": node_color, - } + return {"color": node_color_dict[agent.state], "size": 10} def get_resistant_susceptible_ratio(model): @@ -46,25 +29,9 @@ def get_resistant_susceptible_ratio(model): ratio_text = r"$\infty$" if ratio is math.inf else f"{ratio:.2f}" infected_text = str(number_infected(model)) - return f"Resistant/Susceptible Ratio: {ratio_text}
Infected Remaining: {infected_text}" - - -def make_plot(model): - # This is for the case when we want to plot multiple measures in 1 figure. - fig = Figure() - ax = fig.subplots() - measures = ["Infected", "Susceptible", "Resistant"] - colors = ["tab:red", "tab:green", "tab:gray"] - for i, m in enumerate(measures): - color = colors[i] - df = model.datacollector.get_model_vars_dataframe() - ax.plot(df.loc[:, m], label=m, color=color) - fig.legend() - # Set integer x axis - ax.xaxis.set_major_locator(MaxNLocator(integer=True)) - ax.set_xlabel("Step") - ax.set_ylabel("Number of Agents") - return solara.FigureMatplotlib(fig) + return solara.Markdown( + f"Resistant/Susceptible Ratio: {ratio_text}
Infected Remaining: {infected_text}" + ) model_params = { @@ -120,6 +87,9 @@ def make_plot(model): } SpacePlot = make_space_component(agent_portrayal) +StatePlot = make_plot_measure( + {"Infected": "tab:red", "Susceptible": "tab:green", "Resistant": "tab:gray"} +) model1 = VirusOnNetwork() @@ -127,8 +97,8 @@ def make_plot(model): model1, [ SpacePlot, - make_plot, - # get_resistant_susceptible_ratio, # TODO: Fix and uncomment + StatePlot, + get_resistant_susceptible_ratio, ], model_params=model_params, name="Virus Model",