// Source: patch d2008634c1b87735a37638345feb40ac279afe71,
// "feat(lib): create stats module" (Andrew X. Shah, 2023-09-01).
// The same patch registers this file in `src/lib.rs` with `pub mod stats;`
// and re-exports its items with `pub use stats::*;`, which is why the
// examples below import directly from `engram`.

//! This module contains functions for calculating statistics.

use std::{
    collections::HashMap,
    error::Error,
    fmt,
    hash::Hash,
    ops::{Add, Sub},
};

/// Returns the arithmetic mean of a list of values.
///
/// An empty slice yields `NaN`, because the zero sum is divided by a length
/// of zero.
///
/// # Examples
///
/// ```
/// use engram::mean;
/// let values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
/// assert_eq!(mean(&values), 3.5);
/// ```
pub fn mean<T: Into<f64> + Copy>(data: &[T]) -> f64 {
    let total = data.iter().fold(0.0_f64, |acc, &x| {
        let value: f64 = x.into();
        acc + value
    });
    total / data.len() as f64
}

/// Returns the median of a list of values, or `None` if the list is empty.
///
/// For an even number of values the result is the average of the two middle
/// values. Values are ordered with [`f64::total_cmp`], so `NaN` inputs are
/// ordered deterministically instead of causing a panic.
///
/// # Examples
///
/// ```
/// use engram::median;
/// let values = vec![1.0, 9.0, 2.5, 3.0, 2.0, 8.0];
/// assert_eq!(median(&values), Some(2.75));
/// ```
pub fn median<T: Into<f64> + Copy>(data: &[T]) -> Option<f64> {
    if data.is_empty() {
        return None;
    }

    let mut sorted: Vec<f64> = data.iter().map(|&x| x.into()).collect();
    // A total order is required here: `partial_cmp` returns `None` for NaN.
    sorted.sort_unstable_by(f64::total_cmp);
    let mid = sorted.len() / 2;

    if sorted.len() % 2 == 0 {
        Some((sorted[mid - 1] + sorted[mid]) / 2.0)
    } else {
        Some(sorted[mid])
    }
}

/// Returns the mode (most frequent value) of a list of values, or `None` if
/// the list is empty.
///
/// When several values share the highest frequency, the one whose final
/// occurrence comes last in `data` is returned.
///
/// # Examples
///
/// ```
/// use engram::mode;
/// let values = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 9];
/// assert_eq!(mode(&values), Some(9));
/// ```
pub fn mode<T: Hash + Eq + Copy>(data: &[T]) -> Option<T> {
    let mut counts: HashMap<T, usize> = HashMap::new();
    // The key of each element is its running count so far. The largest
    // running count equals the largest total count, so a single pass is
    // enough; `max_by_key` already yields `None` for an empty iterator.
    data.iter().copied().max_by_key(|&x| {
        let count = counts.entry(x).or_insert(0);
        *count += 1;
        *count
    })
}

/// Returns the sample variance (divisor `n - 1`) of a list of values.
///
/// An empty slice yields zero and a single value yields `NaN`.
///
/// The `Add`/`Sub` bounds are not used by the body; they are kept so the
/// signature is unchanged.
///
/// # Examples
///
/// ```
/// use engram::var;
/// let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
/// assert_eq!(var(&values), 2.5);
/// ```
pub fn var<T>(data: &[T]) -> f64
where
    T: Into<f64> + Copy + Add<Output = T> + Sub<Output = T>,
{
    let center = mean(data);
    // `saturating_sub` keeps the hoisted divisor from underflowing on empty
    // input; for an empty slice the closure below never runs anyway.
    let divisor = data.len().saturating_sub(1) as f64;
    data.iter()
        .map(|&x| {
            let deviation: f64 = x.into() - center;
            deviation.powi(2) / divisor
        })
        .sum()
}

