Skip to content

Commit

Permalink
test: add a test file for model instantiation
Browse files Browse the repository at this point in the history
  • Loading branch information
jung235 committed Jun 13, 2024
1 parent 4fb21ef commit 0619e9c
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 20 deletions.
13 changes: 13 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import pytest

from pydiffuser.models import CONFIG_REGISTRY, MODEL_REGISTRY


@pytest.fixture
def config_dict():
return CONFIG_REGISTRY


@pytest.fixture
def model_dict():
return MODEL_REGISTRY
20 changes: 0 additions & 20 deletions tests/models/test_bm.py

This file was deleted.

41 changes: 41 additions & 0 deletions tests/models/test_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from inspect import signature

import pytest

from pydiffuser.models.core.base import MODEL_KWARGS


@pytest.mark.parametrize(
"name",
[
("abp"),
("aoup"),
("bm"),
("levy"),
("mips"),
("rtp"),
("smoluchowski"),
("vicsek"),
],
)
def test_models(config_dict, model_dict, name):
config = config_dict[name]()
model = model_dict[name].from_config(config=config) # instantiation from config
ens = model.generate()
assert ens.microstate.shape == (config.realization, config.length, config.dimension)

ens = model.generate(realization=5, length=50, dimension=2)
with pytest.raises(AssertionError):
assert ens.microstate.shape == (5, 50, 2)

kwargs = config_dict[name].show_fields()
params = list(signature(model_dict[name].__init__).parameters.keys())
args = (
[kwargs[param] for param in params[1:]]
if MODEL_KWARGS not in params
else ([kwargs[param] for param in params[1:-1]])
)
model_v2 = model_dict[name](*args) # instantiation without config

ens_v2 = model_v2.generate(realization=5, length=50, dimension=2)
assert ens_v2.microstate.shape == (5, 50, 2)

0 comments on commit 0619e9c

Please sign in to comment.