Skip to content

Commit

Permalink
update check_logic, remove temp changes, add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Corvince committed Nov 5, 2024
1 parent 54cc6c2 commit 7bc340d
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 11 deletions.
2 changes: 1 addition & 1 deletion mesa/examples/basic/boltzmann_wealth_model/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def post_process(ax):


# Create initial model instance
model1 = BoltzmannWealthModel(50, 20, 20)
model1 = BoltzmannWealthModel(50, 10, 10)

# Create visualization elements. The visualization elements are solara components
# that receive the model instance as a "prop" and display it in a certain way.
Expand Down
2 changes: 1 addition & 1 deletion mesa/examples/basic/boltzmann_wealth_model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class BoltzmannWealthModel(mesa.Model):
highly skewed distribution of wealth.
"""

def __init__(self, seed, n=100, width=10, height=10):
def __init__(self, hans=213, n=100, width=10, height=10, seed=None):
super().__init__(seed=seed)
self.num_agents = n
self.grid = mesa.space.MultiGrid(width, height, True)
Expand Down
37 changes: 29 additions & 8 deletions mesa/visualization/solara_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,13 +94,6 @@ def SolaraViz(
if not isinstance(model, solara.Reactive):
model = solara.use_reactive(model) # noqa: SH102, RUF100

model_class_arguments = inspect.signature(model.value.__class__).parameters
for k in model_class_arguments:
if k not in model_params:
if model_class_arguments[k].default == inspect.Parameter.empty:
print(f"Missing parameter: {k}")
return solara.Text(f"Missing parameter: {k}")

def connect_to_model():
# Patch the step function to force updates
original_step = model.value.step
Expand Down Expand Up @@ -307,6 +300,12 @@ def ModelCreator(model, model_params, seed=1):
- The component provides an interface for adjusting user-defined parameters and reseeding the model.
"""

solara.use_effect(
lambda: _check_model_params(model.value.__class__.__init__, fixed_params),
[model.value],
)

user_params, fixed_params = split_model_params(model_params)

model_parameters, set_model_parameters = solara.use_state(
Expand All @@ -317,14 +316,36 @@ def ModelCreator(model, model_params, seed=1):
)

def on_change(name, value):
print(f"Setting {name} to {value}")
new_model_parameters = {**model_parameters, name: value}
model.value = model.value.__class__(**new_model_parameters)
set_model_parameters(new_model_parameters)

UserInputs(user_params, on_change=on_change)


def _check_model_params(init_func, model_params):
"""Check if model parameters are valid for the model's initialization function.
Args:
init_func: Model initialization function
model_params: Dictionary of model parameters
Raises:
ValueError: If a parameter is not valid for the model's initialization function
"""
model_parameters = inspect.signature(init_func).parameters
for name in model_parameters:
if (
model_parameters[name].default == inspect.Parameter.empty
and name not in model_params
and name != "self"
):
raise ValueError(f"Missing required model parameter: {name}")
for name in model_params:
if name not in model_parameters:
raise ValueError(f"Invalid model parameter: {name}")


@solara.component
def UserInputs(user_params, on_change=None):
"""Initialize user inputs for configurable model parameters.
Expand Down
46 changes: 45 additions & 1 deletion tests/test_solara_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,19 @@
import unittest

import ipyvuetify as vw
import pytest
import solara

import mesa
import mesa.visualization.components.altair
import mesa.visualization.components.matplotlib
from mesa.visualization.components.matplotlib import make_space_component
from mesa.visualization.solara_viz import Slider, SolaraViz, UserInputs
from mesa.visualization.solara_viz import (
Slider,
SolaraViz,
UserInputs,
_check_model_params,
)


class TestMakeUserInput(unittest.TestCase): # noqa: D101
Expand Down Expand Up @@ -148,3 +154,41 @@ def test_slider(): # noqa: D103
assert not slider_int.is_float_slider
slider_dtype_float = Slider("Homophily", 3, 0, 8, 1, dtype=float)
assert slider_dtype_float.is_float_slider


def test_model_param_checks():
class ModelWithOptionalParams:
def __init__(self, required_param, optional_param=10):
pass

class ModelWithOnlyRequired:
def __init__(self, param1, param2):
pass

# Test that optional params can be omitted
_check_model_params(ModelWithOptionalParams.__init__, {"required_param": 1})

# Test that optional params can be provided
_check_model_params(
ModelWithOptionalParams.__init__, {"required_param": 1, "optional_param": 5}
)

# Test invalid parameter name raises ValueError
with pytest.raises(ValueError, match="Invalid model parameter: invalid_param"):
_check_model_params(
ModelWithOptionalParams.__init__, {"required_param": 1, "invalid_param": 2}
)

# Test missing required parameter raises ValueError
with pytest.raises(ValueError, match="Missing required model parameter: param2"):
_check_model_params(ModelWithOnlyRequired.__init__, {"param1": 1})

# Test passing extra parameters raises ValueError
with pytest.raises(ValueError, match="Invalid model parameter: extra"):
_check_model_params(
ModelWithOnlyRequired.__init__, {"param1": 1, "param2": 2, "extra": 3}
)

# Test empty params dict raises ValueError if required params
with pytest.raises(ValueError, match="Missing required model parameter"):
_check_model_params(ModelWithOnlyRequired.__init__, {})

0 comments on commit 7bc340d

Please sign in to comment.