From bb3135f43f3bd6bdceccb5a3a4d3ebd94b54c6a6 Mon Sep 17 00:00:00 2001 From: Nathaniel Starkman Date: Mon, 17 Jul 2023 19:17:56 -0400 Subject: [PATCH] Control region names (#121) * remove ControlPoints because they are not used. * rename field to center Signed-off-by: nstarman --- src/stream_ml/pytorch/prior/__init__.py | 3 +- src/stream_ml/pytorch/prior/_track.py | 75 +++---------------------- 2 files changed, 8 insertions(+), 70 deletions(-) diff --git a/src/stream_ml/pytorch/prior/__init__.py b/src/stream_ml/pytorch/prior/__init__.py index 48e31f5..6481327 100644 --- a/src/stream_ml/pytorch/prior/__init__.py +++ b/src/stream_ml/pytorch/prior/__init__.py @@ -7,7 +7,7 @@ from stream_ml.core.prior import FunctionPrior, Prior from stream_ml.core.prior import HardThreshold as CoreHardThreshold -from stream_ml.pytorch.prior._track import ControlPoints, ControlRegions +from stream_ml.pytorch.prior._track import ControlRegions from stream_ml.pytorch.typing import Array, ArrayNamespace __all__ = [ @@ -16,7 +16,6 @@ "FunctionPrior", "HardThreshold", # from here - "ControlPoints", "ControlRegions", ] diff --git a/src/stream_ml/pytorch/prior/_track.py b/src/stream_ml/pytorch/prior/_track.py index eec75a3..0564af6 100644 --- a/src/stream_ml/pytorch/prior/_track.py +++ b/src/stream_ml/pytorch/prior/_track.py @@ -35,7 +35,7 @@ def _atleast_2d(x: Array) -> Array: class TrackPrior(Prior[Array]): """Track Prior Base.""" - control_points: Data[Array] + center: Data[Array] lamda: float = 0.05 _: KW_ONLY coord_name: str = "phi1" @@ -49,80 +49,23 @@ def __post_init__(self) -> None: # Pre-store the control points, seprated by indep & dep parameters. self._x: Data[Array] - object.__setattr__(self, "_x", self.control_points[(self.coord_name,)]) + object.__setattr__(self, "_x", self.center[(self.coord_name,)]) dep_names: tuple[str, ...] = tuple( - n for n in self.control_points.names if n != self.coord_name + n for n in self.center.names if n != self.coord_name ) self._y_names: tuple[str, ...] object.__setattr__(self, "_y_names", dep_names) self._y: Array object.__setattr__( - self, "_y", _atleast_2d(xp.squeeze(self.control_points[dep_names].array)) + self, "_y", _atleast_2d(xp.squeeze(self.center[dep_names].array)) ) ##################################################################### -@dataclass(frozen=True) -class ControlPoints(TrackPrior): - """Control points prior. - - Parameters - ---------- - control_points : Data[Array] - The control points. - lamda : float, optional - Importance hyperparameter. - """ - - def logpdf( - self, - mpars: Params[Array], - data: Data[Array], - model: ModelAPI[Array, NNModel], - current_lnpdf: Array | None = None, - /, - ) -> Array: - """Evaluate the logpdf. - - This log-pdf is added to the current logpdf. So if you want to set the - logpdf to a specific value, you can uses the `current_lnpdf` to set the - output value such that ``current_lnpdf + logpdf = ``. - - Parameters - ---------- - mpars : Params[Array], positional-only - Model parameters. Note that these are different from the ML - parameters. - data : Data[Array], position-only - The data for which evaluate the prior. - model : Model, position-only - The model for which evaluate the prior. - current_lnpdf : Array | None, optional position-only - The current logpdf, by default `None`. This is useful for setting - the additive log-pdf to a specific value. - - Returns - ------- - Array - The logpdf. - """ - # Get the model parameters evaluated at the control points. shape (C, 1). - cmpars = model.unpack_params(model(self._x)) # type: ignore[call-overload] # noqa: E501 - cmp_arr = xp.hstack( # (C, F) - tuple(cmpars[(n, self.component_param_name)] for n in self._y_names) - ) - - # For each control point, add the squared distance to the logpdf. - return -self.lamda * self.xp.sum((cmp_arr - self._y) ** 2) # (C, F) -> 1 - - -##################################################################### - - @dataclass(frozen=True) class ControlRegions(TrackPrior): r"""Control regions prior. @@ -142,16 +85,12 @@ class ControlRegions(TrackPrior): Parameters ---------- - control_points : Data[Array] + center : Data[Array] The control points. These are the means of the regions (mu in the above). + width : Data[Array], optional + Width(s) of the region(s). lamda : float, optional Importance hyperparameter. - TODO: make this also able to be an array, so that each region can have - a different width. - width : float, optional - Width of the region. - TODO: make this also able to be an array, so that each region can have - a different width. """ width: float | Data[Array] = 0.5