Skip to content

Commit

Permalink
add mase and emae error metric functions (#108)
Browse files Browse the repository at this point in the history
  • Loading branch information
chungg authored Nov 11, 2024
1 parent 5c9c11b commit cd0dfec
Show file tree
Hide file tree
Showing 3 changed files with 173 additions and 11 deletions.
10 changes: 10 additions & 0 deletions benches/traquer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -708,6 +708,16 @@ fn criterion_benchmark(c: &mut Criterion) {
black_box(statistic::regression::mda(&stats.close, &stats.open).collect::<Vec<_>>())
})
});
c.bench_function("stats-regress-mase", |b| {
b.iter(|| {
black_box(statistic::regression::mase(&stats.close, &stats.open).collect::<Vec<_>>())
})
});
c.bench_function("stats-regress-emae", |b| {
b.iter(|| {
black_box(statistic::regression::emae(&stats.close, &stats.open).collect::<Vec<_>>())
})
});
}

criterion_group!(benches, criterion_benchmark);
Expand Down
84 changes: 73 additions & 11 deletions src/statistic/regression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
//! determing causal relationships.[[1]](https://en.wikipedia.org/wiki/Regression_analysis)
use std::iter;

use itertools::izip;
use num_traits::cast::ToPrimitive;

/// Mean Squared Error
Expand All @@ -31,9 +32,7 @@ pub fn mse<'a, T: ToPrimitive>(data: &'a [T], estimate: &'a [T]) -> impl Iterato
.enumerate()
.zip(estimate)
.scan(0.0, |state, ((cnt, observe), est)| {
*state += (observe.to_f64().unwrap() - est.to_f64().unwrap())
.powi(2)
.max(0.0);
*state += (observe.to_f64().unwrap() - est.to_f64().unwrap()).powi(2);
Some(*state / (cnt + 1) as f64)
})
}
Expand Down Expand Up @@ -65,9 +64,7 @@ pub fn rmse<'a, T: ToPrimitive>(
.enumerate()
.zip(estimate)
.scan(0.0, |state, ((cnt, observe), est)| {
*state += (observe.to_f64().unwrap() - est.to_f64().unwrap())
.powi(2)
.max(0.0);
*state += (observe.to_f64().unwrap() - est.to_f64().unwrap()).powi(2);
Some((*state / (cnt + 1) as f64).sqrt())
})
}
Expand Down Expand Up @@ -97,9 +94,7 @@ pub fn mae<'a, T: ToPrimitive>(data: &'a [T], estimate: &'a [T]) -> impl Iterato
.enumerate()
.zip(estimate)
.scan(0.0, |state, ((cnt, observe), est)| {
*state += (observe.to_f64().unwrap() - est.to_f64().unwrap())
.abs()
.max(0.0);
*state += (observe.to_f64().unwrap() - est.to_f64().unwrap()).abs();
Some(*state / (cnt + 1) as f64)
})
}
Expand Down Expand Up @@ -133,8 +128,7 @@ pub fn mape<'a, T: ToPrimitive>(
.scan(0.0, |state, ((cnt, observe), est)| {
*state += ((observe.to_f64().unwrap() - est.to_f64().unwrap())
/ observe.to_f64().unwrap())
.abs()
.max(0.0);
.abs();
Some(100.0 * *state / (cnt + 1) as f64)
})
}
Expand Down Expand Up @@ -205,3 +199,71 @@ pub fn mda<'a, T: ToPrimitive>(data: &'a [T], estimate: &'a [T]) -> impl Iterato
},
))
}

/// Mean Absolute Scaled Error
///
/// A forecasting accuracy metric that compares the mean absolute error (MAE) of your prediction
/// to the MAE of a naive forecasting method, such as predicting the previous period's value. It
/// helps determine if your forecasting model outperforms a simple baseline model.
///
/// ```math
/// MASE = \mathrm{mean}\left( \frac{\left| e_j \right|}{\frac{1}{T-1}\sum_{t=2}^T \left| Y_t-Y_{t-1}\right|} \right) = \frac{\frac{1}{J}\sum_{j}\left| e_j \right|}{\frac{1}{T-1}\sum_{t=2}^T \left| Y_t-Y_{t-1}\right|}
/// ```
///
/// ## Sources
///
/// [[1]](https://en.wikipedia.org/wiki/Mean_absolute_scaled_error)
/// [[2]](https://otexts.com/fpp2/accuracy.html#scaled-errors)
///
/// # Examples
///
/// ```
/// use traquer::statistic::regression;
///
/// regression::mase(&[1.0,2.0,3.0,4.0,5.0], &[1.0,2.0,3.0,4.0,5.0]).collect::<Vec<f64>>();
/// ```
pub fn mase<'a, T: ToPrimitive>(
data: &'a [T],
estimate: &'a [T],
) -> impl Iterator<Item = f64> + 'a {
let mae_est = mae(data, estimate);
let mae_naive = data.windows(2).zip(1..).scan(0.0, |state, (w, cnt)| {
*state += (w[1].to_f64().unwrap() - w[0].to_f64().unwrap()).abs();
Some(*state / cnt as f64)
});

