From 4ad29bffdc477aa3d93caff6834db4d7c588b192 Mon Sep 17 00:00:00 2001 From: EtienneCmb Date: Wed, 2 Nov 2022 18:15:14 +0100 Subject: [PATCH] Averageover tapers --- frites/conn/conn_spec.py | 6 ++++-- frites/conn/conn_tf.py | 25 ++++++++++++------------- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/frites/conn/conn_spec.py b/frites/conn/conn_spec.py index 888e9f22e..e5b3bcb4f 100644 --- a/frites/conn/conn_spec.py +++ b/frites/conn/conn_spec.py @@ -242,7 +242,8 @@ def conn_spec( kernel = _create_kernel(sm_times, sm_freqs, kernel=sm_kernel) # average over tapers - tapers_average = mode == 'multitaper' + tapers_average = False + # tapers_average = mode == 'multitaper' # define arguments for parallel computing mesg = f'Estimating pairwise {f_name} for trials %s' @@ -339,13 +340,14 @@ def conn_spec( coords=(trials, roi, times)) freqs = np.linspace(2, 60, 40) n_cycles = freqs / 2. + mt_bandwidth = np.linspace(4, 10, len(freqs)) foi = np.array([[2, 4], [5, 7], [8, 13], [13, 30], [30, 60]]) coh = conn_spec( x, sfreq=sfreq, roi='roi', times='times', sm_times=2., sm_freqs=1, mode='multitaper', n_cycles=n_cycles, freqs=freqs, decim=1, foi=None, n_jobs=1, metric='coh', mean_trials=False, - block_size=2, **kw_links + block_size=2, mt_bandwidth=None, **kw_links ) # coh.mean(('trials', 'roi')).plot() diff --git a/frites/conn/conn_tf.py b/frites/conn/conn_tf.py index 7180a3c74..4d65dbb59 100644 --- a/frites/conn/conn_tf.py +++ b/frites/conn/conn_tf.py @@ -66,24 +66,23 @@ def _tf_decomp(data, sf, freqs, mode='morlet', n_cycles=7.0, mt_bandwidth=None, # the MT decomposition is done separatedly for each # Frequency center if isinstance(mt_bandwidth, (list, tuple, np.ndarray)): - raise NotImplementedError("Not compatible with multiple bandwidth") - # # Arrays freqs, n_cycles, mt_bandwidth should have the same size - # assert len(freqs) == len(n_cycles) == len(mt_bandwidth) - # out = [] - # for f_c, n_c, mt in zip(freqs, n_cycles, mt_bandwidth): - # _out = tfr_array_multitaper( - # data, sf, [f_c], n_cycles=float(n_c), time_bandwidth=mt, - # output='complex', decim=decim, n_jobs=n_jobs, **kw_mt - # ) - # out.append(_out) - - # # stack everything - # out = np.concatenate(out, axis=2) + # Arrays freqs, n_cycles, mt_bandwidth should have the same size + assert len(freqs) == len(n_cycles) == len(mt_bandwidth) + out = [] + for f_c, n_c, mt in zip(freqs, n_cycles, mt_bandwidth): + _out = tfr_array_multitaper( + data, sf, [f_c], n_cycles=float(n_c), time_bandwidth=mt, + output='complex', decim=decim, n_jobs=n_jobs, **kw_mt + ) + out.append(_out.mean(2)) + out = np.concatenate(out, axis=2) elif isinstance(mt_bandwidth, (type(None), int, float)): out = tfr_array_multitaper( data, sf, freqs, n_cycles=n_cycles, time_bandwidth=mt_bandwidth, output='complex', decim=decim, n_jobs=n_jobs, **kw_mt) + # mean across tapers + out = out.mean(axis=2) else: raise ValueError('Method should be either "morlet" or "multitaper"')