Skip to content

Commit

Permalink
Add PEP8 and documentation for Goettingen PID
Browse files Browse the repository at this point in the history
Add PEP8 and documentation for Goettingen PID. Make PrettyTable import
optional and issue a warning similar to the ECOS estimator for the Tartu
PID. Otherwise, any analysis crashes because it tries to import something
that is optional and thus probably not installed. Rename current PID
estimators to BivariatePID and new estimator to MultivariatePID to be
consistent with MI and TE class names.

Fix import error in multivariate PID estimator.
  • Loading branch information
pwollstadt committed Apr 14, 2020
1 parent 64ed304 commit 4c4ea88
Show file tree
Hide file tree
Showing 10 changed files with 262 additions and 244 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# Import classes
import numpy as np
from idtxl.partial_information_decomposition import (
PartialInformationDecomposition)
from idtxl.bivariate_pid import BivariatePID
from idtxl.data import Data

# a) Generate test data
Expand All @@ -13,7 +12,7 @@
data = Data(np.vstack((x, y, z)), 'ps', normalise=False)

# b) Initialise analysis object and define settings for both PID estimators
pid = PartialInformationDecomposition()
pid = BivariatePID()
settings_tartu = {'pid_estimator': 'TartuPID', 'lags_pid': [0, 0]}
settings_sydney = {
'alph_s1': alph,
Expand All @@ -30,7 +29,7 @@
settings=settings_tartu, data=data, target=2, sources=[0, 1])

# d) Run Sydney estimator
pid = PartialInformationDecomposition()
pid = BivariatePID()
results_sydney = pid.analyse_single_target(
settings=settings_sydney, data=data, target=2, sources=[0, 1])

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# Import classes
import numpy as np
from idtxl.multivariate_partial_information_decomposition import (
MultivariatePartialInformationDecomposition)
from idtxl.multivariate_pid import MultivariatePID
from idtxl.data import Data

# a) Generate test data
Expand All @@ -13,7 +12,7 @@
data = Data(np.vstack((x, y, z)), 'ps', normalise=False)

# b) Initialise analysis object and define settings for SxPID estimators
pid = MultivariatePartialInformationDecomposition()
pid = MultivariatePID()
settings_SxPID = {'pid_estimator': 'SxPID', 'lags_pid': [0, 0]}

# c) Run Goettingen estimator
Expand All @@ -39,13 +38,13 @@

# Some special Examples

# Pointwise Unique
# Pointwise Unique
x = np.asarray([0, 1, 0, 2])
y = np.asarray([1, 0, 2, 0])
z = np.asarray([1, 1, 2, 2])
data = Data(np.vstack((x, y, z)), 'ps', normalise=False)

pid = MultivariatePartialInformationDecomposition()
pid = MultivariatePID()
settings_SxPID = {'pid_estimator': 'SxPID', 'lags_pid': [0, 0]}

results_SxPID = pid.analyse_single_target(
Expand All @@ -66,13 +65,13 @@
results_SxPID.get_single_target(2)['avg'][((1,2,),)][2],
0.))

# Redundancy Error
# Redundancy Error
x = np.asarray([0, 0, 0, 1, 1, 1, 0, 1])
y = np.asarray([0, 0, 0, 1, 1, 1, 1, 0])
z = np.asarray([0, 0, 0, 1, 1, 1, 0, 1])
data = Data(np.vstack((x, y, z)), 'ps', normalise=False)

pid = MultivariatePartialInformationDecomposition()
pid = MultivariatePID()
settings_SxPID = {'pid_estimator': 'SxPID', 'lags_pid': [0, 0]}

results_SxPID = pid.analyse_single_target(
Expand All @@ -93,14 +92,14 @@
results_SxPID.get_single_target(2)['avg'][((1,2,),)][2],
0.368))

# Three bits hash
# Three bits hash
s1 = np.asarray([0, 0, 0, 0, 1, 1, 1, 1])
s2 = np.asarray([0, 0, 1, 1, 0, 0, 1, 1])
s3 = np.asarray([0, 1, 0, 1, 0, 1, 0, 1])
z = np.asarray([0, 1, 1, 0, 1, 0, 0, 1])
data = Data(np.vstack((s1, s2, s3, z)), 'ps', normalise=False)

pid = MultivariatePartialInformationDecomposition()
pid = MultivariatePID()
settings_SxPID = {'pid_estimator': 'SxPID', 'lags_pid': [0, 0, 0]}

