Skip to content

Commit

Permalink
Fix surrogates for sequential maximum statistics
Browse files Browse the repository at this point in the history
Fix surrogates for sequential maximum statistics. In the old version,
one candidate was missing from the conditioning set when generating
surrogates for sequential max stats. Now the surrogates for each
canditate are created by using a conditioning set that consists of all
selected source and target variables (source variables only for MI
calculation), excluding the candidate that is currently tested.

Add unit test for bivariate sequential max stats.
  • Loading branch information
pwollstadt committed May 16, 2020
1 parent 8a5e34c commit b617ca4
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 112 deletions.
213 changes: 113 additions & 100 deletions idtxl/stats.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -511,22 +511,18 @@ def max_statistic_sequential(analysis_setup, data):
numpy array, float
TE values for individual sources
"""
try:
n_permutations = analysis_setup.settings['n_perm_max_seq']
except KeyError:
try: # use the same n_perm as for min_stats if surr table is reused
analysis_setup.settings['n_perm_max_seq'] = (
analysis_setup.settings['n_perm_min_stat'])
n_permutations = analysis_setup.settings['n_perm_max_seq']
except KeyError: # is surr table is None, use default
analysis_setup.settings['n_perm_max_seq'] = 500
n_permutations = analysis_setup.settings['n_perm_max_seq']
# Set defaults and get test parameters.
analysis_setup.settings.setdefault('n_perm_max_seq', 500)
n_permutations = analysis_setup.settings['n_perm_max_seq']
analysis_setup.settings.setdefault('alpha_max_seq', 0.05)
alpha = analysis_setup.settings['alpha_max_seq']
_check_permute_in_time(analysis_setup, data, n_permutations)
permute_in_time = analysis_setup.settings['permute_in_time']

if analysis_setup.settings['verbose']:
print('sequential maximum statistic, n_perm: {0}'.format(
n_permutations))
print('sequential maximum statistic, n_perm: {0}, testing {1} selected'
' sources'.format(n_permutations,
len(analysis_setup.selected_vars_sources)))

assert analysis_setup.selected_vars_sources, 'No sources to test.'

Expand All @@ -541,29 +537,73 @@ def max_statistic_sequential(analysis_setup, data):

# Calculate TE for each candidate in the conditional source set, i.e.,
# calculate the conditional MI between each candidate and the current
# value, conditional on all selected variables in the conditioning set.
# Then sort the estimated TE values.
# value, conditional on all selected variables in the conditioning set,
# excluding the current source. Calculate surrogates for each candidate by
# shuffling the candidate realisations n_perm times. Afterwards, sort the
# estimated TE values.
i_1 = 0
i_2 = data.n_realisations(analysis_setup.current_value)
surr_table = np.zeros((len(analysis_setup.selected_vars_sources),
n_permutations))
# Collect data for each candidate and the corresponding conditioning set.
for candidate in analysis_setup.selected_vars_sources:
[temp_cond, temp_cand] = analysis_setup._separate_realisations(
idx_conditional,
candidate)
# Use realisations for parallel estimation of the test statistic later.
for idx_c, candidate in enumerate(analysis_setup.selected_vars_sources):
[conditional_realisations_current,
candidate_realisations_current] = analysis_setup._separate_realisations(
idx_conditional, candidate)

# The following may happen if either the requested conditing is 'none'
# or if the conditiong set that is tested consists only of a single
# candidate.
if temp_cond is None:
if conditional_realisations_current is None:
conditional_realisations = None
re_use = ['var2', 'conditional']
else:
conditional_realisations[i_1:i_2, ] = temp_cond
conditional_realisations[i_1:i_2, ] = conditional_realisations_current
re_use = ['var2']
candidate_realisations[i_1:i_2, ] = temp_cand
candidate_realisations[i_1:i_2, ] = candidate_realisations_current
i_1 = i_2
i_2 += data.n_realisations(analysis_setup.current_value)

# Generate surrogates for the current candidate.
if (analysis_setup._cmi_estimator.is_analytic_null_estimator() and
permute_in_time):
# Generate the surrogates analytically
surr_table[idx_c, :] = (
analysis_setup._cmi_estimator.estimate_surrogates_analytic(
n_perm=n_permutations,
var1=data.get_realisations(analysis_setup.current_value,
[candidate])[0],
var2=analysis_setup._current_value_realisations,
conditional=conditional_realisations_current))
else:
analysis_setup.settings['analytical_surrogates'] = False
surr_candidate_realisations = _get_surrogates(
data,
analysis_setup.current_value,
[candidate],
n_permutations,
analysis_setup.settings)
try:
surr_table[idx_c, :] = (
analysis_setup._cmi_estimator.estimate_parallel(
n_chunks=n_permutations,
re_use=['var2', 'conditional'],
var1=surr_candidate_realisations,
var2=analysis_setup._current_value_realisations,
conditional=conditional_realisations_current))
except ex.AlgorithmExhaustedError as aee:
# The aglorithm cannot continue here, so
# we'll terminate the max sequential stats test,
# and declare all not significant
print('AlgorithmExhaustedError encountered in estimations: {}.'.format(
aee.message))
print('Stopping sequential max stats at candidate with rank 0')
return \
(np.zeros(len(analysis_setup.selected_vars_sources)).astype(bool),
np.ones(len(analysis_setup.selected_vars_sources)),
np.zeros(len(analysis_setup.selected_vars_sources)))

# Calculate original statistic (multivariate/bivariate TE/MI)
try:
individual_stat = analysis_setup._cmi_estimator.estimate_parallel(
Expand All @@ -589,36 +629,6 @@ def max_statistic_sequential(analysis_setup, data):

selected_vars_order = utils.argsort_descending(individual_stat)
individual_stat_sorted = utils.sort_descending(individual_stat)

# Re-use surrogate table from previous pruning using min stats, if it
# already exists. This saves some time. Otherwise create surrogate table.
# Sort surrogate table.
if (analysis_setup._min_stats_surr_table is not None and
n_permutations <= analysis_setup._min_stats_surr_table.shape[1]):
surr_table = analysis_setup._min_stats_surr_table[:, :n_permutations]
assert len(analysis_setup.selected_vars_sources) == surr_table.shape[0]
else:
try:
surr_table = _create_surrogate_table(
analysis_setup=analysis_setup,
data=data,
idx_test_set=analysis_setup.selected_vars_sources,
n_perm=n_permutations,
conditional=conditional_realisations)
except ex.AlgorithmExhaustedError as aee:
# The aglorithm cannot continue here, so
# we'll terminate the max sequential stats test,
# and declare all not significant
print('AlgorithmExhaustedError encountered in estimation: '
'{}.'.format(aee.message))
print('Stopping sequential max stats at candidate with rank 0')
# For now we don't need a stack trace:
# traceback.print_tb(aee.__traceback__)
# Return (signficance, pvalue, TEs):
return (
np.zeros(len(analysis_setup.selected_vars_sources)).astype(bool),
np.ones(len(analysis_setup.selected_vars_sources)),
np.zeros(len(analysis_setup.selected_vars_sources)))
max_distribution = _sort_table_max(surr_table)

# Compare each original value with the distribution of the same rank,
Expand Down Expand Up @@ -693,21 +703,18 @@ def max_statistic_sequential_bivariate(analysis_setup, data):
numpy array, float
TE values for individual sources
"""
try:
n_permutations = analysis_setup.settings['n_perm_max_seq']
except KeyError:
try: # use the same n_perm as for min_stats if surr table is reused
n_permutations = analysis_setup._min_stats_surr_table.shape[1]
analysis_setup.settings['n_perm_max_seq'] = n_permutations
except AttributeError: # is surr table is None, use default
analysis_setup.settings['n_perm_max_seq'] = 500
n_permutations = analysis_setup.settings['n_perm_max_seq']
# Set defaults and get test parameters.
analysis_setup.settings.setdefault('n_perm_max_seq', 500)
n_permutations = analysis_setup.settings['n_perm_max_seq']
analysis_setup.settings.setdefault('alpha_max_seq', 0.05)
alpha = analysis_setup.settings['alpha_max_seq']
_check_permute_in_time(analysis_setup, data, n_permutations)
permute_in_time = analysis_setup.settings['permute_in_time']

