Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

docs: add _BundleConditionMixin in documentation #217

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ API Reference
:show-inheritance:
:inherited-members:
:members:
:private-members:

`neurodiffeq.solvers`
------------------------------------------------------
Expand Down
38 changes: 19 additions & 19 deletions neurodiffeq/conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,24 +76,25 @@ def set_impose_on(self, ith_unit):


class _BundleConditionMixin:
"""Mixin class for bundle conditions. Unlike other conditions, parameters of a bundle conditions are not fixed.
For example, a regular `neurodiffeq.conditions.IVP` instance only allows fixed t_0 and u_0. But a bundle
condition allows for t_0 and/or u_0 to be dynamically sampled by generators during training. One can use this to
learn a bundle of solutions (as a function of t_0 and/or u_0).

:param bundle_param_lookup:
A dictionary that maps bundle parameter name to its corresponding index.
For example, if bundle_param_lookup == {'t_0': 0, 'u_0': 1}, then the `thetas` passed to self.enforce()`
must contain two entries, with thetas[0] being a tensor of `t_0` and thetas[1] being a tensor of `u_0`.
Similarly, if bundle_param_lookup == {'t_0': 1, 'u_0': 0}, then the two entries in thetas should be
swapped. Defaults to empty dictionary.
:type bundle_param_lookup: Dict[str, int]
:param allowed_params:
A collection of parameter names allowed in ``bundle_param_lookup``. If specified, an error will be raised
if ``bundle_param_lookup`` contains names that do not appear in ``allowed_params``.
:type allowed_params: Set[str] or List[str] or Tuple[str]
"""

def __init__(self, bundle_param_lookup=None, allowed_params=None):
"""Mixin class for bundle conditions. Unlike other conditions, parameters of a bundle conditions are not fixed.
For example, a regular `neurodiffeq.conditions.IVP` instance only allows fixed t_0 and u_0. But a bundle
condition allows for t_0 and/or u_0 to be dynamically sampled by generators during training. One can use this to
learn a bundle of solutions (as a function of t_0 and/or u_0).

:param bundle_param_lookup:
A dictionary that maps bundle parameter name to its corresponding index.
For example, if bundle_param_lookup == {'t_0': 0, 'u_0': 1}, then the `thetas` passed to self.enforce()`
must contain two entries, with thetas[0] being a tensor of `t_0` and thetas[1] being a tensor of `u_0`.
Similarly, if bundle_param_lookup == {'t_0': 1, 'u_0': 0}, then the two entries in thetas should be
swapped. Defaults to empty dictionary.
:type bundle_param_lookup: Dict[str, int]
:param allowed_params:
A collection of parameter names allowed in ``bundle_param_lookup``. If specified, an error will be raised
if ``bundle_param_lookup`` contains names that do not appear in ``allowed_params``.
:type allowed_params: Set[str] or List[str] or Tuple[str]
"""
self.bundle_param_lookup = bundle_param_lookup or {}

if isinstance(allowed_params, str):
Expand All @@ -113,8 +114,7 @@ def _get_parameter(self, param_name, thetas, override_name=None):

Example:
- If param_name == 't_0', and self.bundle_param_lookup['t_0'] == 0, the method will return `thetas[0]`.
- If `t_0` is not present in self.bundle_param_lookup, it will return `self.t_0`, or, self.some_other_attribute
if override_name is set to 'some_other_attribute'.
- If `t_0` is not present in self.bundle_param_lookup, it will return `self.t_0`, or, self.some_other_attribute if override_name is set to 'some_other_attribute'.

:param param_name: Name of parameter to be used. E.g. `t_0`, `u_0` in the case of an initial value problem.
:type param_name: str
Expand Down