Skip to content

Commit

Permalink
rock: Add tool to compute the cosine similarity of two WAVE files
Browse files Browse the repository at this point in the history
BUG=b:316075022
TEST=target/release/rock cosine out2.wav out3.wav

Change-Id: I7b04bb63494115eddce37992a857f6a3344bb89b
Reviewed-on: https://chromium-review.googlesource.com/c/chromiumos/third_party/adhd/+/5123808
Reviewed-by: Chih-Yang Hsia <paulhsia@chromium.org>
Tested-by: chromeos-cop-builder@chromeos-cop.iam.gserviceaccount.com <chromeos-cop-builder@chromeos-cop.iam.gserviceaccount.com>
Commit-Queue: Li-Yu Yu <aaronyu@google.com>
  • Loading branch information
afq984 authored and Chromeos LUCI committed Dec 20, 2023
1 parent bfc2ba1 commit 38e49df
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 0 deletions.
97 changes: 97 additions & 0 deletions rock/src/cosine.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
// Copyright 2023 The ChromiumOS Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

use std::path::PathBuf;

use anyhow::bail;
use clap::Args;

use crate::wav::read_wav;

#[derive(Args)]
pub(crate) struct CosineCommand {
/// Path to the first WAVE file
a: PathBuf,
/// Path to the second WAVE file
b: PathBuf,
/// Delay of b in frames.
#[clap(long, default_value = "0")]
delay: i32,
}

impl CosineCommand {
pub(crate) fn run(&self) -> anyhow::Result<()> {
let (a_spec, a_buf) = read_wav(&self.a)?;
let (b_spec, b_buf) = read_wav(&self.b)?;
if a_spec.sample_rate != a_spec.sample_rate {
bail!(
"sample rate mismatch: {} != {}",
a_spec.sample_rate,
b_spec.sample_rate
);
}

let val = cosine_with_delay(
a_buf.data[0].as_slice(),
b_buf.data[0].as_slice(),
self.delay,
);
println!("{val}");

Ok(())
}
}

/// Compute the cosine similarity of `aa` ` and `bb`.
/// The longer one is truncated.
fn cosine(aa: &[f32], bb: &[f32]) -> f64 {
let mut dot = 0f64;
let mut a_norm = 0f64;
let mut b_norm = 0f64;

for (a, b) in aa.iter().zip(bb) {
let a = *a as f64;
let b = *b as f64;
dot += a * b;
a_norm += a * a;
b_norm += b * b;
}

dot / a_norm.sqrt() / b_norm.sqrt()
}

fn cosine_with_delay(aa: &[f32], bb: &[f32], delay: i32) -> f64 {
if delay > 0 {
cosine(aa, &bb[delay as usize..])
} else {
cosine(&aa[-delay as usize..], bb)
}
}

#[cfg(test)]
mod tests {
use super::cosine;
use super::cosine_with_delay;

#[test]
fn test_cosine() {
assert_eq!(cosine(&[1.], &[1.]), 1.);
assert_eq!(cosine(&[1., 2.], &[1., 2.]), 1.);
assert_eq!(cosine(&[1., 2.], &[-1., -2.]), -1.);
assert_eq!(cosine(&[1., 2.], &[2., -1.]), 0.);
assert!(cosine(&[1.], &[0.]).is_nan());
assert!(cosine(&[1.], &[]).is_nan());
assert!(cosine(&[], &[]).is_nan());
}

#[test]
fn test_cosine_with_delay() {
let aa = [1., 2., 3., 4., 5.];
let bb = [2., 3., 4., 5., 6.];

assert_eq!(cosine_with_delay(&aa, &bb, -1), cosine(&aa[1..], &bb));
assert!((cosine_with_delay(&aa, &bb, -1) - 1.).abs() < 1e-6);
assert_eq!(cosine_with_delay(&aa, &bb, 1), cosine(&aa, &bb[1..]));
}
}
4 changes: 4 additions & 0 deletions rock/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

mod cosine;
mod delay;
pub(crate) mod wav;

Expand All @@ -18,12 +19,15 @@ struct Cli {
enum Commands {
/// Compute the delay of two WAVE files
Delay(delay::DelayCommand),
/// Compute the cosine similarity of two WAVE files
Cosine(cosine::CosineCommand),
}

impl Cli {
fn run(&self) -> anyhow::Result<()> {
match &self.command {
Commands::Delay(c) => c.run(),
Commands::Cosine(c) => c.run(),
}
}
}
Expand Down

0 comments on commit 38e49df

Please sign in to comment.