Skip to content

Commit

Permalink
push
Browse files Browse the repository at this point in the history
  • Loading branch information
msaroufim committed Jul 25, 2023
1 parent f232aac commit c3d5d10
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions torchbenchmark/models/hf_Whisper/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,10 @@ def __init__(self, test, device, jit=False, batch_size=None, extra_args=[]):
super().__init__(name="hf_Whisper", test=test, device=device, jit=jit, batch_size=batch_size, extra_args=extra_args)
self.feature_size = 80
self.sequence_length = 3000
input_features = torch.randn(size=(self.batch_size, self.feature_size, self.sequence_length),device=self.device).half()
self.example_inputs = {"input_features": input_features.to(self.device)}
self.input_features = torch.randn(size=(self.batch_size, self.feature_size, self.sequence_length),device=self.device).half()
self.example_inputs = {"input_features": self.input_features.to(self.device), "input_ids" : self.input_features.to(self.device)}
self.model.to(self.device)

def get_module(self):
return self.model, (self.example_inputs)

def train(self):
raise NotImplementedError("Training is not implemented.")
raise NotImplementedError("Training is not implemented.")

0 comments on commit c3d5d10

Please sign in to comment.