if analysis_setup.settings['verbose']:
print('sequential maximum statistic, n_perm: {0}'.format(
n_permutations))
print('sequential maximum statistic, n_perm: {0}, testing {1} selected'
' sources'.format(n_permutations,
len(analysis_setup.selected_vars_sources)))

assert analysis_setup.selected_vars_sources, 'No sources to test.'

Expand Down Expand Up @@ -750,13 +757,14 @@ def max_statistic_sequential_bivariate(analysis_setup, data):
# conditioning set. Then sort the estimated TE/MI values.
i_1 = 0
i_2 = data.n_realisations(analysis_setup.current_value)
surr_table = np.zeros((len(source_vars), n_permutations))
# Collect data for each candidate and the corresponding conditioning set.
for candidate in source_vars:
for idx_c, candidate in enumerate(source_vars):
temp_cond = data.get_realisations(
analysis_setup.current_value,
set(source_vars).difference(set([candidate])))[0]
analysis_setup.current_value,
set(source_vars).difference(set([candidate])))[0]
temp_cand = data.get_realisations(
analysis_setup.current_value, [candidate])[0]
analysis_setup.current_value, [candidate])[0]
# The following may happen if either the requested conditing is
# 'none' or if the conditiong set that is tested consists only of
# a single candidate.
Expand All @@ -774,6 +782,45 @@ def max_statistic_sequential_bivariate(analysis_setup, data):
i_1 = i_2
i_2 += data.n_realisations(analysis_setup.current_value)

