From cd0dfec8e11bd032745b0553376791e240f80760 Mon Sep 17 00:00:00 2001
From: gord chung <5091603+chungg@users.noreply.github.com>
Date: Mon, 11 Nov 2024 14:08:50 -0500
Subject: [PATCH] add mase and emae error metric functions (#108)

---
 benches/traquer.rs            | 10 ++++
 src/statistic/regression.rs   | 84 +++++++++++++++++++++++++++-----
 tests/stat_regression_test.rs | 90 +++++++++++++++++++++++++++++++++++
 3 files changed, 173 insertions(+), 11 deletions(-)

diff --git a/benches/traquer.rs b/benches/traquer.rs
index 60f42f2..25f9283 100644
--- a/benches/traquer.rs
+++ b/benches/traquer.rs
@@ -708,6 +708,16 @@ fn criterion_benchmark(c: &mut Criterion) {
             black_box(statistic::regression::mda(&stats.close, &stats.open).collect::<Vec<f64>>())
         })
     });
+    c.bench_function("stats-regress-mase", |b| {
+        b.iter(|| {
+            black_box(statistic::regression::mase(&stats.close, &stats.open).collect::<Vec<f64>>())
+        })
+    });
+    c.bench_function("stats-regress-emae", |b| {
+        b.iter(|| {
+            black_box(statistic::regression::emae(&stats.close, &stats.open).collect::<Vec<f64>>())
+        })
+    });
 }
 
 criterion_group!(benches, criterion_benchmark);
diff --git a/src/statistic/regression.rs b/src/statistic/regression.rs
index e3576f4..edb6ecb 100644
--- a/src/statistic/regression.rs
+++ b/src/statistic/regression.rs
@@ -5,6 +5,7 @@
 //! determing causal relationships.[[1]](https://en.wikipedia.org/wiki/Regression_analysis)
 use std::iter;
 
+use itertools::izip;
 use num_traits::cast::ToPrimitive;
 
 /// Mean Squared Error
@@ -31,9 +32,7 @@ pub fn mse<'a, T: ToPrimitive>(data: &'a [T], estimate: &'a [T]) -> impl Iterato
         .enumerate()
         .zip(estimate)
         .scan(0.0, |state, ((cnt, observe), est)| {
-            *state += (observe.to_f64().unwrap() - est.to_f64().unwrap())
-                .powi(2)
-                .max(0.0);
+            *state += (observe.to_f64().unwrap() - est.to_f64().unwrap()).powi(2);
             Some(*state / (cnt + 1) as f64)
         })
 }
@@ -65,9 +64,7 @@ pub fn rmse<'a, T: ToPrimitive>(
         .enumerate()
         .zip(estimate)
         .scan(0.0, |state, ((cnt, observe), est)| {
-            *state += (observe.to_f64().unwrap() - est.to_f64().unwrap())
-                .powi(2)
-                .max(0.0);
+            *state += (observe.to_f64().unwrap() - est.to_f64().unwrap()).powi(2);
             Some((*state / (cnt + 1) as f64).sqrt())
         })
 }
@@ -97,9 +94,7 @@ pub fn mae<'a, T: ToPrimitive>(data: &'a [T], estimate: &'a [T]) -> impl Iterato
         .enumerate()
         .zip(estimate)
         .scan(0.0, |state, ((cnt, observe), est)| {
-            *state += (observe.to_f64().unwrap() - est.to_f64().unwrap())
-                .abs()
-                .max(0.0);
+            *state += (observe.to_f64().unwrap() - est.to_f64().unwrap()).abs();
             Some(*state / (cnt + 1) as f64)
         })
 }
@@ -133,8 +128,7 @@ pub fn mape<'a, T: ToPrimitive>(
         .scan(0.0, |state, ((cnt, observe), est)| {
             *state += ((observe.to_f64().unwrap() - est.to_f64().unwrap())
                 / observe.to_f64().unwrap())
-            .abs()
-            .max(0.0);
+            .abs();
             Some(100.0 * *state / (cnt + 1) as f64)
         })
 }
@@ -205,3 +199,71 @@ pub fn mda<'a, T: ToPrimitive>(data: &'a [T], estimate: &'a [T]) -> impl Iterato
         },
     ))
 }
