Skip to content

Commit

Permalink
Add tests for signal delay simulation
Browse files Browse the repository at this point in the history
  • Loading branch information
wattai committed Nov 13, 2024
1 parent 7638e38 commit dd86d7c
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 44 deletions.
97 changes: 53 additions & 44 deletions src/sse/simulators/environments.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,36 +9,6 @@
import soundfile as sf


def make_signals(num_sources: int, num_mics: int):
pass


def calc_time_delays_from_position(
r: float,
theta: float,
phi: float,
c: float = 340.0,
) -> np.ndarray:
"""_summary_
Args:
r: Distance.
theta: Azimuth angle.
phi: Elevation angle.
c: Speed of sound. Defaults to 340.0.
Returns:
Position.
"""
positions = (
r * np.cos(theta) * np.sin(phi),
r * np.sin(theta) * np.sin(phi),
r * np.cos(theta),
)
tau = np.linalg.vector_norm(positions)
return tau


class Signal(BaseModel):
values: list[float]
sampling_frequency: float
Expand All @@ -55,7 +25,7 @@ class BaseDevice(abc.ABC):


class BaseSource(BaseDevice):
@abc.abstructmethod
@abc.abstractmethod
def ring(self) -> Signal:
pass

Expand Down Expand Up @@ -122,10 +92,7 @@ def __init__(
self.sampling_frequency = sampling_frequency

def record(self, signals: list[Signal]) -> Signal:
return Signal(
values=signals,
sampling_frequency=self.sampling_frequency,
)
return overlap(signals, self.sampling_frequency)


def overlap(signals: list[Signal], sampling_frequency: float) -> Signal:
Expand All @@ -140,12 +107,16 @@ def overlap(signals: list[Signal], sampling_frequency: float) -> Signal:
)


def resample(signal: Signal, sampling_frequency: float) -> Signal:
return sp.signal.resample(
signal.values,
num=np.ceil(
len(signal.values) * (sampling_frequency / signal.sampling_frequency),
),
def resample(signal: Signal, new_sampling_frequency: float) -> Signal:
# 変換するサンプル数を計算
num_samples = int(
round(len(signal.values) * (new_sampling_frequency / signal.sampling_frequency))
)

# 'num' を整数にキャスト
return Signal(
values=list(sp.signal.resample(signal.values, num_samples)),
sampling_frequency=new_sampling_frequency,
)


Expand Down Expand Up @@ -207,7 +178,7 @@ def calc_received_signals(
[
delay(
signal=source.ring(),
distance=mic.position.distance(source.position),
distance=calc_distance(mic.position, source.position),
sound_speed=sound_speed,
)
for source in sources
Expand All @@ -217,9 +188,47 @@ def calc_received_signals(
]


def calc_distance(p1: Position3D, p2: Position3D) -> float:
return euclidean_distance(
p1=(p1.r, p1.theta, p1.phi),
p2=(p2.r, p2.theta, p2.phi),
)


def polar_to_cartesian(r, theta, phi):
"""極座標をデカルト座標に変換"""
x = r * np.sin(theta) * np.cos(phi)
y = r * np.sin(theta) * np.sin(phi)
z = r * np.cos(theta)
return x, y, z


def euclidean_distance(p1, p2):
"""2つの点間のユークリッド距離を計算"""
(r1, theta1, phi1) = p1
(r2, theta2, phi2) = p2

x1, y1, z1 = polar_to_cartesian(r1, theta1, phi1)
x2, y2, z2 = polar_to_cartesian(r2, theta2, phi2)

distance = np.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2 + (z2 - z1) ** 2)
return distance


def delay(signal: Signal, distance: float, sound_speed: float) -> Signal:
num_delay_points = signal.sampling_frequency * (distance / sound_speed)
# num_delay_points を整数にキャスト
num_delay_points = int(round(signal.sampling_frequency * (distance / sound_speed)))

# スライシングに整数を使用
return Signal(
values=np.pad(signal.values[num_delay_points:], ((0, num_delay_points))),
values=np.pad(signal.values[num_delay_points:], (0, num_delay_points)),
sampling_frequency=signal.sampling_frequency,
)


# def delay(signal: Signal, distance: float, sound_speed: float) -> Signal:
# num_delay_points = signal.sampling_frequency * (distance / sound_speed)
# return Signal(
# values=np.pad(signal.values[num_delay_points:], ((0, num_delay_points))),
# sampling_frequency=signal.sampling_frequency,
# )
74 changes: 74 additions & 0 deletions tests/simulators/test_environments.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import numpy as np
import pytest

from sse.simulators.environments import (
Observer,
Air,
Microphone,
Source,
Position3D,
SineSignalGenerator,
calc_distance,
)


@pytest.mark.parametrize(
"source, microphone",
[
(
Source(
position=Position3D(r=50, theta=np.pi, phi=np.pi / 2),
signal=SineSignalGenerator(frequency=6000).generate(
sampling_frequency=16000,
time_length=1,
),
),
Microphone(
position=Position3D(r=100, theta=-0.5, phi=1.1),
sampling_frequency=16000,
),
),
],
)
def test_num_delayed_points_of_signal(source: Source, microphone: Microphone):
medium = Air()
obs = Observer(
sources=[source],
microphones=[microphone],
medium=medium,
)
outs = obs.ring_sources()
shift = find_max_correlation_shift(
signal1=obs.sources[0].signal.values,
signal2=outs[0].values,
)
print(f"相互相関が最大となるシフト量: {shift}")
assert shift == int(
round(
calc_distance(source.position, microphone.position)
/ medium.sound_speed
* microphone.sampling_frequency
)
)


def find_max_correlation_shift(signal1, signal2):
"""
2つの信号の相互相関が最大となる位置を返す関数
Args:
signal1 (array-like): 最初の信号配列
signal2 (array-like): 2番目の信号配列
Returns:
int: 相互相関が最大になるときの信号2のシフト量
"""
# 長さを揃えるためのゼロパディング
len(signal1) + len(signal2) - 1
corr = np.correlate(signal1, signal2, mode="full")

# 最大相互相関値のインデックスを見つける
max_corr_index = np.argmax(corr)

# ズレ位置を計算(中央を基準としてラグを計算)
shift = max_corr_index - (len(signal2) - 1)
return shift

0 comments on commit dd86d7c

Please sign in to comment.