From 4c4ea88daaefc2f87dd44819d88a4f5326b5f2db Mon Sep 17 00:00:00 2001 From: Patricia Wollstadt Date: Tue, 14 Apr 2020 15:51:57 +0200 Subject: [PATCH] Add PEP8 and documentation for Goettingen PID Add PEP8 and documentation for Goettingen PID. Make PrettyTable import optional and issue a warning similar to the ECOS estimator for the Tartu PID. Otherwise, any analysis crashes because it tries to import something that is optional and thus probably not installed. Rename current PID estimators to BivariatePID and new estimator to MultivariatePID to be consistent with MI and TE class names. Fix import error in multivariate PID estimator. --- ...decomposition.py => demo_bivariate_pid.py} | 7 +- ...omposition.py => demo_multivariate_pid.py} | 19 +- ...tion_decomposition.py => bivariate_pid.py} | 20 +- idtxl/estimators_multivariate_pid.py | 66 +++-- ...n_decomposition.py => multivariate_pid.py} | 60 ++-- idtxl/pid_goettingen.py | 276 ++++++++++-------- idtxl/results.py | 20 +- ...osition.py => systemtest_bivariate_pid.py} | 8 +- ...decomposition.py => test_bivariate_pid.py} | 12 +- ...omposition.py => test_multivariate_pid.py} | 18 +- 10 files changed, 262 insertions(+), 244 deletions(-) rename demos/{demo_partial_information_decomposition.py => demo_bivariate_pid.py} (89%) rename demos/{demo_multivariate_partial_information_decomposition.py => demo_multivariate_pid.py} (93%) rename idtxl/{partial_information_decomposition.py => bivariate_pid.py} (95%) rename idtxl/{multivariate_partial_information_decomposition.py => multivariate_pid.py} (92%) rename test/{systemtest_partial_information_decomposition.py => systemtest_bivariate_pid.py} (93%) rename test/{test_partial_information_decomposition.py => test_bivariate_pid.py} (95%) rename test/{test_multivariate_partial_information_decomposition.py => test_multivariate_pid.py} (93%) diff --git a/demos/demo_partial_information_decomposition.py b/demos/demo_bivariate_pid.py similarity index 89% rename from 
demos/demo_partial_information_decomposition.py rename to demos/demo_bivariate_pid.py index e61eeb45..f8ede06a 100644 --- a/demos/demo_partial_information_decomposition.py +++ b/demos/demo_bivariate_pid.py @@ -1,7 +1,6 @@ # Import classes import numpy as np -from idtxl.partial_information_decomposition import ( - PartialInformationDecomposition) +from idtxl.bivariate_pid import BivariatePID from idtxl.data import Data # a) Generate test data @@ -13,7 +12,7 @@ data = Data(np.vstack((x, y, z)), 'ps', normalise=False) # b) Initialise analysis object and define settings for both PID estimators -pid = PartialInformationDecomposition() +pid = BivariatePID() settings_tartu = {'pid_estimator': 'TartuPID', 'lags_pid': [0, 0]} settings_sydney = { 'alph_s1': alph, @@ -30,7 +29,7 @@ settings=settings_tartu, data=data, target=2, sources=[0, 1]) # d) Run Sydney estimator -pid = PartialInformationDecomposition() +pid = BivariatePID() results_sydney = pid.analyse_single_target( settings=settings_sydney, data=data, target=2, sources=[0, 1]) diff --git a/demos/demo_multivariate_partial_information_decomposition.py b/demos/demo_multivariate_pid.py similarity index 93% rename from demos/demo_multivariate_partial_information_decomposition.py rename to demos/demo_multivariate_pid.py index 46e45e88..ee119f44 100644 --- a/demos/demo_multivariate_partial_information_decomposition.py +++ b/demos/demo_multivariate_pid.py @@ -1,7 +1,6 @@ # Import classes import numpy as np -from idtxl.multivariate_partial_information_decomposition import ( - MultivariatePartialInformationDecomposition) +from idtxl.multivariate_pid import MultivariatePID from idtxl.data import Data # a) Generate test data @@ -13,7 +12,7 @@ data = Data(np.vstack((x, y, z)), 'ps', normalise=False) # b) Initialise analysis object and define settings for SxPID estimators -pid = MultivariatePartialInformationDecomposition() +pid = MultivariatePID() settings_SxPID = {'pid_estimator': 'SxPID', 'lags_pid': [0, 0]} # c) Run Goettingen 
estimator @@ -39,13 +38,13 @@ # Some special Examples -# Pointwise Unique +# Pointwise Unique x = np.asarray([0, 1, 0, 2]) y = np.asarray([1, 0, 2, 0]) z = np.asarray([1, 1, 2, 2]) data = Data(np.vstack((x, y, z)), 'ps', normalise=False) -pid = MultivariatePartialInformationDecomposition() +pid = MultivariatePID() settings_SxPID = {'pid_estimator': 'SxPID', 'lags_pid': [0, 0]} results_SxPID = pid.analyse_single_target( @@ -66,13 +65,13 @@ results_SxPID.get_single_target(2)['avg'][((1,2,),)][2], 0.)) -# Redundancy Error +# Redundancy Error x = np.asarray([0, 0, 0, 1, 1, 1, 0, 1]) y = np.asarray([0, 0, 0, 1, 1, 1, 1, 0]) z = np.asarray([0, 0, 0, 1, 1, 1, 0, 1]) data = Data(np.vstack((x, y, z)), 'ps', normalise=False) -pid = MultivariatePartialInformationDecomposition() +pid = MultivariatePID() settings_SxPID = {'pid_estimator': 'SxPID', 'lags_pid': [0, 0]} results_SxPID = pid.analyse_single_target( @@ -93,14 +92,14 @@ results_SxPID.get_single_target(2)['avg'][((1,2,),)][2], 0.368)) -# Three bits hash +# Three bits hash s1 = np.asarray([0, 0, 0, 0, 1, 1, 1, 1]) s2 = np.asarray([0, 0, 1, 1, 0, 0, 1, 1]) s3 = np.asarray([0, 1, 0, 1, 0, 1, 0, 1]) z = np.asarray([0, 1, 1, 0, 1, 0, 0, 1]) data = Data(np.vstack((s1, s2, s3, z)), 'ps', normalise=False) -pid = MultivariatePartialInformationDecomposition() +pid = MultivariatePID() settings_SxPID = {'pid_estimator': 'SxPID', 'lags_pid': [0, 0, 0]} results_SxPID = pid.analyse_single_target( @@ -113,7 +112,7 @@ print('Uni s2 {0:.4f}\t\t{1:.4f}'.format( results_SxPID.get_single_target(3)['avg'][((2,),)][2], 0.3219)) print('Uni s3 {0:.4f}\t\t{1:.4f}'.format( - results_SxPID.get_single_target(3)['avg'][((3,),)][2], 0.3219)) + results_SxPID.get_single_target(3)['avg'][((3,),)][2], 0.3219)) print('Synergy s1_s2_s3 {0:.4f}\t\t{1:.4f}'.format( results_SxPID.get_single_target(3)['avg'][((1,2,3),)][2], 0.2451)) print('Synergy s1_s2 {0:.4f}\t\t{1:.4f}'.format( diff --git a/idtxl/partial_information_decomposition.py b/idtxl/bivariate_pid.py 
similarity index 95% rename from idtxl/partial_information_decomposition.py rename to idtxl/bivariate_pid.py index b021ce8b..d3b158c9 100644 --- a/idtxl/partial_information_decomposition.py +++ b/idtxl/bivariate_pid.py @@ -8,10 +8,10 @@ import numpy as np from .single_process_analysis import SingleProcessAnalysis from .estimator import find_estimator -from .results import ResultsPartialInformationDecomposition +from .results import ResultsPID -class PartialInformationDecomposition(SingleProcessAnalysis): +class BivariatePID(SingleProcessAnalysis): """Perform partial information decomposition for individual processes. Perform partial information decomposition (PID) for two source processes @@ -75,7 +75,7 @@ def analyse_network(self, settings, data, targets, sources): >>> 'pid_estimator': 'SydneyPID'} >>> targets = [0, 1, 2] >>> sources = [[1, 2], [0, 2], [0, 1]] - >>> pid_analysis = PartialInformationDecomposition() + >>> pid_analysis = BivariatePID() >>> results = pid_analysis.analyse_network(settings, data, targets, >>> sources) @@ -97,9 +97,9 @@ def analyse_network(self, settings, data, targets, sources): [[0, 2], [1, 0]], must have the same length as targets Returns: - ResultsPartialInformationDecomposition instance + ResultsPID instance results of network inference, see documentation of - ResultsPartialInformationDecomposition() + ResultsPID() """ # Set defaults for PID estimation. 
settings.setdefault('verbose', True) @@ -112,7 +112,7 @@ def analyse_network(self, settings, data, targets, sources): list_of_lags = settings['lags_pid'] # Perform PID estimation for each target individually - results = ResultsPartialInformationDecomposition( + results = ResultsPID( n_nodes=data.n_processes, n_realisations=data.n_realisations(), normalised=data.normalise) @@ -158,7 +158,7 @@ def analyse_single_target(self, settings, data, target, sources): >>> 'max_iters': 1000, >>> 'pid_calc_name': 'SydneyPID', >>> 'lags_pid': [2, 3]} - >>> pid_analysis = PartialInformationDecomposition() + >>> pid_analysis = BivariatePID() >>> results = pid_analysis.analyse_single_target(settings=settings, >>> data=data, >>> target=0, @@ -181,9 +181,9 @@ def analyse_single_target(self, settings, data, target, sources): sources : list of ints indices of the two source processes for the target - Returns: ResultsPartialInformationDecomposition instance results of + Returns: ResultsPID instance results of network inference, see documentation of - ResultsPartialInformationDecomposition() + ResultsPID() """ # Check input and initialise values for analysis. self._initialise(settings, data, target, sources) @@ -192,7 +192,7 @@ def analyse_single_target(self, settings, data, target, sources): self._calculate_pid(data) # Add analyis info. - results = ResultsPartialInformationDecomposition( + results = ResultsPID( n_nodes=data.n_processes, n_realisations=data.n_realisations(self.current_value), normalised=data.normalise) diff --git a/idtxl/estimators_multivariate_pid.py b/idtxl/estimators_multivariate_pid.py index 896116d5..ad404639 100755 --- a/idtxl/estimators_multivariate_pid.py +++ b/idtxl/estimators_multivariate_pid.py @@ -1,36 +1,39 @@ """Multivariate Partical information decomposition for discrete random variables. 
-This module provides an estimator for multivariate partial information decomposition -as proposed in +This module provides an estimator for multivariate partial information +decomposition as proposed in -- Makkeh, A. & Gutknecht, A. & Wibral, M. (2020). A Differentiable measure +- Makkeh, A. & Gutknecht, A. & Wibral, M. (2020). A Differentiable measure for shared information. 1- 27 Retrieved from - http://arxiv.org/abs/2002.03356 + http://arxiv.org/abs/2002.03356 """ import numpy as np from . import lattices as lt from . import pid_goettingen from .estimator import Estimator +from .estimators_pid import _join_variables # TODO add support for multivariate estimation for Tartu and Sydney estimator + class SxPID(Estimator): - """Estimate partial information decomposition for multiple inputs (up to 4 inputs) - and one output + """Estimate partial information decomposition for multiple inputs. - Implementation of the multivariate partial information decomposition (PID) - estimator for discrete data. The estimator finds shared information, unique - information and synergistic information between the multiple inputs - s1, s2, ..., sn with respect to the output t for each realization (t, s1, ..., sn) - and then average them according to their distribution weights p(t, s1, ..., sn). - Both the pointwise (on the realization level) PID and the averaged PID are - returned (see the 'return' of 'estimate()'). + Implementation of the multivariate partial information decomposition (PID) + estimator for discrete data with (up to 4 inputs) and one output. The + estimator finds shared information, unique information and synergistic + information between the multiple inputs s1, s2, ..., sn with respect to the + output t for each realization (t, s1, ..., sn) and then average them + according to their distribution weights p(t, s1, ..., sn). Both the + pointwise (on the realization level) PID and the averaged PID are returned + (see the 'return' of 'estimate()'). 
- The algorithm uses recurrsion to compute the partial information decomposition + The algorithm uses recursion to compute the partial information + decomposition. References: - - Makkeh, A. & Wibral, M. (2020). A differentiable pointwise partial + - Makkeh, A. & Wibral, M. (2020). A differentiable pointwise partial Information Decomposition estimator. https://github.com/Abzinger/SxPID. Args: @@ -39,14 +42,12 @@ class SxPID(Estimator): - verbose : bool [optional] - print output to console (default=False) - """ def __init__(self, settings): # get estimation parameters self.settings = settings.copy() self.settings.setdefault('verbose', False) - def is_parallel(): return False @@ -69,7 +70,7 @@ def estimate(self, s, t): 'avg' -> {alpha -> [float, float, float]} } - where the list of floats is ordered + where the list of floats is ordered [informative, misinformative, informative - misinformative] ptw stands for pointwise decomposition avg stands for average decomposition @@ -84,22 +85,24 @@ def estimate(self, s, t): # children is a list of tuples lattices = lt.lattices num_source_vars = len(s) - retval_ptw, retval_avg = pid_goettingen.pid(num_source_vars, pdf_orig=pdf, - chld=lattices[num_source_vars][0], - achain=lattices[num_source_vars][1], - printing = self.settings['verbose']) - - #AskM: Trivariate: does it make sense to name the alphas + retval_ptw, retval_avg = pid_goettingen.pid( + num_source_vars, + pdf_orig=pdf, + chld=lattices[num_source_vars][0], + achain=lattices[num_source_vars][1], + printing=self.settings['verbose']) + + # TODO AskM: Trivariate: does it make sense to name the alphas # for example shared_syn_s1_s2__syn_s1_s3 ? 
results = { 'ptw': retval_ptw, 'avg': retval_avg, - } + } return results def _get_pdf_dict(s, t): - """"Stores the probability mass function estimated via counting to a dictionary""" + """"Write probability mass function estimated via counting to a dict.""" # Create dictionary with probability mass function counts = dict() n_samples = s[0].shape[0] @@ -142,12 +145,12 @@ def _check_input(s, t, settings): alph_new = len(np.unique(s[i][:, 0])) for col in range(1, s[i].shape[1]): alph_col = len(np.unique(s[i][:, col])) - si_joint, alph_new = _join_variables(si_joint, s[i][:, col], - alph_new, alph_col) + si_joint, alph_new = _join_variables( + si_joint, s[i][:, col], alph_new, alph_col) settings['alph_s'+str(i+1)] = alph_new else: - raise ValueError('Input source {0} s{0} has to be a 1D or 2D numpy ' - 'array.'.format(i+1)) + raise ValueError('Input source {0} s{0} has to be a 1D or 2D ' + 'numpy array.'.format(i+1)) if t.ndim != 1: if t.shape[1] == 1: @@ -163,7 +166,7 @@ def _check_input(s, t, settings): if not issubclass(s[i].dtype.type, np.integer): raise TypeError('Input s{0} (source {0}) must be an integer numpy ' 'array.'.format(i+1)) - #^ for + # ^ for if not issubclass(t.dtype.type, np.integer): raise TypeError('Input t (target) must be an integer numpy array.') @@ -173,4 +176,3 @@ def _check_input(s, t, settings): raise ValueError('Number of samples s and t must be equal') return s, t, settings - diff --git a/idtxl/multivariate_partial_information_decomposition.py b/idtxl/multivariate_pid.py similarity index 92% rename from idtxl/multivariate_partial_information_decomposition.py rename to idtxl/multivariate_pid.py index 59432597..b77e1174 100644 --- a/idtxl/multivariate_partial_information_decomposition.py +++ b/idtxl/multivariate_pid.py @@ -1,6 +1,6 @@ """Estimate partial information decomposition (PID). 
-Estimate PID for multiple sources (up to 4 sources) and one target process +Estimate PID for multiple sources (up to 4 sources) and one target process using SxPID estimator. Note: @@ -9,17 +9,17 @@ import numpy as np from .single_process_analysis import SingleProcessAnalysis from .estimator import find_estimator -from .results import ResultsMultivariatePartialInformationDecomposition +from .results import ResultsMultivariatePID -class MultivariatePartialInformationDecomposition(SingleProcessAnalysis): +class MultivariatePID(SingleProcessAnalysis): """Perform partial information decomposition for individual processes. - Perform partial information decomposition (PID) for multiple source - processes (up to 4 sources) and a target process in the network. - Estimate unique, shared, and synergistic information in the multiple - sources about the target. Call analyse_network() on the whole network - or a set of nodes or call analyse_single_target() to estimate PID for + Perform partial information decomposition (PID) for multiple source + processes (up to 4 sources) and a target process in the network. + Estimate unique, shared, and synergistic information in the multiple + sources about the target. Call analyse_network() on the whole network + or a set of nodes or call analyse_single_target() to estimate PID for a single process. See docstrings of the two functions for more information. References: @@ -27,7 +27,7 @@ class MultivariatePartialInformationDecomposition(SingleProcessAnalysis): - Williams, P. L., & Beer, R. D. (2010). Nonnegative Decomposition of Multivariate Information, 1–14. Retrieved from http://arxiv.org/abs/1004.2515 - - Makkeh, A. & Gutknecht, A. & Wibral, M. (2020). A Differentiable measure + - Makkeh, A. & Gutknecht, A. & Wibral, M. (2020). A Differentiable measure for shared information. 
1- 27 Retrieved from http://arxiv.org/abs/2002.03356 @@ -48,10 +48,10 @@ def __init__(self): def analyse_network(self, settings, data, targets, sources): """Estimate partial information decomposition for network nodes. - Estimate, for multiple nodes (target processes), the partial - information decomposition (PID) for multiple source processes - (up to 4 sources) and each of these target processes - in the network. + Estimate, for multiple nodes (target processes), the partial + information decomposition (PID) for multiple source processes + (up to 4 sources) and each of these target processes + in the network. Note: For a detailed description of the algorithm and settings see @@ -67,7 +67,7 @@ def analyse_network(self, settings, data, targets, sources): >>> s3 = np.random.randint(0, alph, n) >>> target1 = np.logical_xor(s1, s2).astype(int) >>> target = np.logical_xor(target1, s3).astype(int) - >>> data = Data(np.vstack((s1, s2, s3, target)), 'ps', + >>> data = Data(np.vstack((s1, s2, s3, target)), 'ps', >>> normalise=False) >>> settings = { >>> 'lags_pid': [[1, 1, 1], [3, 2, 7]], @@ -75,7 +75,7 @@ def analyse_network(self, settings, data, targets, sources): >>> 'pid_estimator': 'SxPID'} >>> targets = [0, 1] >>> sources = [[1, 2, 3], [0, 2, 3]] - >>> pid_analysis = MultivariatePartialInformationDecomposition() + >>> pid_analysis = MultivariatePID() >>> results = pid_analysis.analyse_network(settings, data, targets, >>> sources) @@ -86,7 +86,7 @@ def analyse_network(self, settings, data, targets, sources): contain - lags_pid : list of lists of ints [optional] - lags in samples - between sources and target + between sources and target (default=[[1, 1, ..., 1], [1, 1, ..., 1], ...]) data : Data instance @@ -95,13 +95,13 @@ def analyse_network(self, settings, data, targets, sources): index of target processes sources : list of lists indices of the multiple source processes for each target, e.g., - [[0, 1, 2], [1, 0, 3]], all must lists be of the same lenght and + [[0, 
1, 2], [1, 0, 3]], all must lists be of the same lenght and list of lists must have the same length as targets Returns: - ResultsMultivariatePartialInformationDecomposition instance + ResultsMultivariatePID instance results of network inference, see documentation of - ResultsMultivariatePartialInformationDecomposition() + ResultsMultivariatePID() """ # Set defaults for PID estimation. settings.setdefault('verbose', True) @@ -119,11 +119,11 @@ def analyse_network(self, settings, data, targets, sources): #^ if #^ for #^ for - + list_of_lags = settings['lags_pid'] # Perform PID estimation for each target individually - results = ResultsMultivariatePartialInformationDecomposition( + results = ResultsMultivariatePID( n_nodes=data.n_processes, n_realisations=data.n_realisations(), normalised=data.normalise) @@ -144,7 +144,7 @@ def analyse_network(self, settings, data, targets, sources): def analyse_single_target(self, settings, data, target, sources): """Estimate partial information decomposition for a network node. - Estimate partial information decomposition (PID) for multiple source + Estimate partial information decomposition (PID) for multiple source processes (up to 4 sources) and a target process in the network. 
Note: @@ -160,13 +160,13 @@ def analyse_single_target(self, settings, data, target, sources): >>> s3 = np.random.randint(0, alph, n) >>> target1 = np.logical_xor(s1, s2).astype(int) >>> target = np.logical_xor(target1, s3).astype(int) - >>> data = Data(np.vstack((s1, s2, s3, target)), 'ps', + >>> data = Data(np.vstack((s1, s2, s3, target)), 'ps', >>> normalise=False) >>> settings = { >>> 'verbose' : false, >>> 'pid_estimator': 'SxPID', >>> 'lags_pid': [2, 3, 1]} - >>> pid_analysis = MultivariatePartialInformationDecomposition() + >>> pid_analysis = MultivariatePID() >>> results = pid_analysis.analyse_single_target(settings=settings, >>> data=data, >>> target=0, @@ -189,9 +189,9 @@ def analyse_single_target(self, settings, data, target, sources): sources : list of ints indices of the multiple source processes for the target - Returns: ResultsMultivariatePartialInformationDecomposition instance results of + Returns: ResultsMultivariatePID instance results of network inference, see documentation of - ResultsPartialInformationDecomposition() + ResultsPID() """ # Check input and initialise values for analysis. self._initialise(settings, data, target, sources) @@ -200,7 +200,7 @@ def analyse_single_target(self, settings, data, target, sources): self._calculate_pid(data) # Add analyis info. 
- results = ResultsMultivariatePartialInformationDecomposition( + results = ResultsMultivariatePID( n_nodes=data.n_processes, n_realisations=data.n_realisations(self.current_value), normalised=data.normalise) @@ -253,7 +253,7 @@ def _initialise(self, settings, data, target, sources): self.sources = self._lag_to_idx([ (sources[i], self.settings['lags_pid'][i]) for i in range(len(sources))]) - + def _calculate_pid(self, data): # TODO Discuss how and if the following statistical testing should be @@ -275,14 +275,14 @@ def _calculate_pid(self, data): list_sources_var_realisations = [data.get_realisations( self.current_value, [self.sources[i]])[0] - for i in range(len(self.sources))] + for i in range(len(self.sources))] orig_pid = self._pid_estimator.estimate( s=list_sources_var_realisations, t=target_realisations) - + self.results = orig_pid for i in range(len(self.sources)): self.results['source_'+str(i+1)] = self._idx_to_lag([self.sources[i]]) diff --git a/idtxl/pid_goettingen.py b/idtxl/pid_goettingen.py index adde9f0f..d84b9598 100644 --- a/idtxl/pid_goettingen.py +++ b/idtxl/pid_goettingen.py @@ -1,44 +1,47 @@ -""" -Shared exclusion partial information decomposition (SxPID) -""" +"""Shared exclusion partial information decomposition (SxPID).""" import numpy as np import math -import time from itertools import chain, combinations -from collections import defaultdict -from prettytable import PrettyTable +from . import idtxl_exceptions as ex +try: + from prettytable import PrettyTable +except ImportError as err: + ex.package_missing( + err, + 'PrettyTable is not available on this system. Install it from ' + 'https://pypi.org/project/PrettyTable/ to use the Goettinge PID ' + 'estimator.') + -#--------- -# Lattice -#--------- class Lattice: """Generates the redundancy lattice for 'n' sources - The algerbric structure on which partial information decomposition is + The algebraic structure on which partial information decomposition is build on. 
""" + def __init__(self, n): self.n = n - self.lis = [i for i in range(1,self.n+1)] - #^ _init_() - + self.lis = [i for i in range(1, self.n+1)] + # ^ _init_() + def powerset(self): return chain.from_iterable(combinations(self.lis, r) - for r in range(1,len(self.lis) + 1) ) - #^ powerset() + for r in range(1, len(self.lis) + 1)) + # ^ powerset() def less_than(self, beta, alpha): - """compare whether an antichain beta is smaller than antichain + """compare whether an antichain beta is smaller than antichain alpha""" return all(any(frozenset(b) <= frozenset(a) for b in beta) for a in alpha) - #^ compare() + # ^compare()s - def comparable(self, a,b): + def comparable(self, a, b): return a < b or a > b - #^ comparable() + # ^comparable() def antichain(self): - """Generates the nodes (antichains) of the lattice""" + """Generates the nodes (antichains) of the lattice""" # dummy expensive function might use dit or networkx functions assert self.n < 5, ( "antichain(n): number of sources should be less than 5") @@ -50,57 +53,63 @@ def antichain(self): # check if alpha is an antichain for a in list(alpha): for b in list(alpha): - if a < b and self.comparable(frozenset(a),frozenset(b)): + if a < b and self.comparable(frozenset(a), frozenset(b)): flag = 0 - #^ if - #^ for b - #^ for a - if flag: achain.append(alpha) - #^ for alpha - #^ for r + # ^if + # ^for b + # ^for a + if flag: + achain.append(alpha) + # ^for alpha + # ^for r return achain - #^ antichain() + # ^antichain() def children(self, alpha, achain): - """Enumerates the direct nodes (antichains) ordered by the node - (antichain) 'alpha'""" + """Enumerates the direct nodes (antichains) ordered by the node + (antichain) 'alpha'""" chl = [] - downset = [beta for beta in achain if self.less_than(beta,alpha) and beta != alpha] + downset = [beta for beta in achain if self.less_than( + beta, alpha) and beta != alpha] for beta in downset: - if all(not self.less_than(beta,gamma) for gamma in downset if gamma != beta): + if 
all(not self.less_than(beta, gamma) for gamma in downset if gamma != beta): chl.append(beta) - #^ if - #^ for beta + # ^if + # ^for beta return chl - #^ children() + # ^children() -#^ Lattice() +# ^Lattice() -#--------------- +# --------------- # pi^+(t:alpha) # and -# pi^-(t:alpha) -#--------------- +# pi^-(t:alpha) +# --------------- + def powerset(m): lis = [i for i in range(1, m+1)] return chain.from_iterable(combinations(lis, r) - for r in range(1,len(lis) + 1) ) -#^ powerset() + for r in range(1, len(lis) + 1)) +# ^powerset() + def marg(pdf, rlz, uset): """compute the marginal probability mass e.g. p(t,s1,s2)""" - idxs = [ idx - 1 for idx in list(uset)] + idxs = [idx - 1 for idx in list(uset)] summ = 0. for k in pdf.keys(): - if all(k[idx] == rlz[idx] for idx in idxs): summ += pdf[k] - #^ for + if all(k[idx] == rlz[idx] for idx in idxs): + summ += pdf[k] + # ^for return summ -#^ marg() - +# ^marg() + + def prob(n, pdf, rlz, gamma, target=False): - """Compute the Probability mass on a lattice node + """Compute the Probability mass on a lattice node e.g. 
node = {1}{2} p(s1 \cup s2) using inclusion-exclusion""" m = len(gamma) pset = powerset(m) @@ -110,97 +119,104 @@ def prob(n, pdf, rlz, gamma, target=False): uset = frozenset((n+1,)) else: uset = frozenset(()) - #^ if + # ^if for i in list(idxs): uset |= frozenset(gamma[i-1]) - #^ for i + # ^for i summ += (-1)**(len(idxs) + 1) * marg(pdf, rlz, uset) - #^ for idxs + # ^for idxs return summ -#^ prob() +# ^prob() + def differs(n, pdf, rlz, alpha, chl, target=False): """Compute the probability mass difference - For a node 'alpha' and any child gamma of alpha it computes p(gamma) - + For a node 'alpha' and any child gamma of alpha it computes p(gamma) - p(alpha) for all gamma""" if chl == [] and target: base = prob(n, pdf, rlz, [()], target)/prob(n, pdf, rlz, alpha, target) else: base = prob(n, pdf, rlz, alpha, target) - #^ if bottom + # ^if bottom temp_diffs = [prob(n, pdf, rlz, gamma, target) - base for gamma in chl] temp_diffs.sort() return [base] + temp_diffs -#^ differs() +# ^differs() + def sgn(num_chld): - """Recurrsive function that generates the signs (+ or -) for the + """Recurrsive function that generates the signs (+ or -) for the inclusion-exculison principle""" if num_chld == 0: return np.array([+1]) else: - return np.concatenate((sgn(num_chld - 1), -sgn(num_chld - 1)), axis=None) - #^ if bottom -#^sgn() - + return np.concatenate( + (sgn(num_chld - 1), -sgn(num_chld - 1)), axis=None) + # ^if bottom +# ^ sgn() + + def vec(num_chld, diffs): - """Recurrsive function that returns a numpy vector used in evaluating + """Recurrsive function that returns a numpy vector used in evaluating the moebuis inversion (compute the PPID atoms) - Args: + Args: num_chld: int - the number of the children of alpha: (gamma_1,..., - gamma_{num_chld}) - diffs: list of floats - vector of probability differences (d_i)_i - where d_i = p(gamma_i) - p(alpha) and d_0 = p(alpha) + gamma_{num_chld}) + diffs: list of floats - vector of probability differences (d_i)_i + where d_i = 
p(gamma_i) - p(alpha) and d_0 = p(alpha) """ # print(diffs) if num_chld == 0: return np.array([diffs[0]]) else: - temp = vec(num_chld - 1, diffs) + diffs[num_chld]*np.ones(2**(num_chld - 1)) + temp = vec(num_chld - 1, diffs) + \ + diffs[num_chld]*np.ones(2**(num_chld - 1)) return np.concatenate((vec(num_chld - 1, diffs), temp), axis=None) - #^ if bottom -#^ vec() + # ^if bottom +# ^vec() + def pi_plus(n, pdf, rlz, alpha, chld, achain): """Compute the informative PPID """ diffs = differs(n, pdf, rlz, alpha, chld[tuple(alpha)], False) - return np.dot(sgn(len(chld[alpha])), -np.log2(vec(len(chld[alpha]),diffs))) -#^ pi_plus() + return np.dot(sgn(len(chld[alpha])), -np.log2(vec(len(chld[alpha]), diffs))) +# ^pi_plus() + def pi_minus(n, pdf, rlz, alpha, chld, achain): """Compute the misinformative PPID """ diffs = differs(n, pdf, rlz, alpha, chld[alpha], True) if chld[alpha] == []: - return np.dot(sgn(len(chld[alpha])), np.log2(vec(len(chld[alpha]),diffs))) + return np.dot(sgn(len(chld[alpha])), np.log2(vec(len(chld[alpha]), diffs))) else: - return np.dot(sgn(len(chld[alpha])), -np.log2(vec(len(chld[alpha]),diffs))) - #^ if bottom -#^ pi_minus() + return np.dot(sgn(len(chld[alpha])), -np.log2(vec(len(chld[alpha]), diffs))) + # ^if bottom +# ^pi_minus() def pid(n, pdf_orig, chld, achain, printing=False): """Estimate partial information decomposition for 'n' inputs and one output - + Implementation of the partial information decomposition (PID) estimator for discrete data. The estimator finds shared information, unique information - and synergistic information between the two, three, or four inputs with + and synergistic information between the two, three, or four inputs with respect to the output t. - - P.S. The implementation can be extended to any number 'n' of variables if + + P.S. 
The implementation can be extended to any number 'n' of variables if their corresponding redundancy lattice is provided ( check Lattice() ) Args: n : int - number of pid sources - pdf_orig : dict - the original joint distribution of the inputs and - the output (realizations are the keys). It doesn't have + pdf_orig : dict - the original joint distribution of the inputs and + the output (realizations are the keys). It doesn't have to be a full support distribution, i.e., it can contain - realizations with 'zero' mass probability - chld : dict - list of children for each node in the redundancy + realizations with 'zero' mass probability + chld : dict - list of children for each node in the redundancy lattice (nodes are the keys) - achain : tuple - tuple of all the nodes (antichains) in the + achain : tuple - tuple of all the nodes (antichains) in the redundacy lattice printing: Bool - If true prints the results using PrettyTables - + Returns: tuple pointwise decomposition, averaged decomposition @@ -214,35 +230,35 @@ def pid(n, pdf_orig, chld, achain, printing=False): if __debug__: sum_p = 0. 
- for k,v in pdf_orig.items(): + for k, v in pdf_orig.items(): assert type(k) is tuple, ( 'pid_goettingen.pid(pdf, chld, achain): pdf keys must be tuples') assert len(k) < 6, ( 'pid_goettingen.pid(pdf, chld, achain): pdf keys must be tuples' 'of length at most 5') - assert type(v) is float or ( type(v)==int and v==0 ), ( + assert type(v) is float or (type(v) == int and v == 0), ( 'pid_goettingen.pid(pdf, chld, achain): pdf values must be floats') - assert v >-.1, ( + assert v > -.1, ( 'pid_goettingen.pid(pdf, chld, achain): pdf values must be nonnegative') sum_p += v - #^ for + # ^for assert abs(sum_p - 1) < 1.e-7, ( 'pid_goettingen.pid(pdf, chld, achain): pdf keys must sum up to 1' '(tolerance of precision is 1.e-7)') - #^ if debug + # ^if debug assert type(printing) is bool, ( 'pid_goettingen.pid(pdf, chld, achain, printing): printing must be a bool') # Remove the impossible realization - pdf = {k:v for k,v in pdf_orig.items() if v > 1.e-300 } + pdf = {k: v for k, v in pdf_orig.items() if v > 1.e-300} # Initialize the output where # ptw = { rlz -> { alpha -> pi_alpha } } # avg = { alpha -> PI_alpha } ptw = dict() - #avg = defaultdict(lambda : [0.,0.,0.]) + # avg = defaultdict(lambda : [0.,0.,0.]) avg = dict() # Compute and store the (+, -, +-) atoms for rlz in pdf.keys(): @@ -254,21 +270,21 @@ def pid(n, pdf_orig, chld, achain, printing=False): # avg[alpha][0] += pdf[rlz]*ptw[rlz][alpha][0] # avg[alpha][1] += pdf[rlz]*ptw[rlz][alpha][1] # avg[alpha][2] += pdf[rlz]*ptw[rlz][alpha][2] - - #^ for - #^ for - # compute and store the average of the (+, -, +-) atoms + + # ^for + # ^for + # compute and store the average of the (+, -, +-) atoms for alpha in achain: avgplus = 0. avgminus = 0. avgdiff = 0. 
for rlz in pdf.keys(): - avgplus += pdf[rlz]*ptw[rlz][alpha][0] + avgplus += pdf[rlz]*ptw[rlz][alpha][0] avgminus += pdf[rlz]*ptw[rlz][alpha][1] - avgdiff += pdf[rlz]*ptw[rlz][alpha][2] + avgdiff += pdf[rlz]*ptw[rlz][alpha][2] avg[alpha] = (avgplus, avgminus, avgdiff) - #^ for - #^ for + # ^for + # ^for # Print the result if asked if printing: @@ -282,21 +298,23 @@ def pid(n, pdf_orig, chld, achain, printing=False): stalpha += "{" for i in a: stalpha += str(i) - #^ for i - stalpha += "}" - #^ for a - if count == 0: table.add_row( [str(rlz), stalpha, - str(ptw[rlz][alpha][0]), - str(ptw[rlz][alpha][1]), - str(ptw[rlz][alpha][2])] ) - else: table.add_row( [" ", stalpha, - str(ptw[rlz][alpha][0]), - str(ptw[rlz][alpha][1]), - str(ptw[rlz][alpha][2])] ) - count += 1 - #^ for alpha + # ^for i + stalpha += "}" + # ^for a + if count == 0: + table.add_row([str(rlz), stalpha, + str(ptw[rlz][alpha][0]), + str(ptw[rlz][alpha][1]), + str(ptw[rlz][alpha][2])]) + else: + table.add_row([" ", stalpha, + str(ptw[rlz][alpha][0]), + str(ptw[rlz][alpha][1]), + str(ptw[rlz][alpha][2])]) + count += 1 + # ^for alpha table.add_row(["*", "*", "*", "*", "*"]) - #^ for realization + # ^for realization table.add_row(["-", "-", "-", "-", "-"]) count = 0 @@ -306,23 +324,23 @@ def pid(n, pdf_orig, chld, achain, printing=False): stalpha += "{" for i in a: stalpha += str(i) - #^ for i - stalpha += "}" - #^ for a - if count == 0: table.add_row( ["avg", stalpha, - str(avg[alpha][0]), - str(avg[alpha][1]), - str(avg[alpha][2])] ) - else: table.add_row( [" ", stalpha, - str(avg[alpha][0]), - str(avg[alpha][1]), - str(avg[alpha][2])] ) + # ^for i + stalpha += "}" + # ^for a + if count == 0: + table.add_row(["avg", stalpha, + str(avg[alpha][0]), + str(avg[alpha][1]), + str(avg[alpha][2])]) + else: + table.add_row([" ", stalpha, + str(avg[alpha][0]), + str(avg[alpha][1]), + str(avg[alpha][2])]) count += 1 - #^ for alpha + # ^for alpha print(table) - #^ if printing - - return ptw, avg -#^ pid() + # ^if 
printing -# EOF + return ptw, avg +# ^pid() diff --git a/idtxl/results.py b/idtxl/results.py index 9619fab1..89f19b9c 100644 --- a/idtxl/results.py +++ b/idtxl/results.py @@ -713,7 +713,7 @@ def print_edge_list(self, weights, fdr=True): self._print_edge_list(adjacency_matrix, weights=weights) -class ResultsPartialInformationDecomposition(ResultsNetworkAnalysis): +class ResultsPID(ResultsNetworkAnalysis): """Store results of Partial Information Decomposition (PID) analysis. Provide a container for results of Partial Information Decomposition (PID) @@ -778,15 +778,15 @@ def get_single_target(self, target): (result['selected_vars_sources']) or via dot-notation (result.selected_vars_sources). """ - return super(ResultsPartialInformationDecomposition, + return super(ResultsPID, self).get_single_target(target, fdr=False) -class ResultsMultivariatePartialInformationDecomposition(ResultsNetworkAnalysis): - """Store results of Multivariate Partial Information Decomposition (PID) +class ResultsMultivariatePID(ResultsNetworkAnalysis): + """Store results of Multivariate Partial Information Decomposition (PID) analysis. - Provide a container for results of Multivariate Partial Information + Provide a container for results of Multivariate Partial Information Decomposition (PID) algorithms. 
Note that for convenience all dictionaries in this class can additionally @@ -828,11 +828,11 @@ def get_single_target(self, target): - source_i : tuple - source variable i - selected_vars_sources : list of tuples - source variables used in PID estimation - - avg : dict - avg pid {alpha -> float} where alpha is a redundancy + - avg : dict - avg pid {alpha -> float} where alpha is a redundancy + lattice node + - ptw : dict of dicts - ptw pid {rlz -> {alpha -> float} } where rlz is + a single realisation of the random variables and alpha is a redundancy lattice node - - ptw : dict of dicts - ptw pid {rlz -> {alpha -> float} } where rlz is - a single realisation of the random variables and alpha is a redundancy - lattice node - current_value : tuple - current value used for analysis, described by target and sample index in the data - [estimator-specific settings] @@ -848,7 +848,7 @@ def get_single_target(self, target): (result['selected_vars_sources']) or via dot-notation (result.selected_vars_sources). 
""" - return super(ResultsMultivariatePartialInformationDecomposition, + return super(ResultsMultivariatePID, self).get_single_target(target, fdr=False) diff --git a/test/systemtest_partial_information_decomposition.py b/test/systemtest_bivariate_pid.py similarity index 93% rename from test/systemtest_partial_information_decomposition.py rename to test/systemtest_bivariate_pid.py index 7e5b932a..24208a96 100644 --- a/test/systemtest_partial_information_decomposition.py +++ b/test/systemtest_bivariate_pid.py @@ -1,8 +1,8 @@ """Provide unit tests for high-level PID estimation.""" import time as tm import numpy as np -from idtxl.partial_information_decomposition import ( - PartialInformationDecomposition) +from idtxl.bivariate_pid import ( + BivariatePID) from idtxl.data import Data import idtxl.idtxl_utils as utils @@ -18,7 +18,7 @@ def test_pid_xor_data(): # Run Tartu estimator settings = {'pid_estimator': 'TartuPID', 'lags_pid': [0, 0]} - pid = PartialInformationDecomposition() + pid = BivariatePID() tic = tm.time() est_tartu = pid.analyse_single_target(settings, data=data, target=2, sources=[0, 1]) @@ -36,7 +36,7 @@ def test_pid_xor_data(): 'max_iters': 1000, 'pid_estimator': 'SydneyPID', 'lags_pid': [0, 0]} - pid = PartialInformationDecomposition() + pid = BivariatePID() tic = tm.time() est_sydney = pid.analyse_single_target(settings, data=data, target=2, sources=[0, 1]) diff --git a/test/test_partial_information_decomposition.py b/test/test_bivariate_pid.py similarity index 95% rename from test/test_partial_information_decomposition.py rename to test/test_bivariate_pid.py index 74e6bef0..6235af36 100644 --- a/test/test_partial_information_decomposition.py +++ b/test/test_bivariate_pid.py @@ -1,8 +1,8 @@ """Provide unit tests for high-level PID estimation.""" import pytest import numpy as np -from idtxl.partial_information_decomposition import ( - PartialInformationDecomposition) +from idtxl.bivariate_pid import ( + BivariatePID) from idtxl.data import Data from 
test_estimators_pid import optimiser_missing, float128_not_available @@ -11,7 +11,7 @@ def test_pid_user_input(): """Test if user input is handled correctly.""" # Test missing estimator name - pid = PartialInformationDecomposition() + pid = BivariatePID() with pytest.raises(RuntimeError): pid.analyse_single_target(settings={}, data=Data(), target=0, sources=[1, 2]) @@ -41,7 +41,7 @@ def test_pid_user_input(): # Test two-tailed significance test settings = {'pid_estimator': 'TartuPID', 'tail': 'two', 'lags_pid': [0, 0]} - pid = PartialInformationDecomposition() + pid = BivariatePID() with pytest.raises(RuntimeError): # Test incorrect number of sources pid.analyse_single_target(settings=settings, data=data, target=2, @@ -79,7 +79,7 @@ def test_network_analysis(): data = Data(np.vstack((x, y, z)), 'ps', normalise=False) # Run Tartu estimator - pid = PartialInformationDecomposition() + pid = BivariatePID() settings = {'pid_estimator': 'TartuPID', 'tail': 'two', 'lags_pid': [[0, 0], [0, 0]]} est_tartu = pid.analyse_network(settings=settings, @@ -119,7 +119,7 @@ def test_analyse_single_target(): data = Data(np.vstack((x, y, z)), 'ps', normalise=False) # Run Tartu estimator - pid = PartialInformationDecomposition() + pid = BivariatePID() settings = {'pid_estimator': 'TartuPID', 'tail': 'two', 'lags_pid': [0, 0]} diff --git a/test/test_multivariate_partial_information_decomposition.py b/test/test_multivariate_pid.py similarity index 93% rename from test/test_multivariate_partial_information_decomposition.py rename to test/test_multivariate_pid.py index 385b6270..a7029e47 100644 --- a/test/test_multivariate_partial_information_decomposition.py +++ b/test/test_multivariate_pid.py @@ -1,16 +1,16 @@ """Provide unit tests for high-level Multivariate PID estimation.""" import pytest import numpy as np -from idtxl.multivariate_partial_information_decomposition import ( - MultivariatePartialInformationDecomposition) +from idtxl.multivariate_pid import ( + MultivariatePID) from 
idtxl.data import Data def test_pid_user_input(): """Test if user input is handled correctly.""" # Test missing estimator name - pid = MultivariatePartialInformationDecomposition() - settings = {'verbose': False} + pid = MultivariatePID() + settings = {'verbose': False} with pytest.raises(RuntimeError): pid.analyse_single_target(settings=settings, data=Data(), target=0, sources=[1, 2]) @@ -48,7 +48,7 @@ def test_pid_user_input(): # Test two-tailed significance test settings = {'pid_estimator': 'SxPID', 'tail': 'two', 'lags_pid': [0, 0], 'verbose': False} - pid = MultivariatePartialInformationDecomposition() + pid = MultivariatePID() # Test incorrect number of sources with pytest.raises(RuntimeError): @@ -91,7 +91,7 @@ def test_pid_user_input(): data = Data(np.vstack((s1, s2, s3, s4, s5, z)), 'ps', normalise=False) settings = {'pid_estimator': 'SxPID', 'tail': 'two', 'lags_pid': [0, 0, 0, 0], 'verbose': False} - pid = MultivariatePartialInformationDecomposition() + pid = MultivariatePID() # Test number of sources over limit (N=4) with pytest.raises(RuntimeError): @@ -109,7 +109,7 @@ def test_network_analysis(): data = Data(np.vstack((x, y, z)), 'ps', normalise=False) # Run Goettingen estimator - pid = MultivariatePartialInformationDecomposition() + pid = MultivariatePID() settings = {'pid_estimator': 'SxPID', 'tail': 'two', 'lags_pid': [[0, 0], [0, 0]]} est_goettingen = pid.analyse_network(settings=settings, @@ -135,7 +135,7 @@ def test_network_analysis(): quad_data = Data(np.vstack((s1, s2, s3, s4, quad_target)), 'ps', normalise=False) # Trivariate - pid = MultivariatePartialInformationDecomposition() + pid = MultivariatePID() tri_settings = {'pid_estimator': 'SxPID', 'tail': 'two', 'lags_pid': [[0, 0, 0], [0, 0, 0]]} est_goettingen = pid.analyse_network(settings=tri_settings, @@ -166,7 +166,7 @@ def test_analyse_single_target(): data = Data(np.vstack((x, y, z)), 'ps', normalise=False) # Run Goettingen estimator - pid = MultivariatePartialInformationDecomposition() 
+ pid = MultivariatePID() settings = {'pid_estimator': 'SxPID', 'tail': 'two', 'lags_pid': [0, 0],