diff --git a/torchbenchmark/models/BERT_pytorch/__init__.py b/torchbenchmark/models/BERT_pytorch/__init__.py index f041ea54df..298015c577 100644 --- a/torchbenchmark/models/BERT_pytorch/__init__.py +++ b/torchbenchmark/models/BERT_pytorch/__init__.py @@ -208,7 +208,7 @@ def get_module(self): return self.model.model, self.example_inputs def set_module(self, new_model): - self.model.bert = new_model + self.model.model = new_model def eval(self) -> typing.Tuple[torch.Tensor]: model = self.model