Skip to content

Commit

Permalink
feat: minor updates
Browse files Browse the repository at this point in the history
  • Loading branch information
jung235 committed Jun 23, 2024
1 parent 0619e9c commit ef1feb9
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 1 deletion.
1 change: 1 addition & 0 deletions pydiffuser/models/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def pre_generate(self, *generate_args) -> None:
"The calculation will be significantly slower than with non-interacting particles."
)

jax.config.update("jax_platform_name", "cpu") # TODO
jax.config.update("jax_enable_x64", self.precision_x64)
if self.precision_x64:
logger.debug(
Expand Down
7 changes: 7 additions & 0 deletions pydiffuser/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ def __init__(

@classmethod
def show_fields(cls) -> Dict[str, Any]:
"""Show the variables of the class constructor."""

return copy.deepcopy(vars(cls()))

@classmethod
Expand All @@ -54,6 +56,8 @@ def to_dataclass(self) -> object:
return obj(**info)

def to_dict(self) -> Dict[str, Any]:
"""Show the variables of the instance of the class."""

info = copy.deepcopy(vars(self))
return info

Expand All @@ -62,3 +66,6 @@ def __repr__(self) -> str:

def __contains__(self, param: str) -> bool:
return param in self.to_dict()

def __eq__(self, config: object) -> bool:
return self.to_dict() == config.to_dict() # type: ignore
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ disallow_incomplete_defs = false
module = "pydiffuser._cli.cli"
ignore_errors = true

[[tool.mypy.overrides]]
module = "pydiffuser.models.mips"
ignore_errors = true

[tool.ruff]
select = [
"E", # pycodestyle
Expand Down
4 changes: 3 additions & 1 deletion tests/models/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,13 @@ def test_models(config_dict, model_dict, name):
config = config_dict[name]()
model = model_dict[name].from_config(config=config) # instantiation from config
ens = model.generate()
assert ens.microstate.shape == (config.realization, config.length, config.dimension)
shape = (config.realization, config.length, config.dimension)
assert ens.microstate.shape == shape

ens = model.generate(realization=5, length=50, dimension=2)
with pytest.raises(AssertionError):
assert ens.microstate.shape == (5, 50, 2)
assert ens.microstate.shape == shape

kwargs = config_dict[name].show_fields()
params = list(signature(model_dict[name].__init__).parameters.keys())
Expand Down

0 comments on commit ef1feb9

Please sign in to comment.