iter::once(f64::NAN).chain(
mae_est
.skip(1)
.zip(mae_naive)
.map(|(est, naive)| est / naive),
)
}

/// Envelope-weighted Mean Absolute Error
///
/// A scale-independent, symmetric, error metric
///
/// ## Sources
///
/// [[1]](https://typethepipe.com/post/energy-forecasting-error-metrics/)
/// [[2]](https://pdfs.semanticscholar.org/cf04/65bce25d78ccda6d8c5d12e141099aa606f4.pdf)
///
/// # Examples
///
/// ```
/// use traquer::statistic::regression;
///
/// regression::emae(&[1.0,2.0,3.0,4.0,5.0], &[1.0,2.0,3.0,4.0,5.0]).collect::<Vec<f64>>();
/// ```
pub fn emae<'a, T: ToPrimitive>(
data: &'a [T],
estimate: &'a [T],
) -> impl Iterator<Item = f64> + 'a {
izip!(data, estimate, 1..).scan((0.0, 0.0), |state, (actual, est, n)| {
let actual = actual.to_f64().unwrap();
let est = est.to_f64().unwrap();
*state = (state.0 + (actual - est).abs(), state.1 + actual.max(est));
Some(state.0 / state.1 * 100. / n as f64)
})
}
90 changes: 90 additions & 0 deletions tests/stat_regression_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -243,3 +243,93 @@ fn test_mda() {
result[1..]
);
}

#[test]
fn test_mase() {
let stats = common::test_data();
let result = mase(&stats.close, &stats.high).collect::<Vec<_>>();
assert_eq!(stats.close.len(), result.len());
assert_eq!(
[
0.2586956443618357,
0.5906157252555707,
0.6499810418057367,
0.6629225651466101,
0.7021785561555351,
0.6645714378049574,
0.7081077123039311,
0.7506796751237231,
0.7459128097414857,
0.7868962578322483,
0.8199328220627006,
0.824063446272828,
0.823141252303084,
0.8208614333821371,
0.8261013614605646,
0.8292729100447778,
0.8277639162122615,
0.8046293515788161,
0.838055482335999,
0.8333744069800522,
0.8428637803317439,
0.851792033838831,
0.8475633536685455,
0.8248455333592076,
0.8501631437433327,
0.8600553148730471,
0.8678200431183396,
0.872561058669536,
0.886738913172739,
0.8889511537019651,
0.8886292399826446,
0.9112483905653928,
0.9101752878693691
],
result[1..]
);
}

#[test]
fn test_emae() {
let stats = common::test_data();
let result = emae(&stats.close, &stats.high).collect::<Vec<_>>();
assert_eq!(stats.close.len(), result.len());
assert_eq!(
[
3.1609701950004765,
3.004329977893074,
2.279411642074049,
1.8846528761568404,
1.5081324070745339,
1.1613516217236113,
1.0184840364517995,
0.8769276902702197,
0.7337664883901667,
0.6583605094700364,
0.5913681072056433,
0.5386264166475463,
0.48645105828681245,
0.44218650424889805,
0.4093300617694232,
0.373369288841444,
0.3456666025333258,
0.31527313536376456,
0.30113560279553997,
0.27748629221038573,
0.2593089769405124,
0.2415289624260403,
0.22371297843199978,
0.2091402134785959,
0.20123096114490854,
0.19340777037589577,
0.1847392436861936,
0.17683164475568086,
0.16993788222726583,
0.16244101999900978,
0.1541488048243848,
0.1514282289577207,
0.14565246790002928
],
result[1..]
);
}

0 comments on commit cd0dfec

Please sign in to comment.