From aa8eb329d1d07f0e28f2717849574c9d7d4f5b53 Mon Sep 17 00:00:00 2001
From: "Deng, Weishi"
Date: Tue, 12 Dec 2023 19:20:09 -0800
Subject: [PATCH] yolov3 and tacotron2: add support for xpu devices

---
 torchbenchmark/models/tacotron2/train_tacotron2.py     | 10 +++++++++-
 torchbenchmark/models/yolov3/yolo_utils/torch_utils.py |  8 ++++++--
 2 files changed, 15 insertions(+), 3 deletions(-)

diff --git a/torchbenchmark/models/tacotron2/train_tacotron2.py b/torchbenchmark/models/tacotron2/train_tacotron2.py
index 373b082ce7..a9bacbea92 100644
--- a/torchbenchmark/models/tacotron2/train_tacotron2.py
+++ b/torchbenchmark/models/tacotron2/train_tacotron2.py
@@ -71,7 +71,15 @@ def prepare_directories_and_logger(output_directory, log_directory, rank):
 
 
 def load_model(hparams):
-    model = Tacotron2(hparams).cuda()
+    use_xpu = torch.xpu.is_available()
+    use_cuda = torch.cuda.is_available()
+    if use_xpu:
+        device = 'xpu'
+    elif use_cuda:
+        device = 'cuda'
+    else:
+        device = 'cpu'
+    model = Tacotron2(hparams).to(device)
     if hparams.fp16_run:
         model.decoder.attention_layer.score_mask_value = finfo('float16').min
 
diff --git a/torchbenchmark/models/yolov3/yolo_utils/torch_utils.py b/torchbenchmark/models/yolov3/yolo_utils/torch_utils.py
index 5e09407fc8..4fe24caf48 100644
--- a/torchbenchmark/models/yolov3/yolo_utils/torch_utils.py
+++ b/torchbenchmark/models/yolov3/yolo_utils/torch_utils.py
@@ -26,13 +26,17 @@ def init_seeds(seed=0):
 
 def select_device(device='', apex=False, batch_size=None):
     # device = 'cpu' or '0' or '0,1,2,3'
     cpu_request = device.lower() == 'cpu'
-    if device and not cpu_request:  # if device requested other than 'cpu'
+    xpu_request = device.lower() == 'xpu'
+    if device and not cpu_request and not xpu_request:  # if device requested other than 'cpu' and 'xpu'
         os.environ['CUDA_VISIBLE_DEVICES'] = device  # set environment variable
         assert torch.cuda.is_available(), 'CUDA unavailable, invalid device %s requested' % device  # check availablity
 
-    cuda = False if cpu_request else torch.cuda.is_available()
+    cuda = False if cpu_request or xpu_request else torch.cuda.is_available()
     if cuda:
         return torch.device(f"cuda:{torch.cuda.current_device()}")
+    if xpu_request:
+        print('Using XPU')
+        return torch.device('xpu')
     print('Using CPU')
     return torch.device('cpu')