yolov3 and tacotron2: add support for xpu devices
weishi-deng committed Dec 13, 2023
1 parent 7de2aed commit aa8eb32
Showing 2 changed files with 15 additions and 3 deletions.
10 changes: 9 additions & 1 deletion torchbenchmark/models/tacotron2/train_tacotron2.py
@@ -71,7 +71,15 @@ def prepare_directories_and_logger(output_directory, log_directory, rank):


def load_model(hparams):
-   model = Tacotron2(hparams).cuda()
+   use_xpu = torch.xpu.is_available()
+   use_cuda = torch.cuda.is_available()
+   if use_xpu:
+       device = 'xpu'
+   elif use_cuda:
+       device = 'cuda'
+   else:
+       device = 'cpu'
+   model = Tacotron2(hparams).to(device)
    if hparams.fp16_run:
        model.decoder.attention_layer.score_mask_value = finfo('float16').min

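For reference, the tacotron2 change boils down to the usual fallback order: prefer XPU, then CUDA, then CPU. Below is a minimal standalone sketch of that pattern (not part of the commit); the hasattr() guard is an added precaution, since torch.xpu only exists on PyTorch builds that ship XPU support (for example via intel_extension_for_pytorch).

import torch

def pick_device():
    # Illustrative helper mirroring the load_model change: XPU, then CUDA, then CPU.
    # hasattr() guard added here because torch.xpu is absent on builds without XPU support.
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return torch.device("xpu")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")

# e.g. model = Tacotron2(hparams).to(pick_device())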
8 changes: 6 additions & 2 deletions torchbenchmark/models/yolov3/yolo_utils/torch_utils.py
@@ -26,13 +26,17 @@ def init_seeds(seed=0):
def select_device(device='', apex=False, batch_size=None):
    # device = 'cpu' or '0' or '0,1,2,3'
    cpu_request = device.lower() == 'cpu'
-   if device and not cpu_request:  # if device requested other than 'cpu'
+   xpu_request = device.lower() == 'xpu'
+   if device and not cpu_request and not xpu_request:  # if device requested other than 'cpu' and 'xpu'
        os.environ['CUDA_VISIBLE_DEVICES'] = device  # set environment variable
        assert torch.cuda.is_available(), 'CUDA unavailable, invalid device %s requested' % device  # check availablity

-   cuda = False if cpu_request else torch.cuda.is_available()
+   cuda = False if cpu_request or xpu_request else torch.cuda.is_available()
    if cuda:
        return torch.device(f"cuda:{torch.cuda.current_device()}")
+   if xpu_request:
+       print('Using XPU')
+       return torch.device('xpu')

    print('Using CPU')
    return torch.device('cpu')
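As a usage sketch (assuming the torchbenchmark package layout shown in the path above, and a machine that actually exposes an XPU device), the patched select_device can be exercised like this:

import torch
from torchbenchmark.models.yolov3.yolo_utils.torch_utils import select_device

device = select_device('xpu')           # prints "Using XPU" and returns torch.device('xpu')
net = torch.nn.Linear(8, 2).to(device)  # modules move to XPU the same way as to CUDA or CPU
x = torch.randn(4, 8).to(device)
y = net(x)                              # forward pass runs on the XPU device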
