Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

nvidia_deeprecommender: add support for xpu devices #2085

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 23 additions & 8 deletions torchbenchmark/models/nvidia_deeprecommender/nvinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ def getCommandLineArgs() :
help='jit-ify model before running')
parser.add_argument('--forcecuda', action='store_true',
help='force cuda use')
parser.add_argument('--forcexpu', action='store_true',
help='force xpu use')
parser.add_argument('--forcecpu', action='store_true',
help='force cpu use')
parser.add_argument('--nooutput', action='store_true',
Expand All @@ -57,7 +59,7 @@ def getCommandLineArgs() :

return args

def getBenchmarkArgs(forceCuda):
def getBenchmarkArgs(forceCuda, forceXpu):

class Args:
pass
Expand All @@ -76,7 +78,8 @@ class Args:
args.batch_size = 1
args.jit = False
args.forcecuda = forceCuda
args.forcecpu = not forceCuda
args.forcexpu = forceXpu
args.forcecpu = not (forceCuda or forceXpu)
args.nooutput = True
args.silent = True
args.profile = False
Expand All @@ -93,20 +96,26 @@ def processArgState(args) :
quit()

args.use_cuda = torch.cuda.is_available() # global flag
args.use_xpu = torch.xpu.is_available() # global flag
if not args.silent:
if args.use_cuda:
if args.use_cuda or args.use_xpu:
print('GPU is available.')
else:
print('GPU is not available.')

if args.use_cuda and args.forcecpu:
args.use_cuda = False
device = 'cpu'

if args.use_xpu and args.forcecpu:
args.use_xpu = False
device = 'cpu'

if not args.silent:
if args.use_cuda:
if args.use_cuda or args.use_xpu:
print('Running On GPU')
else:
print('Running On CUDA')
print('Running On CPU')

if args.profile:
print('Profiler Enabled')
Expand All @@ -132,13 +141,18 @@ def __init__(self, device = 'cpu', jit=False, batch_size=256, usecommandlineargs
else:
if device == "cpu":
forcecuda = False
forcexpu = False
elif device == "cuda":
forcecuda = True
forcexpu = False
elif device == "xpu":
forcecuda = False
forcexpu = True
else:
# unknown device string, quit init
return

self.args = getBenchmarkArgs(forcecuda)
self.args = getBenchmarkArgs(forcecuda, forcexpu)

args = processArgState(self.args)

Expand Down Expand Up @@ -199,6 +213,7 @@ def __init__(self, device = 'cpu', jit=False, batch_size=256, usecommandlineargs


if self.args.use_cuda: self.rencoder = self.rencoder.cuda()
if self.args.use_xpu: self.rencoder = self.rencoder.xpu()

if self.toytest == False:
self.inv_userIdMap = {v: k for k, v in self.data_layer.userIdMap.items()}
Expand All @@ -214,7 +229,7 @@ def eval(self, niter=1):
continue

for i, ((out, src), majorInd) in enumerate(self.eval_data_layer.iterate_one_epoch_eval(for_inf=True)):
inputs = Variable(src.cuda().to_dense() if self.args.use_cuda else src.to_dense())
inputs = Variable(src.to(device).to_dense())
targets_np = out.to_dense().numpy()[0, :]

out = self.rencoder(inputs)
Expand All @@ -237,7 +252,7 @@ def TimedInferenceRun(self) :
e_start_time = time.time()

if self.args.profile:
with profiler.profile(record_shapes=True, use_cuda=True) as prof:
with profiler.profile(record_shapes=True, use_cuda=self.args.use_cuda, use_xpu=self.args.use_xpu) as prof:
with profiler.record_function("Inference"):
self.eval()
else:
Expand Down
44 changes: 33 additions & 11 deletions torchbenchmark/models/nvidia_deeprecommender/nvtrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ def getTrainCommandLineArgs() :
help='disable all messages')
parser.add_argument('--forcecuda', action='store_true',
help='force cuda use')
parser.add_argument('--forcexpu', action='store_true',
help='force xpu use')
parser.add_argument('--forcecpu', action='store_true',
help='force cpu use')
parser.add_argument('--profile', action='store_true',
Expand All @@ -120,18 +122,25 @@ def processTrainArgState(args) :
quit()

args.use_cuda = torch.cuda.is_available() # global flag
args.use_xpu = torch.xpu.is_available() # global flag
if not args.silent:
if args.use_cuda:
if args.use_cuda or args.use_xpu:
print('GPU is available.')
else:
print('GPU is not available.')

if args.use_cuda and args.forcecpu:
args.use_cuda = False
device = 'cpu'
if args.use_xpu and args.forcecpu:
args.use_xpu = False
device = 'cpu'

if not args.silent:
if args.use_cuda:
print('Running On CUDA')
elif args.use_xpu:
print('Running On XPU')
else:
print('Running On CPU')

Expand Down Expand Up @@ -164,13 +173,13 @@ def log_var_and_grad_summaries(logger, layers, global_step, prefix, log_histogra
logger.histo_summary(tag="Gradients/{}_{}".format(prefix, ind), values=w.grad.data.cpu().numpy(),
step=global_step)

def DoTrainEval(encoder, evaluation_data_layer, use_cuda):
def DoTrainEval(encoder, evaluation_data_layer, device):
encoder.eval()
denom = 0.0
total_epoch_loss = 0.0
for i, (eval, src) in enumerate(evaluation_data_layer.iterate_one_epoch_eval()):
inputs = Variable(src.cuda().to_dense() if use_cuda else src.to_dense())
targets = Variable(eval.cuda().to_dense() if use_cuda else eval.to_dense())
inputs = Variable(src.to(device).to_dense())
targets = Variable(eval.to(device).to_dense())
outputs = encoder(inputs)
loss, num_ratings = model.MSEloss(outputs, targets)
total_epoch_loss += loss.item()
Expand Down Expand Up @@ -201,14 +210,24 @@ def TrainInit(self, device="cpu", jit=False, batch_size=256, processCommandLine

if device == "cpu":
forcecuda = False
forcecpu = True
forcexpu = False
elif device == "cuda":
forcecuda = True
forcecpu = False
forcexpu = False
elif device == 'xpu':
forcecuda = False
forcecpu = False
forcexpu = True
else:
# unknown device string, quit init
print('warning: skip by unknown device:', device)
return

self.args.forcecuda = forcecuda
self.args.forcecpu = not forcecuda
self.args.forcecpu = forcecpu
self.args.forcexpu = forcexpu

self.args = processTrainArgState(self.args)

Expand Down Expand Up @@ -274,15 +293,18 @@ def TrainInit(self, device="cpu", jit=False, batch_size=256, processCommandLine
gpu_ids = [int(g) for g in self.args.gpu_ids.split(',')]
if not self.args.silent:
print('Using GPUs: {}'.format(gpu_ids))

if len(gpu_ids)>1:
self.rencoder = nn.DataParallel(self.rencoder,
device_ids=gpu_ids)

self.rencoder = self.rencoder.cuda()
self.toyinputs = self.toyinputs.to(device)


if self.args.use_xpu:
self.rencoder = self.rencoder.xpu()
self.toyinputs = self.toyinputs.to('xpu')

if self.args.optimizer == "adam":
self.optimizer = optim.Adam(self.rencoder.parameters(),
lr=self.args.lr,
Expand Down Expand Up @@ -326,7 +348,7 @@ def DoTrain(self):

for i, mb in enumerate(self.data_layer.iterate_one_epoch()):

inputs = Variable(mb.cuda().to_dense() if self.args.use_cuda else mb.to_dense())
inputs = Variable(mb.to(device).to_dense())

self.optimizer.zero_grad()

Expand Down Expand Up @@ -404,7 +426,7 @@ def train(self, niter=1) :
self.logger.scalar_summary("Training_RMSE_per_epoch", sqrt(self.total_epoch_loss/self.denom), self.epoch)
self.logger.scalar_summary("Epoch_time", e_end_time - e_start_time, self.epoch)
if self.epoch % self.args.save_every == 0 or self.epoch == self.args.num_epochs - 1:
eval_loss = DoTrainEval(self.rencoder, self.eval_data_layer, self.args.use_cuda)
eval_loss = DoTrainEval(self.rencoder, self.eval_data_layer, device)
print('Epoch {} EVALUATION LOSS: {}'.format(self.epoch, eval_loss))

self.logger.scalar_summary("EVALUATION_RMSE", eval_loss, self.epoch)
Expand All @@ -417,13 +439,13 @@ def train(self, niter=1) :

# save to onnx
dummy_input = Variable(torch.randn(self.params['batch_size'], self.data_layer.vector_dim).type(torch.float))
torch.onnx.export(self.rencoder.float(), dummy_input.cuda() if self.args.use_cuda else dummy_input,
torch.onnx.export(self.rencoder.float(), dummy_input.to(device),
self.model_checkpoint + ".onnx", verbose=True)
print("ONNX model saved to {}!".format(self.model_checkpoint + ".onnx"))

def TimedTrainingRun(self):
if self.args.profile:
with profiler.profile(record_shapes=True, use_cuda=self.args.use_cuda) as prof:
with profiler.profile(record_shapes=True, use_cuda=self.args.use_cuda, use_xpu=self.args.use_xpu) as prof:
with profiler.record_function("training_epoch"):
self.train(self.args.num_epochs)
else:
Expand Down
Loading