Skip to content

Commit

Permalink
Merge pull request #13 from jung235/feat/alter-ctrw-strategy
Browse files Browse the repository at this point in the history
feat: alter the simulation strategy for CTRW
  • Loading branch information
jung235 authored Jul 30, 2024
2 parents 5333005 + d54b807 commit 8d74f27
Show file tree
Hide file tree
Showing 11 changed files with 207 additions and 42 deletions.
3 changes: 2 additions & 1 deletion pydiffuser/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
VicsekModelConfig,
)
from .tracer import Ensemble, Trajectory
from .utils import load, save
from .utils import load, random, save

__all__ = [
"__version__",
Expand All @@ -41,5 +41,6 @@
"Ensemble",
"Trajectory",
"load",
"random",
"save",
]
9 changes: 2 additions & 7 deletions pydiffuser/models/bm.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import jax.numpy as jnp
from scipy.stats import norm

from pydiffuser.logger import logger
from pydiffuser.models.core import BaseDiffusion
from pydiffuser.tracer import Ensemble, Trajectory
from pydiffuser.utils import BaseDiffusionConfig, jitted
from pydiffuser.utils import BaseDiffusionConfig, helpers, jitted


class BrownianMotionConfig(BaseDiffusionConfig):
Expand Down Expand Up @@ -49,14 +48,10 @@ def generate(
ens.update_microstate(microstate=x)
return ens

@helpers.deprecated
def create(
self, realization: int, length: int, dimension: int, dt: float
) -> Ensemble:
logger.warning(
"The method `create` is deprecated. "
"Please use `generate` instead, which utilizes batch calculation.",
stacklevel=2,
)
import numpy as np

ens = Ensemble(dt=dt)
Expand Down
75 changes: 62 additions & 13 deletions pydiffuser/models/core/ctrw.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
from functools import partial
from typing import Any, Callable, Dict, Tuple

import jax.numpy as jnp
import numpy as np
from jax import Array
from numpy.random import rand, randint

from pydiffuser.exceptions import MemoryAllocationError, ShapeMismatchError
from pydiffuser.models.core import BaseDiffusion
from pydiffuser.tracer import Ensemble
from pydiffuser.typing import ConstType
from pydiffuser.utils import BaseDiffusionConfig, jitted
from pydiffuser.typing import ConstType, LongLongPosType
from pydiffuser.utils import BaseDiffusionConfig, helpers, jitted


class ContinuousTimeRandomWalkConfig(BaseDiffusionConfig):
Expand All @@ -23,6 +25,7 @@ class ContinuousTimeRandomWalk(BaseDiffusion):

def __init__(self) -> None:
super(ContinuousTimeRandomWalk, self).__init__()
self.precision_x64 = True

@property
def one_step(self) -> ConstType:
Expand All @@ -41,21 +44,30 @@ def generate(
ens = super().generate(realization, length, dimension, dt, **generate_kwargs)
realization, length, dimension, _ = list(self.generate_info.values())[:4]

t = self.get_time_steps()
if t.shape[0] != realization or t.ndim != 2:
raise ShapeMismatchError("(`realization`, -1) is required")
interarrival: Dict[int, Array] = self.get_interarrival_times()
ens.update_meta_dict(item={"interarrival": interarrival})
return ens

u = self.get_orientation_steps(realization, capacity=t.shape[1])
u = jnp.stack(
arrays=[jnp.repeat(_u, _t, axis=0) for _u, _t in zip(u, t, strict=True)],
axis=0,
)
u = u[:, : length - 1] # realization x (length - 1) x -1
def get_interarrival_times(self) -> Dict[int, Array]:
raise NotImplementedError

def _get_interarrival_times_per_tracer(
self, generator: Callable[[int], Any], key: Array
) -> Tuple[Array, Array]:
_, length, _, dt = list(self.generate_info.values())[:4]

tau = jitted.get_noise(generator=partial(generator, key=key), size=length)
# make it countable, and use `int64` to prevent overflow
tau = jnp.array(jnp.round(tau, -int(np.log10(dt))) / dt, dtype=jnp.int64)
cutoff = jnp.argmax(jnp.cumsum(tau) >= length)
return tau, cutoff

def _get_microstate_from_orientation(self, u: Array) -> LongLongPosType:
x = self._get_initial_position()
if dimension == 1:
dim = x.shape[-1]
if dim == 1:
dx = self.one_step * jnp.cos(u)
elif dimension == 2:
elif dim == 2:
dx = jnp.concatenate(jitted.polar_to_cartesian(self.one_step, u), axis=2)
else:
dx = jnp.concatenate(
Expand All @@ -64,6 +76,31 @@ def generate(
)
x = jnp.concatenate((x, dx), axis=1)
x = jnp.cumsum(x, axis=1)
return x

@helpers.deprecated
def create(
self,
realization: int = 10,
length: int = 1000,
dimension: int = 2,
dt: float = 1.0,
**generate_kwargs,
) -> Ensemble:
ens = super().generate(realization, length, dimension, dt, **generate_kwargs)
realization, length, dimension, _ = list(self.generate_info.values())[:4]

t = self.get_time_steps()
if t.shape[0] != realization or t.ndim != 2:
raise ShapeMismatchError("(`realization`, -1) is required")

u = self.get_orientation_steps(realization, capacity=t.shape[1])
u = jnp.stack(
arrays=[jnp.repeat(ui, ti, axis=0) for ui, ti in zip(u, t, strict=True)],
axis=0,
)
u = u[:, : length - 1] # realization x (length - 1) x -1
x = self._get_microstate_from_orientation(u)
ens.update_microstate(microstate=x)
return ens

Expand All @@ -81,6 +118,18 @@ def get_orientation_steps(self, realization: int, capacity: int) -> Array:
u = u.at[:, :, -1].multiply(2) if dim == 3 else u
return u

@staticmethod
def repeat(
arr: Array, repeats_info: Dict[int, Array], cutoff: ConstType | None = None
) -> Array:
arrs = []
for id, ti in repeats_info.items():
arrs.append(jnp.repeat(arr[id, : len(ti)], repeats=ti, axis=0))
arr = jnp.stack(arrs, axis=0)
if cutoff:
arr = arr[:, :cutoff] # realization x cutoff x -1
return arr

@staticmethod
def slice(arr: Array, threshold: ConstType, axis: int = 0) -> Array:
"""For a 2darray with shape `realization` x -1 (here, `arr.T`),
Expand Down
52 changes: 50 additions & 2 deletions pydiffuser/models/levy.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
from functools import partial
from typing import Dict

import jax
import jax.numpy as jnp
from jax import Array
from jax import Array, vmap
from jax.random import PRNGKey
from numpy.random import randint
from scipy.stats import pareto

from pydiffuser.models.core import (
ContinuousTimeRandomWalk,
ContinuousTimeRandomWalkConfig,
)
from pydiffuser.tracer import Ensemble
from pydiffuser.typing import ConstType
from pydiffuser.utils import jitted

Expand All @@ -21,7 +26,8 @@ def __init__(self, speed: float = 1.0, exponent: float = 1.5, **kwargs):
Args:
speed (float): A constant speed.
exponent (float): The positive scaling exponent `b` of pareto distribution
in https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.pareto.html.
in https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.pareto.html
or https://jax.readthedocs.io/en/latest/_autosummary/jax.random.pareto.html.
"""

super(LevyWalkConfig, self).__init__(**kwargs)
Expand All @@ -41,6 +47,43 @@ def __init__(self, speed: float, exponent: float):
def one_step(self) -> ConstType:
return self.speed * self.generate_info["dt"]

def generate(
self,
realization: int = 10,
length: int = 1000,
dimension: int = 2,
dt: float = 1.0,
**generate_kwargs,
) -> Ensemble:
ens = super().generate(realization, length, dimension, dt, **generate_kwargs)
realization, length, dimension, _ = list(self.generate_info.values())[:4]
interarrival: Dict[int, Array] = ens.meta["interarrival"]

u = self.get_orientation_steps(realization, capacity=length)
u = self.repeat(
u, interarrival, cutoff=length - 1 # realization x (length - 1) x -1
)
x = self._get_microstate_from_orientation(u)
ens.update_microstate(microstate=x)
return ens

def get_interarrival_times(self) -> Dict[int, Array]:
realization, length, _, _ = self.generate_info.values()

generator = partial(jax.random.pareto, b=self.exponent)
keys = jax.random.split(PRNGKey(randint(0, 1e9)), num=realization)
vmap_fn = vmap(self._get_interarrival_times_per_tracer, (None, 0))
taus, cutoffs = vmap_fn(generator, keys)

interarrival = {}
for id, (tau, cutoff) in enumerate(zip(taus, cutoffs, strict=True)):
if cutoff > 0:
arr = tau[:cutoff]
interarrival[id] = jnp.append(arr, length - jnp.sum(arr))
else:
interarrival[id] = jnp.array([length], dtype=int)
return interarrival

def get_time_steps(self) -> Array:
realization, length, _, dt = self.generate_info.values()

Expand All @@ -53,5 +96,10 @@ def get_time_steps(self) -> Array:
shape=(-1, realization),
)
tau = jnp.array(jnp.round(tau, -int(jnp.log10(dt))) / dt) # countable
if jnp.any(jnp.isinf(tau)):
raise ValueError(
f"`tau` contains infinite values for exponent={self.exponent}"
)

tau = self.slice(arr=tau.astype(int), threshold=length - 1).T
return tau
2 changes: 0 additions & 2 deletions pydiffuser/models/mips.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,6 @@ def generate(
x = jnp.concatenate((x, jnp.transpose(stx, (1, 0, 2))), axis=1)

ens.update_microstate(microstate=x)
for id in range(realization):
ens[id].update_meta_dict(item={"diameter": sigma[id]})
ens.update_meta_dict(item={"diameter": sigma})
return ens

Expand Down
48 changes: 45 additions & 3 deletions pydiffuser/models/rtp.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
from functools import partial
from typing import Dict

import jax
import jax.numpy as jnp
from jax import Array
from jax import Array, vmap
from jax.random import PRNGKey
from numpy.random import randint
from scipy.stats import expon

from pydiffuser.logger import logger
from pydiffuser.models.core import (
ContinuousTimeRandomWalk,
ContinuousTimeRandomWalkConfig,
)
from pydiffuser.tracer import Ensemble
from pydiffuser.typing import ConstType
from pydiffuser.utils import jitted
from pydiffuser.utils import jitted, random


class RunAndTumbleParticleConfig(ContinuousTimeRandomWalkConfig):
Expand Down Expand Up @@ -47,10 +52,47 @@ def pre_generate(self, *generate_args) -> None:
logger.warning("The precision is significantly low.", stacklevel=2)
return

def generate(
self,
realization: int = 10,
length: int = 1000,
dimension: int = 2,
dt: float = 1.0,
**generate_kwargs,
) -> Ensemble:
ens = super().generate(realization, length, dimension, dt, **generate_kwargs)
realization, length, dimension, _ = list(self.generate_info.values())[:4]
interarrival: Dict[int, Array] = ens.meta["interarrival"]

u = self.get_orientation_steps(realization, capacity=length)
u = self.repeat(
u, interarrival, cutoff=length - 1 # realization x (length - 1) x -1
)
x = self._get_microstate_from_orientation(u)
ens.update_microstate(microstate=x)
return ens

def get_interarrival_times(self) -> Dict[int, Array]:
realization, length, _, _ = self.generate_info.values()

generator = partial(random.exponential, scale=self.tc)
keys = jax.random.split(PRNGKey(randint(0, 1e9)), num=realization)
vmap_fn = vmap(self._get_interarrival_times_per_tracer, (None, 0))
taus, cutoffs = vmap_fn(generator, keys)

interarrival = {}
for id, (tau, cutoff) in enumerate(zip(taus, cutoffs, strict=True)):
if cutoff > 0:
arr = tau[:cutoff]
interarrival[id] = jnp.append(arr, length - jnp.sum(arr))
else:
interarrival[id] = jnp.array([length], dtype=int)
return interarrival

def get_time_steps(self) -> Array:
realization, length, _, dt = self.generate_info.values()

num_runs = int(length / self.tc * 2) if self.rate < 1 else length # TODO
num_runs = int(length / self.tc * 2) if self.rate < 1 else length
tau = jitted.get_noise(
generator=partial(expon.rvs, scale=self.tc),
size=realization * num_runs,
Expand Down
2 changes: 0 additions & 2 deletions pydiffuser/models/vicsek.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,6 @@ def generate(
x = jnp.concatenate((x, jnp.transpose(stx_x, (1, 0, 2))), axis=1)

ens.update_microstate(microstate=x)
for id in range(realization):
ens[id].update_meta_dict(item={"direction": phi[id]})
ens.update_meta_dict(item={"direction": phi})
return ens

Expand Down
14 changes: 14 additions & 0 deletions pydiffuser/utils/helpers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import functools
import warnings
from dataclasses import dataclass
from typing import Callable, Optional

Expand Down Expand Up @@ -54,3 +55,16 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
return wrapper

return decorator


def deprecated(fn: Callable[P, T]) -> Callable[P, T]:
@functools.wraps(fn)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
warnings.warn(
f"The callable `{fn.__name__}` in `{fn.__module__}` is deprecated.",
category=DeprecationWarning,
stacklevel=2,
)
return fn(*args, **kwargs)

return wrapper
17 changes: 9 additions & 8 deletions pydiffuser/utils/jitted.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import os
import time
from functools import partial
from typing import Any, Callable, Tuple

import jax
import jax.numpy as jnp
from jax import Array, jit
from jax.numpy import linalg as jLA
from jaxlib.xla_extension import PjitFunction

from pydiffuser.typing import ArrayType, ConstType

Expand Down Expand Up @@ -54,12 +52,15 @@ def get_noise(
):
noise = jnp.array(generator(size))

elif (hasattr(generator, "__module__") and "jax" in generator.__module__) or (
isinstance(generator, partial) and "jax" in generator.func.__module__
elif (
(hasattr(generator, "__module__") and "jax" in generator.__module__)
or (isinstance(generator, partial) and "jax" in generator.func.__module__)
or (
isinstance(generator, partial)
and isinstance(generator.func, PjitFunction)
)
):
seed = int(time.time()) + os.getpid()
key = jax.random.PRNGKey(seed)
noise = generator(key, shape=(size,)) # type: ignore[call-arg]
noise = generator(shape=(size,)) # type: ignore[call-arg]

else:
raise RuntimeError(f"{exc}") from exc
Expand Down
Loading

0 comments on commit 8d74f27

Please sign in to comment.