Skip to content

Commit

Permalink
Control region names (#121)
Browse files Browse the repository at this point in the history
* remove ControlPoints because they are not used.
* rename field to center

Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman authored Jul 17, 2023
1 parent 1f56a00 commit bb3135f
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 70 deletions.
3 changes: 1 addition & 2 deletions src/stream_ml/pytorch/prior/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from stream_ml.core.prior import FunctionPrior, Prior
from stream_ml.core.prior import HardThreshold as CoreHardThreshold

from stream_ml.pytorch.prior._track import ControlPoints, ControlRegions
from stream_ml.pytorch.prior._track import ControlRegions
from stream_ml.pytorch.typing import Array, ArrayNamespace

__all__ = [
Expand All @@ -16,7 +16,6 @@
"FunctionPrior",
"HardThreshold",
# from here
"ControlPoints",
"ControlRegions",
]

Expand Down
75 changes: 7 additions & 68 deletions src/stream_ml/pytorch/prior/_track.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def _atleast_2d(x: Array) -> Array:
class TrackPrior(Prior[Array]):
"""Track Prior Base."""

control_points: Data[Array]
center: Data[Array]
lamda: float = 0.05
_: KW_ONLY
coord_name: str = "phi1"
Expand All @@ -49,80 +49,23 @@ def __post_init__(self) -> None:

# Pre-store the control points, seprated by indep & dep parameters.
self._x: Data[Array]
object.__setattr__(self, "_x", self.control_points[(self.coord_name,)])
object.__setattr__(self, "_x", self.center[(self.coord_name,)])

dep_names: tuple[str, ...] = tuple(
n for n in self.control_points.names if n != self.coord_name
n for n in self.center.names if n != self.coord_name
)
self._y_names: tuple[str, ...]
object.__setattr__(self, "_y_names", dep_names)

self._y: Array
object.__setattr__(
self, "_y", _atleast_2d(xp.squeeze(self.control_points[dep_names].array))
self, "_y", _atleast_2d(xp.squeeze(self.center[dep_names].array))
)


#####################################################################


@dataclass(frozen=True)
class ControlPoints(TrackPrior):
"""Control points prior.
Parameters
----------
control_points : Data[Array]
The control points.
lamda : float, optional
Importance hyperparameter.
"""

def logpdf(
self,
mpars: Params[Array],
data: Data[Array],
model: ModelAPI[Array, NNModel],
current_lnpdf: Array | None = None,
/,
) -> Array:
"""Evaluate the logpdf.
This log-pdf is added to the current logpdf. So if you want to set the
logpdf to a specific value, you can uses the `current_lnpdf` to set the
output value such that ``current_lnpdf + logpdf = <want>``.
Parameters
----------
mpars : Params[Array], positional-only
Model parameters. Note that these are different from the ML
parameters.
data : Data[Array], position-only
The data for which evaluate the prior.
model : Model, position-only
The model for which evaluate the prior.
current_lnpdf : Array | None, optional position-only
The current logpdf, by default `None`. This is useful for setting
the additive log-pdf to a specific value.
Returns
-------
Array
The logpdf.
"""
# Get the model parameters evaluated at the control points. shape (C, 1).
cmpars = model.unpack_params(model(self._x)) # type: ignore[call-overload] # noqa: E501
cmp_arr = xp.hstack( # (C, F)
tuple(cmpars[(n, self.component_param_name)] for n in self._y_names)
)

# For each control point, add the squared distance to the logpdf.
return -self.lamda * self.xp.sum((cmp_arr - self._y) ** 2) # (C, F) -> 1


#####################################################################


@dataclass(frozen=True)
class ControlRegions(TrackPrior):
r"""Control regions prior.
Expand All @@ -142,16 +85,12 @@ class ControlRegions(TrackPrior):
Parameters
----------
control_points : Data[Array]
center : Data[Array]
The control points. These are the means of the regions (mu in the above).
width : Data[Array], optional
Width(s) of the region(s).
lamda : float, optional
Importance hyperparameter.
TODO: make this also able to be an array, so that each region can have
a different width.
width : float, optional
Width of the region.
TODO: make this also able to be an array, so that each region can have
a different width.
"""

width: float | Data[Array] = 0.5
Expand Down

0 comments on commit bb3135f

Please sign in to comment.