Skip to content

Commit

Permalink
use better repr from core (#142)
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman authored Sep 12, 2023
1 parent b89fa71 commit 2b5517c
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/stream_ml/pytorch/builtin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@
# [("array_namespace", ArrayNamespace[Array], field(default=xp, kw_only=True))],
# bases=(CoreSkewNormal[Array, NNModel], ModelBase),
# unsafe_hash=True,
# repr=False,
# )


Expand All @@ -97,4 +98,5 @@
# [("array_namespace", ArrayNamespace[Array], field(default=xp, kw_only=True))],
# bases=(CoreTruncatedSkewNormal[Array, NNModel], ModelBase),
# unsafe_hash=True,
# repr=False,
# )
2 changes: 2 additions & 0 deletions src/stream_ml/pytorch/params/bounds.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
[("array_namespace", ArrayNamespace[Array], field(default=xp, kw_only=True))],
bases=(CoreNoBounds[Array],),
frozen=True,
repr=False,
)


Expand All @@ -42,6 +43,7 @@
[("array_namespace", ArrayNamespace[Array], field(default=xp, kw_only=True))],
bases=(CoreClippedBounds[Array],),
frozen=True,
repr=False,
)


Expand Down
1 change: 1 addition & 0 deletions src/stream_ml/pytorch/prior/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,6 @@
[("array_namespace", ArrayNamespace[Array], field(default=xp, kw_only=True))],
bases=(CoreHardThreshold[Array],),
frozen=True,
repr=False,
unsafe_hash=True,
)
5 changes: 5 additions & 0 deletions src/stream_ml/pytorch/prior/_track.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,13 @@ def logpdf(
)

lnpdf = xp.zeros_like(cmp_arr)

# Lower side
# Note that comparison to NaN is always False.
where = cmp_arr <= self._y - self._w
lnpdf[where] = (cmp_arr[where] - (self._y[where] - self._w[where])) ** 2

# Upper side
where = cmp_arr >= self._y + self._w
lnpdf[where] = (cmp_arr[where] - (self._y[where] + self._w[where])) ** 2

Expand Down

0 comments on commit 2b5517c

Please sign in to comment.