Skip to content

Commit

Permalink
test: add a test file for model instantiation (#12)
Browse files Browse the repository at this point in the history
* test: add a test file for model instantiation

* feat: minor updates

* fix

* fix
  • Loading branch information
jung235 authored Jun 23, 2024
1 parent 4fb21ef commit 05a5381
Show file tree
Hide file tree
Showing 7 changed files with 75 additions and 28 deletions.
1 change: 1 addition & 0 deletions pydiffuser/models/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def pre_generate(self, *generate_args) -> None:
"The calculation will be significantly slower than with non-interacting particles."
)

jax.config.update("jax_platform_name", "cpu") # TODO
jax.config.update("jax_enable_x64", self.precision_x64)
if self.precision_x64:
logger.debug(
Expand Down
7 changes: 7 additions & 0 deletions pydiffuser/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ def __init__(

@classmethod
def show_fields(cls) -> Dict[str, Any]:
"""Show the variables of the class constructor."""

return copy.deepcopy(vars(cls()))

@classmethod
Expand All @@ -54,6 +56,8 @@ def to_dataclass(self) -> object:
return obj(**info)

def to_dict(self) -> Dict[str, Any]:
"""Show the variables of the instance of the class."""

info = copy.deepcopy(vars(self))
return info

Expand All @@ -62,3 +66,6 @@ def __repr__(self) -> str:

def __contains__(self, param: str) -> bool:
return param in self.to_dict()

def __eq__(self, config: object) -> bool:
return self.to_dict() == config.to_dict() # type: ignore
15 changes: 7 additions & 8 deletions pydiffuser/utils/jitted.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,16 +46,15 @@ def get_noise(
try:
noise = generator(size=size) # type: ignore[call-arg]
except Exception as exc:
if isinstance(exc, TypeError):
module = inspect.getmodule(generator)
if module is None:
noise = generator(size)
elif "jax" in module.__name__:
raise NotImplementedError(
"Random number generator via JAX is unsupported"
) from exc
if "rand" in generator.__name__:
noise = generator(size)
elif "jax" in inspect.getmodule(generator).__name__: # type: ignore[union-attr]
raise NotImplementedError(
"Random number generator via JAX is unsupported"
) from exc
else:
raise RuntimeError(f"{exc}") from exc

noise = jnp.array(noise)
if shape is not None:
noise = noise.reshape(shape)
Expand Down
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ disallow_incomplete_defs = false
module = "pydiffuser._cli.cli"
ignore_errors = true

[[tool.mypy.overrides]]
module = "pydiffuser.models.mips"
ignore_errors = true

[tool.ruff]
select = [
"E", # pycodestyle
Expand Down
13 changes: 13 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import pytest

from pydiffuser.models import CONFIG_REGISTRY, MODEL_REGISTRY


@pytest.fixture
def config_dict():
return CONFIG_REGISTRY


@pytest.fixture
def model_dict():
return MODEL_REGISTRY
20 changes: 0 additions & 20 deletions tests/models/test_bm.py

This file was deleted.

43 changes: 43 additions & 0 deletions tests/models/test_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from inspect import signature

import pytest

from pydiffuser.models.core.base import MODEL_KWARGS


@pytest.mark.parametrize(
"name",
[
("abp"),
("aoup"),
("bm"),
("levy"),
("mips"),
("rtp"),
("smoluchowski"),
("vicsek"),
],
)
def test_models(config_dict, model_dict, name):
config = config_dict[name]()
model = model_dict[name].from_config(config=config) # instantiation from config
ens = model.generate()
shape = (config.realization, config.length, config.dimension)
assert ens.microstate.shape == shape

ens = model.generate(realization=5, length=50, dimension=2)
with pytest.raises(AssertionError):
assert ens.microstate.shape == (5, 50, 2)
assert ens.microstate.shape == shape

kwargs = config_dict[name].show_fields()
params = list(signature(model_dict[name].__init__).parameters.keys())
args = (
[kwargs[param] for param in params[1:]]
if MODEL_KWARGS not in params
else ([kwargs[param] for param in params[1:-1]])
)
model_v2 = model_dict[name](*args) # instantiation without config

ens_v2 = model_v2.generate(realization=5, length=50, dimension=2)
assert ens_v2.microstate.shape == (5, 50, 2)

0 comments on commit 05a5381

Please sign in to comment.