From 1b411ac29a47c5db7a7ba0e0e4e95e50e2c4f13d Mon Sep 17 00:00:00 2001 From: JHoelli Date: Tue, 19 Sep 2023 08:01:40 +0200 Subject: [PATCH] save --- .../counterfactual/TSEvo/Problem.py | 12 ++++++++---- .../InterpretabilityModels/counterfactual/TSEvoCF.py | 10 +++++----- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/TSInterpret/InterpretabilityModels/counterfactual/TSEvo/Problem.py b/TSInterpret/InterpretabilityModels/counterfactual/TSEvo/Problem.py index 8b82c51..b2c2c92 100644 --- a/TSInterpret/InterpretabilityModels/counterfactual/TSEvo/Problem.py +++ b/TSInterpret/InterpretabilityModels/counterfactual/TSEvo/Problem.py @@ -104,7 +104,7 @@ def get_prediction_torch( full=False, ): if self.mode == "time": - individual = np.swapaxes(individual,-1,-2).reshape( + individual = np.swapaxes(individual, -1, -2).reshape( 1, individual.shape[-1], individual.shape[-2] ) else: @@ -125,7 +125,9 @@ def get_prediction_torch( def get_prediction_tensorflow(self, individual, full=False): individual = np.array(individual.tolist(), dtype=np.float64) - output = self.model.predict(np.swapaxes(individual,-1,-2).reshape(1, self.window, -1), verbose=0) + output = self.model.predict( + np.swapaxes(individual, -1, -2).reshape(1, self.window, -1), verbose=0 + ) idx = output.argmax() if full: @@ -135,7 +137,7 @@ def get_prediction_tensorflow(self, individual, full=False): def get_prediction_target_torch(self, individual, full=False, binary=False): individual = np.array(individual.tolist(), dtype=np.float64) if self.mode == "time": - individual = np.swapaxes(individual,-1,-2).reshape( + individual = np.swapaxes(individual, -1, -2).reshape( 1, individual.shape[-1], individual.shape[-2] ) else: @@ -151,7 +153,9 @@ def get_prediction_target_torch(self, individual, full=False, binary=False): def get_prediction_target_tensorflow(self, individual, full=False): individual = np.array(individual.tolist(), dtype=np.float64) - output = self.model.predict(np.swapaxes(individual,-1,-2).reshape(1, self.window, -1), verbose=0) + output = self.model.predict( + np.swapaxes(individual, -1, -2).reshape(1, self.window, -1), verbose=0 + ) idx = output.argmax() if full: return idx, output[0] diff --git a/TSInterpret/InterpretabilityModels/counterfactual/TSEvoCF.py b/TSInterpret/InterpretabilityModels/counterfactual/TSEvoCF.py index 00393f2..08bc8f5 100644 --- a/TSInterpret/InterpretabilityModels/counterfactual/TSEvoCF.py +++ b/TSInterpret/InterpretabilityModels/counterfactual/TSEvoCF.py @@ -80,9 +80,9 @@ def explain( if len(original_x.shape) < 3: original_x = np.array([original_x]) if self.backend == "TF" or self.mode == "time": - original_x = np.swapaxes(original_x,2,1) #original_x.reshape( - #original_x.shape[0], original_x.shape[2], original_x.shape[1] - #) + original_x = np.swapaxes(original_x, 2, 1) # original_x.reshape( + # original_x.shape[0], original_x.shape[2], original_x.shape[1] + # ) neighborhood = [] if target_y is not None: if not type(target_y) == int: @@ -95,9 +95,9 @@ def explain( else: reference_set = self.x[np.where(self.y != original_y)] if self.backend == "TF" or self.mode == "time": - reference_set = np.swapaxes(reference_set,1,2)#.reshape( + reference_set = np.swapaxes(reference_set, 1, 2) # .reshape( # -1, original_x.shape[1], original_x.shape[2] - #) + # ) if len(reference_set.shape) == 2: reference_set = reference_set.reshape(-1, 1, reference_set.shape[-1])