Skip to content

Commit

Permalink
Averageover tapers
Browse files Browse the repository at this point in the history
  • Loading branch information
EtienneCmb committed Nov 2, 2022
1 parent 6d3f127 commit 4ad29bf
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 15 deletions.
6 changes: 4 additions & 2 deletions frites/conn/conn_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,8 @@ def conn_spec(
kernel = _create_kernel(sm_times, sm_freqs, kernel=sm_kernel)

# average over tapers
tapers_average = mode == 'multitaper'
tapers_average = False
# tapers_average = mode == 'multitaper'

# define arguments for parallel computing
mesg = f'Estimating pairwise {f_name} for trials %s'
Expand Down Expand Up @@ -339,13 +340,14 @@ def conn_spec(
coords=(trials, roi, times))
freqs = np.linspace(2, 60, 40)
n_cycles = freqs / 2.
mt_bandwidth = np.linspace(4, 10, len(freqs))

foi = np.array([[2, 4], [5, 7], [8, 13], [13, 30], [30, 60]])
coh = conn_spec(
x, sfreq=sfreq, roi='roi', times='times', sm_times=2.,
sm_freqs=1, mode='multitaper', n_cycles=n_cycles, freqs=freqs,
decim=1, foi=None, n_jobs=1, metric='coh', mean_trials=False,
block_size=2, **kw_links
block_size=2, mt_bandwidth=None, **kw_links
)

# coh.mean(('trials', 'roi')).plot()
Expand Down
25 changes: 12 additions & 13 deletions frites/conn/conn_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,24 +66,23 @@ def _tf_decomp(data, sf, freqs, mode='morlet', n_cycles=7.0, mt_bandwidth=None,
# the MT decomposition is done separatedly for each
# Frequency center
if isinstance(mt_bandwidth, (list, tuple, np.ndarray)):
raise NotImplementedError("Not compatible with multiple bandwidth")
# # Arrays freqs, n_cycles, mt_bandwidth should have the same size
# assert len(freqs) == len(n_cycles) == len(mt_bandwidth)
# out = []
# for f_c, n_c, mt in zip(freqs, n_cycles, mt_bandwidth):
# _out = tfr_array_multitaper(
# data, sf, [f_c], n_cycles=float(n_c), time_bandwidth=mt,
# output='complex', decim=decim, n_jobs=n_jobs, **kw_mt
# )
# out.append(_out)

# # stack everything
# out = np.concatenate(out, axis=2)
# Arrays freqs, n_cycles, mt_bandwidth should have the same size
assert len(freqs) == len(n_cycles) == len(mt_bandwidth)
out = []
for f_c, n_c, mt in zip(freqs, n_cycles, mt_bandwidth):
_out = tfr_array_multitaper(
data, sf, [f_c], n_cycles=float(n_c), time_bandwidth=mt,
output='complex', decim=decim, n_jobs=n_jobs, **kw_mt
)
out.append(_out.mean(2))
out = np.concatenate(out, axis=2)
elif isinstance(mt_bandwidth, (type(None), int, float)):
out = tfr_array_multitaper(
data, sf, freqs, n_cycles=n_cycles,
time_bandwidth=mt_bandwidth, output='complex', decim=decim,
n_jobs=n_jobs, **kw_mt)
# mean across tapers
out = out.mean(axis=2)
else:
raise ValueError('Method should be either "morlet" or "multitaper"')

Expand Down

0 comments on commit 4ad29bf

Please sign in to comment.