From c3d5d100f27939af568acf874d4e50cd081eb96a Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Tue, 25 Jul 2023 21:30:20 +0000 Subject: [PATCH] push --- torchbenchmark/models/hf_Whisper/__init__.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/torchbenchmark/models/hf_Whisper/__init__.py b/torchbenchmark/models/hf_Whisper/__init__.py index 6ff94c461c..a2d8126cfa 100644 --- a/torchbenchmark/models/hf_Whisper/__init__.py +++ b/torchbenchmark/models/hf_Whisper/__init__.py @@ -10,12 +10,10 @@ def __init__(self, test, device, jit=False, batch_size=None, extra_args=[]): super().__init__(name="hf_Whisper", test=test, device=device, jit=jit, batch_size=batch_size, extra_args=extra_args) self.feature_size = 80 self.sequence_length = 3000 - input_features = torch.randn(size=(self.batch_size, self.feature_size, self.sequence_length),device=self.device).half() - self.example_inputs = {"input_features": input_features.to(self.device)} + self.input_features = torch.randn(size=(self.batch_size, self.feature_size, self.sequence_length),device=self.device).half() + self.example_inputs = {"input_features": self.input_features.to(self.device), "input_ids" : self.input_features.to(self.device)} self.model.to(self.device) - def get_module(self): - return self.model, (self.example_inputs) - def train(self): - raise NotImplementedError("Training is not implemented.") \ No newline at end of file + raise NotImplementedError("Training is not implemented.") +