Skip to content

Commit

Permalink
Handle linear ssq_freqs in ssq_stft
Browse files Browse the repository at this point in the history
and warn user about passing both `fs` and `t`
  • Loading branch information
OverLordGoldDragon authored Apr 21, 2021
1 parent 785a115 commit 7fa6fd4
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 1 deletion.
8 changes: 7 additions & 1 deletion ssqueezepy/_ssq_stft.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import numpy as np
from ._stft import stft, get_window, _check_NOLA
from ._ssq_cwt import _invert_components, _process_component_inversion_args
from .utils.cwt_utils import _process_fs_and_t
from .utils.cwt_utils import _process_fs_and_t, infer_scaletype
from .utils.common import WARN, EPS32, EPS64
from .utils import backend as S
from .utils.backend import torch
Expand All @@ -28,6 +28,7 @@ def ssq_stft(x, window=None, n_fft=None, win_len=None, hop_len=1, fs=None, t=Non
ssq_freqs, squeezing
See `help(ssqueezing.ssqueeze)`.
`ssq_freqs`, if array, must be linearly distributed.
gamma: float / None
See `help(_ssq_cwt.phase_cwt)`.
Expand Down Expand Up @@ -74,6 +75,11 @@ def ssq_stft(x, window=None, n_fft=None, win_len=None, hop_len=1, fs=None, t=Non
raise NotImplementedError("`get_w=True` unsupported with batched input.")
_, fs, _ = _process_fs_and_t(fs, t, x.shape[-1])
_check_ssqueezing_args(squeezing)
# assert ssq_freqs, if array, is linear
if (isinstance(ssq_freqs, np.ndarray) and
infer_scaletype(ssq_freqs) != 'linear'):
raise ValueError("`ssq_freqs` must be linearly distributed "
"for `ssq_stft`")

Sx, dSx = stft(x, window, n_fft=n_fft, win_len=win_len, hop_len=hop_len,
fs=fs, padtype=padtype, modulated=modulated, derivative=True,
Expand Down
2 changes: 2 additions & 0 deletions ssqueezepy/utils/cwt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,6 +697,8 @@ def _process_fs_and_t(fs, t, N):
"""Ensures `t` is uniformly-spaced and of same length as `x` (==N)
and returns `fs` and `dt` based on it, or from defaults if `t` is None.
"""
if fs is not None and t is not None:
WARN("`t` will override `fs` (both were passed)")
if t is not None:
if len(t) != N:
# not explicitly used anywhere but ensures wrong `t` wasn't supplied
Expand Down

0 comments on commit 7fa6fd4

Please sign in to comment.