diff --git a/benches/traquer.rs b/benches/traquer.rs index c532c70..4302cd8 100644 --- a/benches/traquer.rs +++ b/benches/traquer.rs @@ -611,6 +611,9 @@ fn criterion_benchmark(c: &mut Criterion) { black_box(correlation::hoeffd(&stats.close, &stats.close, 16).collect::>()) }) }); + c.bench_function("correlation-dcor", |b| { + b.iter(|| black_box(correlation::dcor(&stats.close, &stats.close, 16).collect::>())) + }); c.bench_function("stats-dist-variance", |b| { b.iter(|| { diff --git a/src/correlation.rs b/src/correlation.rs index a7cf639..82275e1 100644 --- a/src/correlation.rs +++ b/src/correlation.rs @@ -413,3 +413,101 @@ pub fn hoeffd<'a, T: ToPrimitive + PartialOrd + Clone>( }), ) } + +/// Distance Correlation +/// +/// Measures both linear and nonlinear association between two random variables or random vectors. +/// Not to be confused with correlation distance which is related to Pearson Coefficient[3]. +/// +/// ## Usage +/// +/// Generates a value between 0 and 1 where 0 implies series are independent and 1 implies they are +/// surely equal. +/// +/// ## Sources +/// +/// [[1]](https://en.wikipedia.org/wiki/Distance_correlation) +/// [[2]](https://www.freecodecamp.org/news/how-machines-make-predictions-finding-correlations-in-complex-data-dfd9f0d87889/) +/// [[3]](https://en.wikipedia.org/wiki/Pearson_correlation_coefficient#Pearson's_distance) +/// +/// # Examples +/// +/// ``` +/// use traquer::correlation; +/// +/// correlation::dcor( +/// &[1.0,2.0,3.0,4.0,5.0,6.0,4.0,5.0], +/// &[1.0,2.0,3.0,4.0,5.0,6.0,4.0,5.0], +/// 6).collect::>(); +/// +/// ``` +pub fn dcor<'a, T: ToPrimitive>( + series1: &'a [T], + series2: &'a [T], + window: usize, +) -> impl Iterator + 'a { + fn centred_matrix(x: &[T]) -> Vec { + let n = x.len(); + // flattened NxN distance matrix, where [x_00..x0j, ... ,x_i0..x_ij] + let mut matrix = vec![0.0; n * n]; + for i in 0..n { + for j in 0..n { + matrix[(i * n) + j] = (x[i].to_f64().unwrap() - x[j].to_f64().unwrap()).abs(); + } + } + + // "double-centre" the matrix + let row_means: Vec = (0..matrix.len()) + .step_by(n) + .map(|i| matrix[i..i + n].iter().sum::() / n as f64) + .collect(); + let col_means: Vec = (0..n) + .map(|i| { + (i..matrix.len()) + .step_by(n) + .fold(0.0, |acc, j| acc + matrix[j]) + / n as f64 + }) + .collect(); + let matrix_mean: f64 = matrix.iter().sum::() / (n * n) as f64; + for i in 0..n { + for j in 0..n { + matrix[(i * n) + j] += -row_means[i] - col_means[j] + matrix_mean; + } + } + matrix + } + + iter::repeat(f64::NAN).take(window - 1).chain( + series1 + .windows(window) + .zip(series2.windows(window)) + .map(move |(x_win, y_win)| { + let centred_x = centred_matrix(x_win); + let centred_y = centred_matrix(y_win); + let dcov = (centred_x + .iter() + .zip(¢red_y) + .map(|(a, b)| a * b) + .sum::() + / window.pow(2) as f64) + .sqrt(); + let dvar_x = (centred_x + .iter() + .zip(¢red_x) + .map(|(a, b)| a * b) + .sum::() + / window.pow(2) as f64) + .sqrt(); + let dvar_y = (centred_y + .iter() + .zip(¢red_y) + .map(|(a, b)| a * b) + .sum::() + / window.pow(2) as f64) + .sqrt(); + + dcov / (dvar_x * dvar_y).sqrt() + }), + ) +} diff --git a/tests/correlation_test.rs b/tests/correlation_test.rs index 85a5da6..7e1efb4 100644 --- a/tests/correlation_test.rs +++ b/tests/correlation_test.rs @@ -244,3 +244,46 @@ fn test_hoeffd() { result[16 - 1..] ); } + +#[test] +fn test_dcor() { + let stats = common::test_data(); + let stats2 = common::test_data_path("./tests/sp500.input"); + let ln_ret1 = stats + .close + .iter() + .zip(&stats.close[1..]) + .map(|(x, y)| (y / x).ln()) + .collect::>(); + let ln_ret2 = stats2 + .close + .iter() + .zip(&stats2.close[1..]) + .map(|(x, y)| (y / x).ln()) + .collect::>(); + let result = correlation::dcor(&ln_ret1, &ln_ret2, 16).collect::>(); + assert_eq!(ln_ret1.len(), result.len()); + assert_eq!( + vec![ + 0.39390614319365574, + 0.39845847602318907, + 0.4012684752778961, + 0.4532725521408008, + 0.5335623994772698, + 0.5899262972738498, + 0.6886450053961184, + 0.7578847633388898, + 0.7748853182014356, + 0.7670646492585647, + 0.8058436110412499, + 0.818014456822133, + 0.8064793069072755, + 0.687572965245447, + 0.6362198562198043, + 0.5628963827860937, + 0.5699710776508861, + 0.44892542156220505 + ], + result[16 - 1..] + ); +}