From 3fbb96b39730c555c0aeeb69a5d238e93c6ca8f4 Mon Sep 17 00:00:00 2001
From: Francesco Zanetta
Date: Mon, 26 Aug 2024 10:18:16 +0200
Subject: [PATCH] implement crps for binomial distribution

Add `crps_binomial`, the closed-form CRPS for a binomial forecast
(Jordan et al., 2019), with its core implementation, an API docs entry
and tests.

Supporting changes:
- JAX backend: compute `comb` from factorials instead of calling
  `jsp.special.comb`.
- TensorFlow/torch backends: use floor division in `comb`.
- stats: `_binom_pdf`/`_binom_cdf` take `k` instead of `x`; the pmf no
  longer multiplies by a support indicator and the cdf clamps `k` at `n`.

Co-Authored-By: Sam Allen <34094291+sallen12@users.noreply.github.com>
---
 docs/api/crps.md                   |  2 +
 scoringrules/__init__.py           |  2 +
 scoringrules/_crps.py              | 44 +++++++++++++++++++
 scoringrules/backend/jax.py        |  6 ++-
 scoringrules/backend/tensorflow.py |  2 +-
 scoringrules/backend/torch.py      |  2 +-
 scoringrules/core/crps/__init__.py |  3 +-
 scoringrules/core/crps/_closed.py  | 70 +++++++++++++++++++++++++++++-
 scoringrules/core/stats.py         |  9 ++--
 tests/test_crps.py                 | 33 ++++++++++++++
 10 files changed, 162 insertions(+), 11 deletions(-)

diff --git a/docs/api/crps.md b/docs/api/crps.md
index 6c4be20..b96e106 100644
--- a/docs/api/crps.md
+++ b/docs/api/crps.md
@@ -15,6 +15,8 @@ When the true forecast CDF is not fully known, but represented by a finite ensem
 
 ::: scoringrules.crps_beta
 
+::: scoringrules.crps_binomial
+
 ::: scoringrules.crps_exponential
 
 ::: scoringrules.crps_lognormal
