
Commit

Merge pull request #45 from fzi-forschungszentrum-informatik/swpaxes
Swpaxes
JHoelli authored Sep 18, 2023
2 parents a248f6b + baa94c5 commit 3deccfd
Showing 25 changed files with 617 additions and 277 deletions.
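Most of the changes below swap `.reshape(...)` calls for `np.swapaxes(...)` when moving between `(batch, time, features)` and `(batch, features, time)` layouts. As a quick illustration with toy shapes (not the library's own data), `reshape` only reinterprets the flat memory order and silently mixes channels, while `swapaxes` actually transposes the two axes:

```python
import numpy as np

# One sample, 3 time steps, 2 features (channels).
x = np.arange(6).reshape(1, 3, 2)   # x[0, t, c]

# reshape keeps the flat memory order, so time and feature values
# end up interleaved on the wrong channel.
wrong = x.reshape(1, 2, 3)
# array([[[0, 1, 2],
#         [3, 4, 5]]])

# swapaxes transposes the time and feature axes: right[0, c, t] == x[0, t, c].
right = np.swapaxes(x, -1, -2)
# array([[[0, 2, 4],
#         [1, 3, 5]]])
```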
Binary file modified ClassificationModels/models/BasicMotions/ResNet
Binary file modified ClassificationModels/models/Epilepsy/ResNet
45 changes: 30 additions & 15 deletions TSInterpret/InterpretabilityModels/Saliency/SaliencyMethods_PTY.py
@@ -243,7 +243,12 @@ def _getTwoStepRescaling(
newGrad = np.zeros((input_size, sequence_length))
# print("has Sliding Window", hasSliding_window_shapes)
if self.mode == "time":
input = input.reshape(-1, sequence_length, input_size)
newGrad = np.swapaxes(newGrad, -1, -2)
# print(input.shape)
# print('mode timw')
# print('inüut1',input)
# input = np.swapaxes(input,-1,-2)#.reshape(-1, sequence_length, input_size)
# print('inüut1',input)

if hasBaseline is None:
ActualGrad = (
@@ -283,9 +288,11 @@ def _getTwoStepRescaling(
)
# if self.mode == "time":
# ActualGrad = ActualGrad.reshape(-1, input_size, sequence_length)
if self.mode == "time":
input = np.swapaxes(
input, -1, -2
) # input.reshape(-1, input_size, sequence_length)
for t in range(sequence_length):
if self.mode == "time":
input = input.reshape(-1, input_size, sequence_length)
newInput = input.clone()
# if newInput.shape[-1] == self.NumTimeSteps:
# print('A')
@@ -294,8 +301,9 @@ def _getTwoStepRescaling(
# print('B')
# newInput[:, t,:] = assignment
if self.mode == "time":
newInput = newInput.reshape(-1, sequence_length, input_size)

newInput = np.swapaxes(
newInput, -1, -2
) # .reshape(-1, sequence_length, input_size)
if hasBaseline is None:
timeGrad_perTime = (
self.Grad.attribute(newInput, target=TestingLabel)
@@ -339,9 +347,9 @@ def _getTwoStepRescaling(

timeGrad_perTime = np.absolute(ActualGrad - timeGrad_perTime)
if self.mode == "time":
timeGrad_perTime = timeGrad_perTime.reshape(
-1, input_size, sequence_length
)
timeGrad_perTime = np.swapaxes(timeGrad_perTime, -1, -2) # .reshape(
# -1, input_size, sequence_length
# )
timeGrad[:, t] = np.sum(timeGrad_perTime)

timeContribution = preprocessing.minmax_scale(timeGrad, axis=1)
@@ -354,7 +362,9 @@ def _getTwoStepRescaling(
newInput = input.clone()
newInput[:, c, t] = assignment
if self.mode == "time":
newInput = newInput.reshape(-1, sequence_length, input_size)
newInput = np.swapaxes(
newInput, -1, -2
) # .reshape(-1, sequence_length, input_size)

if hasBaseline is None:
inputGrad_perInput = (
@@ -395,21 +405,26 @@ def _getTwoStepRescaling(
)

inputGrad_perInput = np.absolute(ActualGrad - inputGrad_perInput)
inputGrad_perInput = inputGrad_perInput.reshape(
-1, input_size, sequence_length
)
inputGrad_perInput = np.swapaxes(
inputGrad_perInput, -1, -2
) # .reshape(
# -1, input_size, sequence_length
# )
inputGrad[c, :] = np.sum(inputGrad_perInput)
featureContribution = preprocessing.minmax_scale(inputGrad, axis=0)

else:
featureContribution = np.ones((input_size, 1)) * 0.1
# print('FC',featureContribution)
newGrad = newGrad.reshape(input_size, sequence_length)
# newGrad = newGrad#.reshape(input_size, sequence_length)
if self.mode == "time":
# newGrad = newGrad.reshape(sequence_length, input_size)
newGrad = np.swapaxes(newGrad, -1, -2)
for c in range(input_size):
newGrad[c, t] = timeContribution[0, t] * featureContribution[c, 0]
if self.mode == "time":
newGrad = newGrad.reshape(sequence_length, input_size)
# print('NewGrad',newGrad.shape)
# newGrad = newGrad.reshape(sequence_length, input_size)
newGrad = np.swapaxes(newGrad, -1, -2)
return newGrad

def _givenAttGetRescaledSaliency(self, attributions, isTensor=True):
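In the PyTorch saliency backend, time mode now transposes with `np.swapaxes` before blanking out a single time step and transposes back before the attribution call. A simplified, library-agnostic sketch of that round-trip (the function name and the `attribute` callable are placeholders, not TSInterpret's API):

```python
import numpy as np

def occlusion_time_importance(x_time_major, attribute, baseline_value=0.0):
    """x_time_major: (1, time, features); attribute: placeholder callable that
    returns an attribution array for an input in the same layout.

    Sketch of the swapaxes round-trip used in "time" mode: transpose to
    (1, features, time), blank one time step across all channels, transpose
    back, and compare against the unperturbed attribution.
    """
    _, seq_len, _ = x_time_major.shape
    actual_grad = attribute(x_time_major)
    time_grad = np.zeros(seq_len)
    for t in range(seq_len):
        masked = np.swapaxes(x_time_major.copy(), -1, -2)  # (1, features, time)
        masked[:, :, t] = baseline_value                   # occlude time step t
        masked = np.swapaxes(masked, -1, -2)               # back to (1, time, features)
        time_grad[t] = np.abs(actual_grad - attribute(masked)).sum()
    return time_grad
```

The per-time scores would then be min-max scaled, mirroring the `preprocessing.minmax_scale(timeGrad, axis=1)` call above.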
27 changes: 19 additions & 8 deletions TSInterpret/InterpretabilityModels/Saliency/SaliencyMethods_TF.py
@@ -143,6 +143,10 @@ def _getTwoStepRescaling(
):
sequence_length = self.NumTimeSteps
input_size = self.NumFeatures
# print(sequence_length)
# print(input_size)
# print('inputshape',input.shape)
# print('Saliency Rescaling',input)
assignment = input[0, 0, 0]
timeGrad = np.zeros((1, sequence_length))
inputGrad = np.zeros((input_size, 1))
@@ -162,11 +166,16 @@ def _getTwoStepRescaling(
ActualGrad = self.Grad.explain(
(input, None), self.model, class_index=TestingLabel
) # .data.cpu().numpy()

# print('Actual GRad', ActualGrad)
for t in range(sequence_length):
newInput = input.copy().reshape(1, input_size, sequence_length)
newInput = np.swapaxes(input.copy(), 2, 1).reshape(
1, input_size, sequence_length
)
# print('NEW INPUT',newInput)
newInput[:, :, t] = assignment
newInput = newInput.reshape(1, sequence_length, input_size)
newInput = np.swapaxes(newInput, 2, 1).reshape(
1, sequence_length, input_size
)
if self.method == "FO":
timeGrad_perTime = self.Grad.explain(
(newInput, None),
@@ -187,13 +196,16 @@ def _getTwoStepRescaling(

timeContibution = preprocessing.minmax_scale(timeGrad, axis=1)
meanTime = np.quantile(timeContibution, 0.55)

for t in range(sequence_length):
if timeContibution[0, t] > meanTime:
for c in range(input_size):
newInput = input.copy().reshape(1, input_size, sequence_length)
newInput = np.swapaxes(input.copy(), 2, 1).reshape(
1, input_size, sequence_length
)
newInput[:, c, t] = assignment
newInput = newInput.reshape(1, sequence_length, input_size)
newInput = np.swapaxes(newInput, 2, 1).reshape(
1, sequence_length, input_size
)
if self.method == "FO":
inputGrad_perInput = self.Grad.explain(
(newInput, None),
@@ -204,7 +216,6 @@ def _getTwoStepRescaling(
elif self.method == "DLS" or self.method == "GS":
inputGrad_perInput = self.Grad.shap_values(newInput)
inputGrad_perInput = np.array(inputGrad_perInput)
# print(inputGrad_perInput.shape)
else:
newInput = newInput.reshape(1, sequence_length, input_size, 1)
inputGrad_perInput = self.Grad.explain(
@@ -220,4 +231,4 @@ def _getTwoStepRescaling(

for c in range(input_size):
newGrad[c, t] = timeContibution[0, t] * featureContibution[c, 0]
return newGrad.reshape(sequence_length, input_size)
return np.swapaxes(newGrad, 0, 1)
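Both backends end by filling `newGrad[c, t] = timeContribution[0, t] * featureContribution[c, 0]` and, in time mode, swapping the result back to a time-major layout. Ignoring the quantile gate that decides which time steps get their own feature pass, that combination step is an outer product; a toy sketch with made-up contribution values:

```python
import numpy as np

sequence_length, input_size = 4, 3
time_contribution = np.array([[0.1, 0.9, 0.3, 0.6]])     # shape (1, time)
feature_contribution = np.array([[0.2], [0.8], [0.5]])   # shape (features, 1)

# Outer product: the saliency of feature c at time t is the product of the
# two per-axis contributions, giving a (features, time) map ...
new_grad = feature_contribution @ time_contribution      # shape (3, 4)

# ... which is returned time-major in "time" mode, matching
# `return np.swapaxes(newGrad, 0, 1)` in the diff above.
new_grad_time_major = np.swapaxes(new_grad, 0, 1)        # shape (4, 3)
```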
7 changes: 5 additions & 2 deletions TSInterpret/InterpretabilityModels/Saliency/Saliency_Base.py
@@ -2,6 +2,7 @@
import seaborn as sns

from TSInterpret.InterpretabilityModels.FeatureAttribution import FeatureAttribution
import numpy as np


class Saliency(FeatureAttribution):
@@ -58,8 +59,10 @@ def plot(self, item, exp, figsize=(6.4, 4.8), heatmap=False, save=None):
i = 0
if self.mode == "time":
print("time mode")
item = item.reshape(1, item.shape[2], item.shape[1])
exp = exp.reshape(exp.shape[-1], -1)
item = np.swapaxes(
item, -1, -2
) # item.reshape(1, item.shape[2], item.shape[1])
exp = np.swapaxes(exp, -1, -2) # exp.reshape(exp.shape[-1], -1)
else:
print("NOT Time mode")

8 changes: 5 additions & 3 deletions TSInterpret/InterpretabilityModels/counterfactual/CF.py
@@ -178,12 +178,14 @@ def plot_in_one(
save_fig str: Path to Save the figure.
"""
if self.mode == "time":
item = item.reshape(item.shape[-1], item.shape[-2])
exp = exp.reshape(exp.shape[-1], exp.shape[-2])
org = item
item = np.swapaxes(item, -1, -2).reshape(org.shape[-1], org.shape[-2])
exp = np.swapaxes(exp, -1, -2).reshape(org.shape[-1], org.shape[-2])
else:
item = item.reshape(item.shape[-2], item.shape[-1])
exp = exp.reshape(item.shape[-2], item.shape[-1])

# item = np.swapaxes(item, -2, -1) # .reshape(item.shape[-1], item.shape[-2])
# exp = np.swapaxes(exp, -2, -1) # exp.reshape(exp.shape[-1], exp.shape[-2])
# TODO This is new and needs to be testes
ind = ""
# print("Item Shape", item.shape[-2])
(another changed file; filename not captured in this view)
@@ -49,8 +49,7 @@ def construct_per_class_trees(self):
return
self.per_class_trees = {}
self.per_class_node_indices = {c: [] for c in np.unique(self.labels)}

input_ = self.timeseries.reshape(-1, self.channels, self.window_size)
input_ = self.timeseries

preds = np.argmax(self.clf(input_), axis=1)
true_positive_node_ids = {c: [] for c in np.unique(self.labels)}
@@ -246,7 +245,7 @@ def _find_best(self, x_test, distractor, label_idx):
CLASSIFIER = self.clf
X_TEST = x_test
DISTRACTOR = distractor
input_ = x_test.reshape(1, -1, self.window_size)
input_ = x_test
best_case = self.clf(input_)[0][label_idx]
best_column = None
tuples = []
@@ -271,7 +270,7 @@ def _find_best(self, x_test, distractor, label_idx):
return best_column, best_case

def explain(self, x_test, to_maximize=None, num_features=10):
input_ = x_test.reshape(1, -1, self.window_size)
input_ = x_test
orig_preds = self.clf(input_)
if to_maximize is None:
to_maximize = np.argsort(orig_preds)[0][-2:-1][0]
@@ -292,10 +291,8 @@ def explain(self, x_test, to_maximize=None, num_features=10):
prev_best = 0
# best_dist = dist
while True:
input_ = modified.reshape(1, -1, self.window_size)
input_ = modified
probas = self.clf(input_)
# print('Current may',np.argmax(probas))
# print(to_maximize)
if np.argmax(probas) == to_maximize:
current_best = np.max(probas)
if current_best > best_explanation_score:
@@ -376,16 +373,15 @@ def _prune_explanation(
modified = x_test.copy()
for c in short_explanation:
modified[0][c] = dist[0][c]
input_ = modified.reshape(1, -1, self.window_size)
input_ = modified
prev_proba = self.clf(input_)[0][to_maximize]
best_col = None
best_diff = 0
for c in explanation:
tmp = modified.copy()

tmp[0][c] = dist[0][c]

input_ = tmp.reshape(1, self.channels, self.window_size)
input_ = tmp
cur_proba = self.clf(input_)[0][to_maximize]
if cur_proba - prev_proba > best_diff:
best_col = c
@@ -399,7 +395,7 @@ def _prune_explanation(
def explain(
self, x_test, num_features=None, to_maximize=None
) -> Tuple[np.array, int]:
input_ = x_test.reshape(1, -1, self.window_size)
input_ = x_test
orig_preds = self.clf(input_)

orig_label = np.argmax(orig_preds)
Expand All @@ -419,8 +415,6 @@ def explain(
x_test, num_features=num_features, to_maximize=to_maximize
)
best, other = explanation
# print('Other',np.array(other).shape)
# print('Best',np.array(best).shape)
target = np.argmax(self.clf(best), axis=1)

return best, target
@@ -429,8 +423,6 @@ def _get_explanation(self, x_test, to_maximize, num_features):
distractors = self._get_distractors(
x_test, to_maximize, n_distractors=self.num_distractors
)
# print('distracotr shape',np.array(distractors).shape)
# print('distracotr classification',np.argmax(self.clf(np.array(distractors).reshape(2,6,100)), axis=1))

# Avoid constructing KDtrees twice
self.backup.per_class_trees = self.per_class_trees
@@ -469,7 +461,7 @@ def _get_explanation(self, x_test, to_maximize, num_features):
for c in columns:
if c in explanation:
modified[0][c] = dist[0][c]
input_ = modified.reshape(1, -1, self.window_size)
input_ = modified # .reshape(1, -1, self.window_size)
probas = self.clf(input_)

if not self.silent:
(another changed file; filename not captured in this view)
@@ -85,32 +85,17 @@ def evaluate(self, feature_matrix):

for col_replace, a in zip(self.cols_swap, feature_matrix):
if a == 1:
# print(self.distractor.shape)
new_case[0][col_replace] = self.distractor[0][col_replace]

replaced_feature_count = np.sum(feature_matrix)
# print('replaced_Feature', replaced_feature_count)

# print('NEW CASE', new_case)
# print('self xtest', self.x_test)
# print('NEW CASE', new_case.shape)
# print('self xtest', self.x_test.shape)
# print('DIFF', np.where((self.x_test.reshape(-1)-new_case.reshape(-1)) != 0) )

input_ = new_case.reshape(1, self.channels, self.window_size)
input_ = new_case
result_org = self.clf(input_)
result = result_org[0][self.target]
# print('RESULT',result)
feature_loss = self.reg * np.maximum(
0, replaced_feature_count - self.max_features
)

# print('FEATURE LOSS',feature_loss)
loss_pred = np.square(np.maximum(0, 0.95 - result))
# print('losspred ',loss_pred)
# if np.argmax(result_org[0]) != self.target:
# loss_pred=np.inf

loss_pred = loss_pred + feature_loss

return loss_pred
8 changes: 6 additions & 2 deletions TSInterpret/InterpretabilityModels/counterfactual/COMTECF.py
@@ -52,7 +52,9 @@ def __init__(
# Parse test data into (1, feat, time):
change = True
self.ts_length = shape[-2]
test_x = test_x.reshape(test_x.shape[0], test_x.shape[2], test_x.shape[1])
test_x = np.swapaxes(
test_x, 2, 1
) # test_x.reshape(test_x.shape[0], test_x.shape[2], test_x.shape[1])
elif mode == "feat":
change = False
self.ts_length = shape[-1]
@@ -87,7 +89,7 @@ def explain(
"""
org_shape = x.shape
if self.mode != "feat":
x = x.reshape(-1, x.shape[-1], x.shape[-2])
x = np.swapaxes(x, -1, -2) # x.reshape(-1, x.shape[-1], x.shape[-2])
train_x, train_y = self.referenceset
if len(train_y.shape) > 1:
train_y = np.argmax(train_y, axis=1)
@@ -106,4 +108,6 @@
elif self.method == "brute":
opt = BruteForceSearch(self.predict, train_x, train_y, threads=1)
exp, label = opt.explain(x, to_maximize=target)
if self.mode != "feat":
exp = np.swapaxes(exp, -1, -2)
return exp.reshape(org_shape), label
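The counterfactual wrapper follows the same round-trip: time-mode inputs are transposed to a feature-major layout for the search, and the returned counterfactual is transposed back before being reshaped to the caller's original shape. A condensed sketch of that flow (`search_fn` stands in for the optimizer's explain call):

```python
import numpy as np

def to_feat_layout_and_back(x, search_fn, mode="time"):
    """Sketch of the round-trip in COMTECF.explain: transpose time-mode input
    to (batch, features, time), run the search, then restore the original
    layout so callers get back the shape they passed in."""
    org_shape = x.shape
    if mode != "feat":
        x = np.swapaxes(x, -1, -2)          # (batch, features, time)
    exp, label = search_fn(x)               # placeholder for the search step
    if mode != "feat":
        exp = np.swapaxes(exp, -1, -2)      # restore (batch, time, features)
    return exp.reshape(org_shape), label
```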
2 changes: 1 addition & 1 deletion TSInterpret/Models/PyTorchModel.py
@@ -25,7 +25,7 @@ def predict(self, item) -> List:
"""
item = np.array(item.tolist()) # , dtype=np.float64)
if self.change:
item = torch.from_numpy(item.reshape(-1, item.shape[-1], item.shape[-2]))
item = torch.from_numpy(np.swapaxes(item, -1, -2))

else:
item = torch.from_numpy(item)
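The model wrapper keeps channels-first classifiers usable with time-major inputs by transposing instead of reshaping. A self-contained sketch of the same predict pattern with a toy `Conv1d` classifier (the model and shapes are made up for illustration, not TSInterpret's `PyTorchModel`):

```python
import numpy as np
import torch
import torch.nn as nn

# Toy channels-first classifier: Conv1d expects (batch, features, time).
model = nn.Sequential(nn.Conv1d(2, 4, kernel_size=3), nn.AdaptiveAvgPool1d(1),
                      nn.Flatten(), nn.Linear(4, 3))

def predict(item, change=True):
    """Explainers hand in (batch, time, features); the transpose restores the
    (batch, features, time) layout the model expects."""
    item = np.asarray(item, dtype=np.float32)
    if change:
        item = np.swapaxes(item, -1, -2)    # transpose, don't reshape
    with torch.no_grad():
        out = model(torch.from_numpy(np.ascontiguousarray(item)))
    return out.numpy()

scores = predict(np.random.rand(1, 50, 2))  # 50 time steps, 2 features
```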
5 changes: 4 additions & 1 deletion TSInterpret/Models/TensorflowModel.py
@@ -1,6 +1,7 @@
import tensorflow as tf

from TSInterpret.Models.base_model import BaseModel
import numpy as np


class TensorFlowModel(BaseModel):
@@ -20,7 +21,9 @@ def predict(self, item):
an array of output scores for a classifier.
"""
if self.change:
item = item.reshape(item.shape[0], item.shape[2], item.shape[1])
item = np.swapaxes(
item, 2, 1
) # item.reshape(item.shape[0], item.shape[2], item.shape[1])
out = self.model.predict(item)
return out

2 changes: 1 addition & 1 deletion TSInterpret/__version__.py
@@ -1,2 +1,2 @@
VERSION = (0, 3, 4)
VERSION = (0, 4, 0)
__version__ = ".".join(map(str, VERSION)) # noqa: F401
Binary file modified docs/Notebooks/Ates.png
