Skip to content

Commit

Permalink
Massive performance improvements
Browse files Browse the repository at this point in the history
Vectorized for-loops in `synsq_squeeze`, used list-append instead of array-concat for `buffer`

Also changed import logic
  • Loading branch information
OverLordGoldDragon committed Feb 1, 2020
1 parent bc9e741 commit 048a39f
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 44 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ There are more unlisted (see original repo), but not all will be implemented, in
- **Moved functions**; each no longer has its own file, but is grouped with other relevant functions
- **Code style**; grouped parts of code as sub-functions for improved readability; indentation for vertical alignment; other
- **Performance**; this repo may run faster or slower than the original: NumPy arrays are faster than C arrays, but some of the original functions use MEX-optimized code with no NumPy equivalent. This repo also uses dense instead of sparse matrices (see below).
- **Performance**; this repo _will_ run **10x+ faster** for some methods, whose for-loops were replaced with vectorized NumPy operations

**Other**:
- Dense instead of sparse matrices for `stft_fwd` in [stft_transforms.py](https://github.com/OverLordGoldDragon/ssqueezepy/blob/master/synchrosqueezing/stft_transforms.py), as NumPy doesn't handle the latter in the operations involved
Expand Down
4 changes: 2 additions & 2 deletions ssqueezepy/stft_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
# (https://github.com/ebrevdo/synchrosqueezing/)

import numpy as np
from .utils import wfiltfn, padsignal, buffer
from quadpy import quad as quadgk
from utils import wfiltfn, padsignal, buffer

PI = np.pi
EPS = np.finfo(np.float64).eps # machine epsilon for float64
Expand Down Expand Up @@ -85,7 +85,7 @@ def _process_opts(opts, x):
Sfs = np.linspace(0, 1, opts['winlen'] + 1)
Sfs = Sfs[:np.floor(opts['winlen'] / 2).astype('int64') + 1] / dt

# compute STFt and keep only the positive frequencies
# compute STFT and keep only the positive frequencies
xbuf = buffer(x, opts['winlen'], opts['winlen'] - 1, 'nodelay')
xbuf = np.diag(window) @ xbuf
Sx = np.fft.fft(xbuf, None, axis=0)
Expand Down
6 changes: 3 additions & 3 deletions ssqueezepy/synsq_cwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
# (https://github.com/ebrevdo/synchrosqueezing/)

import numpy as np
from utils import est_riskshrink_thresh, p2up, synsq_adm
from wavelet_transforms import phase_cwt, phase_cwt_num
from wavelet_transforms import cwt_fwd, synsq_squeeze
from .utils import est_riskshrink_thresh, p2up, synsq_adm
from .wavelet_transforms import phase_cwt, phase_cwt_num
from .wavelet_transforms import cwt_fwd, synsq_squeeze


def synsq_cwt_fwd(x, t=None, fs=None, nv=32, opts=None):
Expand Down
6 changes: 3 additions & 3 deletions ssqueezepy/synsq_stft.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
# (https://github.com/ebrevdo/synchrosqueezing/)

import numpy as np
from utils import wfiltfn
from stft_transforms import stft_fwd, phase_stft
from wavelet_transforms import synsq_squeeze
from .utils import wfiltfn
from .stft_transforms import stft_fwd, phase_stft
from .wavelet_transforms import synsq_squeeze
from quadpy import quad as quadgk

PI = np.pi
Expand Down
42 changes: 19 additions & 23 deletions ssqueezepy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,39 +248,35 @@ def buffer(x, n, p=0, opt=None):
ryanjdillon: https://stackoverflow.com/questions/38453249/
is-there-a-matlabs-buffer-equivalent-in-numpy#answer-40105995
"""
import numpy as np

if opt not in ('nodelay', None):
raise ValueError('{} not implemented'.format(opt))

i = 0
first_iter = True
while i < len(x):
if first_iter:
if opt == 'nodelay':
# No zeros at array start
result = x[:n]
i = n
else:
# Start with `p` zeros
result = np.hstack([np.zeros(p), x[:n-p]])
i = n-p
# Make 2D array and pivot
result = np.expand_dims(result, axis=0).T
first_iter = False
continue
if opt == 'nodelay':
# No zeros at array start
result = x[:n]
i = n
else:
# Start with `p` zeros
result = np.hstack([np.zeros(p), x[:n-p]])
i = n-p

# Make 2D array
result = np.expand_dims(result, axis=0)
result = list(result)

while i < len(x):
# Create next column, add `p` results from last col if given
col = x[i:i+(n-p)]
if p != 0:
col = np.hstack([result[:,-1][-p:], col])
i += n-p
col = np.hstack([result[-1][-p:], col])

# Append zeros if last row and not length `n`
if len(col) < n:
col = np.hstack([col, np.zeros(n-len(col))])
if len(col):
col = np.hstack([col, np.zeros(n - len(col))])

# Combine result with next row
result = np.hstack([result, np.expand_dims(col, axis=0).T])
result.append(np.array(col))
i += (n - p)

return result
return np.vstack(result).T
72 changes: 59 additions & 13 deletions ssqueezepy/wavelet_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# (https://github.com/ebrevdo/synchrosqueezing/)

import numpy as np
from utils import padsignal, wfiltfn
from .utils import padsignal, wfiltfn

EPS = np.finfo(np.float64).eps # machine epsilon for float64
PI = np.pi
Expand Down Expand Up @@ -56,6 +56,63 @@ def synsq_squeeze(Wx, w, t, nv=None, opts={}):
SIAM Journal on Mathematical Analysis, 43(5):2078-2095, 2011.
"""
def _squeeze(w, Wx, Wx_scales, fs, dfs, scaleterm, na, N, lfm, lfM, opts):
def _vectorized(scaleterm):  # possible to optimize further
    """Accumulate every coefficient of `Wx` into its frequency bin of `Tx`.

    For each entry `w[ai, b]` a 0-based row index `k` into `Tx` is computed
    by whichever of the enclosing flags `v1` / `v2` / `v3` is set, and
    `Wx[ai, b] * scaleterm[ai]` is added to `Tx[k, b]`. NaN entries of `w`
    are sent to the highest bin the active branch allows.

    Reads `w`, `Wx`, `Wx_scales`, `fs`, `dfs`, `na`, `N`, `lfm`, `lfM`,
    `v1`, `v2`, `v3` and `Tx` from the enclosing scope.

    Parameters
    ----------
    scaleterm : np.ndarray
        Per-scale weights. A 1D array is turned into a column so that it
        multiplies `Wx` row-wise.

    Returns
    -------
    Tx : np.ndarray
        The enclosing-scope accumulator, updated in place.

    Raises
    ------
    ValueError
        If none of `v1`, `v2`, `v3` is set.
    """
    if scaleterm.ndim == 1:
        scaleterm = scaleterm[:, np.newaxis]

    n_bins = len(fs)
    if v1:
        _k = np.round(w * dfs)
        _k[np.isnan(_k)] = n_bins
        # clamp the 1-based bin to [1, n_bins], then shift: MAT to py idx
        k = np.clip(_k, 1, n_bins).astype('int32') - 1
    elif v2:  # TESTED
        _k = 1 + np.round(na / (lfM - lfm) * (np.log2(w) - lfm))
        _k[np.isnan(_k)] = n_bins
        # clamp the 1-based bin to [1, na], then shift: MAT to py idx
        k = np.clip(_k, 1, na).astype('int32') - 1
    elif v3:
        # Nearest-bin search: we need the *position* of the closest entry
        # of `fs`, not the minimal distance itself. Broadcasting yields a
        # (len(fs), *w.shape) distance array without needing `np.matlib`.
        dist = np.abs(w[np.newaxis] - fs.reshape(-1, 1, 1))
        k = np.argmin(dist, axis=0)  # already 0-based; no MAT shift
        # argmin over an all-NaN column is meaningless; use the last bin,
        # matching the NaN handling of the other branches
        k[np.isnan(w)] = n_bins - 1
    else:
        raise ValueError("unsupported combination of opts['findbins'] "
                         "and opts['freqscale']")

    Ws_prod = Wx * scaleterm

    # Unbuffered scatter-add: entries that share a (bin, column) target are
    # all accumulated, exactly as the element-wise `+=` double loop did.
    n_scales = len(Wx_scales)
    np.add.at(Tx, (k[:n_scales, :N], np.arange(N)), Ws_prod[:n_scales, :N])
    return Tx

def _for_loop():  # much slower; tested 11x slowdown
    """Element-by-element reference version of the bin reassignment.

    For every column `b` in `range(N)` and every scale index `ai` in
    `range(len(Wx_scales))`, a 0-based bin index `k` is computed from
    `w[ai, b]` according to the active flag (`v1`, `v2` or `v3`), and
    `Wx[ai, b] * scaleterm[ai]` is added to `Tx[k, b]`. A NaN in `w` maps
    to bin `len(fs) - 1` in every branch.

    Reads `w`, `Wx`, `Wx_scales`, `fs`, `dfs`, `scaleterm`, `na`, `N`,
    `lfm`, `lfM`, `v1`, `v2`, `v3` and `Tx` from the enclosing scope.

    Returns
    -------
    Tx : np.ndarray
        The enclosing-scope accumulator, updated in place.

    Raises
    ------
    ValueError
        If none of `v1`, `v2`, `v3` is set (instead of silently returning
        an untouched `Tx`).
    """
    n_bins = len(fs)
    n_scales = len(Wx_scales)

    if v1:
        for b in range(N):
            for ai in range(n_scales):
                _k = np.round(w[ai, b] * dfs)
                k = n_bins if np.isnan(_k) else min(max(_k, 1), n_bins)

                k = int(k - 1)  # MAT to py idx
                Tx[k, b] += Wx[ai, b] * scaleterm[ai]
    elif v2:
        for b in range(N):
            for ai in range(n_scales):
                _k = 1 + np.round(na / (lfM - lfm) * (
                    np.log2(w[ai, b]) - lfm))
                k = n_bins if np.isnan(_k) else min(max(_k, 1), na)

                k = int(k - 1)  # MAT to py idx
                Tx[k, b] += Wx[ai, b] * scaleterm[ai]
    elif v3:
        for b in range(N):
            for ai in range(n_scales):
                if np.isnan(w[ai, b]):
                    # no nearest bin exists; use the last one, as the
                    # other branches do for NaN
                    k = n_bins - 1
                else:
                    # position of the closest entry of `fs` (already
                    # 0-based), not the minimal distance itself
                    k = int(np.argmin(np.abs(w[ai, b] - fs)))
                Tx[k, b] += Wx[ai, b] * scaleterm[ai]
    else:
        raise ValueError("unsupported combination of opts['findbins'] "
                         "and opts['freqscale']")
    return Tx

# must cast to complex else value assignment discards imaginary component
Tx = np.zeros((len(fs), Wx.shape[1])).astype('complex128')

Expand All @@ -66,18 +123,7 @@ def _squeeze(w, Wx, Wx_scales, fs, dfs, scaleterm, na, N, lfm, lfM, opts):
v2 = (opts['findbins'] == 'direct') and (opts['freqscale'] == 'log')
v3 = (opts['findbins'] == 'min')

for b in range(N):
for ai in range(len(Wx_scales)):
if v1:
_k = np.round(w[ai, b] * dfs)
k = min(max(_k, 1), len(fs)) if not np.isnan(_k) else len(fs)
elif v2:
_k = 1 + np.round(na / (lfM - lfm) * (np.log2(w[ai, b]) - lfm))
k = min(max(_k, 1), na) if not np.isnan(_k) else len(fs)
elif v3:
k = np.min(np.abs(w[ai, b] - fs))
k = int(k - 1) # MAT to py idx
Tx[k, b] += Wx[ai, b] * scaleterm[ai]
Tx = _vectorized(scaleterm)

if opts['transform'] == 'CWT':
Tx *= (1 / nv)
Expand Down

0 comments on commit 048a39f

Please sign in to comment.