# Generate surrogates for the current candidate.
if (analysis_setup._cmi_estimator.is_analytic_null_estimator() and
permute_in_time):
# Generate the surrogates analytically
surr_table[idx_c, :] = (
analysis_setup._cmi_estimator.estimate_surrogates_analytic(
n_perm=n_permutations,
var1=data.get_realisations(analysis_setup.current_value,
[candidate])[0],
var2=analysis_setup._current_value_realisations,
conditional=temp_cond))
else:
analysis_setup.settings['analytical_surrogates'] = False
surr_candidate_realisations = _get_surrogates(
data,
analysis_setup.current_value,
[candidate],
n_permutations,
analysis_setup.settings)
try:
surr_table[idx_c, :] = (
analysis_setup._cmi_estimator.estimate_parallel(
n_chunks=n_permutations,
re_use=['var2', 'conditional'],
var1=surr_candidate_realisations,
var2=analysis_setup._current_value_realisations,
conditional=temp_cond))
except ex.AlgorithmExhaustedError as aee:
# The aglorithm cannot continue here, so
# we'll terminate the max sequential stats test,
# and declare all not significant
print('AlgorithmExhaustedError encountered in estimations: {}.'.format(
aee.message))
print('Stopping sequential max stats at candidate with rank 0')
return \
(np.zeros(len(analysis_setup.selected_vars_sources)).astype(bool),
np.ones(len(analysis_setup.selected_vars_sources)),
np.zeros(len(analysis_setup.selected_vars_sources)))

# Calculate original statistic (multivariate/bivariate TE/MI)
try:
individual_stat = analysis_setup._cmi_estimator.estimate_parallel(
Expand All @@ -799,40 +846,6 @@ def max_statistic_sequential_bivariate(analysis_setup, data):

selected_vars_order = utils.argsort_descending(individual_stat)
individual_stat_sorted = utils.sort_descending(individual_stat)

# Don't re-use surrogate table from previous pruning using min stats
# like for the multivariate algorithm. There is no longer a global
# min_stats including all sources variables, but a separate table per
# source.
conditional_realisations_sources = data.get_realisations(
analysis_setup.current_value, source_vars)[0]
if conditional_realisations_target is None:
conditional_realisations = conditional_realisations_sources
else:
conditional_realisations = np.hstack((
conditional_realisations_sources,
conditional_realisations_target))
try:
surr_table = _create_surrogate_table(
analysis_setup=analysis_setup,
data=data,
idx_test_set=analysis_setup.selected_vars_sources,
n_perm=n_permutations,
conditional=conditional_realisations)
except ex.AlgorithmExhaustedError as aee:
# The algorithm cannot continue here, so
# we'll terminate the max sequential stats test,
# and declare all not significant
print('AlgorithmExhaustedError encountered in '
'estimations: {}.'.format(aee.message))
print('Stopping sequential max stats at candidate with rank 0')
# For now we don't need a stack trace:
# traceback.print_tb(aee.__traceback__)
# Return (signficance, pvalue, TEs):
return (
np.zeros(len(analysis_setup.selected_vars_sources)).astype(bool),
np.ones(len(analysis_setup.selected_vars_sources)),
np.zeros(len(analysis_setup.selected_vars_sources)))
max_distribution = _sort_table_max(surr_table)

# Compare each original value with the distribution of the same rank,
Expand Down
27 changes: 15 additions & 12 deletions test/test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ def test_max_statistic_sequential():
1)
[sign, p, te] = stats.max_statistic_sequential(analysis_setup=setup,
data=data)
[sign, p, te] = stats.max_statistic_sequential_bivariate(
analysis_setup=setup,
data=data)


def test_network_fdr():
Expand Down Expand Up @@ -326,16 +329,16 @@ def test_analytical_surrogates():


if __name__ == '__main__':
test_ais_fdr()
test_analytical_surrogates()
test_data_type()
test_network_fdr()
test_find_pvalue()
test_find_table_max()
test_find_table_min()
test_sort_table_max()
test_sort_table_min()
test_omnibus_test()
test_max_statistic()
test_min_statistic()
# test_ais_fdr()
# test_analytical_surrogates()
# test_data_type()
# test_network_fdr()
# test_find_pvalue()
# test_find_table_max()
# test_find_table_min()
# test_sort_table_max()
# test_sort_table_min()
# test_omnibus_test()
# test_max_statistic()
# test_min_statistic()
test_max_statistic_sequential()

0 comments on commit b617ca4

Please sign in to comment.