Skip to content

Commit

Permalink
tacotron2: benchmark coverage for custom devices
Browse files Browse the repository at this point in the history
  • Loading branch information
weishi-deng committed Apr 9, 2024
1 parent d9a9600 commit 1cdda97
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion torchbenchmark/models/tacotron2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def __init__(self, test, device, batch_size=None, extra_args=[]):
raise NotImplementedError("Tacotron2 doesn't support CPU because load_model assumes CUDA.")

self.hparams = self.create_hparams(batch_size=self.batch_size)
self.model = load_model(self.hparams).to(device=device)
self.model = load_model(self.hparams, device)
self.optimizer = torch.optim.Adam(self.model.parameters(),
lr=self.hparams.learning_rate,
weight_decay=self.hparams.weight_decay)
Expand Down
4 changes: 2 additions & 2 deletions torchbenchmark/models/tacotron2/train_tacotron2.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ def prepare_directories_and_logger(output_directory, log_directory, rank):
return logger


def load_model(hparams):
model = Tacotron2(hparams).cuda()
def load_model(hparams, device='cuda'):
model = Tacotron2(hparams).to(device)
if hparams.fp16_run:
model.decoder.attention_layer.score_mask_value = finfo('float16').min

Expand Down

0 comments on commit 1cdda97

Please sign in to comment.