+
+/// Mean Absolute Scaled Error
+///
+/// A forecasting accuracy metric that compares the mean absolute error (MAE) of your prediction
+/// to the MAE of a naive forecasting method, such as predicting the previous period's value. It
+/// helps determine if your forecasting model outperforms a simple baseline model.
+///
+/// ```math
+/// MASE = \mathrm{mean}\left( \frac{\left| e_j \right|}{\frac{1}{T-1}\sum_{t=2}^T \left| Y_t-Y_{t-1}\right|} \right) = \frac{\frac{1}{J}\sum_{j}\left| e_j \right|}{\frac{1}{T-1}\sum_{t=2}^T \left| Y_t-Y_{t-1}\right|}
+/// ```
+///
+/// ## Sources
+///
+/// [[1]](https://en.wikipedia.org/wiki/Mean_absolute_scaled_error)
+/// [[2]](https://otexts.com/fpp2/accuracy.html#scaled-errors)
+///
+/// # Examples
+///
+/// ```
+/// use traquer::statistic::regression;
+///
+/// regression::mase(&[1.0,2.0,3.0,4.0,5.0], &[1.0,2.0,3.0,4.0,5.0]).collect::<Vec<f64>>();
+/// ```
+pub fn mase<'a, T: ToPrimitive>(
+    data: &'a [T],
+    estimate: &'a [T],
+) -> impl Iterator<Item = f64> + 'a {
+    let mae_est = mae(data, estimate);
+    let mae_naive = data.windows(2).zip(1..).scan(0.0, |state, (w, cnt)| {
+        *state += (w[1].to_f64().unwrap() - w[0].to_f64().unwrap()).abs();
+        Some(*state / cnt as f64)
+    });
+
+    iter::once(f64::NAN).chain(
+        mae_est
+            .skip(1)
+            .zip(mae_naive)
+            .map(|(est, naive)| est / naive),
+    )
+}
+
+/// Envelope-weighted Mean Absolute Error
+///
+/// A scale-independent, symmetric, error metric
+///
+/// ## Sources
+///
+/// [[1]](https://typethepipe.com/post/energy-forecasting-error-metrics/)
+/// [[2]](https://pdfs.semanticscholar.org/cf04/65bce25d78ccda6d8c5d12e141099aa606f4.pdf)
+///
+/// # Examples
+///
+/// ```
+/// use traquer::statistic::regression;
+///
+/// regression::emae(&[1.0,2.0,3.0,4.0,5.0], &[1.0,2.0,3.0,4.0,5.0]).collect::<Vec<f64>>();
+/// ```
+pub fn emae<'a, T: ToPrimitive>(
+    data: &'a [T],
+    estimate: &'a [T],
+) -> impl Iterator<Item = f64> + 'a {
+    izip!(data, estimate, 1..).scan((0.0, 0.0), |state, (actual, est, n)| {
+        let actual = actual.to_f64().unwrap();
+        let est = est.to_f64().unwrap();
+        *state = (state.0 + (actual - est).abs(), state.1 + actual.max(est));
+        Some(state.0 / state.1 * 100. / n as f64)
+    })
+}
diff --git a/tests/stat_regression_test.rs b/tests/stat_regression_test.rs
index 1de61fc..3b98b78 100644
--- a/tests/stat_regression_test.rs
+++ b/tests/stat_regression_test.rs
@@ -243,3 +243,93 @@ fn test_mda() {
         result[1..]
     );
 }
+
+#[test]
+fn test_mase() {
+    let stats = common::test_data();
+    let result = mase(&stats.close, &stats.high).collect::<Vec<f64>>();
+    assert_eq!(stats.close.len(), result.len());
+    assert_eq!(
+        [
+            0.2586956443618357,
+            0.5906157252555707,
+            0.6499810418057367,
+            0.6629225651466101,
+            0.7021785561555351,
+            0.6645714378049574,
+            0.7081077123039311,
+            0.7506796751237231,
+            0.7459128097414857,
+            0.7868962578322483,
+            0.8199328220627006,
+            0.824063446272828,
+            0.823141252303084,
+            0.8208614333821371,
+            0.8261013614605646,
+            0.8292729100447778,
+            0.8277639162122615,
+            0.8046293515788161,
+            0.838055482335999,
+            0.8333744069800522,
+            0.8428637803317439,
+            0.851792033838831,
+            0.8475633536685455,
+            0.8248455333592076,
+            0.8501631437433327,
+            0.8600553148730471,
+            0.8678200431183396,
+            0.872561058669536,
+            0.886738913172739,
+            0.8889511537019651,
+            0.8886292399826446,
+            0.9112483905653928,
+            0.9101752878693691
+        ],
+        result[1..]
+    );
+}
+
+#[test]
+fn test_emae() {
+    let stats = common::test_data();
+    let result = emae(&stats.close, &stats.high).collect::<Vec<f64>>();
+    assert_eq!(stats.close.len(), result.len());
+    assert_eq!(
+        [
+            3.1609701950004765,
+            3.004329977893074,
+            2.279411642074049,
+            1.8846528761568404,
+            1.5081324070745339,
+            1.1613516217236113,
+            1.0184840364517995,
+            0.8769276902702197,
+            0.7337664883901667,
+            0.6583605094700364,
+            0.5913681072056433,
+            0.5386264166475463,
+            0.48645105828681245,
+            0.44218650424889805,
+            0.4093300617694232,
+            0.373369288841444,
+            0.3456666025333258,
+            0.31527313536376456,
+            0.30113560279553997,
+            0.27748629221038573,
+            0.2593089769405124,
+            0.2415289624260403,
+            0.22371297843199978,
+            0.2091402134785959,
+            0.20123096114490854,
+            0.19340777037589577,
+            0.1847392436861936,
+            0.17683164475568086,
+            0.16993788222726583,
+            0.16244101999900978,
+            0.1541488048243848,
+            0.1514282289577207,
+            0.14565246790002928
+        ],
+        result[1..]
+    );
+}