/// Returns the population variance (divisor `n`) of a list of values.
///
/// An empty slice yields zero.
///
/// The `Add`/`Sub` bounds are not used by the body; they are kept so the
/// signature is unchanged.
///
/// # Examples
///
/// ```
/// use engram::pop_var;
/// let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
/// assert_eq!(pop_var(&values), 2.0);
/// ```
pub fn pop_var<T>(data: &[T]) -> f64
where
    T: Into<f64> + Copy + Add<Output = T> + Sub<Output = T>,
{
    let center = mean(data);
    let divisor = data.len() as f64;
    data.iter()
        .map(|&x| {
            let deviation: f64 = x.into() - center;
            deviation.powi(2) / divisor
        })
        .sum()
}

/// Returns the sample standard deviation of a list of values, i.e. the
/// square root of [`var`].
///
/// # Examples
///
/// ```
/// use engram::std;
/// let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
/// assert_eq!(std(&values), 1.5811388300841898);
/// ```
pub fn std<T>(data: &[T]) -> f64
where
    T: Into<f64> + Copy + Add<Output = T> + Sub<Output = T>,
{
    var(data).sqrt()
}

/// Returns the population standard deviation of a list of values, i.e. the
/// square root of [`pop_var`].
///
/// # Examples
///
/// ```
/// use engram::pop_std;
/// let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
/// assert_eq!(pop_std(&values), 1.4142135623730951);
/// ```
pub fn pop_std<T>(data: &[T]) -> f64
where
    T: Into<f64> + Copy + Add<Output = T> + Sub<Output = T>,
{
    pop_var(data).sqrt()
}

/// Returns the standard error of the mean: the sample standard deviation
/// ([`std`]) divided by the square root of the number of values.
///
/// # Examples
///
/// ```
/// use engram::mean_std_err;
/// let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
/// assert_eq!(mean_std_err(&values), 0.7071067811865476);
/// ```
pub fn mean_std_err<T>(data: &[T]) -> f64
where
    T: Into<f64> + Copy + Add<Output = T> + Sub<Output = T>,
{
    std(data) / (data.len() as f64).sqrt()
}

/// Error returned by [`mean_ci`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MeanCIError {
    /// The input slice contained no values.
    EmptyData,
    /// The confidence level was not strictly between 0 and 1.
    InvalidConfidence,
}

impl fmt::Display for MeanCIError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::EmptyData => "cannot compute a confidence interval for empty data",
            Self::InvalidConfidence => "confidence level must be strictly between 0 and 1",
        };
        f.write_str(message)
    }
}

impl Error for MeanCIError {}

/// Returns the mean confidence interval of a list of values as
/// `(lower, upper)`.
/// Confidence level is a value between 0 and 1 (both exclusive).
///
/// The interval is `mean ± confidence * mean_std_err`. A slice with a single
/// value produces `NaN` bounds, because the sample variance is undefined.
///
/// # Errors
///
/// * [`MeanCIError::EmptyData`] if `data` is empty.
/// * [`MeanCIError::InvalidConfidence`] if `confidence` is not strictly
///   between 0 and 1 (this includes `NaN`).
///
/// # Examples
///
/// ```
/// use engram::mean_ci;
/// let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
/// assert_eq!(mean_ci(&values, 0.95).unwrap(), (2.32824855787278, 3.67175144212722));
/// ```
pub fn mean_ci<T>(data: &[T], confidence: f64) -> Result<(f64, f64), MeanCIError>
where
    T: Into<f64> + Copy + Add<Output = T> + Sub<Output = T>,
{
    if data.is_empty() {
        return Err(MeanCIError::EmptyData);
    }
    // Written as a negated range test so that NaN is rejected as well.
    if !(confidence > 0.0 && confidence < 1.0) {
        return Err(MeanCIError::InvalidConfidence);
    }

    let center = mean(data);
    // NOTE(review): the confidence level itself is used as the multiplier of
    // the standard error; it is not converted to a z or t critical value
    // (for example, about 1.96 for a 95% normal interval). Kept as is so that
    // existing results do not change; confirm whether this is intended.
    let margin = confidence * mean_std_err(data);

    Ok((center - margin, center + margin))
}