Skip to content

Commit

Permalink
Add new music implementation and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
wattai committed Nov 3, 2024
1 parent 89a0b63 commit d2660be
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 1 deletion.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install ruff pytest
pip install -e ".[dev]"
- name: Run tests with pytest
run: pytest
14 changes: 14 additions & 0 deletions src/sse/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-

import abc

import numpy as np


class SoundSourceLocatorBase(abc.ABC):
@abc.abstractmethod
def fit_transform(
self,
X: np.ndarray,
) -> np.ndarray:
pass
96 changes: 96 additions & 0 deletions src/sse/music_v2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# -*- coding: utf-8 -*-

import numpy as np

from sse.base import SoundSourceLocatorBase


class MusicSoundSourceLocator(SoundSourceLocatorBase):
"""Music Sound Source Locator."""

def __init__(
self,
fs: float,
d: float,
c: float = 340.0,
N_theta: int = 180,
):
"""Implementation of MUSIC method.
Args:
fs: Sampling frequency [Hz].
d: distance between mics.
c: sound speed. Defaults to 340.0.
N_theta (int, optional): _description_. Defaults to 180.
"""

self.fs = fs # sampling frequency [Hz].
self.T = 1.0 / fs # sampling Term [sec].
self.d = d # mic. distance [m].
self.k = None # data index series [point].
self.L = None # data length [point].
self.c = c # sound speed [m/sec].
self.N_theta = N_theta
self.N_ch = None

self.minvec = None
self.thetas = None

def fit_transform(
self,
X: np.ndarray,
y: np.ndarray | None = None,
) -> "MusicSoundSourceLocator":
# x: input sound signal
# x.shape = (sample, N_ch)

self.L = len(X) // 1
self.k = np.arange(self.L // 2)
self.N_ch = X.shape[1]

winfunc = np.hanning(len(X)).reshape(-1, 1)

X_freq = np.fft.fft(winfunc * X, axis=0)[: self.L // 2, :]
self.R = (
X_freq.conj().reshape(-1, self.N_ch, 1) @ X_freq.reshape(-1, 1, self.N_ch)
) / (X_freq.shape[0] / self.fs)

self.eigval, self.eigvec = np.linalg.eig(self.R)
self.minvec = self.eigvec[:, :, -1].reshape(-1, self.N_ch, 1)

self.thetas = np.linspace(-np.pi / 2, np.pi / 2, self.N_theta)
self.S = self._calc_s_music(self.thetas)
return self._theta_hat()

def _theta_hat(self):
return np.rad2deg(self.thetas[(np.abs(self.S) ** 2).mean(0).argmax()])

def _calc_s_music(self, thetas):
return np.array(
[
np.sum(self._calc_alpha(theta).conj() * self._calc_alpha(theta), axis=0)
/ np.abs(
self.minvec.conj().transpose(0, 2, 1)
@ self._calc_alpha(theta).transpose(1, 0).reshape(-1, self.N_ch, 1)
).squeeze()
** 2
for theta in thetas
]
).T * np.sqrt(
self.eigval[:, 0][:, None]
) # / (self.eigval[:, -1][:, None]+1e-8) )

def _calc_alpha(self, theta):
return np.array(
[
np.ones(len(self.k)),
np.exp(
2j
* np.pi
* self.d
* np.sin(theta)
/ (self.c * self.T)
* ((self.k) / (len(self.k) * 2))
),
]
)
16 changes: 16 additions & 0 deletions tests/test_music.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# -*- coding: utf-8 -*-

import numpy as np

from sse.music_v2 import MusicSoundSourceLocator


class TestMusicSoundSourceLocator:
def test_fit_transform(self):
self.locator = MusicSoundSourceLocator(
fs=16000,
d=1.0,
)
X = np.random.randn(1000, 2)
out = self.locator.fit_transform(X=X)
print(out)

0 comments on commit d2660be

Please sign in to comment.