Skip to content

Commit

Permalink
Cast return values of JIDT estimators to float
Browse files Browse the repository at this point in the history
Cast return values of JIDT estimators to float. Without this, return
values may be JPype types. This is inconvenient, for example, when
saving IDTxl results to disk and loading them in a new session, JPype
types require that a JVM is running to load the results. If the JVM
does not run, pickle fails to load the results with a hard to
understand error message.
  • Loading branch information
pwollstadt committed Apr 27, 2022
1 parent de42bfb commit 6f3916b
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions idtxl/estimators_jidt.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,7 @@ def estimate(self, var1, var2, conditional=None):
if self.settings['local_values']:
return np.array(self.calc.computeLocalOfPreviousObservations())
else:
return self.calc.computeAverageLocalOfObservations()
return float(self.calc.computeAverageLocalOfObservations())


class JidtDiscreteCMI(JidtDiscrete):
Expand Down Expand Up @@ -627,7 +627,7 @@ def estimate(self, var1, var2, conditional=None, return_calc=False):
jp.JArray(jp.JInt, 1)(var2.tolist()),
jp.JArray(jp.JInt, 1)(conditional.tolist())))
else:
result = self.calc.computeAverageLocalOfObservations()
result = float(self.calc.computeAverageLocalOfObservations())
if return_calc:
return (result, self.calc)
else:
Expand Down Expand Up @@ -784,7 +784,7 @@ def estimate(self, var1, var2, return_calc=False):
jp.JArray(jp.JInt, 1)(var1.tolist()),
jp.JArray(jp.JInt, 1)(var2.tolist())))
else:
result = self.calc.computeAverageLocalOfObservations()
result = float(self.calc.computeAverageLocalOfObservations())
if return_calc:
return (result, self.calc)
else:
Expand Down Expand Up @@ -908,7 +908,7 @@ def estimate(self, var1, var2):
if self.settings['local_values']:
return np.array(self.calc.computeLocalOfPreviousObservations())
else:
return self.calc.computeAverageLocalOfObservations()
return float(self.calc.computeAverageLocalOfObservations())


class JidtKraskovAIS(JidtKraskov):
Expand Down Expand Up @@ -999,7 +999,7 @@ def estimate(self, process):
if self.settings['local_values']:
return np.array(self.calc.computeLocalOfPreviousObservations())
else:
return self.calc.computeAverageLocalOfObservations()
return float(self.calc.computeAverageLocalOfObservations())


class JidtDiscreteAIS(JidtDiscrete):
Expand Down Expand Up @@ -1125,7 +1125,7 @@ def estimate(self, process, return_calc=False):
result = np.array(self.calc.computeLocalFromPreviousObservations(
jp.JArray(jp.JInt, 1)(process.tolist())))
else:
result = self.calc.computeAverageLocalOfObservations()
result = float(self.calc.computeAverageLocalOfObservations())
if return_calc:
return (result, self.calc)
else:
Expand Down Expand Up @@ -1228,7 +1228,7 @@ def estimate(self, process):
if self.settings['local_values']:
return np.array(self.calc.computeLocalOfPreviousObservations())
else:
return self.calc.computeAverageLocalOfObservations()
return float(self.calc.computeAverageLocalOfObservations())


class JidtGaussianMI(JidtGaussian):
Expand Down Expand Up @@ -1300,7 +1300,7 @@ def estimate(self, var1, var2):
if self.settings['local_values']:
return np.array(self.calc.computeLocalOfPreviousObservations())
else:
return self.calc.computeAverageLocalOfObservations()
return float(self.calc.computeAverageLocalOfObservations())


class JidtGaussianCMI(JidtGaussian):
Expand Down Expand Up @@ -1387,7 +1387,7 @@ def estimate(self, var1, var2, conditional=None):
if self.settings['local_values']:
return np.array(self.calc.computeLocalOfPreviousObservations())
else:
return self.calc.computeAverageLocalOfObservations()
return float(self.calc.computeAverageLocalOfObservations())

def get_analytic_distribution(self, var1, var2, conditional=None):
"""Return a JIDT AnalyticNullDistribution object.
Expand Down Expand Up @@ -1507,7 +1507,7 @@ def estimate(self, source, target):
if self.settings['local_values']:
return np.array(self.calc.computeLocalOfPreviousObservations())
else:
return self.calc.computeAverageLocalOfObservations()
return float(self.calc.computeAverageLocalOfObservations())


class JidtDiscreteTE(JidtDiscrete):
Expand Down Expand Up @@ -1652,7 +1652,7 @@ def estimate(self, source, target, return_calc=False):
jp.JArray(jp.JInt, 1)(source.tolist()),
jp.JArray(jp.JInt, 1)(target.tolist())))
else:
result = self.calc.computeAverageLocalOfObservations()
result = float(self.calc.computeAverageLocalOfObservations())
if return_calc:
return (result, self.calc)
else:
Expand Down Expand Up @@ -1765,7 +1765,7 @@ def estimate(self, source, target):
if self.settings['local_values']:
return np.array(self.calc.computeLocalOfPreviousObservations())
else:
return self.calc.computeAverageLocalOfObservations()
return float(self.calc.computeAverageLocalOfObservations())


def common_estimate_surrogates_analytic(estimator, n_perm=200, **data):
Expand Down

0 comments on commit 6f3916b

Please sign in to comment.