Skip to content

Commit

Permalink
test: test get_noise
Browse files Browse the repository at this point in the history
  • Loading branch information
jung235 committed Jun 30, 2024
1 parent 05a5381 commit 5333005
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 21 deletions.
31 changes: 19 additions & 12 deletions pydiffuser/utils/jitted.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import inspect
import os
import time
from functools import partial
from typing import Any, Callable, Optional, Tuple
from typing import Any, Callable, Tuple

import jax
import jax.numpy as jnp
from jax import Array, jit
from jax.numpy import linalg as jLA
Expand Down Expand Up @@ -41,21 +43,26 @@ def spherical_to_cartesian(
def get_noise(
generator: Callable[[int], Any],
size: int,
shape: Optional[Tuple[int, ...]] = None,
shape: Tuple[int, ...] | None = None,
) -> Array:
try:
noise = generator(size=size) # type: ignore[call-arg]
noise = jnp.array(generator(size=size)) # type: ignore[call-arg]

except Exception as exc:
if "rand" in generator.__name__:
noise = generator(size)
elif "jax" in inspect.getmodule(generator).__name__: # type: ignore[union-attr]
raise NotImplementedError(
"Random number generator via JAX is unsupported"
) from exc
if (hasattr(generator, "__name__") and "rand" in generator.__name__) or (
isinstance(generator, partial) and "rand" in generator.func.__name__
):
noise = jnp.array(generator(size))

elif (hasattr(generator, "__module__") and "jax" in generator.__module__) or (
isinstance(generator, partial) and "jax" in generator.func.__module__
):
seed = int(time.time()) + os.getpid()
key = jax.random.PRNGKey(seed)
noise = generator(key, shape=(size,)) # type: ignore[call-arg]

else:
raise RuntimeError(f"{exc}") from exc

noise = jnp.array(noise)
if shape is not None:
noise = noise.reshape(shape)
return noise
16 changes: 8 additions & 8 deletions tests/models/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@
@pytest.mark.parametrize(
"name",
[
("abp"),
("aoup"),
("bm"),
("levy"),
("mips"),
("rtp"),
("smoluchowski"),
("vicsek"),
"abp",
"aoup",
"bm",
"levy",
"mips",
"rtp",
"smoluchowski",
"vicsek",
],
)
def test_models(config_dict, model_dict, name):
Expand Down
29 changes: 28 additions & 1 deletion tests/utils/test_jitted.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
from functools import partial

import jax
import jax.numpy as jnp
import pytest
from numpy.random import rand, randint, uniform
from scipy.stats import expon, norm, pareto

from pydiffuser.utils.jitted import normalize
from pydiffuser.utils.jitted import get_noise, normalize


def test_normalize():
Expand All @@ -13,3 +19,24 @@ def test_normalize():
normed_arr = normalize(arr=arr, axis=0)
expected_arr = jnp.array([[0.6], [0.8]])
assert jnp.all(normed_arr == expected_arr)


@pytest.mark.parametrize(
"generator",
[
rand,
partial(rand),
partial(randint, 0, 2),
partial(uniform, -1, 1),
partial(expon.rvs, scale=10),
norm.rvs,
partial(pareto.rvs, b=1.5),
jax.random.uniform,
jax.random.exponential,
jax.random.normal,
partial(jax.random.pareto, b=1.5),
],
)
def test_get_noise(generator):
p = get_noise(generator=generator, size=500, shape=(5, 50, 2))
assert p.shape == (5, 50, 2)

0 comments on commit 5333005

Please sign in to comment.