Skip to content

Commit

Permalink
add tests for cond and cens likelihood scores
Browse files Browse the repository at this point in the history
  • Loading branch information
sallen12 committed Oct 8, 2024
1 parent ae2b3a6 commit cb5a825
Showing 1 changed file with 37 additions and 0 deletions.
37 changes: 37 additions & 0 deletions tests/test_logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,43 @@ def test_ensemble(backend):
assert np.isclose(res, expected)


@pytest.mark.parametrize("backend", BACKENDS)
def test_clogs(backend):
obs = np.random.randn(N)
mu = obs + np.random.randn(N) * 0.1
sigma = abs(np.random.randn(N))
fct = np.random.randn(N, ENSEMBLE_SIZE) * sigma[..., None] + mu[..., None]

res0 = _logs.logs_ensemble(obs, fct, axis=-1, backend=backend)
res = _logs.clogs_ensemble(obs, fct, axis=-1, backend=backend)
res_co = _logs.clogs_ensemble(obs, fct, axis=-1, cens=False, backend=backend)
assert res.shape == (N,)
assert res_co.shape == (N,)
assert np.allclose(res, res0)
assert np.allclose(res_co, res0)

fct = fct.T
res0 = _logs.clogs_ensemble(obs, fct, axis=0, backend=backend)
assert np.allclose(res, res0)

obs, fct = 6.2, [4.2, 5.1, 6.1, 7.6, 8.3, 9.5]
a = 8.0
res = _logs.clogs_ensemble(obs, fct, a=a, backend=backend)
expected = 0.3929448
assert np.isclose(res, expected)

a = 5.0
res = _logs.clogs_ensemble(obs, fct, a=a, cens=False, backend=backend)
expected = 1.646852
assert np.isclose(res, expected)

fct = [-142.7, -160.3, -179.4, -184.5]
b = -150.0
res = _logs.clogs_ensemble(obs, fct, b=b, backend=backend)
expected = 1.420427
assert np.isclose(res, expected)


@pytest.mark.parametrize("backend", BACKENDS)
def test_beta(backend):
if backend == "torch":
Expand Down

0 comments on commit cb5a825

Please sign in to comment.