results_SxPID = pid.analyse_single_target(
Expand All @@ -113,7 +112,7 @@
print('Uni s2 {0:.4f}\t\t{1:.4f}'.format(
results_SxPID.get_single_target(3)['avg'][((2,),)][2], 0.3219))
print('Uni s3 {0:.4f}\t\t{1:.4f}'.format(
results_SxPID.get_single_target(3)['avg'][((3,),)][2], 0.3219))
results_SxPID.get_single_target(3)['avg'][((3,),)][2], 0.3219))
print('Synergy s1_s2_s3 {0:.4f}\t\t{1:.4f}'.format(
results_SxPID.get_single_target(3)['avg'][((1,2,3),)][2], 0.2451))
print('Synergy s1_s2 {0:.4f}\t\t{1:.4f}'.format(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
import numpy as np
from .single_process_analysis import SingleProcessAnalysis
from .estimator import find_estimator
from .results import ResultsPartialInformationDecomposition
from .results import ResultsPID


class PartialInformationDecomposition(SingleProcessAnalysis):
class BivariatePID(SingleProcessAnalysis):
"""Perform partial information decomposition for individual processes.
Perform partial information decomposition (PID) for two source processes
Expand Down Expand Up @@ -75,7 +75,7 @@ def analyse_network(self, settings, data, targets, sources):
>>> 'pid_estimator': 'SydneyPID'}
>>> targets = [0, 1, 2]
>>> sources = [[1, 2], [0, 2], [0, 1]]
>>> pid_analysis = PartialInformationDecomposition()
>>> pid_analysis = BivariatePID()
>>> results = pid_analysis.analyse_network(settings, data, targets,
>>> sources)
Expand All @@ -97,9 +97,9 @@ def analyse_network(self, settings, data, targets, sources):
[[0, 2], [1, 0]], must have the same length as targets
Returns:
ResultsPartialInformationDecomposition instance
ResultsPID instance
results of network inference, see documentation of
ResultsPartialInformationDecomposition()
ResultsPID()
"""
# Set defaults for PID estimation.
settings.setdefault('verbose', True)
Expand All @@ -112,7 +112,7 @@ def analyse_network(self, settings, data, targets, sources):
list_of_lags = settings['lags_pid']

# Perform PID estimation for each target individually
results = ResultsPartialInformationDecomposition(
results = ResultsPID(
n_nodes=data.n_processes,
n_realisations=data.n_realisations(),
normalised=data.normalise)
Expand Down Expand Up @@ -158,7 +158,7 @@ def analyse_single_target(self, settings, data, target, sources):
>>> 'max_iters': 1000,
>>> 'pid_calc_name': 'SydneyPID',
>>> 'lags_pid': [2, 3]}
>>> pid_analysis = PartialInformationDecomposition()
>>> pid_analysis = BivariatePID()
>>> results = pid_analysis.analyse_single_target(settings=settings,
>>> data=data,
>>> target=0,
Expand All @@ -181,9 +181,9 @@ def analyse_single_target(self, settings, data, target, sources):
sources : list of ints
indices of the two source processes for the target
Returns: ResultsPartialInformationDecomposition instance results of
Returns: ResultsPID instance results of
network inference, see documentation of
ResultsPartialInformationDecomposition()
ResultsPID()
"""
# Check input and initialise values for analysis.
self._initialise(settings, data, target, sources)
Expand All @@ -192,7 +192,7 @@ def analyse_single_target(self, settings, data, target, sources):
self._calculate_pid(data)

# Add analyis info.
results = ResultsPartialInformationDecomposition(
results = ResultsPID(
n_nodes=data.n_processes,
n_realisations=data.n_realisations(self.current_value),
normalised=data.normalise)
Expand Down
66 changes: 34 additions & 32 deletions idtxl/estimators_multivariate_pid.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,39 @@
"""Multivariate Partial information decomposition for discrete random variables.
This module provides an estimator for multivariate partial information decomposition
as proposed in
This module provides an estimator for multivariate partial information
decomposition as proposed in
- Makkeh, A. & Gutknecht, A. & Wibral, M. (2020). A Differentiable measure
- Makkeh, A. & Gutknecht, A. & Wibral, M. (2020). A Differentiable measure
for shared information. 1- 27 Retrieved from
http://arxiv.org/abs/2002.03356
http://arxiv.org/abs/2002.03356
"""
import numpy as np
from . import lattices as lt
from . import pid_goettingen
from .estimator import Estimator
from .estimators_pid import _join_variables

# TODO add support for multivariate estimation for Tartu and Sydney estimator


class SxPID(Estimator):
"""Estimate partial information decomposition for multiple inputs (up to 4 inputs)
and one output
"""Estimate partial information decomposition for multiple inputs.
Implementation of the multivariate partial information decomposition (PID)
estimator for discrete data. The estimator finds shared information, unique
information and synergistic information between the multiple inputs
s1, s2, ..., sn with respect to the output t for each realization (t, s1, ..., sn)
and then average them according to their distribution weights p(t, s1, ..., sn).
Both the pointwise (on the realization level) PID and the averaged PID are
returned (see the 'return' of 'estimate()').
Implementation of the multivariate partial information decomposition (PID)
estimator for discrete data with (up to 4 inputs) and one output. The
estimator finds shared information, unique information and synergistic
information between the multiple inputs s1, s2, ..., sn with respect to the
output t for each realization (t, s1, ..., sn) and then average them
according to their distribution weights p(t, s1, ..., sn). Both the
pointwise (on the realization level) PID and the averaged PID are returned
(see the 'return' of 'estimate()').
The algorithm uses recurrsion to compute the partial information decomposition
The algorithm uses recursion to compute the partial information
decomposition.
References:
- Makkeh, A. & Wibral, M. (2020). A differentiable pointwise partial
- Makkeh, A. & Wibral, M. (2020). A differentiable pointwise partial
Information Decomposition estimator. https://github.com/Abzinger/SxPID.
Args:
Expand All @@ -39,14 +42,12 @@ class SxPID(Estimator):
- verbose : bool [optional] - print output to console
(default=False)
"""

def __init__(self, settings):
# get estimation parameters
self.settings = settings.copy()
self.settings.setdefault('verbose', False)


def is_parallel():
return False
Expand All @@ -69,7 +70,7 @@ def estimate(self, s, t):
'avg' -> {alpha -> [float, float, float]}
}
where the list of floats is ordered
where the list of floats is ordered
[informative, misinformative, informative - misinformative]
ptw stands for pointwise decomposition
avg stands for average decomposition
Expand All @@ -84,22 +85,24 @@ def estimate(self, s, t):
# children is a list of tuples
lattices = lt.lattices
num_source_vars = len(s)
retval_ptw, retval_avg = pid_goettingen.pid(num_source_vars, pdf_orig=pdf,
chld=lattices[num_source_vars][0],
achain=lattices[num_source_vars][1],
printing = self.settings['verbose'])

#AskM: Trivariate: does it make sense to name the alphas
retval_ptw, retval_avg = pid_goettingen.pid(
num_source_vars,
pdf_orig=pdf,
chld=lattices[num_source_vars][0],
achain=lattices[num_source_vars][1],
printing=self.settings['verbose'])

# TODO AskM: Trivariate: does it make sense to name the alphas
# for example shared_syn_s1_s2__syn_s1_s3 ?
results = {
'ptw': retval_ptw,
'avg': retval_avg,
}
}
return results


def _get_pdf_dict(s, t):
""""Stores the probability mass function estimated via counting to a dictionary"""
""""Write probability mass function estimated via counting to a dict."""
# Create dictionary with probability mass function
counts = dict()
n_samples = s[0].shape[0]
Expand Down Expand Up @@ -142,12 +145,12 @@ def _check_input(s, t, settings):
alph_new = len(np.unique(s[i][:, 0]))
for col in range(1, s[i].shape[1]):
alph_col = len(np.unique(s[i][:, col]))
si_joint, alph_new = _join_variables(si_joint, s[i][:, col],
alph_new, alph_col)
si_joint, alph_new = _join_variables(
si_joint, s[i][:, col], alph_new, alph_col)
settings['alph_s'+str(i+1)] = alph_new
else:
raise ValueError('Input source {0} s{0} has to be a 1D or 2D numpy '
'array.'.format(i+1))
raise ValueError('Input source {0} s{0} has to be a 1D or 2D '
'numpy array.'.format(i+1))

if t.ndim != 1:
if t.shape[1] == 1:
Expand All @@ -163,7 +166,7 @@ def _check_input(s, t, settings):
if not issubclass(s[i].dtype.type, np.integer):
raise TypeError('Input s{0} (source {0}) must be an integer numpy '
'array.'.format(i+1))
#^ for
# ^ for
if not issubclass(t.dtype.type, np.integer):
raise TypeError('Input t (target) must be an integer numpy array.')

Expand All @@ -173,4 +176,3 @@ def _check_input(s, t, settings):
raise ValueError('Number of samples s and t must be equal')

return s, t, settings

Loading

0 comments on commit 4c4ea88

Please sign in to comment.