diff --git a/pydiffuser/models/core/base.py b/pydiffuser/models/core/base.py index c82b2dd..9967c12 100644 --- a/pydiffuser/models/core/base.py +++ b/pydiffuser/models/core/base.py @@ -85,6 +85,7 @@ def pre_generate(self, *generate_args) -> None: "The calculation will be significantly slower than with non-interacting particles." ) + jax.config.update("jax_platform_name", "cpu") # TODO jax.config.update("jax_enable_x64", self.precision_x64) if self.precision_x64: logger.debug( diff --git a/pydiffuser/utils/config.py b/pydiffuser/utils/config.py index 7312f52..61177f5 100644 --- a/pydiffuser/utils/config.py +++ b/pydiffuser/utils/config.py @@ -35,6 +35,8 @@ def __init__( @classmethod def show_fields(cls) -> Dict[str, Any]: + """Show the variables of the class constructor.""" + return copy.deepcopy(vars(cls())) @classmethod @@ -54,6 +56,8 @@ def to_dataclass(self) -> object: return obj(**info) def to_dict(self) -> Dict[str, Any]: + """Show the variables of the instance of the class.""" + info = copy.deepcopy(vars(self)) return info @@ -62,3 +66,6 @@ def __repr__(self) -> str: def __contains__(self, param: str) -> bool: return param in self.to_dict() + + def __eq__(self, config: object) -> bool: + return self.to_dict() == config.to_dict() # type: ignore diff --git a/pyproject.toml b/pyproject.toml index f0c31f3..b19df9e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,6 +65,10 @@ disallow_incomplete_defs = false module = "pydiffuser._cli.cli" ignore_errors = true +[[tool.mypy.overrides]] +module = "pydiffuser.models.mips" +ignore_errors = true + [tool.ruff] select = [ "E", # pycodestyle diff --git a/tests/models/test_models.py b/tests/models/test_models.py index ff8191b..4335563 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -22,11 +22,13 @@ def test_models(config_dict, model_dict, name): config = config_dict[name]() model = model_dict[name].from_config(config=config) # instantiation from config ens = model.generate() - assert ens.microstate.shape == (config.realization, config.length, config.dimension) + shape = (config.realization, config.length, config.dimension) + assert ens.microstate.shape == shape ens = model.generate(realization=5, length=50, dimension=2) with pytest.raises(AssertionError): assert ens.microstate.shape == (5, 50, 2) + assert ens.microstate.shape == shape kwargs = config_dict[name].show_fields() params = list(signature(model_dict[name].__init__).parameters.keys())