diff --git a/scoringrules/__init__.py b/scoringrules/__init__.py
index 235e584..fa4816b 100644
--- a/scoringrules/__init__.py
+++ b/scoringrules/__init__.py
@@ -3,6 +3,7 @@
 from scoringrules._brier import brier_score
 from scoringrules._crps import (
     crps_beta,
+    crps_binomial,
     crps_ensemble,
     crps_exponential,
     crps_logistic,
@@ -38,6 +39,7 @@
     "backends",
     "crps_ensemble",
     "crps_beta",
+    "crps_binomial",
     "crps_normal",
     "crps_exponential",
     "crps_lognormal",
diff --git a/scoringrules/_crps.py b/scoringrules/_crps.py
index 904ed53..30bdf53 100644
--- a/scoringrules/_crps.py
+++ b/scoringrules/_crps.py
@@ -405,6 +405,50 @@ def crps_beta(
     return crps.beta(observation, a, b, lower, upper, backend=backend)
 
 
+def crps_binomial(
+    observation: "ArrayLike",
+    n: "ArrayLike",
+    prob: "ArrayLike",
+    /,
+    *,
+    backend: "Backend" = None,
+) -> "ArrayLike":
+    r"""Compute the closed form of the CRPS for the binomial distribution.
+
+    It is based on the following formulation from
+    [Jordan et al. (2019)](https://www.jstatsoft.org/article/view/v090i12):
+
+    $$
+    \mathrm{CRPS}(F_{n, p}, y) = 2 \sum_{x = 0}^{n} f_{n,p}(x) (1\{y < x\}
+    - F_{n,p}(x) + f_{n,p}(x)/2) (x - y),
+    $$
+
+    where $f_{n, p}$ and $F_{n, p}$ are the PMF and CDF of the binomial distribution
+    with size parameter $n = 0, 1, 2, ...$ and probability parameter $p \in [0, 1]$.
+
+    Parameters
+    ----------
+    observation:
+        The observed values as an integer or array of integers.
+    n:
+        Size parameter of the forecast binomial distribution as an integer or array of integers.
+    prob:
+        Probability parameter of the forecast binomial distribution as a float or array of floats.
+
+    Returns
+    -------
+    score:
+        The CRPS between Binomial(n, prob) and the observation.
+
+    Examples
+    --------
+    >>> import scoringrules as sr
+    >>> sr.crps_binomial(4, 10, 0.5)
+    0.5955715179443359
+    """
+    return crps.binomial(observation, n, prob, backend=backend)
+
+
 def crps_exponential(
     observation: "ArrayLike",
     rate: "ArrayLike",
diff --git a/scoringrules/backend/jax.py b/scoringrules/backend/jax.py
index b112f11..0049d0e 100644
--- a/scoringrules/backend/jax.py
+++ b/scoringrules/backend/jax.py
@@ -165,7 +165,7 @@ def apply_along_axis(
         try:
             x_shape = list(x.shape)
             return jax.vmap(func1d)(x.reshape(-1, x_shape.pop(axis))).reshape(x_shape)
-        except Exception:
+        except (Exception, jax.errors.ConcretizationTypeError):
             return jnp.apply_along_axis(func1d, axis, x)
 
     def floor(self, x: "Array") -> "Array":
@@ -205,7 +205,9 @@ def hypergeometric(self, a: "Array", b: "Array", c: "Array", z: "Array"):
         return jsp.special.hyp2f1(a, b, c, z)
 
     def comb(self, n: "ArrayLike", k: "ArrayLike") -> "ArrayLike":
-        return jsp.special.comb(n, k)
+        return jsp.special.factorial(n) // (
+            jsp.special.factorial(k) * jsp.special.factorial(n - k)
+        )
 
     def expi(self, x: "Array") -> "Array":
         return jsp.special.expi(x)
diff --git a/scoringrules/backend/tensorflow.py b/scoringrules/backend/tensorflow.py
index 28d62d2..4e600c0 100644
--- a/scoringrules/backend/tensorflow.py
+++ b/scoringrules/backend/tensorflow.py
@@ -247,7 +247,7 @@ def hypergeometric(
         raise NotImplementedError
 
     def comb(self, n: "Tensor", k: "Tensor") -> "Tensor":
-        return self.factorial(n) / (self.factorial(k) * self.factorial(n - k))
+        return self.factorial(n) // (self.factorial(k) * self.factorial(n - k))
 
     def expi(self, x: "Tensor") -> "Tensor":
         return tf.math.special.expint(x)
diff --git a/scoringrules/backend/torch.py b/scoringrules/backend/torch.py
index 5407323..1761054 100644
--- a/scoringrules/backend/torch.py
+++ b/scoringrules/backend/torch.py
@@ -222,7 +222,7 @@ def hypergeometric(
         return None
 
     def comb(self, n: "Tensor", k: "Tensor") -> "Tensor":
-        return self.factorial(n) / (self.factorial(k) * self.factorial(n - k))
+        return self.factorial(n) // (self.factorial(k) * self.factorial(n - k))
 
     def expi(self, x: "Tensor") -> "Tensor":
         return None
diff --git a/scoringrules/core/crps/__init__.py b/scoringrules/core/crps/__init__.py
index a878b7f..82539db 100644
--- a/scoringrules/core/crps/__init__.py
+++ b/scoringrules/core/crps/__init__.py
@@ -1,5 +1,5 @@
 from ._approx import ensemble, ow_ensemble, quantile_pinball, vr_ensemble
-from ._closed import beta, exponential, logistic, lognormal, normal
+from ._closed import beta, binomial, exponential, logistic, lognormal, normal
 from ._gufuncs import estimator_gufuncs, quantile_pinball_gufunc
 
 __all__ = [
@@ -7,6 +7,7 @@
     "ow_ensemble",
     "vr_ensemble",
     "beta",
+    "binomial",
     "exponential",
     "logistic",
     "lognormal",
diff --git a/scoringrules/core/crps/_closed.py b/scoringrules/core/crps/_closed.py
index 4952fbd..84e54f9 100644
--- a/scoringrules/core/crps/_closed.py
+++ b/scoringrules/core/crps/_closed.py
@@ -1,7 +1,14 @@
 import typing as tp
 
 from scoringrules.backend import backends
-from scoringrules.core.stats import _exp_cdf, _logis_cdf, _norm_cdf, _norm_pdf
+from scoringrules.core.stats import (
+    _binom_cdf,
+    _binom_pdf,
+    _exp_cdf,
+    _logis_cdf,
+    _norm_cdf,
+    _norm_pdf,
+)
 
 if tp.TYPE_CHECKING:
     from scoringrules.core.typing import Array, ArrayLike, Backend
@@ -42,6 +49,67 @@ def beta(
     return s
 
 
+def binomial(
+    obs: "ArrayLike",
+    n: "ArrayLike",
+    prob: "ArrayLike",
+    backend: "Backend" = None,
+) -> "Array":
+    """Compute the CRPS for the binomial distribution.
+
+    Note
+    ----
+    This is a bit of a hacky implementation, due to how the arrays
+    must be broadcast, but it should work for now.
+    """
+    B = backends.active if backend is None else backends[backend]
+    obs, n, prob = map(B.asarray, (obs, n, prob))
+    ones_like_n = 0.0 * n + 1
+
+    def _inner(params):
+        obs, n, prob = params
+        x = B.arange(0, n + 1)
+        w = _binom_pdf(x, n, prob)
+        a = _binom_cdf(x, n, prob) - 0.5 * w
+        s = 2 * B.sum(w * ((obs < x) - a) * (x - obs))
+        return s
+
+    # if n is a scalar, then if needed we must broadcast obs and prob against the support x
+    # TODO: implement B.broadcast() for backends
+    if n.size == 1:
+        x = B.arange(0, n + 1)
+        need_broadcast = not (obs.size == 1 and prob.size == 1)
+
+        if need_broadcast:
+            obs = obs[:, None] if obs.size > 1 else obs[None]
+            prob = prob[:, None] if prob.size > 1 else prob[None]
+            x = x[None]
+            x = x * ones_like_n
+            prob = prob * ones_like_n
+            obs = obs * ones_like_n
+
+        w = _binom_pdf(x, n, prob)
+        a = _binom_cdf(x, n, prob) - 0.5 * w
+        s = 2 * B.sum(
+            w * ((obs < x) - a) * (x - obs), axis=-1 if need_broadcast else None
+        )
+
+    # otherwise, since x would have variable sizes, we must apply the function along the axis
+    else:
+        obs = obs * ones_like_n if obs.size == 1 else obs
+        prob = prob * ones_like_n if prob.size == 1 else prob
+
+        # option 1: in a loop
+        s = B.stack(
+            [_inner(params) for params in zip(obs, n, prob, strict=True)], axis=-1
+        )
+
+        # option 2: apply_along_axis (does not work with JAX)
+        # s = B.apply_along_axis(_inner, B.stack((obs, n, prob), axis=-1), -1)
+
+    return s
+
+
 def exponential(
     obs: "ArrayLike", rate: "ArrayLike", backend: "Backend" = None
 ) -> "Array":
diff --git a/scoringrules/core/stats.py b/scoringrules/core/stats.py
index f1c55d4..e1e659f 100644
--- a/scoringrules/core/stats.py
+++ b/scoringrules/core/stats.py
@@ -90,20 +90,19 @@ def _gpd_cdf(x: "ArrayLike", shape: "ArrayLike", backend: "Backend" = None) -> "
 
 
 def _binom_pdf(
-    x: "ArrayLike", n: "ArrayLike", prob: "ArrayLike", backend: "Backend" = None
+    k: "ArrayLike", n: "ArrayLike", prob: "ArrayLike", backend: "Backend" = None
 ) -> "Array":
     """Probability mass function for the binomial distribution."""
     B = backends.active if backend is None else backends[backend]
-    ind = B.isinteger(x) * (1 - B.isnegative(x)) * (1 - B.negative(n - x))
-    return ind * B.comb(n, x) * prob**x * (1 - prob) ** (n - x)
+    return B.comb(n, k) * prob**k * (1 - prob) ** (n - k)
 
 
 def _binom_cdf(
-    x: "ArrayLike", n: "ArrayLike", prob: "ArrayLike", backend: "Backend" = None
+    k: "ArrayLike", n: "ArrayLike", prob: "ArrayLike", backend: "Backend" = None
 ) -> "Array":
     """Cumulative distribution function for the binomial distribution."""
     B = backends.active if backend is None else backends[backend]
-    return (1 - B.isnegative(x)) * B.betainc(n - B.floor(x), B.floor(x) + 1, 1 - prob)
+    return B.betainc(n - B.minimum(k, n) + 1e-36, k + 1, 1 - prob)
 
 
 def _hypergeo_pdf(
diff --git a/tests/test_crps.py b/tests/test_crps.py
index 2c1fb79..8afb45d 100644
--- a/tests/test_crps.py
+++ b/tests/test_crps.py
@@ -137,3 +137,36 @@ def test_beta(backend):
     res = _crps.crps_beta(-3.0, 0.7, 1.1, lower=-5.0, upper=4.0, backend=backend)
     expected = 0.883206751
     assert np.isclose(res, expected)
+
+
+@pytest.mark.parametrize("backend", BACKENDS)
+def test_binomial(backend):
+    if backend == "torch":
+        pytest.skip("Not implemented in torch backend")
+
+    # test correctness
+    res = _crps.crps_binomial(8, 10, 0.9, backend=backend)
+    expected = 0.6685115
+    assert np.isclose(res, expected)
+
+    res = _crps.crps_binomial(-8, 10, 0.9, backend=backend)
+    expected = 16.49896
+    assert np.isclose(res, expected)
+
+    res = _crps.crps_binomial(18, 10, 0.9, backend=backend)
+    expected = 8.498957
+    assert np.isclose(res, expected)
+
+    # test broadcasting
+    ones = np.ones(2)
+    k, n, p = 8, 10, 0.9
+    s = _crps.crps_binomial(k * ones, n, p, backend=backend)
+    assert np.isclose(s, np.array([0.6685115, 0.6685115])).all()
+    s = _crps.crps_binomial(k * ones, n * ones, p, backend=backend)
+    assert np.isclose(s, np.array([0.6685115, 0.6685115])).all()
+    s = _crps.crps_binomial(k * ones, n * ones, p * ones, backend=backend)
+    assert np.isclose(s, np.array([0.6685115, 0.6685115])).all()
+    s = _crps.crps_binomial(k, n * ones, p * ones, backend=backend)
+    assert np.isclose(s, np.array([0.6685115, 0.6685115])).all()
+    s = _crps.crps_binomial(k * ones, n, p * ones, backend=backend)
+    assert np.isclose(s, np.array([0.6685115, 0.6685115])).all()