Skip to content

Commit

Permalink
Implement log-probability for StudentTRV
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf authored and brandonwillard committed Nov 4, 2022
1 parent 961496a commit 54a5888
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 1 deletion.
14 changes: 14 additions & 0 deletions aeppl/logprob.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,20 @@ def hypergeometric_logprob(op, values, *inputs, **kwargs):
return res


@_logprob.register(arb.StudentTRV)
def studentt_logprob(op, values, *inputs, **kwargs):
(value,) = values
df, loc, scale = inputs[3:]
res = (
at.gammaln((df + 1.0) / 2.0)
- at.gammaln(df / 2.0)
- 0.5 * at.log(np.pi * df * scale**2)
- (df + 1.0) / 2.0 * at.log1p(((value - loc) / scale) ** 2 / df)
)
res = CheckParameterValue("scale >= 0")(res, at.all(at.ge(scale, 0.0)))
return res


@_logprob.register(arb.CategoricalRV)
def categorical_logprob(op, values, *inputs, **kwargs):
(value,) = values
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
install_requires=[
"numpy>=1.18.1",
"scipy>=1.4.0",
"aesara >= 2.8.5",
"aesara >= 2.8.8",
],
tests_require=["pytest"],
long_description=open("README.rst").read() if exists("README.rst") else "",
Expand Down
28 changes: 28 additions & 0 deletions tests/test_logprob.py
Original file line number Diff line number Diff line change
Expand Up @@ -883,6 +883,34 @@ def scipy_logprob(obs, good, bad, n):
scipy_logprob_tester(x, obs, dist_params, test_fn=scipy_logprob)


@pytest.mark.parametrize(
"dist_params, obs, size, error",
[
((1, 0, 2), np.array([-10, 0, 10], dtype=np.float32), (), False),
((1, 0, 2), np.array([-10, 0, 10], dtype=np.float32), (3, 2), False),
(
(np.array([10, 5, 3], dtype=np.int64), 1, 2),
np.array([-1, 1, 84], dtype=np.int64),
(),
False,
),
],
)
def test_studentt_logprob(dist_params, obs, size, error):
dist_params_at, obs_at, size_at = create_aesara_params(dist_params, obs, size)
dist_params = dict(zip(dist_params_at, dist_params))

x = at.random.t(*dist_params_at, size=size_at)

cm = contextlib.suppress() if not error else pytest.raises(AssertionError)

def scipy_logprob(obs, df, loc, scale):
return stats.t.logpdf(obs, df, loc=loc, scale=scale)

with cm:
scipy_logprob_tester(x, obs, dist_params, test_fn=scipy_logprob)


@pytest.mark.parametrize(
"dist_params, obs, size, exc_type, chk_bcast",
[
Expand Down

0 comments on commit 54a5888

Please sign in to comment.