From 21ed5d73d8416899f06b352e39f74d9070c764d8 Mon Sep 17 00:00:00 2001 From: Sathvik Bhagavan Date: Tue, 16 Jan 2024 18:25:25 +0530 Subject: [PATCH 1/2] docs: fix docstring for _BundleConditionMixin --- neurodiffeq/conditions.py | 38 +++++++++++++++++++------------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/neurodiffeq/conditions.py b/neurodiffeq/conditions.py index 397803a..da2ad50 100644 --- a/neurodiffeq/conditions.py +++ b/neurodiffeq/conditions.py @@ -76,24 +76,25 @@ def set_impose_on(self, ith_unit): class _BundleConditionMixin: + """Mixin class for bundle conditions. Unlike other conditions, parameters of a bundle conditions are not fixed. + For example, a regular `neurodiffeq.conditions.IVP` instance only allows fixed t_0 and u_0. But a bundle + condition allows for t_0 and/or u_0 to be dynamically sampled by generators during training. One can use this to + learn a bundle of solutions (as a function of t_0 and/or u_0). + + :param bundle_param_lookup: + A dictionary that maps bundle parameter name to its corresponding index. + For example, if bundle_param_lookup == {'t_0': 0, 'u_0': 1}, then the `thetas` passed to self.enforce()` + must contain two entries, with thetas[0] being a tensor of `t_0` and thetas[1] being a tensor of `u_0`. + Similarly, if bundle_param_lookup == {'t_0': 1, 'u_0': 0}, then the two entries in thetas should be + swapped. Defaults to empty dictionary. + :type bundle_param_lookup: Dict[str, int] + :param allowed_params: + A collection of parameter names allowed in ``bundle_param_lookup``. If specified, an error will be raised + if ``bundle_param_lookup`` contains names that do not appear in ``allowed_params``. + :type allowed_params: Set[str] or List[str] or Tuple[str] + """ + def __init__(self, bundle_param_lookup=None, allowed_params=None): - """Mixin class for bundle conditions. Unlike other conditions, parameters of a bundle conditions are not fixed. - For example, a regular `neurodiffeq.conditions.IVP` instance only allows fixed t_0 and u_0. But a bundle - condition allows for t_0 and/or u_0 to be dynamically sampled by generators during training. One can use this to - learn a bundle of solutions (as a function of t_0 and/or u_0). - - :param bundle_param_lookup: - A dictionary that maps bundle parameter name to its corresponding index. - For example, if bundle_param_lookup == {'t_0': 0, 'u_0': 1}, then the `thetas` passed to self.enforce()` - must contain two entries, with thetas[0] being a tensor of `t_0` and thetas[1] being a tensor of `u_0`. - Similarly, if bundle_param_lookup == {'t_0': 1, 'u_0': 0}, then the two entries in thetas should be - swapped. Defaults to empty dictionary. - :type bundle_param_lookup: Dict[str, int] - :param allowed_params: - A collection of parameter names allowed in ``bundle_param_lookup``. If specified, an error will be raised - if ``bundle_param_lookup`` contains names that do not appear in ``allowed_params``. - :type allowed_params: Set[str] or List[str] or Tuple[str] - """ self.bundle_param_lookup = bundle_param_lookup or {} if isinstance(allowed_params, str): @@ -113,8 +114,7 @@ def _get_parameter(self, param_name, thetas, override_name=None): Example: - If param_name == 't_0', and self.bundle_param_lookup['t_0'] == 0, the method will return `thetas[0]`. - - If `t_0` is not present in self.bundle_param_lookup, it will return `self.t_0`, or, self.some_other_attribute - if override_name is set to 'some_other_attribute'. + - If `t_0` is not present in self.bundle_param_lookup, it will return `self.t_0`, or, self.some_other_attribute if override_name is set to 'some_other_attribute'. :param param_name: Name of parameter to be used. E.g. `t_0`, `u_0` in the case of an initial value problem. :type param_name: str From 291d69695c5bc7c053d8d6f1466f19a03d28944f Mon Sep 17 00:00:00 2001 From: Sathvik Bhagavan Date: Tue, 16 Jan 2024 18:26:02 +0530 Subject: [PATCH 2/2] docs: add private members in automodule for conditions for _BundleConditionMixin --- docs/api.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/api.rst b/docs/api.rst index 8a4e83a..d8c3bbe 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -21,6 +21,7 @@ API Reference :show-inheritance: :inherited-members: :members: + :private-members: `neurodiffeq.solvers` ------------------------------------------------------