From 1e450ca437ce90bd79ba7b78f5fd966cc30e87f2 Mon Sep 17 00:00:00 2001 From: Sam Allen <34094291+sallen12@users.noreply.github.com> Date: Thu, 5 Sep 2024 15:58:32 +0200 Subject: [PATCH] Add CRPS for the Poisson distribution (#51) * add crps for poisson distribution * fix bessel functions and poisson crps --- docs/api/crps.md | 2 + scoringrules/__init__.py | 2 + scoringrules/_crps.py | 44 +++++++++++++++++++-- scoringrules/backend/jax.py | 4 +- scoringrules/backend/numpy.py | 7 ++-- scoringrules/backend/tensorflow.py | 4 +- scoringrules/core/crps/__init__.py | 2 + scoringrules/core/crps/_closed.py | 62 +++++++++++++++++++----------- scoringrules/core/stats.py | 8 +++- tests/test_crps.py | 18 +++++++++ 10 files changed, 119 insertions(+), 34 deletions(-) diff --git a/docs/api/crps.md b/docs/api/crps.md index 216dbda..5e60f40 100644 --- a/docs/api/crps.md +++ b/docs/api/crps.md @@ -53,6 +53,8 @@ When the true forecast CDF is not fully known, but represented by a finite ensem ::: scoringrules.crps_normal +::: scoringrules.crps_poisson + ::: scoringrules.crps_uniform ## Ensemble-based estimators diff --git a/scoringrules/__init__.py b/scoringrules/__init__.py index 5011eeb..49b46a3 100644 --- a/scoringrules/__init__.py +++ b/scoringrules/__init__.py @@ -23,6 +23,7 @@ crps_loglogistic, crps_lognormal, crps_normal, + crps_poisson, crps_uniform, crps_quantile, owcrps_ensemble, @@ -73,6 +74,7 @@ "crps_loglogistic", "crps_lognormal", "crps_normal", + "crps_poisson", "crps_quantile", "crps_uniform", "owcrps_ensemble", diff --git a/scoringrules/_crps.py b/scoringrules/_crps.py index 094388e..f12dffb 100644 --- a/scoringrules/_crps.py +++ b/scoringrules/_crps.py @@ -1333,18 +1333,55 @@ def crps_normal( Mean of the forecast normal distribution. sigma: ArrayLike Standard deviation of the forecast normal distribution. - + Returns ------- crps: array_like The CRPS between Normal(mu, sigma) and obs. 
+ + Examples + -------- + >>> import scoringrules as sr + >>> sr.crps_normal(0.0, 0.1, 0.4) + """ + return crps.normal(observation, mu, sigma, backend=backend) + + +def crps_poisson( + observation: "ArrayLike", + mean: "ArrayLike", + /, + *, + backend: "Backend" = None, +) -> "ArrayLike": + r"""Compute the closed form of the CRPS for the Poisson distribution. + + It is based on the following formulation from + [Wei and Held (2014)](https://link.springer.com/article/10.1007/s11749-014-0380-8): + + $$ \mathrm{CRPS}(F_{\lambda}, y) = (y - \lambda) (2F_{\lambda}(y) - 1) + 2 \lambda f_{\lambda}(\lfloor y \rfloor ) - \lambda \exp (-2 \lambda) (I_{0} (2 \lambda) + I_{1} (2 \lambda)),$$ + + where $F_{\lambda}$ and $f_{\lambda}$ are the distribution function and probability mass function of the Poisson distribution with mean parameter $\lambda > 0$, + and $I_{0}$ and $I_{1}$ are modified Bessel functions of the first kind. + + Parameters + ---------- + observation: ArrayLike + The observed values. + mean: ArrayLike + Mean parameter of the forecast Poisson distribution. + Returns + ------- + crps: array_like + The CRPS between Pois(mean) and obs. 
+ Examples -------- >>> import scoringrules as sr - >>> sr.crps_normal(0.1, 0.4, 0.0) + >>> sr.crps_poisson(1, 2) """ - return crps.normal(observation, mu, sigma, backend=backend) + return crps.poisson(observation, mean, backend=backend) def crps_uniform( @@ -1420,5 +1457,6 @@ def crps_uniform( "crps_loglogistic", "crps_lognormal", "crps_normal", + "crps_poisson", "crps_uniform", ] diff --git a/scoringrules/backend/jax.py b/scoringrules/backend/jax.py index b5cc1f4..0d6a23f 100644 --- a/scoringrules/backend/jax.py +++ b/scoringrules/backend/jax.py @@ -187,10 +187,10 @@ def betainc(self, x: "Array", y: "Array", z: "Array") -> "Array": return jsp.special.betainc(x, y, z) def mbessel0(self, x: "Array") -> "Array": - return jsp.special.jv(0, x) + return jsp.special.i0(x) def mbessel1(self, x: "Array") -> "Array": - return jsp.special.jv(1, x) + return jsp.special.i1(x) def gamma(self, x: "Array") -> "Array": return jsp.special.gamma(x) diff --git a/scoringrules/backend/numpy.py b/scoringrules/backend/numpy.py index fac708c..694b934 100644 --- a/scoringrules/backend/numpy.py +++ b/scoringrules/backend/numpy.py @@ -12,7 +12,8 @@ gammainc, gammaincc, hyp2f1, - jv, + i0, + i1, ) if tp.TYPE_CHECKING: @@ -186,10 +187,10 @@ def betainc(self, x: "NDArray", y: "NDArray", z: "NDArray") -> "NDArray": return betainc(x, y, z) def mbessel0(self, x: "NDArray") -> "NDArray": - return jv(0, x) + return i0(x) def mbessel1(self, x: "NDArray") -> "NDArray": - return jv(1, x) + return i1(x) def gamma(self, x: "NDArray") -> "NDArray": return gamma(x) diff --git a/scoringrules/backend/tensorflow.py b/scoringrules/backend/tensorflow.py index 30c663f..d6fdf3a 100644 --- a/scoringrules/backend/tensorflow.py +++ b/scoringrules/backend/tensorflow.py @@ -227,10 +227,10 @@ def betainc(self, x: "Tensor", y: "Tensor", z: "Tensor") -> "Tensor": return tf.math.betainc(x, y, z) def mbessel0(self, x: "Tensor") -> "Tensor": - return tf.math.bessel_i0e(x) + return tf.math.bessel_i0(x) def mbessel1(self, x: 
"Tensor") -> "Tensor": - return tf.math.bessel_i1e(x) + return tf.math.bessel_i1(x) def gamma(self, x: "Tensor") -> "Tensor": return tf.math.exp(tf.math.lgamma(x)) diff --git a/scoringrules/core/crps/__init__.py b/scoringrules/core/crps/__init__.py index 0a174cc..8138a4c 100644 --- a/scoringrules/core/crps/__init__.py +++ b/scoringrules/core/crps/__init__.py @@ -16,6 +16,7 @@ loglogistic, lognormal, normal, + poisson, uniform, ) from ._gufuncs import estimator_gufuncs, quantile_pinball_gufunc @@ -40,6 +41,7 @@ "loglogistic", "lognormal", "normal", + "poisson", "uniform", "estimator_gufuncs", "quantile_pinball", diff --git a/scoringrules/core/crps/_closed.py b/scoringrules/core/crps/_closed.py index 360aa4f..f85fd66 100644 --- a/scoringrules/core/crps/_closed.py +++ b/scoringrules/core/crps/_closed.py @@ -13,6 +13,8 @@ _logis_cdf, _norm_cdf, _norm_pdf, + _pois_cdf, + _pois_pdf, ) if tp.TYPE_CHECKING: @@ -406,6 +408,16 @@ def laplace( return sigma * (B.abs(obs) + B.exp(-B.abs(obs)) - 3 / 4) +def logistic( + obs: "ArrayLike", mu: "ArrayLike", sigma: "ArrayLike", backend: "Backend" = None +) -> "Array": + """Compute the CRPS for the logistic distribution.""" + B = backends.active if backend is None else backends[backend] + mu, sigma, obs = map(B.asarray, (mu, sigma, obs)) + ω = (obs - mu) / sigma + return sigma * (ω - 2 * B.log(_logis_cdf(ω, backend=backend)) - 1) + + def loglaplace( obs: "ArrayLike", locationlog: "ArrayLike", @@ -441,18 +453,20 @@ def loglaplace( return s -def normal( - obs: "ArrayLike", mu: "ArrayLike", sigma: "ArrayLike", backend: "Backend" = None +def loglogistic( + obs: "ArrayLike", + mulog: "ArrayLike", + sigmalog: "ArrayLike", + backend: "Backend" = None, ) -> "Array": - """Compute the CRPS for the normal distribution.""" + """Compute the CRPS for the log-logistic distribution.""" B = backends.active if backend is None else backends[backend] - mu, sigma, obs = map(B.asarray, (mu, sigma, obs)) - ω = (obs - mu) / sigma - return sigma * ( - ω * (2.0 
* _norm_cdf(ω, backend=backend) - 1.0) - + 2.0 * _norm_pdf(ω, backend=backend) - - 1.0 / B.sqrt(B.pi) - ) + mulog, sigmalog, obs = map(B.asarray, (mulog, sigmalog, obs)) + F_ms = 1 / (1 + B.exp(-(B.log(obs) - mulog) / sigmalog)) + b = B.beta(1 + sigmalog, 1 - sigmalog) + I_B = B.betainc(1 + sigmalog, 1 - sigmalog, F_ms) + s = obs * (2 * F_ms - 1) - B.exp(mulog) * b * (2 * I_B + sigmalog - 1) + return s def lognormal( @@ -473,29 +487,33 @@ def lognormal( ) -def logistic( +def normal( obs: "ArrayLike", mu: "ArrayLike", sigma: "ArrayLike", backend: "Backend" = None ) -> "Array": """Compute the CRPS for the logistic distribution.""" B = backends.active if backend is None else backends[backend] mu, sigma, obs = map(B.asarray, (mu, sigma, obs)) ω = (obs - mu) / sigma - return sigma * (ω - 2 * B.log(_logis_cdf(ω, backend=backend)) - 1) + return sigma * ( + ω * (2.0 * _norm_cdf(ω, backend=backend) - 1.0) + + 2.0 * _norm_pdf(ω, backend=backend) + - 1.0 / B.sqrt(B.pi) + ) -def loglogistic( +def poisson( obs: "ArrayLike", - mulog: "ArrayLike", - sigmalog: "ArrayLike", + mean: "ArrayLike", backend: "Backend" = None, ) -> "Array": - """Compute the CRPS for the log-logistic distribution.""" + """Compute the CRPS for the poisson distribution.""" B = backends.active if backend is None else backends[backend] - mulog, sigmalog, obs = map(B.asarray, (mulog, sigmalog, obs)) - F_ms = 1 / (1 + B.exp(-(B.log(obs) - mulog) / sigmalog)) - b = B.beta(1 + sigmalog, 1 - sigmalog) - I_B = B.betainc(1 + sigmalog, 1 - sigmalog, F_ms) - s = obs * (2 * F_ms - 1) - B.exp(mulog) * b * (2 * I_B + sigmalog - 1) + mean, obs = map(B.asarray, (mean, obs)) + F_m = _pois_cdf(obs, mean, backend=backend) + f_m = _pois_pdf(B.floor(obs), mean, backend=backend) + I0 = B.mbessel0(2 * mean) + I1 = B.mbessel1(2 * mean) + s = (obs - mean) * (2 * F_m - 1) + 2 * mean * f_m - mean * B.exp(-2 * mean) * (I0 + I1) return s @@ -506,7 +524,7 @@ def uniform( B = backends.active if backend is None else backends[backend] min, 
max, lmass, umass, obs = map(B.asarray, (min, max, lmass, umass, obs)) ω = (obs - min) / (max - min) - F_ω = B.minimum(B.maximum(ω, B.asarray(0)), B.asarray(1)) + F_ω = B.minimum(B.maximum(ω, B.asarray(0.0)), B.asarray(1.0)) s = B.abs(ω - F_ω) + (F_ω**2) * (1 - lmass - umass) - F_ω * (1 - 2 * lmass) + ((1 - lmass - umass)**2) / 3 + (1 - lmass) * umass return (max - min) * s diff --git a/scoringrules/core/stats.py b/scoringrules/core/stats.py index 1ffd19b..bc92d0e 100644 --- a/scoringrules/core/stats.py +++ b/scoringrules/core/stats.py @@ -42,13 +42,17 @@ def _gamma_cdf( def _pois_cdf(x: "ArrayLike", mean: "ArrayLike", backend: "Backend" = None) -> "Array": """Cumulative distribution function for the Poisson distribution.""" B = backends.active if backend is None else backends[backend] - return B.max(B.ui_gamma(B.floor(x + 1), mean) / B.gamma(B.floor(x + 1)), 0) + x_plus = B.abs(x) + p = B.gammauinc(B.floor(x_plus + 1), mean) / B.gamma(B.floor(x_plus + 1)) + return B.where(x < 0.0, 0.0, p) def _pois_pdf(x: "ArrayLike", mean: "ArrayLike", backend: "Backend" = None) -> "Array": """Probability mass function for the Poisson distribution.""" B = backends.active if backend is None else backends[backend] - return B.isinteger(x) * (mean**x * B.exp(-x) / B.factorial(x)) + x_plus = B.abs(x) + d = B.where(B.floor(x_plus) < x_plus, 0.0, mean**(x_plus) * B.exp(-mean) / B.factorial(x_plus)) + return B.where(mean < 0.0, B.nan, B.where(x < 0.0, 0.0, d)) def _t_pdf(x: "ArrayLike", df: "ArrayLike", backend: "Backend" = None) -> "Array": diff --git a/tests/test_crps.py b/tests/test_crps.py index 691203c..16d8d50 100644 --- a/tests/test_crps.py +++ b/tests/test_crps.py @@ -415,6 +415,24 @@ def test_normal(backend): assert not np.any(res - 0.0 > 0.0001) +@pytest.mark.parametrize("backend", BACKENDS) +def test_poisson(backend): + obs, mean = 1.0, 3.0 + res = _crps.crps_poisson(obs, mean, backend=backend) + expected = 1.143447 + assert np.isclose(res, expected) + + obs, mean = 1.5, 2.3 + 
res = _crps.crps_poisson(obs, mean, backend=backend) + expected = 0.5001159 + assert np.isclose(res, expected) + + obs, mean = -1.0, 1.5 + res = _crps.crps_poisson(obs, mean, backend=backend) + expected = 1.840259 + assert np.isclose(res, expected) + + @pytest.mark.parametrize("backend", BACKENDS) def test_uniform(backend): obs, min, max, lmass, umass = 0.3, -1.0, 2.1, 0.3, 0.1