diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..b6e47617 --- /dev/null +++ b/.gitignore @@ -0,0 +1,129 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ diff --git a/README.md b/README.md new file mode 100644 index 00000000..0279e66e --- /dev/null +++ b/README.md @@ -0,0 +1,99 @@ +# Synchrosqueezing in Python +Synchrosqueezing Toolbox ported to Python, authored by Eugene Brevdo, Gaurav Thakur. Original: https://github.com/ebrevdo/synchrosqueezing + +**Reviewers needed**; the repository is in a development stage - details below. + +## Features + - Forward & inverse CWT- and STFT-based Synchrosqueezing + - Forward & inverse discretized Continuous Wavelet Transform (CWT) + - Forward & inverse discretized Short-Time Fourier Transform (STFT) + - Phase CWT & STFT + +## Reviewers needed +An eventual goal is a merged Pull Request to PyWavelets ([relevant Issue](https://github.com/PyWavelets/pywt/issues/258)). Points of review include: + 1. **Correctness**; plots generated from CWT transforms resemble those in the publications, but not entirely + 2. **Completeness**; parts of code in the original Github are missing or incorrect, and I'm unable to replace them all + 3. **Unit tests**; I'm not familiar with Synchrosqueezing itself, so I don't know how to validate its various functionalities + 4. **Licensing**; unsure how to proceed here; [original's](https://github.com/ebrevdo/synchrosqueezing/blob/master/LICENSE) says to "Redistributions in binary form must reproduce the above copyright notice" - but I'm not "redistributing" it, I'm distributing my rewriting of it + 5. **Code style**; I'm aware PyWavelets conforms with PEP8 (but I don't), so I'll edit PR code accordingly + +## Review To-do: + +**Correctness**: + - [ ] 1. Example 1 + - [ ] 2. Example 2 + +**Completeness**: + - [ ] 1. `freqband` in `synsq_cwt_inv` and `synsq_stft_inf` is defaulted to an integer, but indexed into as a 2D array; the two have nearly identical docstrings, and reference the same equation, but the equation appears completely irrelevant to both. + - [ ] 2. `quadgk` has been ported as quadpy's [`quad`](https://github.com/nschloe/quadpy/blob/master/quadpy/line_segment/_tools.py#L16) (linked its wrapped function), which does not support infinite integration bounds, and has [trouble](https://github.com/nschloe/quadpy/issues/236) with computing `synsq_stft_inv`'s integral. Needs a workaround. + - [ ] 3. As seen in examples, the y-axis shows "scales", not frequencies - and the relation between the two is neither clear nor linear; it also isn't linear w.r.t. `len(t)`, `nv`, or `fs`. Publications show frequencies instead. + +**Unit tests**: Whatever suffices for PyWavelets will suffice for me + +## Implementation To-do: + One checkmark = code written; two = reviewed + +| Status | Toolbox name | Repository name | Repository file | +| --- | --- | --- | --- | +| [ ] [**x**] | [`synsq_cwt_fw`](https://github.com/ebrevdo/synchrosqueezing/blob/master/synchrosqueezing/synsq_cwt_fw.m) | `synsq_cwt_fwd` | [synsq_cwt.py](https://github.com/OverLordGoldDragon/synchrosqueezing_python/blob/master/synchrosqueezing/synsq_cwt.py) | +| [ ] [**x**] | [`synsq_cwt_iw`](https://github.com/ebrevdo/synchrosqueezing/blob/master/synchrosqueezing/synsq_cwt_iw.m) | `synsq_cwt_inv` | [synsq_cwt.py](https://github.com/OverLordGoldDragon/synchrosqueezing_python/blob/master/synchrosqueezing/synsq_cwt.py) | +| [ ] [**x**] | [`synsq_stft_fw`](https://github.com/ebrevdo/synchrosqueezing/blob/master/synchrosqueezing/synsq_stft_fw.m) | `synsq_stft_fwd` | [synsq_stft.py](https://github.com/OverLordGoldDragon/synchrosqueezing_python/blob/master/synchrosqueezing/synsq_stft.py) | +| [ ] [**x**] | [`synsq_stft_iw`](https://github.com/ebrevdo/synchrosqueezing/blob/master/synchrosqueezing/synsq_stft_iw.m) | `synsq_stft_inv` | [synsq_stft.py](https://github.com/OverLordGoldDragon/synchrosqueezing_python/blob/master/synchrosqueezing/synsq_stft.py) | +| [ ] [**x**] | [`synsq_squeeze`](https://github.com/ebrevdo/synchrosqueezing/blob/master/synchrosqueezing/synsq_squeeze.m) | `synsq_squeeze` | [wavelet_transforms.py](https://github.com/OverLordGoldDragon/synchrosqueezing_python/blob/master/synchrosqueezing/wavelet_transforms.py) | +| [ ] [**x**] | [`synsq_cwt_squeeze`](https://github.com/ebrevdo/synchrosqueezing/blob/master/synchrosqueezing/synsq_cwt_squeeze.m) | `synsq_cwt_squeeze` | [wavelet_transforms.py](https://github.com/OverLordGoldDragon/synchrosqueezing_python/blob/master/synchrosqueezing/wavelet_transforms.py) | +| [ ] [**x**] | [`phase_cwt`](https://github.com/ebrevdo/synchrosqueezing/blob/master/synchrosqueezing/phase_cwt.m) | `phase_cwt` | [wavelet_transforms.py](https://github.com/OverLordGoldDragon/synchrosqueezing_python/blob/master/synchrosqueezing/wavelet_transforms.py) | +| [ ] [**x**] | [`phase_cwt_num`](https://github.com/ebrevdo/synchrosqueezing/blob/master/synchrosqueezing/phase_cwt_num.m) | `phase_cwt_num` | [wavelet_transforms.py](https://github.com/OverLordGoldDragon/synchrosqueezing_python/blob/master/synchrosqueezing/wavelet_transforms.py) | +| [ ] [**x**] | [`cwt_fw`](https://github.com/ebrevdo/synchrosqueezing/blob/master/synchrosqueezing/cwt_fw.m) | `cwt_fwd` | [wavelet_transforms.py](https://github.com/OverLordGoldDragon/synchrosqueezing_python/blob/master/synchrosqueezing/wavelet_transforms.py) | +| [ ] [ ] | [`cwt_iw`](https://github.com/ebrevdo/synchrosqueezing/blob/master/synchrosqueezing/cwt_iw.m) | +| [ ] [**x**] | [`stft_fw`](https://github.com/ebrevdo/synchrosqueezing/blob/master/synchrosqueezing/stft_fw.m) | `stft_fwd` | [stft_transforms.py](https://github.com/OverLordGoldDragon/synchrosqueezing_python/blob/master/synchrosqueezing/stft_transforms.py) | +| [ ] [**x**] | [`stft_iw`](https://github.com/ebrevdo/synchrosqueezing/blob/master/synchrosqueezing/stft_iw.m) | `stft_inv` | [stft_transforms.py](https://github.com/OverLordGoldDragon/synchrosqueezing_python/blob/master/synchrosqueezing/stft_transforms.py) | +| [ ] [**x**] | [`phase_stft`](https://github.com/ebrevdo/synchrosqueezing/blob/master/synchrosqueezing/phase_stft.m) | `phase_stft` | [stft_transforms.py](https://github.com/OverLordGoldDragon/synchrosqueezing_python/blob/master/synchrosqueezing/stft_transforms.py) | +| [ ] [**x**] | [`padsignal`](https://github.com/ebrevdo/synchrosqueezing/blob/master/synchrosqueezing/padsignal.m) | `padsignal` | [utils.py](https://github.com/OverLordGoldDragon/synchrosqueezing_python/blob/master/synchrosqueezing/utils.py) | +| [ ] [**x**] | [`wfiltfn`](https://github.com/ebrevdo/synchrosqueezing/blob/master/synchrosqueezing/wfiltfn.m) | `wfiltfn` | [utils.py](https://github.com/OverLordGoldDragon/synchrosqueezing_python/blob/master/synchrosqueezing/utils.py) | +| [ ] [ ] | [`wfilth`](https://github.com/ebrevdo/synchrosqueezing/blob/master/synchrosqueezing/wfilth.m) | +| [ ] [**x**] | [`synsq_adm`](https://github.com/ebrevdo/synchrosqueezing/blob/master/synchrosqueezing/synsq_adm.m) | `synsq_adm` | [utils.py](https://github.com/OverLordGoldDragon/synchrosqueezing_python/blob/master/synchrosqueezing/utils.py) | +| [ ] [**x**] | [`buffer`](https://www.mathworks.com/help/signal/ref/buffer.html) | `buffer` | [utils.py](https://github.com/OverLordGoldDragon/synchrosqueezing_python/blob/master/synchrosqueezing/utils.py) | +| [**x**] [**x**] | [`est_riskshrink_thresh`](https://github.com/ebrevdo/synchrosqueezing/blob/master/synchrosqueezing/est_riskshrink_thresh.m) | `est_riskshrink_thresh` | [utils.py](https://github.com/OverLordGoldDragon/synchrosqueezing_python/blob/master/synchrosqueezing/utils.py) | +| [**x**] [**x**] | [`p2up`](https://github.com/ebrevdo/synchrosqueezing/blob/master/synchrosqueezing/p2up.m) | `p2up` | [utils.py](https://github.com/OverLordGoldDragon/synchrosqueezing_python/blob/master/synchrosqueezing/utils.py) | + +There are more unlisted (see original repo), but not all will be implemented, in particular GUI implementations. + +## Differences w.r.t. original + + - **Renamed variables/functions**; more Pythonic & readable + - **Removed unused arguments / variables** + - **Improved nan/inf handling** + - **Added examples**; original repo lacks any examples in README + - **Indexing / var references**; MATLAB is 1-indexed, and handles object reference / assignment, and 'range' ops, differently + - **Edited docstrings**; filled missing info, & few corrections + - **Moved functions**; each no longer has its own file, but is grouped with other relevant functions + - **Code style**; grouped parts of code as sub-functions for improved readability; indentation for vertical alignment; other + - **Performance**; this repo may work faster or slower, as Numpy arrays are faster than C arrays, but some of original funcs use MEX-optimized code with no Numpy equivalent. Also using dense instead of sparse matrices (see below). + + **Other**: + - Dense instead of sparse matrices for `stft_fwd` in [stft_transforms.py](https://github.com/OverLordGoldDragon/ssqueezepy/blob/master/synchrosqueezing/stft_transforms.py), as Numpy doesn't handle latter in ops involved + + + +## Examples + +See [examples.py](https://github.com/OverLordGoldDragon/ssqueezepy/blob/master/examples.py). Links to: [paper [1]](https://sci-hub.se/https://doi.org/10.1016/j.sigpro.2012.11.029), [paper[2]](https://arxiv.org/pdf/0912.2437.pdf). `_inv` methods (reconstruction, inversion) have been omitted as they involve `freqband`. + +**EX 1:** Paper [1], pg. 1086 + +Only real components shown; imaginary are nearly identical, sometimes sign-opposite. + + + + + + +synsq-CWT (`synsq_cwt_fwd`) appears to produce strongest agreement with paper (FIG 4), while none of STFT yield any resemblance of anything in the papers. It's also unclear whether `synsq_squeeze` was used for "synsq" in FIG 4 instead. + +**EX 2:** Paper [2], pg. 18 + + + + + +Similar situation as EX 1; again CWT has close resemblance, and STFT is in a separate reality. The two apparent discrepancies w/ CWT are: (1) slope of the forking incline, steeper in FIG. 3; (2) position of horizontal line, lower in FIG. 3. As for the black lines in FIG 3, they seem to be the (manual) "markings" mentioned under the figure in the paper. diff --git a/examples.py b/examples.py new file mode 100644 index 00000000..bc2c4a3f --- /dev/null +++ b/examples.py @@ -0,0 +1,111 @@ +# -*- coding: utf-8 -*- +# Papers: +# [1] https://sci-hub.se/https://doi.org/10.1016/j.sigpro.2012.11.029 +# [2] https://arxiv.org/pdf/0912.2437.pdf +import numpy as np +import matplotlib.pyplot as plt + +from ssqueezepy import synsq_cwt_fwd, synsq_stft_fwd +from ssqueezepy import cwt_fwd, stft_fwd + +#%% +def viz_y(y, y1, y2): + _, axes = plt.subplots(1, 2, sharey=True, figsize=(11, 3)) + axes[0].plot(y1) + axes[1].plot(y2) + plt.subplots_adjust(left=0, right=1, bottom=0, top=1, wspace=0, hspace=0) + plt.show() + + plt.plot(y) + plt.gcf().set_size_inches(14, 4) + plt.show() + +def viz_s(s, sN, s1, s2, s3): + _, axes = plt.subplots(1, 3, sharey=True, figsize=(10, 3)) + axes[0].plot(t, s1) + axes[1].plot(t, s2) + axes[2].plot(t, s3) + plt.subplots_adjust(left=0, right=1, bottom=0, top=1, wspace=0, hspace=0) + plt.show() + + _, axes = plt.subplots(2, 1, sharex=True, figsize=(10, 5)) + plt.subplots_adjust(left=0, right=1, bottom=0, top=1, wspace=0, hspace=0) + axes[0].plot(t, s) + axes[1].plot(t, sN) + plt.show() + +def _get_norm(data, norm_rel, norm_abs): + if norm_abs is not None and norm_rel != 1: + raise ValueError("specify only one of `norm_rel`, `norm_abs`") + + if norm_abs is None: + vmax = np.max(np.abs(data)) * norm_rel + vmin = -vmax + else: + vmin, vmax = norm_abs + return vmin, vmax + +def _make_plots(*data, cmap='bwr', norm=None, titles=None): + vmin, vmax = norm or None, None + + _, axes = plt.subplots(1, len(data), sharey=True, figsize=(11, 4)) + plt.subplots_adjust(left=0, right=1, bottom=0, top=1, wspace=.1, hspace=0) + + for i, x in enumerate(data): + axes[i].imshow(x, cmap=cmap, aspect='auto', vmin=vmin, vmax=vmax) + axes[i].invert_yaxis() + axes[i].set_title(titles[i], fontsize=14, weight='bold') + plt.show() + +def viz_gen(x, cmap='bwr', norm_rel=1, norm_abs=None): + vmin, vmax = _get_norm(np.real(x), norm_rel, norm_abs) + _make_plots(np.real(x), np.imag(x), cmap=cmap, + norm=(vmin, vmax), titles=("Real", "Imag")) + +def viz_gen2(x1, x2, cmap='bwr', norm_rel=1, norm_abs=None): + vmin, vmax = _get_norm(np.real(x1), norm_rel, norm_abs) + _make_plots(np.real(x1), np.real(x2), cmap=cmap, + norm=(vmin, vmax), titles=("s", "sN")) + +OPTS = {'type': 'bump', 'mu':1} +#%% +"""Paper [2], pg. 18""" +t = np.linspace(0, 12, 1000) +y1 = np.cos(8*t) +y2 = np.cos(t**2 + t + np.cos(t)) +y = y1 + y2 +viz_y(y, y1, y2) +#%% +Tx_yc, *_ = synsq_cwt_fwd(y, fs=len(t)/t[-1], nv=32, opts=OPTS) +Tx_ys, *_ = synsq_stft_fwd(t, y, opts=OPTS) +#%% +viz_gen(Tx_yc, cmap='bwr', norm_abs=(-5e-5, 5e-5)) +viz_gen(Tx_ys, cmap='bwr', norm_abs=(-5e-5, 5e-5)) +#%% +"""Paper [1], pg. 1086""" +t = np.linspace(0, 10, 2048) +s1 = (1 + .2*np.cos(t)) * np.cos(2*np.pi*(2*t + .3*np.cos(t))) +s2 = (1 + .3*np.cos(t)) * np.cos(2*np.pi*(2.4*t + .3*np.sin(t) + .5*t**1.2) + ) * np.exp(-t/15) +s3 = np.cos(2*np.pi*(5.3*t + 0.2*t**1.3)) +N = np.random.randn(len(t)) * np.sqrt(2.4) +s = s1 + s2 + s3 +sN = s + N +viz_s(s, sN, s1, s2, s3) +#%% +# feed "denoised" (actually noiseless) signal as noted on pg. 1086 of [1] +Tx_sc, *_ = synsq_cwt_fwd(s, fs=len(t)/t[-1], nv=32, opts=OPTS) +Tx_sNc, *_ = synsq_cwt_fwd(sN, fs=len(t)/t[-1], nv=32, opts=OPTS) +Tx_ss, *_ = synsq_stft_fwd(t, s, opts=OPTS) +Tx_sNs, *_ = synsq_stft_fwd(t, sN, opts=OPTS) +#%% +viz_gen2(Tx_sc, Tx_sNc, cmap='bwr', norm_abs=(-5e-5, 5e-5)) +viz_gen2(Tx_ss, Tx_sNs, cmap='bwr', norm_abs=(-5e-5, 5e-5)) +#%% +Wx_s, *_ = cwt_fwd(s, 'bump', opts=OPTS) +Wx_sN, *_ = cwt_fwd(sN, 'bump', opts=OPTS) +Sx_s, *_ = stft_fwd(s, dt=t[1]-t[0], opts=OPTS) +Sx_sN, *_ = stft_fwd(sN, dt=t[1]-t[0], opts=OPTS) +#%% +viz_gen2(Wx_s, Wx_sN, cmap='bwr', norm_abs=(-1.3, 1.3)) +viz_gen2(Sx_s, Sx_sN, cmap='bwr', norm_abs=(-1.3, 1.3)) diff --git a/ssqueezepy/__init__.py b/ssqueezepy/__init__.py new file mode 100644 index 00000000..da57289c --- /dev/null +++ b/ssqueezepy/__init__.py @@ -0,0 +1,8 @@ +from synsq_cwt import * +from synsq_stft import * +from wavelet_transforms import * +from stft_transforms import * +from utils import * + + +__version__ = '0.80' diff --git a/ssqueezepy/stft_transforms.py b/ssqueezepy/stft_transforms.py new file mode 100644 index 00000000..400b6847 --- /dev/null +++ b/ssqueezepy/stft_transforms.py @@ -0,0 +1,285 @@ +# Ported from the Synchrosqueezing Toolbox, authored by +# Eugine Brevdo, Gaurav Thakur +# (http://www.math.princeton.edu/~ebrevdo/) +# (https://github.com/ebrevdo/synchrosqueezing/) + +import numpy as np +from quadpy import quad as quadgk +from utils import wfiltfn, padsignal, buffer + +PI = np.pi +EPS = np.finfo(np.float64).eps # machine epsilon for float64 + + +def stft_fwd(x, dt, opts={}): + """Compute the short-time Fourier transform and modified short-time + Fourier transform from [1]. The former is very closely based on Steven + Schimmel's stft.m and istft.m from his SPHSC 503: Speech Signal Processing + course at Univ. Washington. + + # Arguments: + x: np.ndarray. Input signal vector, length `n` (need not be dyadic). + dt: int, sampling period (defaults to 1). + opts: dict. Options: + 'type': str. Wavelet type. See `wfiltfn` + 'winlen': int. length of window in samples; Nyquist frequency + is winlen/2 + 'padtype': str ('symmetric', 'repliace', 'circular'). Type + of padding (default = 'symmetric') + 'rpadded': bool. Whether to return padded `Sx` and `dSx` + (default = True) + 's', 'mu', ... : window options (see `wfiltfn`) + + # Returns: + Sx: (na x n) size matrix (rows = scales, cols = times) containing + samples of the CWT of `x`. + Sfs: vector containign the associated frequencies. + dSx: (na x n) size matrix containing samples of the time-derivatives + of the STFT of `x`. + + # References: + 1. G. Thakur and H.-T. Wu, + "Synchrosqueezing-based Recovery of Instantaneous Frequency + from Nonuniform Samples", + SIAM Journal on Mathematical Analysis, 43(5):2078-2095, 2011. + """ + def _process_opts(opts, x): + # opts['window'] is window length; opts['type'] overrides the + # default hamming window + opts['stft_type'] = opts.get('stft_type', 'normal') + opts['winlen'] = opts.get('winlen', int(np.round(len(x) / 8))) + # 'padtype' is one of: 'symmetric', 'replicate', 'circular' + opts['padtype'] = opts.get('padtype', 'symmetric') + opts['rpadded'] = opts.get('rpadded', False) + + windowfunc, diff_windowfunc = None, None + if 'type' in opts: + windowfunc = wfiltfn(opts['type'], opts, derivative=False) + diff_windowfunc = wfiltfn(opts['type'], opts, derivative=True) + + return opts, windowfunc, diff_windowfunc + + opts, windowfunc, diff_windowfunc = _process_opts(opts, x) + + # Pre-pad signal; this only works well for 'normal' STFT + n = len(x) + if opts['stft_type'] == 'normal': + x, N_old, n1, n2 = padsignal(x, opts['padtype'], opts['winlen']) + n1 = n1 // 2 + else: + n1 = 0 + + N = len(x) + + if opts['stft_type'] == 'normal': + # set up window + if 'type' in opts: + window = windowfunc(np.linspace(-1, 1, opts['winlen'])) + diff_window = diff_windowfunc(np.linspace(-1, 1, opts['winlen'])) + else: + window = np.hamming(opts['winlen']) + diff_window = np.hstack([np.diff(np.hamming(opts['winlen'])), 0]) + diff_window[np.where(np.isnan(diff_window))] = 0 + + # frequency range + Sfs = np.linspace(0, 1, opts['winlen'] + 1) + Sfs = Sfs[:np.floor(opts['winlen'] / 2).astype('int64') + 1] / dt + + # compute STFt and keep only the positive frequencies + xbuf = buffer(x, opts['winlen'], opts['winlen'] - 1, 'nodelay') + xbuf = np.diag(window) @ xbuf + Sx = np.fft.fft(xbuf, None, axis=0) + Sx = Sx[:opts['winlen'] // 2 + 1] / np.sqrt(N) + + # same steps for STFT derivative + dxbuf = buffer(x, opts['winlen'], opts['winlen'] - 1, 'nodelay') + dxbuf = np.diag(diff_window) @ dxbuf + dSx = np.fft.fft(dxbuf, None, axis=0) + dSx = dSx[:opts['winlen'] // 2 + 1] / np.sqrt(N) + dSx /= dt + + elif opts['stfttype'] == 'modified': + # modified STFt is more accurately done in the frequency domain, + # like a filter bank over different frequency bands + # uses a lot of memory, so best used on small blocks + # (<5000 samples) at a time + Sfs = np.linspace(0, 1, N) / dt + Sx = np.zeros((N, N)) + dSx = np.zeros((N, N)) + + halfN = np.round(N / 2) + halfwin = np.floor((opts['winlen'] - 1) / 2) + window = windowfunc(np.linspace(-1, 1, opts['winlen'])).T # TODO chk dim + diff_window = diff_windowfunc(np.linspace(-1, 1, opts['winlen'])).T * ( + 2 / opts['winlen'] / dt) + + for k in range(N): + freqs = np.arange(-min(halfN - 1, halfwin, k - 1), + min(halfN - 1, halfwin, N- k) + 1) + indices = np.mod(freqs, N) + Sx[indices, k] = x[k + freqs] * window(halfwin + freqs + 1) + dSx[indices, k] = x[k + freqs] * diff_window(halfwin + freqs + 1) + + Sx = np.fft.fft(Sx) / np.sqrt(N) + dSx = np.fft.fft(dSx) / np.sqrt(N) + + # only keep the positive frequencies + Sx = Sx[:halfN] + dSx = dSx[:halfN] + Sfs = Sfs[:halfN] + + # Shorten Sx to proper size (remove padding) + if not opts['rpadded']: + Sx = Sx[:, range(n1, n1 + n)] + dSx = dSx[:, range(n1, n1 + n)] + + return Sx, Sfs, dSx + + +def stft_inv(Sx, opts={}): + """Inverse short-time Fourier transform. + + Very closely based on Steven Schimel's stft.m and istft.m from his + SPHSC 503: Speech Signal Processing course at Univ. Washington. + Adapted for use with Synchrosqueeing Toolbox. + + # Arguments: + Sx: np.ndarray. Wavelet transform of a signal (see `stft_fwd`). + opts: dict. Options: + 'type': str. Wavelet type. See `stft_fwd`, and `wfiltfn`. + Others; see `stft_fwd` and source code. + + # Returns: + x: the signal, as reconstructed from `Sx`. + """ + def _unbuffer(x, w, o): + # Undo the effect of 'buffering' by overlap-add; + # returns the signal A that is the unbuffered version of B + y = [] + skip = w - o + N = np.ceil(w / skip) + L = (x.shape[1] - 1) * skip + x.shape[0] + + # zero-pad columns to make length nearest integer multiple of `skip` + if x.shape[0] < skip * N: + x[skip * N - 1, -1] = 0 # TODO columns? + + # selectively reshape columns of input into 1d signals + for i in range(N): + t = x[:, range(i, len(x) - 1, N)].reshape(1, -1) + l = len(t) + y[i, l + (i - 1)*skip - 1] = 0 + y[i, np.arange(l) + (i - 1)*skip] = t + + # overlap-add + y = np.sum(y, axis=0) + y = y[:L] + + return y + + def _process_opts(opts, Sx): + # opts['window'] is window length; opts['type'] overrides + # default hamming window + opts['winlen'] = opts.get('winlen', int(np.round(Sx.shape[1] / 16))) + opts['overlap'] = opts.get('overlap', opts['winlen'] - 1) + opts['rpadded'] = opts.get('rpadded', False) + + if 'type' in opts: + A = wfiltfn(opts['type'], opts) + window = A(np.linspace(-1, 1, opts['winlen'])) + else: + window = np.hamming(opts['winlen']) + + return opts, window + + opts, window = _process_opts(opts, Sx) + + # window = window / norm(window, 2) --> Unit norm + n_win = len(window) + + # find length of padding, similar to outputs of `padsignal` + n = Sx.shape[1] + if not opts['rpadded']: + xLen = n + else: + xLen == n - n_win + + # n_up = xLen + 2 * n_win + n1 = n_win - 1 + # n2 = n_win + new_n1 = np.floor((n1 - 1) / 2) + + # add STFT apdding if it doesn't exist + if not opts['rpadded']: + Sxp = np.zeros(Sx.shape) + Sxp[:, range(new_n1, new_n1 + n + 1)] = Sx + Sx = Sxp + else: + n = xLen + + # regenerate the full spectrum 0...2pi (minus zero Hz value) + Sx = np.hstack([Sx, np.conj(Sx[np.arange( + np.floor((n_win + 1) / 2), 3, -1)])]) + + # take the inverse fft over the columns + xbuf = np.real(np.fft.ifft(Sx, None, axis=0)) + + # apply the window to the columns + xbuf *= np.matlib.repmat(window.flatten(), 1, xbuf.shape[1]) + + # overlap-add the columns + x = _unbuffer(xbuf, n_win, opts['overlap']) + + # keep the unpadded part only + x = x[n1:n1 + n + 1] + + # compute L2-norm of window to normalize STFT with + windowfunc = wfiltfn(opts['type'], opts, derivative=False) + C = lambda x: quadgk(windowfunc(x) ** 2, -np.inf, np.inf) + + # `quadgk` is a bit inaccurate with the 'bump' function, + # this scales it correctly + if opts['type'] == 'bump': + C *= 0.8675 + + x *= 2 / (PI * C) + + return x + + +def phase_stft(Sx, dSx, Sfs, t, opts={}): + """Calculate the phase transform of modified STFT at each (freq, time) pair: + w[a, b] = Im( eta - d/dt(Sx[t, eta]) / Sx[t, eta] / (2*pi*j)) + Uses direct differentiation by calculating dSx/dt in frequency domain + (the secondary output of `stft_fwd`, see `stft_fwd`). + + # Arguments: + Sx: np.ndarray. Wavelet transform of `x` (see `stft_fwd`). + dSx: np.ndarray. Samples of time-derivative of STFT of `x` + (see `stft_fwd`). + opts: dict. Options: + 'gamma': float. Wavelet threshold (default: sqrt(machine epsilon)) + + # Returns: + w: phase transform, w.shape == Sx.shape + + # References: + 1. G. Thakur and H.-T. Wu, + "Synchrosqueezing-based Recovery of Instantaneous Frequency from + Nonuniform Samples", + SIAM Journal on Mathematical Analysis, 43(5):2078-2095, 2011. + + 2. G. Thakur, E. Brevdo, N.-S. Fučkar, and H.-T. Wu, + "The Synchrosqueezing algorithm for time-varying spectral analysis: + robustness properties and new paleoclimate applications," + Signal Processing, 93:1079-1094, 2013. + """ + opts['gamma'] = opts.get('gamma', np.sqrt(EPS)) + + # calculate phase transform; modified STFT amounts to extra frequency term + w = np.matlib.repmat(Sfs, len(t), 1).T - np.imag(dSx / Sx / (2 * PI)) + + # threshold out small points + w[np.abs(Sx) < opts['gamma']] = np.inf + + return w diff --git a/ssqueezepy/synsq_cwt.py b/ssqueezepy/synsq_cwt.py new file mode 100644 index 00000000..e54337e2 --- /dev/null +++ b/ssqueezepy/synsq_cwt.py @@ -0,0 +1,187 @@ +# Ported from the Synchrosqueezing Toolbox, authored by +# Eugine Brevdo, Gaurav Thakur +# (http://www.math.princeton.edu/~ebrevdo/) +# (https://github.com/ebrevdo/synchrosqueezing/) + +import numpy as np +from utils import est_riskshrink_thresh, p2up, synsq_adm +from wavelet_transforms import phase_cwt, phase_cwt_num +from wavelet_transforms import cwt_fwd, synsq_squeeze + + +def synsq_cwt_fwd(x, t=None, fs=None, nv=32, opts=None): + """Calculates the synchrosqueezing transform of vector `x`, with samples + taken at times given in vector `t`. Uses `nv` voices. Implements the + algorithm described in Sec. III of [1]. + + # Arguments: + x: np.ndarray. Vector of signal samples (e.g. x = np.cos(20 * np.pi * t)) + t: np.ndarray / None. Vector of times samples are taken + (e.g. np.linspace(0, 1, n)). If None, defaults to np.arange(len(x)). + Overrides `fs` if not None. + fs: float. Sampling frequency of `x`; overridden by `t`, if provided. + nv: int. Number of voices. Recommended 32 or 64 by [1]. + opts: dict. Options specifying how synchrosqueezing is computed. + 'type': str. type of wavelet. See `wfiltfn` docstring. + 'gamma': float / None. Wavelet hard thresholding value. If None, + is estimated automatically. + 'difftype': str. 'direct', 'phase', or 'numerical' differentiation. + 'numerical' uses MEX differentiation, which is faster and + uses less memory, but may be less accurate. + + # Returns: + Tx: Synchrosqueeze-transformed `x`, columns associated w/ `t` + fs: Frequencies associated with rows of `Tx`. + Wx: Wavelet transform of `x` (see `cwt_fwd`) + Wx_scales: scales associated with rows of `Wx`. + w: Phase transform for each element of `Wx`. + + + # References: + 1. G. Thakur, E. Brevdo, N.-S. Fučkar, and H.-T. Wu, + "The Synchrosqueezing algorithm for time-varying spectral analysis: + robustness properties and new paleoclimate applications", + Signal Processing, 93:1079-1094, 2013. + 2. I. Daubechies, J. Lu, H.T. Wu, "Synchrosqueezed Wavelet Transforms: + an empricial mode decomposition-like tool", + Applied and Computational Harmonic Analysis, 30(2):243-261, 2011. + """ + def _get_opts(opts): + opts_default = {'type': 'morlet', + 'difftype': 'direct', + 'gamma': None} + if opts is None: + opts = opts_default + else: + opts = {} + for opt_name in opts_default: + opts[opt_name] = opts.get(opt_name, opts_default[opt_name]) + return opts + + def _wavelet_transform(x, nv, dt, opts): + N = len(x) + N_up, n1, n2 = p2up(N) + + if opts['difftype'] == 'direct': + # calculate derivative directly in the wavelet domain + # before taking wavelet transform + opts['rpadded'] = 0 + + Wx, Wx_scales, dWx, _ = cwt_fwd(x, opts['type'], nv, dt, opts) + w = phase_cwt(Wx, dWx, opts) + + elif opts['difftype'] == 'phase': + # take derivative of unwrapped CWT phase + # directly in phase transform + opts['rpadded'] = 0 + + Wx, Wx_scales, _ = cwt_fwd(x, opts['type'], nv, dt, opts) + w = phase_cwt(Wx, None, opts) + else: + # calculate derivative numerically after calculating wavelet + # transform. This requires less memory and is more accurate + # for lesser `a`. + opts['rpadded'] = 1 + + Wx, Wx_scales, _ = cwt_fwd(x, opts['type'], nv, dt, opts) + Wx = Wx[:, (n1 - 5 + 1):(n1 + N + 3)] + w = phase_cwt_num(Wx, dt, opts) + + return Wx, w, Wx_scales, opts + + def _validate_spacing_uniformity(t): + if np.any([(np.diff(t, 2) / (t[-1] - t[0]) > 1e-5)]): + raise Exception("Time vector `t` must be uniformly sampled.") + + if t is None: + fs = fs or 1. + t = np.linspace(0., len(x) / fs, len(x)) + else: + _validate_spacing_uniformity(t) + opts = _get_opts(opts) + + dt = t[1] - t[0] # sampling period, assuming regular spacing + Wx, w, Wx_scales, opts = _wavelet_transform(x, nv, dt, opts) + + if opts['gamma'] is None: + opts['gamma'] = est_riskshrink_thresh(Wx, nv) + + # calculate the synchrosqueezed frequency decomposition + opts['transform'] = 'CWT' + Tx, fs = synsq_squeeze(Wx, w, t, nv, opts) + + if opts['difftype'] == 'numerical': + Wx = Wx[:, (3 + 1):(len(Wx) - 1 - 5)] + w = w[: (3 + 1):(len(w) - 1 - 5)] + Tx = Tx[:, (3 + 1):(len(Tx) - 1 - 5)] + + return Tx, fs, Wx, Wx_scales, w + + +def synsq_cwt_inv(Tx, fs, opts={}, Cs=None, freqband=None): #TODO Arguments + """Inverse synchrosqueezing transform of `Tx` with associated frequencies + in `fs` and curve bands in time-frequency plane specified by `Cs` and + `freqband`. This implements Eq. 5 of [1]. + + # Arguments: + Tx: np.ndarray. Synchrosqueeze-transformed `x` (see `synsq_cwt_fwd`). + fs: np.ndarray. Frequencies associated with rows of Tx. + (see `synsq_cwt_fwd`). + opts: dict. Options (see `synsq_cwt_fwd`): + 'type': type of wavelet used in `synsq_cwt_fwd` + + other wavelet options ('mu', 's') should also match + those used in `synsq_cwt_fwd` + 'Cs': (optional) curve centerpoints + 'freqs': (optional) curve bands + + # Returns: + x: components of reconstructed signal, and residual error + + # Example: + Tx, fs = synsq_cwt_fwd(t, x, 32) # synchrosqueeizing + Txf = synsq_filter_pass(Tx,fs, -np.inf, 1) # pass band filter + xf = synsq_cwt_inv(Txf, fs) # filtered signal reconstruction + + # References: + 1. G. Thakur, E. Brevdo, N.-S. Fučkar, and H.-T. Wu, + "The Synchrosqueezing algorithm for time-varying spectral analysis: + robustness properties and new paleoclimate applications", + Signal Processing, 93:1079-1094, 2013. + """ + opts = opts or {'type': 'morlet'} + Cs = Cs or np.ones((Tx.shape[1], 1)) + freqband = Tx.shape[0] + + # Find the admissibility coefficient Cpsi + Css = synsq_adm(opts['type'], opts) + + # Invert Tx around curve masks in the time-frequency plane to recover + # individual components; last one is the remaining signal + # Integration over all frequencies recovers original signal + # Factor of 2 is because real parts contain half the energy + x = np.zeros((Cs.shape[0], Cs.shape[1] + 1)) + TxRemainder = Tx + + for n in range(Cs.shape[1]): + TxMask = np.zeros(Tx.shape) + UpperCs = min(max(Cs[:, n] + freqband[:, n], 1), len(fs)) + LowerCs = min(max(Cs[:, n] - freqband[:, n], 1), len(fs)) + + # Cs==0 corresponds to no curve at that time, so this removes + # such points from the inversion + UpperCs[np.where(Cs[:, n] < 1)] = 1 + LowerCs[np.where(Cs[:, n] < 1)] = 2 + for m in range(Tx.shape[1]): + idxs = slice(LowerCs[m] - 1, UpperCs[m]) + TxMask[idxs, m] = Tx[idxs, m] + TxRemainder[idxs, m] = 0 + + # Due to linear discretization of integral in log(fs), + # this becomes a simple normalized sum + x[:, n] = (1 / Css) * np.sum(np.real(TxMask), axis=0).T + + x[:, n + 1] = (1 / Css) * np.sum(np.real(TxRemainder), axis=0).T + x = x.T + + return x diff --git a/ssqueezepy/synsq_stft.py b/ssqueezepy/synsq_stft.py new file mode 100644 index 00000000..e2e2f1ff --- /dev/null +++ b/ssqueezepy/synsq_stft.py @@ -0,0 +1,123 @@ +# Ported from the Synchrosqueezing Toolbox, authored by +# Eugine Brevdo, Gaurav Thakur +# (http://www.math.princeton.edu/~ebrevdo/) +# (https://github.com/ebrevdo/synchrosqueezing/) + +import numpy as np +from utils import wfiltfn +from stft_transforms import stft_fwd, phase_stft +from wavelet_transforms import synsq_squeeze +from quadpy import quad as quadgk + +PI = np.pi + + +def synsq_stft_fwd(t, x, opts={}): + """Calculates the STFt synchrosqueezing transform of vector `x`, with + samples taken at times given in vector `t`. This implements the algorithm + described in Sec. III of [1]. + + # Arguments: + t: np.ndarray. Vector of times samples are taken + (e.g. np.linspace(0, 1, n)) + x: np.ndarray. Vector of signal samples (e.g. x = np.cos(20 * PI * t)) + opts: dict. Options: + 'type': str. Type of wavelet (see `wfiltfn`) + 's', 'mu': float. Wavelet parameters (see `wfiltfn`) + 'gamma': float. Wavelet hard thresholding value + (see `cwt_freq_direct`) + + # Returns: + Tx: synchrosqueezed output of `x` (columns associated with time `t`) + fs: frequencies associated with rows of `Tx` + Sx: STFT of `x` (see `stft_fw`) + Sfs: frequencies associated with rows of `Sx` + w: phase transform of `Sx` + """ + def _validate_spacing_uniformity(t): + if np.any([(np.diff(t, 2) / (t[-1] - t[0]) > 1e-5)]): + raise Exception("Time vector `t` must be uniformly sampled.") + + _validate_spacing_uniformity(t) + + opts['type'] = opts.get('type', 'bump') + opts['rpadded'] = opts.get('rpadded', False) + + dt = t[1] - t[0] + + # Calculate the modified STFT, using window of opts['winlen'] + # in frequency domain + opts['stfttype'] = 'modified' + Sx, Sfs, dSx = stft_fwd(x, dt, opts) + + w = phase_stft(Sx, dSx, Sfs, t, opts) + + # Calculate the synchrosqueezed frequency decomposition + # The parameter alpha from reference [2] is given by Sfs[1] - Sfs[0] + opts['transform'] = 'STFT' + Tx, fs = synsq_squeeze(Sx, w, t, None, opts) + + return Tx, fs, Sx, Sfs, w, dSx + + +def synsq_stft_inv(Tx, fs, opts, Cs=None, freqband=None): + """Inverse STFT synchrosqueezing transform of `Tx` with associated + frequencies in `fs` and curve bands in time-frequency plane + specified by `Cs` and `freqband`. This implements Eq. 5 of [1]. + + # Arguments: + Tx: np.ndarray. Synchrosqueeze-transformed `x` (see `synsq_cwt_fwd`). + fs: np.ndarray. Frequencies associated with rows of Tx. + (see `synsq_cwt_fwd`). + opts. dict. Options: + 'type': type of wavelet used in `synsq_cwt_fwd` (required). + + other wavelet options ('mu', 's') should also match those + used in `synsq_cwt_fwd` + 'Cs': (optional) curve centerpoints + 'freqs': (optional) curve bands + + # Returns: + x: components of reconstructed signal, and residual error + + Example: + Tx, fs = synsq_cwt_fwd(t, x, 32) # synchrosqueezing + Txf = synsq_filter_pass(Tx, fs, -np.inf, 1) # pass band filter + xf = synsq_cwt_inv(Txf, fs) # filtered signal reconstruction + """ + Cs = Cs or np.ones((Tx.shape[1], 1)) + freqband = freqband or Tx.shape[0] + + windowfunc = wfiltfn(opts['type'], opts, derivative=False) + inf_lim = 1000 # quadpy can't handle np.inf limits + C = quadgk(lambda x: windowfunc(x)**2, -inf_lim, inf_lim) + if opts['type'] == 'bump': + C *= 0.8675 + + # Invert Tx around curve masks in the time-frequency plane to recover + # individual components; last one is the remaining signal + # Integration over all frequencies recovers original signal + # Factor of 2 is because real parts contain half the energy + x = np.zeros((Cs.shape[0], Cs.shape[1] + 1)) + TxRemainder = Tx + for n in range(Cs.shape[1]): + TxMask = np.zeros(Tx.shape) + UpperCs = min(max(Cs[:, n] + freqband[:, n], 1), len(fs)) + LowerCs = min(max(Cs[:, n] - freqband[:, n], 1), len(fs)) + + # Cs==0 corresponds to no curve at that time, so this removes + # such points from the inversion + # NOTE: transposed + flattened to match MATLAB's 'linear indices' + UpperCs[np.where(Cs[:, n].T.flatten() < 1)] = 1 + LowerCs[np.where(Cs[:, n].T.flatten() < 1)] = 2 + + for m in range(Tx.shape[1]): + idxs = slice(LowerCs[m] - 1, UpperCs[m]) + TxMask[idxs, m] = Tx[idxs, m] + TxRemainder[idxs, m] = 0 + x[:, n] = 1 / (PI * C) * np.sum(np.real(TxMask), axis=0).T + + x[:, n + 1] = 1 / (PI * C) * np.sum(np.real(TxRemainder), axis=0).T + x = x.T + + return x diff --git a/ssqueezepy/utils.py b/ssqueezepy/utils.py new file mode 100644 index 00000000..9b480f8c --- /dev/null +++ b/ssqueezepy/utils.py @@ -0,0 +1,286 @@ +# Ported from the Synchrosqueezing Toolbox, authored by +# Eugine Brevdo, Gaurav Thakur +# (http://www.math.princeton.edu/~ebrevdo/) +# (https://github.com/ebrevdo/synchrosqueezing/) + +import numpy as np +import numpy.matlib +from quadpy import quad as quadgk + +PI = np.pi + + +def mad(data, axis=None): + return np.mean(np.abs(data - np.mean(data, axis)), axis) + + +def est_riskshrink_thresh(Wx, nv): + """Estimate the RiskShrink hard thresholding level. + + # Arguments: + Wx: np.ndarray. Wavelet transform of a signal. + opt: dict. Options structure used for forward wavelet transform. + + # Returns: + gamma: float. The RiskShrink hard thresholding estimate. + """ + na, n = Wx.shape + + Wx_fine = np.abs(Wx[:nv]) + gamma = 1.4826 * np.sqrt(2 * np.log(n)) * np.mad(Wx_fine) + + return gamma + + +def p2up(n): + """Calculates next power of 2, and left/right padding to center + the original `n` locations. + + # Arguments: + n: int. Non-dyadic integer. + + # Returns: + up: next power of 2 + n1: length on left + n2: length on right + """ + eps = np.finfo(np.float64).eps # machine epsilon for float64 + up = 2 ** (1 + np.round(np.log2(n + eps))) + + n1 = np.floor((up - n) / 2) + n2 = n1 + + if (2 * n1 + n) % 2 == 1: + n2 = n1 + 1 + return up, n1, n2 + + +def padsignal(x, padtype='symmetric', padlength=None): + """Pads signal and returns indices of original signal. + + # Arguments: + x: np.ndarray. Original signal. + padtype: str ('symmetric', 'replicate'). + padlength: int. Number of samples to pad on each side. Default is + nearest power of 2. + + # Returns: + x: padded signal. + n_up: next power of 2. + n1: length on left. + n2: length on right. + """ + padtypes = ('symmetric', 'replicate') + if padtype not in padtypes: + raise ValueError(("Unsupported `padtype` {}; must be one of: {}" + ).format(padtype, ", ".join(padtypes))) + n = len(x) + + if padlength is None: + # pad up to the nearest power of 2 + n_up, n1, n2 = p2up(n) + else: + n_up = n + 2 * padlength + n1 = padlength + 1 + n2 = padlength + n_up, n1, n2 = int(n_up), int(n1), int(n2) + + if padtype == 'symmetric': + xl = np.matlib.repmat(np.hstack([x, np.flipud(x)]), + m=int(np.ceil(n1 / (2 * n))), n=1).squeeze() + xr = np.matlib.repmat(np.hstack([np.flipud(x), x]), + m=int(np.ceil(n2 / (2 * n))), n=1).squeeze() + elif padtype == 'replicate': + xl = x[0] * np.ones(n1) + xr = x[-1] * np.ones(n2) + + xpad = np.hstack([xl[-n1:], x, xr[:n2]]) + + return xpad, n_up, n1, n2 + + +def wfiltfn(wavelet_type, opts, derivative=False): + """Wavelet transform function of the wavelet filter in question, + Fourier domain. + + # Arguments: + wavelet_type: str. See below. + opts: dict. Options, e.g. {'s': 1, 'mu': 5} + + # Returns: + lambda xi: psihfn(xi) + + _______________________________________________________________________ + Filter types Use for synsq? Parameters (default) + + mhat no s (1) + cmhat yes s (1), mu (1) + morlet yes mu (2*pi) + shannon no -- (NOT recommended for analysis) + hshannon yes -- (NOT recommended for analysis) + hhat no + hhhat yes mu (5) + bump yes s (1), mu (5) + _______________________________________________________________________ + + # Example: + psihfn = wfiltfn('bump', {'s': .5, 'mu': 1}) + plt.plot(psihfn(np.arange(-5, 5.01, step=.01))) + """ + supported_types = ('bump', 'mhat', 'cmhat', 'morlet', 'shannon', + 'hshannon', 'hhat', 'hhhat') + if wavelet_type not in supported_types: + raise ValueError(("Unsupported `wavelet_type` '{}'; must be one of: {}" + ).format(wavelet_type, ", ".join(supported_types))) + if wavelet_type == 'bump': + mu = opts.get('mu', 5) + s = opts.get('s', 1) + om = opts.get('om', 0) + + psihfnorig = lambda w: (np.abs(w) < .999) * np.exp( + -1. / (1 - (w * (np.abs(w) < .999)) ** 2)) / .443993816053287 + + psihfn = lambda w: np.exp(2 * PI * 1j * om * w) * psihfnorig( + (w - mu) / s) / s + if derivative: + _psihfn = psihfn; del psihfn + psihfn = lambda w: _psihfn(w) * ( + 2 * PI * 1j * om - 2 * ((w - mu) / s**2) / ( + 1 - ((w - mu) / s)**2)**2) + + elif wavelet_type == 'mhat': # mexican hat + s = opts.get('s', 1) + psihfn = lambda w: -np.sqrt(8) * s**(5/2) * PI**(1/4) / np.sqrt( + 3) * w**2 * np.exp(-s**2 * w**2 / 2) + + elif wavelet_type == 'cmhat': + # complex mexican hat; hilbert analytic function of sombrero + # can be used with synsq + s = opts.get('s', 1) + mu = opts.get('mu', 1) + psihfnshift = lambda w: 2 * np.sqrt(2/3) * PI**(-1/4) * ( + s**(5/2) * w**2 * np.exp(-s**2 * w**2 / 2) * (w >= 0)) + psihfn = lambda w: psihfnshift(w - mu) + + elif wavelet_type == 'morlet': + # can be used with synsq for large enough `s` (e.g. >5) + mu = opts.get('mu', 2 * PI) + cs = (1 + np.exp(-mu**2) - 2 * np.exp(-3/4 * mu**2)) ** (-.5) + ks = np.exp(-.5 * mu**2) + psihfn = lambda w: cs * PI**(-1/4) * (np.exp(-.5 * (mu - w)**2) + - ks * np.exp(-.5 * w**2)) + + elif wavelet_type == 'shannon': + psihfn = lambda w: np.exp(-1j * w / 2) * (np.abs(w) >= PI + and np.abs(w) <= 2 * PI) + elif wavelet_type == 'hshannon': + # hilbert analytic function of shannon transform + # time decay is too slow to be of any use in synsq transform + mu = opts.get('mu', 0) + psihfnshift = lambda w: np.exp(-1j * w / 2) * ( + w >= PI and w <= 2 * PI) * (1 + np.sign(w)) + psihfn = lambda w: psihfnshift(w - mu) + + elif wavelet_type == 'hhat': # hermitian hat + psihfnshift = lambda w: 2 / np.sqrt(5) * PI**(-1 / 4) * ( + w * (1 + w) * np.exp(-.5 * w**2)) + psihfn = lambda w: psihfnshift(w - mu) + + elif wavelet_type == 'hhhat': + # hilbert analytic function of hermitian hat; can be used with synsq + mu = opts.get('mu', 5) + psihfnshift = lambda w: 2 / np.sqrt(5) * PI**(-1/4) * ( + w * (1 + w) * np.exp(-1/2 * w**2)) * (1 + np.sign(w)) + psihfn = lambda w: psihfnshift(w - mu) + + return psihfn + + +def synsq_adm(wavelet_type, opts={}): + """Calculate the synchrosqueezing admissibility constant, the term + R_\psi in Eq. 3 of [1]. Note, here we multiply R_\psi by the inverse of + log(2)/nv (found in Alg. 1 of [1]). + + Uses numerical intergration. + + # Arguments: + wavelet_type: str. See `wfiltfn`. + opts: dict. Options. See `wfiltfn`. + + # Returns: + Css: proportional to 2 * integral(conj(f(w)) / w, w=0..inf) + + # References: + 1. G. Thakur, E. Brevdo, N.-S. Fučkar, and H.-T. Wu, + "The Synchrosqueezing algorithm for time-varying spectral analysis: + robustness properties and new paleoclimate applications", + Signal Processing, 93:1079-1094, 2013. + + """ + psihfn = wfiltfn(wavelet_type, opts) + Css = lambda x: quadgk(np.conj(psihfn(x)) / x, 0, np.inf) + + # Normalization constant, due to logarithmic scaling + # in wavelet transform + _Css = Css; del Css + Css = lambda x: _Css(x) / np.sqrt(2 * PI) * 2 * np.log(2) + + return Css + + +def buffer(x, n, p=0, opt=None): + """Mimic MATLAB routine to generate buffer array + + MATLAB docs here: https://se.mathworks.com/help/signal/ref/buffer.html + + # Arguments: + x: np.ndarray. Signal array. + n: int. Number of data segments. + p: int. Number of values to overlap + opt: str. Initial condition options. Default sets the first `p` + values to zero, while 'nodelay' begins filling the buffer immediately. + + # Returns: + result : (n,n) ndarray + Buffer array created from x. + + # References: + ryanjdillon: https://stackoverflow.com/questions/38453249/ + is-there-a-matlabs-buffer-equivalent-in-numpy#answer-40105995 + """ + import numpy as np + + if opt not in ('nodelay', None): + raise ValueError('{} not implemented'.format(opt)) + + i = 0 + first_iter = True + while i < len(x): + if first_iter: + if opt == 'nodelay': + # No zeros at array start + result = x[:n] + i = n + else: + # Start with `p` zeros + result = np.hstack([np.zeros(p), x[:n-p]]) + i = n-p + # Make 2D array and pivot + result = np.expand_dims(result, axis=0).T + first_iter = False + continue + + # Create next column, add `p` results from last col if given + col = x[i:i+(n-p)] + if p != 0: + col = np.hstack([result[:,-1][-p:], col]) + i += n-p + + # Append zeros if last row and not length `n` + if len(col) < n: + col = np.hstack([col, np.zeros(n-len(col))]) + + # Combine result with next row + result = np.hstack([result, np.expand_dims(col, axis=0).T]) + + return result diff --git a/ssqueezepy/wavelet_transforms.py b/ssqueezepy/wavelet_transforms.py new file mode 100644 index 00000000..9d50e62d --- /dev/null +++ b/ssqueezepy/wavelet_transforms.py @@ -0,0 +1,436 @@ +# Ported from the Synchrosqueezing Toolbox, authored by +# Eugine Brevdo, Gaurav Thakur +# (http://www.math.princeton.edu/~ebrevdo/) +# (https://github.com/ebrevdo/synchrosqueezing/) + +import numpy as np +from utils import padsignal, wfiltfn + +EPS = np.finfo(np.float64).eps # machine epsilon for float64 +PI = np.pi + + +def synsq_squeeze(Wx, w, t, nv=None, opts={}): + """Calculates the synchrosqueezed CWT or STFT of `x`. Used internally by + `synsq_cwt_fw` and `synsq_stft_fw`. + + # Arguments: + Wx or Sx: np.ndarray. CWT or STFT of `x`. + w: np.ndarray. Phase transform at same locations in T-F plane. + t: np.ndarray. Time vector. + nv: int. Number of voices (CWT only). + opts: dict. Options: + 'transform': ('CWT', 'STFT'). Underlying time-frequency transform. + 'freqscale': ('log', 'linear'). Frequency bins/divisions. + 'findbins': ('min', 'direct'). Method to find bins. + 'direct' is faster. + 'squeezing': ('full', 'measure'). Latter corresponds to approach + in [3], which is not invertible but has better + robustness properties in some cases. + + # Returns: + Tx: synchrosqueezed output. + fs: associated frequencies. + + Note the multiplicative correction term x in `synsq_cwt_squeeze_mex`, + required due to the fact that the squeezing integral of Eq. (2.7), in, + [1], is taken w.r.t. dlog(a). This correction term needs to be included + as a factor of Eq. (2.3), which we implement here. + + A more detailed explanation is available in Sec. III of [2]. + Note the constant multiplier log(2)/nv has been moved to the + inverse of the normalization constant, as calculated in `synsq_adm`. + + # References: + 1. I. Daubechies, J. Lu, H.T. Wu, "Synchrosqueezed Wavelet Transforms: + an empricial mode decomposition-like tool", + Applied and Computational Harmonic Analysis, 30(2):243-261, 2011. + + 2. G. Thakur, E. Brevdo, N.-S. Fučkar, and H.-T. Wu, + "The Synchrosqueezing algorithm for time-varying spectral analysis: + robustness properties and new paleoclimate applications", + Signal Processing, 93:1079-1094, 2013. + + 3. G. Thakur and H.-T. Wu, "Synchrosqueezing-based Recovery of + Instantaneous Frequency from Nonuniform Samples", + SIAM Journal on Mathematical Analysis, 43(5):2078-2095, 2011. + """ + def _squeeze(w, Wx, Wx_scales, fs, dfs, scaleterm, na, N, lfm, lfM, opts): + # must cast to complex else value assignment discards imaginary component + Tx = np.zeros((len(fs), Wx.shape[1])).astype('complex128') + + # do squeezing by finding which frequency bin each phase transform + # point w(ai, b) lands in + # look only at points where w(ai, b) is positive and finite + v1 = (opts['findbins'] == 'direct') and (opts['freqscale'] == 'linear') + v2 = (opts['findbins'] == 'direct') and (opts['freqscale'] == 'log') + v3 = (opts['findbins'] == 'min') + + for b in range(N): + for ai in range(len(Wx_scales)): + if v1: + _k = np.round(w[ai, b] * dfs) + k = min(max(_k, 1), len(fs)) if not np.isnan(_k) else len(fs) + elif v2: + _k = 1 + np.round(na / (lfM - lfm) * (np.log2(w[ai, b]) - lfm)) + k = min(max(_k, 1), na) if not np.isnan(_k) else len(fs) + elif v3: + k = np.min(np.abs(w[ai, b] - fs)) + k = int(k - 1) # MAT to py idx + Tx[k, b] += Wx[ai, b] * scaleterm[ai] + + if opts['transform'] == 'CWT': + Tx *= (1 / nv) + Tx *= (fs[1] - fs[0]) + + return Tx, fs + + def _compute_associated_frequencies(na, N, fm, fM, dt, opts): + # frequency divisions `w_l` to search over in Synchrosqueezing + if opts['freqscale'] == 'log': + lfm = np.log2(fm) + lfM = np.log2(fM) + _fs = fm * np.power(fM / fm, np.arange(na - 1) / (np.floor(na) - 1)) + fs = np.hstack([_fs, fM]) + dfs = None + elif opts['freqscale'] == 'linear': + if opts['transform'] == 'CWT': + fs = np.linspace(fm, fM, na) + elif opts['transform'] == 'STFT': + fs = np.linspace(0, 1, N) / dt + fs = fs[:N // 2] + dfs = 1 / (fs[1] - fs[0]) + lfm, lfM = None, None + return fs, dfs, lfm, lfM + + def _process_opts(opts): + if 'freqscale' not in opts: + if opts['transform'] == 'CWT': + opts['freqscale'] = 'log' + elif opts['transform'] == 'STFT': + opts['freqscale'] = 'linear' + opts['findbins'] = opts.get('findbins', 'direct') + opts['squeezing'] = opts.get('squeezing', 'full') + + return opts + + opts = _process_opts(opts) + + dt = t[1] - t[0] + dT = t[-1] - t[0] + + # maximum measurable (Nyquist) frequency of data + fM = 1 / (2 * dt) + # minimum measurable (fundamental) frequency of data + fm = 1 / dT + + # `na` is number of scales for CWT, number of freqs for STFT + na, N = Wx.shape + fs, dfs, lfm, lfM = _compute_associated_frequencies(na, N, fm, fM, dt, opts) + + if opts['transform'] == 'CWT': + Wx_scales = np.power(2 ** (1 / nv), + np.expand_dims(np.arange(1, na + 1), 1)) + scaleterm = np.power(Wx_scales, -.5) + elif opts['transform'] == 'STFT': + Wx_scales = np.linspace(fm, fM, na) + scaleterm = np.ones(Wx_scales.shape) + + # measure version from reference [3] + if opts['squeezing'] == 'measure': + Wx = np.ones(Wx.shape) / Wx.shape[0] + + # incorporate threshold by zeroing out Inf values, so they get ignored below + Wx[np.isinf(w)] = 0 + + Tx, fs = _squeeze(w, Wx, Wx_scales, fs, dfs, scaleterm, + na, N, lfm, lfM, opts) + # MEX version, deprecated (above code has been reworked to attain + # similar speed with JIT compiler) + # Tx = 1/nv * synsq_cwt_squeeze_mex(Wx, w, as, fs, ones[fs.shape], lfm, lfM) + return Tx, fs + + +def synsq_cwt_squeeze(Wx, w, t, nv): + """Calculates the synchrosqueezed transform of `f` on a logarithmic scale. + Used internally by `synsq_cwt_fwd`. + + # Arguments: + Wx: np.ndarray. Wavelet transform of `x`. + w: np.ndarray. Estimate of frequency locations in `Wx` + (see `synsq_cwt_fwd`). + t: np.ndarray. Time vector. + nv: int. Number of voices. + + # Returns: + Tx: synchrosqueezed output. + fs: associated frequencies. + + Note the multiplicative correction term `f` in `_cwt_squeeze`, required + due to the fact that the squeezing integral of Eq. (2.7), in, [1], is taken + w.r.t. dlog(a). This correction term needs to be included as a factor of + Eq. (2.3), which we implement here. + + A more detailed explanation is available in Sec. III of [2]. + Specifically, this is an implementation of Sec. IIIC, Alg. 1. + Note the constant multiplier log(2)/nv has been moved to the + inverse of the normalization constant, as calculated in `synsq_adm`. + + # References: + 1. I. Daubechies, J. Lu, H.T. Wu, "Synchrosqueezed Wavelet Transforms: a + tool for empirical mode decomposition", 2010. + + 2. E. Brevdo, N.S. Fučkar, G. Thakur, and H-T. Wu, "The + Synchrosqueezing algorithm: a robust analysis tool for signals + with time-varying spectrum," 2011. + """ + def _cwt_squeeze(Wx, w, Wx_scales, fs, dfs, N, lfm, lfM): + Tx = np.zeros(Wx.shape) + + for b in range(N): + for ai in range(len(Wx_scales)): + if not np.isinf(w[ai, b]) and w[ai, b] > 0: + # find w_l nearest to w[ai, b] + k = int(np.min(np.max( + 1 + np.floor(na / (lfM - lfm) * (np.log2(w[ai, b]) - lfm)), + 0), na - 1)) + Tx[k, b] += Wx[ai, b] * Wx_scales[ai] ** (-0.5) + + return Tx + + dt = t[1] - t[0] + dT = t[-1] - t[0] + + # Maximum measurable frequency of data + #fM = 1/(4*dt) # wavelet limit - tested + fM = 1/(2*dt) # standard + # Minimum measurable frequency, due to wavelet + fm = 1/dT; + #fm = 1/(2*dT); # standard + + na, N = Wx.shape + + Wx_scales = np.power(2 ** (1 / nv), np.expand_dims(np.arange(1, na + 1))) + # dWx_scales = np.array([1, np.diff(Wx_scales)]) + + lfm = np.log2(fm) + lfM = np.log2(fM) + fs = np.power(2, np.linspace(lfm, lfM, na)) + #dfs = np.array([fs[0], np.diff(fs)]) + + # Harmonics of diff. frequencies but same magniude have same |Tx| + dfs = np.ones(fs.shape) + + if np.linalg.norm(Wx, 'fro') < EPS: + Tx = np.zeros(Wx.shape) + else: + Tx = (1 / nv) * _cwt_squeeze( + Wx, w, Wx_scales, fs, dfs, N, lfm, lfM) + + return Tx + + +def phase_cwt(Wx, dWx, opts={}): + """Calculate the phase transform at each (scale, time) pair: + w[a, b] = Im((1/2pi) * d/db (Wx[a,b]) / Wx[a,b]) + Uses direct differentiation by calculating dWx/db in frequency domain + (the secondary output of `cwt_fwd`, see `cwt_fwd`) + + This is the analytic implementation of Eq. (7) of [1]. + + # Arguments: + Wx: np.ndarray. wavelet transform of `x` (see `cwt_fwd`). + dWx: np.ndarray. Samples of time derivative of wavelet transform of `x` + (see `cwt_fwd`). + opts. dict. Options: + 'gamma': wavelet threshold (default: sqrt(machine epsilon)) + + # Returns: + w: phase transform, w.shape == Wx.shape. + + # References: + 1. G. Thakur, E. Brevdo, N.-S. Fučkar, and H.-T. Wu, + "The Synchrosqueezing algorithm for time-varying spectral analysis: + robustness properties and new paleoclimate applications," + Signal Processing, 93:1079-1094, 2013. + + 2. I. Daubechies, J. Lu, H.T. Wu, "Synchrosqueezed Wavelet Transforms: + an empricial mode decomposition-like tool", + Applied and Computational Harmonic Analysis 30(2):243-261, 2011. + """ + if opts.get('gamma', None) is None: + opts['gamma'] = np.sqrt(EPS) + + # Calculate phase transform for each `ai`, normalize by (2 * pi) + if opts.get('dtype', None) == 'phase': + u = np.unwrap(np.angle(Wx)).T + w = np.array([np.diff(u), u[-1] - u[0]]).T / (2 * PI) + else: + w = np.abs(np.imag(dWx / Wx / (2 * PI))) + + w[np.abs(Wx) < opts['gamma']] = np.inf + return w + + +def phase_cwt_num(Wx, dt, opts={}): + """Calculate the phase transform at each (scale, time) pair: + w[a, b] = Im((1/2pi) * d/db (Wx[a,b]) / Wx[a,b]) + Uses numerical differentiation (1st, 2nd, or 4th order). + + This is a numerical differentiation implementation of Eq. (7) of [1]. + + # Arguments: + Wx: np.ndarray. Wavelet transform of `x` (see `cwt_fwd`). + dt: int. Sampling period (e.g. t[1] - t[0]). + opts. dict. Options: + 'dorder': int (1, 2, 4). Differences order. (default = 4) + 'gamma': float. Wavelet threshold. (default = sqrt(machine epsilon)) + + # Returns: + w: demodulated FM-estimates, w.shape == Wx.shape. + + # References: + 1. G. Thakur, E. Brevdo, N.-S. Fučkar, and H.-T. Wu, + "The Synchrosqueezing algorithm for time-varying spectral analysis: + robustness properties and new paleoclimate applications," + Signal Processing, 93:1079-1094, 2013. + """ + def _differentiate(Wx, dt, opts): + if opts['dorder'] == 1: + w = np.array([Wx[:, 1:] - Wx[:, :-1], + Wx[:, 0] - Wx[:, -1]]) + w /= dt + elif opts['dorder'] == 2: + # append for differentiating + Wxr = np.array([Wx[:, -2:], Wx, Wx[:, :2]]) + # calculate 2nd-order forward difference + w = -Wxr[:, 4:] + 4 * Wxr[:, 3:-1] - 3 * Wxr[:, 2:-2] + w /= (2 * dt) + elif opts['dorder'] == 4: + # calculate 4th-order central difference + w = -Wxr[:, 4:] + w += Wxr[:, 3:-1] * 8 + w -= Wxr[:, 1:-3] * 8 + w += Wxr[:, 0:-4] + w /= (12 * dt) + + return w + + def _process_opts(opts): + # order of differentiation (1, 2, or 4) + opts['dorder'] = opts.get('dorder', 4) + + # epsilon from Daubechies, H-T Wu, et al. + # gamma from Brevdo, H-T Wu, et al. + opts['gamma'] = opts.get('gamma', np.sqrt(EPS)) + + if opts['dorder'] not in (1, 2, 4): + raise ValueError("Differentiation order %d not supported" + % opts['dorder']) + return opts + + opts = _process_opts(opts) + + w = _differentiate(Wx, dt, opts) + w[np.abs(Wx) < opts['gamma']] = np.nan + + # calculate inst. freq for each `ai`, normalize by (2*pi) for true freq + w = np.real(-1j * w / Wx) / (2 * PI) + + return w + + +def cwt_fwd(x, wavelet_type, nv=32, dt=1, opts={}): + """Forward continuous wavelet transform, discretized, as described in + Sec. 4.3.3 of [1] and Sec. IIIA for [2]. This algorithm uses the FFT and + samples the wavelet atoms in the Fourier domain. Options such as padding + of the original signal are allowed. Returns the vector of scales and, if + requested, the analytic time-derivative of the wavelet transform (as + described in Sec. IIIB of [2]). + + # Arguments: + x: np.ndarray. Input signal vector, length `n` (need not be dyadic). + wavelet_type: str. See `wfiltfn`. + nv: int. Number of voices. Suggested >= 32. + dt: int. sampling period. + opts: dict. Options: + 'padtype': ('symmetric', 'replicate', 'circular'). Type of padding. + (default = 'symmetric') + 'rpadded': bool. Whether to return padded Wx and dWx. + (default = False) + 'type', 's', 'mu', ...: str. Wavelet options (see `wfiltfn`). + + # Returns: + Wx: (na x n) size matrix (rows = scales, cols = times), containing + samples of the CWT of `x`. + Wx_scales: `na` length vector containing the associated scales. + dWx: (na x n) size matrix containing samples of the time-derivatives + of the CWT of `x`. + xMean: mean of padded `x`. + + # References: + 1. Mallat, S., Wavelet Tour of Signal Processing 3rd ed. + + 2. G. Thakur, E. Brevdo, N.-S. Fučkar, and H.-T. Wu, + "The Synchrosqueezing algorithm for time-varying spectral analysis: + robustness properties and new paleoclimate applications," + Signal Processing, 93:1079-1094, 2013. + """ + opts['padtype'] = opts.get('padtype', 'symmetric') + opts['rpadded'] = opts.get('rpadded', 0) + + n = len(x) + + # pad x first + x, N, n1, n2 = padsignal(x, opts['padtype']) + + xMean = np.mean(x) + x -= xMean + + # choosing more than this means the wavelet window becomes too short + noct = np.log2(N) - 1 + assert(noct > 0 and noct % 1 == 0) + assert(nv > 0 and nv % 1 == 0) + assert(dt > 0) + assert(not np.any(np.isnan(x))) + + na = int(noct * nv) + Wx_scales = np.power(2 ** (1 / nv), np.arange(1, na + 1)) + + # must cast to complex else value assignment discards imaginary component + Wx = np.zeros((na, N)).astype('complex128') + dWx = Wx.copy() + opts['dt'] = dt + + # x = x.T # already shaped as a row vector + xh = np.fft.fft(x) + + # for each octave + # reworked this part to not use `wfilth`, which slows things down a lot + # due to branching and temp objects; see that function for more comments + k = np.arange(N) + xi = np.zeros((1, N)) + xi[:, :N // 2] = 2 * PI / N * np.arange(N // 2) + xi[:, N // 2 + 1:] = 2 * PI / N * np.arange(-N // 2 + 1, 0) + psihfn = wfiltfn(wavelet_type, opts) + + for ai in range(na): + a = Wx_scales[ai] + psih = psihfn(a * xi) * np.sqrt(a) / np.sqrt(2 *PI) * (-1)**k + dpsih = (1j * xi / opts['dt']) * psih + + xcpsi = np.fft.ifftshift(np.fft.ifft(psih * xh)) + Wx[ai] = xcpsi + + dxcpsi = np.fft.ifftshift(np.fft.ifft(dpsih * xh)) + dWx[ai] = dxcpsi + + # shorten W to proper size + if not opts['rpadded']: + Wx = Wx[ :, n1:n1 + n] + dWx = dWx[:, n1:n1 + n] + + # output for graphing purposes; scale by `dt` + Wx_scales *= dt + + return Wx, Wx_scales, dWx, xMean