Skip to content

Commit

Permalink
save
Browse files Browse the repository at this point in the history
  • Loading branch information
JHoelli committed Sep 19, 2023
1 parent fc4fb68 commit 1b411ac
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def get_prediction_torch(
full=False,
):
if self.mode == "time":
individual = np.swapaxes(individual,-1,-2).reshape(
individual = np.swapaxes(individual, -1, -2).reshape(
1, individual.shape[-1], individual.shape[-2]
)
else:
Expand All @@ -125,7 +125,9 @@ def get_prediction_torch(

def get_prediction_tensorflow(self, individual, full=False):
individual = np.array(individual.tolist(), dtype=np.float64)
output = self.model.predict(np.swapaxes(individual,-1,-2).reshape(1, self.window, -1), verbose=0)
output = self.model.predict(
np.swapaxes(individual, -1, -2).reshape(1, self.window, -1), verbose=0
)
idx = output.argmax()

if full:
Expand All @@ -135,7 +137,7 @@ def get_prediction_tensorflow(self, individual, full=False):
def get_prediction_target_torch(self, individual, full=False, binary=False):
individual = np.array(individual.tolist(), dtype=np.float64)
if self.mode == "time":
individual = np.swapaxes(individual,-1,-2).reshape(
individual = np.swapaxes(individual, -1, -2).reshape(
1, individual.shape[-1], individual.shape[-2]
)
else:
Expand All @@ -151,7 +153,9 @@ def get_prediction_target_torch(self, individual, full=False, binary=False):

def get_prediction_target_tensorflow(self, individual, full=False):
individual = np.array(individual.tolist(), dtype=np.float64)
output = self.model.predict(np.swapaxes(individual,-1,-2).reshape(1, self.window, -1), verbose=0)
output = self.model.predict(
np.swapaxes(individual, -1, -2).reshape(1, self.window, -1), verbose=0
)
idx = output.argmax()
if full:
return idx, output[0]
Expand Down
10 changes: 5 additions & 5 deletions TSInterpret/InterpretabilityModels/counterfactual/TSEvoCF.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,9 @@ def explain(
if len(original_x.shape) < 3:
original_x = np.array([original_x])
if self.backend == "TF" or self.mode == "time":
original_x = np.swapaxes(original_x,2,1) #original_x.reshape(
#original_x.shape[0], original_x.shape[2], original_x.shape[1]
#)
original_x = np.swapaxes(original_x, 2, 1) # original_x.reshape(
# original_x.shape[0], original_x.shape[2], original_x.shape[1]
# )
neighborhood = []
if target_y is not None:
if not type(target_y) == int:
Expand All @@ -95,9 +95,9 @@ def explain(
else:
reference_set = self.x[np.where(self.y != original_y)]
if self.backend == "TF" or self.mode == "time":
reference_set = np.swapaxes(reference_set,1,2)#.reshape(
reference_set = np.swapaxes(reference_set, 1, 2) # .reshape(
# -1, original_x.shape[1], original_x.shape[2]
#)
# )
if len(reference_set.shape) == 2:
reference_set = reference_set.reshape(-1, 1, reference_set.shape[-1])

Expand Down

0 comments on commit 1b411ac

Please sign in to comment.