Skip to content

Commit

Permalink
test: add test code (#5)
Browse files Browse the repository at this point in the history
  • Loading branch information
jung235 authored Oct 18, 2023
1 parent 35b3a66 commit 1e3f986
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 3 deletions.
20 changes: 20 additions & 0 deletions tests/models/test_bm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import pytest

from pydiffuser.models.bm import BrownianMotion, BrownianMotionConfig


def test_bm():
model = BrownianMotion()
model.generate(dimension=3)
_, _, dim, _ = model.generate_info.values()
assert dim == 3

config = BrownianMotionConfig(dimension=1)
assert config.name == model.name

model_v2 = BrownianMotion.from_config(config)
model_v2.generate(dimension=3)
_, _, dim, _ = model_v2.generate_info.values()
with pytest.raises(AssertionError):
assert dim == 3
assert dim == 1
18 changes: 18 additions & 0 deletions tests/tracer/test_ensemble.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import jax.numpy as jnp
import pytest

from pydiffuser.tracer.ensemble import Ensemble
from pydiffuser.tracer.trajectory import Trajectory


def test_ensemble():
ens = Ensemble(dt=1.0)

tracer = Trajectory(dt=0.1, position_x1=[0, 1, 2], position_x2=[0, -1, -2])
with pytest.raises(AssertionError):
ens.add(tracer)

tracer = Trajectory(dt=1.0, position_x1=[0, 1, 2], position_x2=[0, -1, -2])
ens.add(tracer)
microstate = jnp.array([[[3, -3], [4, -4], [5, -5]]])
ens.update_microstate(microstate)
11 changes: 8 additions & 3 deletions tests/utils/test_jitted.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@


def test_normalize():
arr = jnp.array([3, 4])
arr = jnp.array([[3, 4]])
normed_arr = normalize(arr=arr)
assert normed_arr[0] == 0.6
assert normed_arr[1] == 0.8
expected_arr = jnp.array([[0.6, 0.8]])
assert jnp.all(normed_arr == expected_arr)

arr = jnp.array([[3], [4]])
normed_arr = normalize(arr=arr, axis=0)
expected_arr = jnp.array([[0.6], [0.8]])
assert jnp.all(normed_arr == expected_arr)

0 comments on commit 1e3f986

Please sign in to comment.