Stochastic Branch PR #202

Open · wants to merge 7 commits into base: stochastic
2 changes: 1 addition & 1 deletion docs/conf.py
@@ -15,7 +15,7 @@
#
import os
import sys
-sys.path.insert(0, os.path.abspath('../../'))
+sys.path.insert(0, os.path.abspath('../'))
html_static_path = []


46 changes: 46 additions & 0 deletions neurodiffeq/conditions.py
@@ -1164,3 +1164,49 @@ def parameterize(self, output_tensor, r):
return self.R_0 * torch.exp(-self.order * dr) + \
self.R_inf * torch.tanh(dr) + \
torch.exp(-self.order * dr) * torch.tanh(dr) * output_tensor



class BrownianIVP(BaseCondition):
"""
Represents a Brownian initial value condition :math:`X_{t=0} = X_0`.

:param u_0: The initial value of :math:`X_t` at time :math:`t=0`.
"""

def __init__(self, u_0):
super().__init__()
self.X_0 = u_0

def parameterize(self, output_tensor, t, Bt):
"""
Re-parameterizes outputs such that the initial value condition is satisfied.

The re-parameterization is
:math:`u(t, B_t) = u_0 + (1-e^{-t}) \mathrm{ANN}(t, B_t)` and :math:`\mathrm{ANN}` is the neural network.

Why there is no condition on :math:`B_t`? Because the constraint :math:`B_0 = 0` is implemented in `BrownianGenerator()` and `UniBrownianGenerator()`.

:param output_tensor: The output tensor.
:type output_tensor: `torch.Tensor`
:param t: Input variable :math:`t` to the networks, i.e., the time samples produced by the generator.
:type t: `torch.Tensor`
:param Bt: Input variable :math:`B_t` to the networks, i.e., the Brownian motion samples produced by the generator. This parameter is a placeholder and need not be set manually: the constraint on :math:`B_t` is already encoded in the generator, so the condition is only imposed on :math:`t` (the 0-th input).
:type Bt: `torch.Tensor`
:return: The parameterized tensor.
:rtype: `torch.Tensor`
"""
return self.X_0 + (1 - torch.exp(-t)) * output_tensor

def set_impose_on(self, ith_unit=0):
r"""**[DEPRECATED]** When training several functions with a single, multi-output network, this method is called
(by a `Solver` class or a `solve` function) to keep track of which output is being parameterized.

:param ith_unit: The index of network output to be parameterized.
:type ith_unit: int

.. note::
This method is deprecated and retained for backward compatibility only. Users interested in enforcing
conditions on multi-output networks should consider using a ``neurodiffeq.conditions.EnsembleCondition``.
"""
return super().set_impose_on(ith_unit)
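
A minimal sanity check of the re-parameterization (hypothetical usage, not part of the diff; assumes this branch is installed):

.. code-block:: python3

    import torch
    from neurodiffeq.conditions import BrownianIVP

    cond = BrownianIVP(u_0=1.0)
    t = torch.zeros(4, 1)        # at t = 0 ...
    Bt = torch.zeros(4, 1)       # ... B_0 = 0, as enforced by the generators
    raw = torch.randn(4, 1)      # arbitrary network output
    u = cond.parameterize(raw, t, Bt)
    assert torch.allclose(u, torch.ones_like(u))  # (1 - e^0) = 0 kills the network term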
93 changes: 93 additions & 0 deletions neurodiffeq/generators.py
@@ -1016,3 +1016,96 @@ def _internal_vars(self) -> dict:
generator=self.generator,
))
return d


class BrownianGenerator(BaseGenerator):
"""
A generator for generating time samples :math:`t` and standard Brownian motion samples :math:`W_t` where:

- :math:`t \sim U(0, T)`
- :math:`W_t \sim N(0, t)`
- :math:`Cov(W_s, W_t) = \min(s, t)`

:param size: The number of points to generate each time ``get_examples()`` is called. Defaults to 8.
:type size: int
:param T: The time length. Defaults to 1.0.
:type T: float
"""

def __init__(self, size=8, T=1.0):
"""
Initializes the BrownianGenerator with the given size and T.
"""
super(BrownianGenerator, self).__init__()
self.size = size
self.T = T

def get_examples(self):
"""
:returns: A tuple containing the samples of :math:`t` and :math:`W_t`, both of shape (size,).
:rtype: tuple[torch.Tensor, torch.Tensor]
"""
t_sample = torch.rand(self.size, requires_grad=True) * self.T
t_np = t_sample.detach().cpu().numpy()
# covariance matrix of standard Brownian motion: Cov(W_s, W_t) = min(s, t)
sigma = np.minimum.outer(t_np, t_np)
# sample W_t jointly via the Cholesky factor L: L @ z with z ~ N(0, I) has covariance L @ L.T = sigma
Wt_sample = torch.tensor(
    np.linalg.cholesky(sigma) @ np.random.normal(size=len(t_np)),
    dtype=t_sample.dtype,  # numpy defaults to float64; match t_sample's dtype
    requires_grad=True,
)

return t_sample, Wt_sample

def _internal_vars(self) -> dict:
d = super(BrownianGenerator, self)._internal_vars()
d.update(dict(size=self.size, T=self.T))
return d
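
A quick shape-and-autograd check (hypothetical usage; the marginal of each :math:`W_t` is :math:`N(0, t)` by construction):

.. code-block:: python3

    from neurodiffeq.generators import BrownianGenerator

    gen = BrownianGenerator(size=8, T=2.0)
    t, Wt = gen.get_examples()
    assert t.shape == Wt.shape == (8,)           # one Brownian sample per time sample
    assert t.requires_grad and Wt.requires_grad  # both feed into the autograd residuals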



class UniBrownianGenerator(BaseGenerator):
"""
Generate examples of unifomed :math:`t` and :math:`W_t`, where:

- :math:`t \sim U(0, T)`
- :math:`W_t \sim U(-3\sqrt{t}, 3\sqrt{t})`

This generator is used to generate samples uniformly within the 99.5% confidence interval, which can be used as training points when solving a SDE.

:param size: The number of examples to generate. Default is 8.
:type size: int
:param T: The time horizon. Default is 1.0.
:type T: float
"""

def __init__(self, size=8, T=1.0):
"""
Initializes the UniBrownianGenerator with the given size and T.
"""
super(UniBrownianGenerator, self).__init__()
self.size = size
self.T = T

def get_examples(self):
"""
:returns: A tuple containing the samples of :math:`t` and :math:`W_t`. Both have the same length, which is at most ``size`` because out-of-band proposals are rejected.
:rtype: tuple[torch.Tensor, torch.Tensor]
"""
sample_t = torch.rand(self.size, requires_grad=True) * self.T
# propose B_t uniformly over the widest possible band, [-3*sqrt(T), 3*sqrt(T)]
sample_Bt = (2 * torch.rand(self.size, requires_grad=True) - 1) * (
    3 * np.sqrt(self.T)
)
# rejection step: keep only the pairs inside the time-dependent band |B_t| < 3*sqrt(t)
mask = (-3 * torch.sqrt(sample_t) < sample_Bt) & (
    sample_Bt < 3 * torch.sqrt(sample_t)
)

return sample_t[mask], sample_Bt[mask]

def _internal_vars(self) -> dict:
d = super(UniBrownianGenerator, self)._internal_vars()
d.update(dict(size=self.size, T=self.T))
return d
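
Because of the rejection step, fewer than ``size`` pairs may be returned, and every accepted pair lies inside the band :math:`|B_t| < 3\sqrt{t}`. A short check (hypothetical usage):

.. code-block:: python3

    import torch
    from neurodiffeq.generators import UniBrownianGenerator

    gen = UniBrownianGenerator(size=1024, T=1.0)
    t, Bt = gen.get_examples()
    assert len(t) == len(Bt) <= 1024                   # some proposals were rejected
    assert bool((Bt.abs() < 3 * torch.sqrt(t)).all())  # all pairs are inside the band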
206 changes: 206 additions & 0 deletions neurodiffeq/solvers.py
@@ -19,8 +19,11 @@
from .generators import Generator1D
from .generators import Generator2D
from .generators import GeneratorND
from .generators import BrownianGenerator
from .generators import UniBrownianGenerator
from .function_basis import RealSphericalHarmonics
from .conditions import BaseCondition
from .conditions import BrownianIVP
from .neurodiffeq import safe_diff as diff
from .losses import _losses

@@ -1590,3 +1593,206 @@ def _get_internal_variables(self):
'xy_max': self.xy_max,
})
return available_variables


class StochasticSolver1D(BaseSolver):
r"""A solver class for solving Stochastic Differential Equation (SDE), specifically Ito's process:

:math:`dX_t = \mu(X_t,t,B_t)dt + \sigma(X_t,t,B_t)dB_t` with initial value :math:`X_0`.

:param mu_func:
The drift function :math:`\mu(X_t,t,B_t)`. It should take in three arguments: X, t, B and returns the function value.

.. code-block:: python3

# example 1
lambda X, t, B: 1/2*X
# example 2
def my_mu_func(X, t, B):
return 1-X
:type mu_func: callable
:param sigma_func:
The diffusion function :math:`\sigma(X_t,t,B_t)`. It should take in three arguments: X, t, B and returns the function value.

.. code-block:: python3

# example 1
lambda X, t, B: 1/2*X
# example 2
def my_sigma_func(X, t, B):
return 1-X
:type sigma_func: callable
:param init_value:
    The initial value of the Itô process, :math:`X_0`.
:type init_value: float
:param T:
    The time span of the process to solve. Training points are generated within :math:`[0, T]`.
:type T: float
:param n_sample_train:
    The number of training samples generated on the time interval :math:`[0, T]`. Defaults to 512.
:type n_sample_train: int
:param n_sample_valid:
    The number of validation samples generated on the time interval :math:`[0, T]`. Defaults to 512.
:type n_sample_valid: int
:param nets:
List of neural networks for parameterized solution.
If provided, length of ``nets`` must equal that of ``conditions``
:type nets: list[torch.nn.Module], optional
:param train_generator:
    Generator for sampling training points,
    which must provide a ``.get_examples()`` method and a ``.size`` field.
    Defaults to ``UniBrownianGenerator(size=n_sample_train, T=T)``.
:type train_generator: `neurodiffeq.generators.BaseGenerator`, optional
:param valid_generator:
    Generator for sampling validation points,
    which must provide a ``.get_examples()`` method and a ``.size`` field.
    Defaults to ``BrownianGenerator(size=n_sample_valid, T=T)``.
:type valid_generator: `neurodiffeq.generators.BaseGenerator`, optional
:param analytic_solutions:
Analytical solutions to be compared with neural net solutions.
It maps a torch.Tensor to a tuple of function values.
Output shape should match that of ``nets``.
:type analytic_solutions: callable, optional
:param optimizer:
Optimizer to be used for training.
Defaults to a ``torch.optim.Adam`` instance that trains on all parameters of ``nets``.
:type optimizer: ``torch.optim.Optimizer``, optional
:param loss_fn:
The loss function used for training.

- If a str, must be present in the keys of `neurodiffeq.losses._losses`.
- If a `torch.nn.modules.loss._Loss` instance, just pass the instance.
- If any other callable, it must map
A) a residual tensor (shape `(n_points, n_equations)`),
B) a function values tuple (length `n_funcs`, each element a tensor of shape `(n_points, 1)`), and
C) a coordinate values tuple (length `n_coords`, each element a tensor of shape `(n_points, 1)`)
to a tensor of empty shape (i.e. a scalar). The returned tensor must be connected to the computational graph,
so that backpropagation can be performed.
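
For example, a sketch of a custom callable that computes the mean squared residual (hypothetical, for illustration only):

.. code-block:: python3

    def my_loss(residuals, funcs, coords):
        # residuals has shape (n_points, n_equations); funcs and coords are unused here
        return (residuals ** 2).mean()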

:type loss_fn:
str or `torch.nn.modules.loss._Loss` or callable
:param n_batches_train:
Number of batches to train in every epoch, where batch-size equals ``train_generator.size``.
Defaults to 1.
:type n_batches_train: int, optional
:param n_batches_valid:
Number of batches to validate in every epoch, where batch-size equals ``valid_generator.size``.
Defaults to 4.
:type n_batches_valid: int, optional
:param metrics:
Additional metrics to be logged (besides loss). ``metrics`` should be a dict where

- Keys are metric names (e.g. 'analytic_mse');
- Values are functions (callables) that compute the metric value.
These functions must accept the same inputs as the drift and diffusion functions, i.e. ``X``, ``t``, ``B``; see the sketch below.
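
For example (hypothetical, for illustration only):

.. code-block:: python3

    metrics = {'mean_abs': lambda X, t, B: X.abs().mean()}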

:type metrics: dict[str, callable], optional
:param batch_size:
**[DEPRECATED and IGNORED]**
Each batch will use all samples generated.
Please specify ``n_batches_train`` and ``n_batches_valid`` instead.
:type batch_size: int
:param shuffle:
**[DEPRECATED and IGNORED]**
Shuffling should be performed by generators.
:type shuffle: bool
"""

def __init__(
self,
mu_func,
sigma_func,
init_value,
T,
n_sample_train=512,
n_sample_valid=512,
nets=None,
train_generator=None,
valid_generator=None,
analytic_solutions=None,
optimizer=None,
loss_fn=None,
n_batches_train=1,
n_batches_valid=4,
metrics=None,
# deprecated arguments are listed below
batch_size=None,
shuffle=None,
):
# store as internal variables
self.init_value = init_value
self.T = T
self.mu_func = mu_func
self.sigma_func = sigma_func

# construct the SDE system: for X = f(t, B_t), Itô's lemma gives
# dX = (f_t + (1/2) f_BB) dt + f_B dB, so matching drift and diffusion against
# dX = mu dt + sigma dB yields the two residuals below
sde_system = lambda X, t, B: [
diff(X, t, 1) + 1 / 2 * diff(X, B, 2) - mu_func(X, t, B),
diff(X, B, 1) - sigma_func(X, t, B),
]

# formulate the initial condition
mycondition = BrownianIVP(u_0=init_value)
mycondition.set_impose_on(0)
conditions = [mycondition]

super(StochasticSolver1D, self).__init__(
    diff_eqs=sde_system,
    conditions=conditions,
    nets=nets,
    # fall back to the default generators when none are supplied
    train_generator=train_generator or UniBrownianGenerator(size=n_sample_train, T=T),
    valid_generator=valid_generator or BrownianGenerator(size=n_sample_valid, T=T),
analytic_solutions=analytic_solutions,
optimizer=optimizer,
loss_fn=loss_fn,
n_batches_train=n_batches_train,
n_batches_valid=n_batches_valid,
metrics=metrics,
n_input_units=2,
n_output_units=1,
shuffle=shuffle,
batch_size=batch_size,
)

def get_solution(self, copy=True, best=True):
r"""Get a (callable) solution object. See this usage example:

.. code-block:: python3

solution = solver.get_solution()
t, Bt = train_generator.get_examples()
value_at_points = solution(t, Bt)

:param copy:
Whether to make a copy of the networks so that subsequent training doesn't affect the solution;
Defaults to True.
:type copy: bool
:param best:
Whether to return the solution with lowest loss instead of the solution after the last epoch.
Defaults to True.
:type best: bool
:return:
A solution object which can be called.
To evaluate the solution on certain points,
you should pass the coordinates vector(s) to the returned solution.
:rtype: BaseSolution
"""
nets = self.best_nets if best else self.nets
conditions = self.conditions
if copy:
nets = deepcopy(nets)
conditions = deepcopy(conditions)

return Solution2D(nets, conditions)  # the solution takes the two inputs (t, B_t), so Solution2D is reused

def _get_internal_variables(self):
available_variables = super(StochasticSolver1D, self)._get_internal_variables()
available_variables.update(
{
"init_value": self.init_value,
"T": self.T,
"mu_func": self.mu_func,
"sigma_func": self.sigma_func,
}
)
return available_variables
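
A hypothetical end-to-end sketch, solving geometric Brownian motion :math:`dX_t = \mu X_t\,dt + \sigma X_t\,dB_t` (names and hyperparameters are illustrative, not part of the diff):

.. code-block:: python3

    from neurodiffeq.solvers import StochasticSolver1D

    mu, sigma = 0.05, 0.2
    solver = StochasticSolver1D(
        mu_func=lambda X, t, B: mu * X,
        sigma_func=lambda X, t, B: sigma * X,
        init_value=1.0,
        T=1.0,
    )
    solver.fit(max_epochs=1000)       # .fit() is inherited from BaseSolver
    solution = solver.get_solution()  # callable as solution(t, Bt)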