Skip to content

Commit

Permalink
implement crps for beta distribution (#37)
Browse files Browse the repository at this point in the history
Co-authored-by: sallen12 <34094291+sallen12@users.noreply.github.com>
  • Loading branch information
frazane and sallen12 authored Aug 23, 2024
1 parent d998250 commit 725c1b9
Show file tree
Hide file tree
Showing 8 changed files with 135 additions and 3 deletions.
2 changes: 2 additions & 0 deletions docs/api/crps.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ When the true forecast CDF is not fully known, but represented by a finite ensem

<h2>Analytical formulations</h2>

::: scoringrules.crps_beta

::: scoringrules.crps_exponential

::: scoringrules.crps_lognormal
Expand Down
2 changes: 2 additions & 0 deletions scoringrules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from scoringrules._brier import brier_score
from scoringrules._crps import (
crps_beta,
crps_ensemble,
crps_exponential,
crps_logistic,
Expand Down Expand Up @@ -36,6 +37,7 @@
"register_backend",
"backends",
"crps_ensemble",
"crps_beta",
"crps_normal",
"crps_exponential",
"crps_lognormal",
Expand Down
55 changes: 55 additions & 0 deletions scoringrules/_crps.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,61 @@ def vrcrps_ensemble(
)


def crps_beta(
    observation: "ArrayLike",
    a: "ArrayLike",
    b: "ArrayLike",
    /,
    lower: "ArrayLike" = 0.0,
    upper: "ArrayLike" = 1.0,
    *,
    backend: "Backend" = None,
) -> "ArrayLike":
    r"""Compute the closed form of the CRPS for the beta distribution.

    It is based on the following formulation from
    [Jordan et al. (2019)](https://www.jstatsoft.org/article/view/v090i12):

    $$
    \mathrm{CRPS}(F_{\alpha, \beta}, y) = (u - l)\left\{ \frac{y - l}{u - l}
    \left( 2F_{\alpha, \beta} \left( \frac{y - l}{u - l} \right) - 1 \right)
    + \frac{\alpha}{\alpha + \beta} \left( 1 - 2F_{\alpha + 1, \beta}
    \left( \frac{y - l}{u - l} \right)
    - \frac{2B(2\alpha, 2\beta)}{\alpha B(\alpha, \beta)^{2}} \right) \right\}
    $$

    where $F_{\alpha, \beta}$ is the beta distribution function with shape parameters
    $\alpha, \beta > 0$, and lower and upper bounds $l, u \in \mathbb{R}$, $l < u$.

    Parameters
    ----------
    observation:
        The observed values.
    a:
        First shape parameter of the forecast beta distribution.
    b:
        Second shape parameter of the forecast beta distribution.
    lower:
        Lower bound of the forecast beta distribution.
    upper:
        Upper bound of the forecast beta distribution.
    backend:
        The name of the backend used for computations. Defaults to 'numba' if available, else 'numpy'.

    Returns
    -------
    score:
        The CRPS between Beta(a, b) and the observation.

    Raises
    ------
    ValueError
        If non-default bounds are supplied and any `lower` is not strictly
        less than the corresponding `upper`.
    NotImplementedError
        If the torch backend is used, since it does not implement `betainc`.

    Examples
    --------
    >>> import scoringrules as sr
    >>> sr.crps_beta(0.3, 0.7, 1.1)
    0.0850102437
    """
    # Thin public wrapper: all computation happens in the core implementation.
    return crps.beta(observation, a, b, lower, upper, backend=backend)


def crps_exponential(
observation: "ArrayLike",
rate: "ArrayLike",
Expand Down
4 changes: 3 additions & 1 deletion scoringrules/backend/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,9 @@ def beta(self, x: "Tensor", y: "Tensor") -> "Tensor":
return torch.exp(torch.lgamma(x) + torch.lgamma(y) - torch.lgamma(x + y))

def betainc(self, x: "Tensor", y: "Tensor", z: "Tensor") -> "Tensor":
    """Placeholder for the `betainc` special function on the torch backend.

    The torch backend does not provide this function, so the call always
    fails loudly instead of silently returning ``None`` (which would only
    surface later as an obscure error in downstream arithmetic).

    Raises
    ------
    NotImplementedError
        Always.
    """
    raise NotImplementedError(
        "The `betainc` function is not implemented in the torch backend."
    )

def mbessel0(self, x: "Tensor") -> "Tensor":
return torch.special.i0(x)
Expand Down
3 changes: 2 additions & 1 deletion scoringrules/core/crps/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from ._approx import ensemble, ow_ensemble, quantile_pinball, vr_ensemble
from ._closed import exponential, logistic, lognormal, normal
from ._closed import beta, exponential, logistic, lognormal, normal
from ._gufuncs import estimator_gufuncs, quantile_pinball_gufunc

__all__ = [
"ensemble",
"ow_ensemble",
"vr_ensemble",
"beta",
"exponential",
"logistic",
"lognormal",
Expand Down
41 changes: 41 additions & 0 deletions scoringrules/core/crps/_closed.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,41 @@
from scoringrules.core.typing import Array, ArrayLike, Backend


def beta(
    obs: "ArrayLike",
    a: "ArrayLike",
    b: "ArrayLike",
    lower: "ArrayLike" = 0.0,
    upper: "ArrayLike" = 1.0,
    backend: "Backend" = None,
) -> "Array":
    """Compute the CRPS for the beta distribution.

    Parameters
    ----------
    obs:
        The observed values.
    a:
        First shape parameter of the forecast beta distribution.
    b:
        Second shape parameter of the forecast beta distribution.
    lower:
        Lower bound of the forecast distribution's support.
    upper:
        Upper bound of the forecast distribution's support.
    backend:
        Name of the backend to use; the active backend is used when ``None``.

    Returns
    -------
    The CRPS of the beta forecast for each observation.

    Raises
    ------
    ValueError
        If non-default bounds are given and any ``lower >= upper``.
    """
    B = backends.active if backend is None else backends[backend]
    obs, a, b, lower, upper = map(B.asarray, (obs, a, b, lower, upper))

    # The default unit interval needs neither validation nor rescaling.
    special_limits = not (
        _is_scalar_value(lower, 0.0) and _is_scalar_value(upper, 1.0)
    )
    if special_limits:
        if B.any(lower >= upper):
            raise ValueError("lower must be less than upper")
        # Map the observation onto the standard [0, 1] support.
        obs = (obs - lower) / (upper - lower)

    # The beta CDF is 0 below the support and 1 above it, but the incomplete
    # beta function is only defined on [0, 1] and yields NaN elsewhere.
    # Clip the *argument* so that out-of-support observations get the correct
    # CDF value; the unclipped `obs` is still used in the linear term below.
    obs_in_support = B.minimum(B.maximum(obs, 0), 1)
    I_ab = B.betainc(a, b, obs_in_support)
    I_a1b = B.betainc(a + 1, b, obs_in_support)
    # Guard against tiny numerical excursions outside [0, 1].
    F_ab = B.minimum(B.maximum(I_ab, 0), 1)
    F_a1b = B.minimum(B.maximum(I_a1b, 0), 1)
    bet_rat = 2 * B.beta(2 * a, 2 * b) / (a * B.beta(a, b) ** 2)
    s = obs * (2 * F_ab - 1) + (a / (a + b)) * (1 - 2 * F_a1b - bet_rat)

    if special_limits:
        # Undo the rescaling: the CRPS scales linearly with the support width.
        s = s * (upper - lower)

    return s


def exponential(
obs: "ArrayLike", rate: "ArrayLike", backend: "Backend" = None
) -> "Array":
Expand Down Expand Up @@ -57,3 +92,9 @@ def logistic(
mu, sigma, obs = map(B.asarray, (mu, sigma, obs))
ω = (obs - mu) / sigma
return sigma * (ω - 2 * B.log(_logis_cdf(ω, backend=backend)) - 1)


def _is_scalar_value(x, value):
if x.size != 1:
return False
return x.item() == value
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from scoringrules.backend import backends

DATA_DIR = Path(__file__).parent / "data"
RUN_TESTS = ["numba", "jax", "torch"]
RUN_TESTS = ["numpy", "numba", "jax", "torch"]
BACKENDS = [b for b in backends.available_backends if b in RUN_TESTS]

if os.getenv("SR_TEST_OUTPUT", "False").lower() in ("true", "1", "t"):
Expand Down
29 changes: 29 additions & 0 deletions tests/test_crps.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,3 +108,32 @@ def test_lognormal(backend):

assert not np.any(np.isnan(res))
assert not np.any(res - 0.0 > 0.0001)


@pytest.mark.parametrize("backend", BACKENDS)
def test_beta(backend):
    """Check shape, NaN-freeness, input validation and reference values of crps_beta."""
    if backend == "torch":
        pytest.skip("Not implemented in torch backend")

    res = _crps.crps_beta(
        np.random.uniform(0, 1, (3, 3)),
        np.random.uniform(0, 3, (3, 3)),
        1.1,
        backend=backend,
    )
    assert res.shape == (3, 3)
    assert not np.any(np.isnan(res))

    # test exceptions
    # No `return` here: a stray one would either be dead code or silently
    # skip the correctness checks below.
    with pytest.raises(ValueError):
        _crps.crps_beta(0.3, 0.7, 1.1, lower=1.0, upper=0.0, backend=backend)

    # correctness tests
    res = _crps.crps_beta(0.3, 0.7, 1.1, backend=backend)
    expected = 0.0850102437
    assert np.isclose(res, expected)

    res = _crps.crps_beta(-3.0, 0.7, 1.1, lower=-5.0, upper=4.0, backend=backend)
    expected = 0.883206751
    assert np.isclose(res, expected)

0 comments on commit 725c1b9

Please sign in to comment.