From 7fa6fd41c63c666d29c6492c05c371ed415e49e8 Mon Sep 17 00:00:00 2001 From: OverLordGoldDragon <16495490+OverLordGoldDragon@users.noreply.github.com> Date: Wed, 21 Apr 2021 23:03:33 +0400 Subject: [PATCH] Handle linear `ssq_freqs` in `ssq_stft` and warn user about passing both `fs` and `t` --- ssqueezepy/_ssq_stft.py | 8 +++++++- ssqueezepy/utils/cwt_utils.py | 2 ++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/ssqueezepy/_ssq_stft.py b/ssqueezepy/_ssq_stft.py index eed806d..796123e 100644 --- a/ssqueezepy/_ssq_stft.py +++ b/ssqueezepy/_ssq_stft.py @@ -2,7 +2,7 @@ import numpy as np from ._stft import stft, get_window, _check_NOLA from ._ssq_cwt import _invert_components, _process_component_inversion_args -from .utils.cwt_utils import _process_fs_and_t +from .utils.cwt_utils import _process_fs_and_t, infer_scaletype from .utils.common import WARN, EPS32, EPS64 from .utils import backend as S from .utils.backend import torch @@ -28,6 +28,7 @@ def ssq_stft(x, window=None, n_fft=None, win_len=None, hop_len=1, fs=None, t=Non ssq_freqs, squeezing See `help(ssqueezing.ssqueeze)`. + `ssq_freqs`, if array, must be linearly distributed. gamma: float / None See `help(_ssq_cwt.phase_cwt)`. @@ -74,6 +75,11 @@ def ssq_stft(x, window=None, n_fft=None, win_len=None, hop_len=1, fs=None, t=Non raise NotImplementedError("`get_w=True` unsupported with batched input.") _, fs, _ = _process_fs_and_t(fs, t, x.shape[-1]) _check_ssqueezing_args(squeezing) + # assert ssq_freqs, if array, is linear + if (isinstance(ssq_freqs, np.ndarray) and + infer_scaletype(ssq_freqs) != 'linear'): + raise ValueError("`ssq_freqs` must be linearly distributed " + "for `ssq_stft`") Sx, dSx = stft(x, window, n_fft=n_fft, win_len=win_len, hop_len=hop_len, fs=fs, padtype=padtype, modulated=modulated, derivative=True, diff --git a/ssqueezepy/utils/cwt_utils.py b/ssqueezepy/utils/cwt_utils.py index 3b48810..9895a21 100644 --- a/ssqueezepy/utils/cwt_utils.py +++ b/ssqueezepy/utils/cwt_utils.py @@ -697,6 +697,8 @@ def _process_fs_and_t(fs, t, N): """Ensures `t` is uniformly-spaced and of same length as `x` (==N) and returns `fs` and `dt` based on it, or from defaults if `t` is None. """ + if fs is not None and t is not None: + WARN("`t` will override `fs` (both were passed)") if t is not None: if len(t) != N: # not explicitly used anywhere but ensures wrong `t` wasn't supplied