Skip to content

Commit

Permalink
precommit
Browse files Browse the repository at this point in the history
  • Loading branch information
JHoelli committed Jun 14, 2023
1 parent daa9c43 commit c3ad128
Showing 1 changed file with 7 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,11 @@ def __init__(
max_iter int : max number of runs
"""
super().__init__(model, mode)

self.backend = backend
test_x, test_y = data
test_x = np.array(test_x) # , dtype=np.float32)
shape= (test_x.shape[-2],test_x.shape[-1])
shape = (test_x.shape[-2], test_x.shape[-1])
print(shape)
if mode == "time":
# Parse test data into (1, feat, time):
Expand All @@ -74,9 +74,9 @@ def __init__(

if backend == "PYT":
self.remove_all_hooks(self.model)
#try:
self.cam_extractor = CAM(self.model,input_shape=shape)
#except:
# try:
self.cam_extractor = CAM(self.model, input_shape=shape)
# except:
# print("GradCam Hook already registered")
change = False
if self.mode == "time":
Expand All @@ -100,8 +100,8 @@ def __init__(
self.n_neighbors = n_neighbors
# Manipulate reference set replace original y with predicted y

def remove_all_hooks(self,model: torch.nn.Module) -> None:
#TODO Move THIS TO TSINTERPRET !
def remove_all_hooks(self, model: torch.nn.Module) -> None:
# TODO Move THIS TO TSINTERPRET !
if hasattr(model, "_forward_hooks"):
if model._forward_hooks != OrderedDict():
model._forward_hooks: Dict[int, Callable] = OrderedDict()
Expand Down

0 comments on commit c3ad128

Please sign in to comment.