Skip to content

Commit

Permalink
Fix conditioning in bivariate MI network inference
Browse files Browse the repository at this point in the history
Fix conditioning in bivariate MI network inference. Per default, the
stats module used all selected variables for conditioning when creating
the surrogate table. Instead, explicitely pass the conditioning set to
the function.

Fix conditioning for AIS estimation. Remove minimum candidate from
conditioning set.

Fix formatting and PEP8.
  • Loading branch information
pwollstadt committed Apr 14, 2020
1 parent 4c4ea88 commit 8a5e34c
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 128 deletions.
64 changes: 35 additions & 29 deletions idtxl/active_information_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import numpy as np
from . import stats
from .single_process_analysis import SingleProcessAnalysis
from .estimator import find_estimator
from .results import ResultsSingleProcessAnalysis
from . import idtxl_exceptions as ex

Expand Down Expand Up @@ -344,8 +343,8 @@ def _include_candidates(self, candidate_set, data):
"""
success = False
if self.settings['verbose']:
print('testing candidate set: {0}'.format(
self._idx_to_lag(candidate_set)))
print('testing candidate set: {0}'.format(
self._idx_to_lag(candidate_set)))
while candidate_set:
# Get realisations for all candidates.
cand_real = data.get_realisations(self.current_value,
Expand All @@ -365,12 +364,12 @@ def _include_candidates(self, candidate_set, data):
# we'll terminate the search for more candidates,
# though those identified already remain valid
print('AlgorithmExhaustedError encountered in '
'estimations: ' + aee.message)
'estimations: ' + aee.message)
print('Halting current estimation set.')
# For now we don't need a stack trace:
# traceback.print_tb(aee.__traceback__)
break

# Test max CMI for significance with maximum statistics.
te_max_candidate = max(temp_te)
max_candidate = candidate_set[np.argmax(temp_te)]
Expand All @@ -379,19 +378,20 @@ def _include_candidates(self, candidate_set, data):
self._idx_to_lag([max_candidate])[0]), end='')
significant = False
try:
significant = stats.max_statistic(self, data, candidate_set,
te_max_candidate)[0]
significant = stats.max_statistic(
self, data, candidate_set, te_max_candidate,
conditional=self._selected_vars_realisations)[0]
except ex.AlgorithmExhaustedError as aee:
# The algorithm cannot continue here, so
# we'll terminate the check on the max stats and not let the
# source pass
print('AlgorithmExhaustedError encountered in '
'estimations: ' + aee.message)
'estimations: ' + aee.message)
print('Halting max stats and further selection for target.')
# For now we don't need a stack trace:
# traceback.print_tb(aee.__traceback__)
break

# If the max is significant keep it and test the next candidate. If
# it is not significant break. There will be no further significant
# sources b/c they all have lesser TE.
Expand Down Expand Up @@ -469,35 +469,41 @@ def _prune_candidates(self, data):
var2=self._current_value_realisations,
conditional=conditional_realisations)
except ex.AlgorithmExhaustedError as aee:
# The algorithm cannot continue here, so
# we'll terminate the pruning check,
# assuming that we need not prune any more
print('AlgorithmExhaustedError encountered in '
'estimations: ' + aee.message)
print('Halting current pruning and allowing others to'
' remain.')
# For now we don't need a stack trace:
# traceback.print_tb(aee.__traceback__)
break
# The algorithm cannot continue here, so we'll terminate the
# pruning check, assuming that we need not prune any more
print('AlgorithmExhaustedError encountered in '
'estimations: ' + aee.message)
print('Halting current pruning and allowing others to'
' remain.')
# For now we don't need a stack trace:
# traceback.print_tb(aee.__traceback__)
break

# Test min TE for significance with minimum statistics.
te_min_candidate = min(temp_te)
min_candidate = self.selected_vars_sources[np.argmin(temp_te)]
if self.settings['verbose']:
print('{0}'.format(self._idx_to_lag([min_candidate])[0]))
print('testing candidate: {0}'.format(
self._idx_to_lag([min_candidate])[0]))
remaining_candidates = set(self.selected_vars_sources).difference(
set([min_candidate]))
conditional_realisations = data.get_realisations(
self.current_value, remaining_candidates)[0]
try:
[significant, p, surr_table] = stats.min_statistic(
self, data,
self.selected_vars_sources,
te_min_candidate)
analysis_setup=self,
data=data,
candidate_set=self.selected_vars_sources,
te_min_candidate=te_min_candidate,
conditional=conditional_realisations)
except ex.AlgorithmExhaustedError as aee:
# The algorithm cannot continue here, so
# we'll terminate the min statistics
# assuming that we need not prune any more
print('AlgorithmExhaustedError encountered in '
'estimations: ' + aee.message)
'estimations: ' + aee.message)
print('Halting current pruning and allowing others to'
' remain.')
' remain.')
# For now we don't need a stack trace:
# traceback.print_tb(aee.__traceback__)
break
Expand Down Expand Up @@ -527,7 +533,7 @@ def _test_final_conditional(self, data):
# The algorithm cannot continue here, so
# we'll set the results to zero
print('AlgorithmExhaustedError encountered in '
'estimations: ' + aee.message)
'estimations: ' + aee.message)
print('Halting AIS final conditional test and setting to not significant.')
# For now we don't need a stack trace:
# traceback.print_tb(aee.__traceback__)
Expand All @@ -553,15 +559,15 @@ def _test_final_conditional(self, data):
# The algorithm cannot continue here, so
# we'll set the results to zero
print('AlgorithmExhaustedError encountered in '
'final local AIS estimations: ' + aee.message)
'final local AIS estimations: ' + aee.message)
print('Setting all local results to zero (but leaving'
' surrogate statistical test results)')
' surrogate statistical test results)')
# For now we don't need a stack trace:
# traceback.print_tb(aee.__traceback__)
# Return local AIS values of all zeros:
# (length gleaned from line below)
local_ais = np.zeros(
(max(replication_ind) + 1)*sum(replication_ind == 0));
(max(replication_ind) + 1)*sum(replication_ind == 0))

# Reshape local AIS to a [replications x samples] matrix.
self.ais = local_ais.reshape(
Expand Down
Loading

0 comments on commit 8a5e34c

Please sign in to comment.