Tacotron2, yolov3*: benchmark coverage for custom devices. (#2230)
Summary:
Part of Roadmap #1293 to increase benchmark coverage.

For five models (tacotron2, yolov3, nvidia_deeprecommender, LearningToPaint, and pytorch_CycleGAN_and_pix2pix), running on custom devices other than CPU and CUDA (e.g. XPU) raises an error because the CPU/CUDA backends are hard-coded.
In this PR, we accept the device argument as a parameter in the training and inference paths, covering model initialization and host-to-device data movement for these custom devices.
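The fix follows one pattern across the affected models: device placement is driven by the device string instead of hard-coded .cuda()/.cpu() calls. Below is a minimal sketch of that pattern, assuming the requested backend (e.g. "xpu") is actually installed; the run_step helper and its arguments are illustrative and not taken from the benchmark code.

```python
import torch

def run_step(model, batch, device="cuda"):
    # Place the model and inputs on whichever backend was requested
    # ("cpu", "cuda", "xpu", ...) rather than hard-coding .cuda()/.cpu().
    model = model.to(device)
    batch = batch.to(device)
    with torch.no_grad():
        return model(batch)
```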

Pull Request resolved: #2230

Reviewed By: aaronenyeshi

Differential Revision: D56097643

Pulled By: xuzhao9

fbshipit-source-id: deba28fee42b5119f62dbddc15e017bf00eb6843
weishi-deng authored and facebook-github-bot committed Apr 13, 2024
1 parent 6bff330 commit 9d68c51
Showing 14 changed files with 85 additions and 59 deletions.
@@ -37,7 +37,7 @@ def prBlack(prt):


def to_numpy(var):
return var.cpu().data.numpy() if USE_CUDA else var.data.numpy()
return var.cpu().data.numpy()


def to_tensor(ndarray, device):
26 changes: 16 additions & 10 deletions torchbenchmark/models/nvidia_deeprecommender/nvinfer.py
@@ -57,7 +57,7 @@ def getCommandLineArgs() :

return args

def getBenchmarkArgs(forceCuda):
def getBenchmarkArgs(forceCuda, device='cuda'):

class Args:
pass
@@ -76,10 +76,11 @@ class Args:
args.batch_size = 1
args.jit = False
args.forcecuda = forceCuda
args.forcecpu = not forceCuda
args.forcecpu = False if forceCuda else device == 'cpu'
args.nooutput = True
args.silent = True
args.profile = False
args.device = device

return args

@@ -93,20 +94,22 @@ def processArgState(args) :
quit()

args.use_cuda = torch.cuda.is_available() # global flag
args.use_xpu = torch.xpu.is_available()
if not args.silent:
if args.use_cuda:
if args.use_cuda or args.use_xpu:
print('GPU is available.')
else:
print('GPU is not available.')

if args.use_cuda and args.forcecpu:
if args.forcecpu:
args.use_cuda = False

args.use_xpu = False

if not args.silent:
if args.use_cuda:
if args.use_cuda or args.use_xpu:
print('Running On GPU')
else:
print('Running On CUDA')
print('Running On CPU')

if args.profile:
print('Profiler Enabled')
@@ -134,11 +137,13 @@ def __init__(self, device = 'cpu', jit=False, batch_size=256, usecommandlineargs
forcecuda = False
elif device == "cuda":
forcecuda = True
elif device == "xpu":
forcecuda = False
else:
# unknown device string, quit init
return

self.args = getBenchmarkArgs(forcecuda)
self.args = getBenchmarkArgs(forcecuda, device)

args = processArgState(self.args)

@@ -199,6 +204,7 @@ def __init__(self, device = 'cpu', jit=False, batch_size=256, usecommandlineargs


if self.args.use_cuda: self.rencoder = self.rencoder.cuda()
elif self.args.use_xpu: self.rencoder = self.rencoder.xpu()

if self.toytest == False:
self.inv_userIdMap = {v: k for k, v in self.data_layer.userIdMap.items()}
@@ -214,7 +220,7 @@ def eval(self, niter=1):
continue

for i, ((out, src), majorInd) in enumerate(self.eval_data_layer.iterate_one_epoch_eval(for_inf=True)):
inputs = Variable(src.cuda().to_dense() if self.args.use_cuda else src.to_dense())
inputs = Variable(src.to(device).to_dense())
targets_np = out.to_dense().numpy()[0, :]

out = self.rencoder(inputs)
@@ -237,7 +243,7 @@ def TimedInferenceRun(self) :
e_start_time = time.time()

if self.args.profile:
with profiler.profile(record_shapes=True, use_cuda=True) as prof:
with profiler.profile(record_shapes=True, use_cuda=self.args.use_cuda, use_xpu=self.args.use_xpu) as prof:
with profiler.record_function("Inference"):
self.eval()
else:
30 changes: 17 additions & 13 deletions torchbenchmark/models/nvidia_deeprecommender/nvtrain.py
@@ -120,8 +120,9 @@ def processTrainArgState(args) :
quit()

args.use_cuda = torch.cuda.is_available() # global flag
args.use_xpu = args.device == 'xpu'
if not args.silent:
if args.use_cuda:
if args.use_cuda or args.use_xpu:
print('GPU is available.')
else:
print('GPU is not available.')
@@ -130,8 +131,8 @@ def processTrainArgState(args) :
args.use_cuda = False

if not args.silent:
if args.use_cuda:
print('Running On CUDA')
if args.use_cuda or args.use_xpu:
print('Running On GPU')
else:
print('Running On CPU')

@@ -164,13 +165,13 @@ def log_var_and_grad_summaries(logger, layers, global_step, prefix, log_histogram
logger.histo_summary(tag="Gradients/{}_{}".format(prefix, ind), values=w.grad.data.cpu().numpy(),
step=global_step)

def DoTrainEval(encoder, evaluation_data_layer, use_cuda):
def DoTrainEval(encoder, evaluation_data_layer, device):
encoder.eval()
denom = 0.0
total_epoch_loss = 0.0
for i, (eval, src) in enumerate(evaluation_data_layer.iterate_one_epoch_eval()):
inputs = Variable(src.cuda().to_dense() if use_cuda else src.to_dense())
targets = Variable(eval.cuda().to_dense() if use_cuda else eval.to_dense())
inputs = Variable(src.to(device).to_dense())
targets = Variable(eval.to(device).to_dense())
outputs = encoder(inputs)
loss, num_ratings = model.MSEloss(outputs, targets)
total_epoch_loss += loss.item()
@@ -203,12 +204,15 @@ def TrainInit(self, device="cpu", jit=False, batch_size=256, processCommandLine
forcecuda = False
elif device == "cuda":
forcecuda = True
elif device == "xpu":
forcecuda = False
else:
# unknown device string, quit init
return

self.args.forcecuda = forcecuda
self.args.forcecpu = not forcecuda
self.args.forcecpu = not forcecuda and device == 'cpu'
self.args.device = device

self.args = processTrainArgState(self.args)

@@ -279,9 +283,9 @@ def TrainInit(self, device="cpu", jit=False, batch_size=256, processCommandLine
self.rencoder = nn.DataParallel(self.rencoder,
device_ids=gpu_ids)

self.rencoder = self.rencoder.cuda()
self.toyinputs = self.toyinputs.to(device)

self.toyinputs = self.toyinputs.to(device)
self.rencoder = self.rencoder.to(device)

if self.args.optimizer == "adam":
self.optimizer = optim.Adam(self.rencoder.parameters(),
@@ -326,7 +330,7 @@ def DoTrain(self):

for i, mb in enumerate(self.data_layer.iterate_one_epoch()):

inputs = Variable(mb.cuda().to_dense() if self.args.use_cuda else mb.to_dense())
inputs = Variable(mb.to(self.args.device).to_dense())

self.optimizer.zero_grad()

@@ -404,7 +408,7 @@ def train(self, niter=1) :
self.logger.scalar_summary("Training_RMSE_per_epoch", sqrt(self.total_epoch_loss/self.denom), self.epoch)
self.logger.scalar_summary("Epoch_time", e_end_time - e_start_time, self.epoch)
if self.epoch % self.args.save_every == 0 or self.epoch == self.args.num_epochs - 1:
eval_loss = DoTrainEval(self.rencoder, self.eval_data_layer, self.args.use_cuda)
eval_loss = DoTrainEval(self.rencoder, self.eval_data_layer, self.args.device)
print('Epoch {} EVALUATION LOSS: {}'.format(self.epoch, eval_loss))

self.logger.scalar_summary("EVALUATION_RMSE", eval_loss, self.epoch)
@@ -417,13 +421,13 @@ def train(self, niter=1) :

# save to onnx
dummy_input = Variable(torch.randn(self.params['batch_size'], self.data_layer.vector_dim).type(torch.float))
torch.onnx.export(self.rencoder.float(), dummy_input.cuda() if self.args.use_cuda else dummy_input,
torch.onnx.export(self.rencoder.float(), dummy_input.to(device),
self.model_checkpoint + ".onnx", verbose=True)
print("ONNX model saved to {}!".format(self.model_checkpoint + ".onnx"))

def TimedTrainingRun(self):
if self.args.profile:
with profiler.profile(record_shapes=True, use_cuda=self.args.use_cuda) as prof:
with profiler.profile(record_shapes=True, use_cuda=self.args.use_cuda, use_xpu=self.args.use_xpu) as prof:
with profiler.record_function("training_epoch"):
self.train(self.args.num_epochs)
else:
@@ -35,16 +35,18 @@ def __init__(self, test, device, batch_size=None, extra_args=[]):
results_arg = f"--results_dir {results_dir}"
data_root = os.path.join(DATA_PATH, "pytorch_CycleGAN_and_pix2pix_inputs")
device_arg = ""
device_type_arg = f"--device_type {self.device}"
if self.device == "cpu":
device_arg = "--gpu_ids -1"
elif self.device == "cuda":
else:
device_arg = "--gpu_ids 0"

if self.test == "train":
train_args = f"--tb_device {self.device} --dataroot {data_root}/datasets/horse2zebra --name horse2zebra --model cycle_gan --display_id 0 --n_epochs 3 " + \
f"--n_epochs_decay 3 {device_arg} {checkpoints_arg}"
f"--n_epochs_decay 3 {device_type_arg} {device_arg} {checkpoints_arg}"
self.training_loop = prepare_training_loop(train_args.split(' '))
args = f"--dataroot {data_root}/datasets/horse2zebra/testA --name horse2zebra_pretrained --model test " + \
f"--no_dropout {device_arg} {checkpoints_arg} {results_arg}"
f"--no_dropout {device_type_arg} {device_arg} {checkpoints_arg} {results_arg}"
self.model, self.input = get_model(args, self.device)

def get_module(self):
@@ -31,8 +31,9 @@ def __init__(self, opt):
"""
self.opt = opt
self.gpu_ids = opt.gpu_ids
self.device_type = opt.device_type
self.isTrain = opt.isTrain
self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') # get device name: CPU or GPU
self.device = torch.device('{}:{}'.format(self.device_type, self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') # get device name: CPU or GPU
self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) # save all the checkpoints to save_dir
if opt.preprocess != 'scale_width': # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.
torch.backends.cudnn.benchmark = True
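The self.device construction above now builds the torch.device from a device type string plus the first GPU id, rather than assuming CUDA. A small illustration with example values; a non-CPU device only works at runtime if that backend is installed.

```python
import torch

device_type, gpu_ids = "xpu", [0]  # example values; "cuda" and "cpu" take the same path
device = torch.device('{}:{}'.format(device_type, gpu_ids[0])) if gpu_ids else torch.device('cpu')
print(device)  # xpu:0
```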
@@ -71,15 +71,15 @@ def __init__(self, opt):
# The naming is different from those used in the paper.
# Code (vs. paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm,
not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, self.device_type)
self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf, opt.netG, opt.norm,
not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, self.device_type)

if self.isTrain: # define discriminators
self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.netD,
opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids, self.device_type)
self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.netD,
opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids, self.device_type)

if self.isTrain:
if opt.lambda_identity > 0.0: # only works when input and output images have the same number of channels
@@ -97,25 +97,28 @@ def init_func(m): # define the initialization function
net.apply(init_func) # apply the initialization function <init_func>


def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[], device='cuda'):
"""Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights
Parameters:
net (network) -- the network to be initialized
init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
gain (float) -- scaling factor for normal, xavier and orthogonal.
gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
device (str) -- device type: cpu/cuda/xpu
Return an initialized network.
"""
if len(gpu_ids) > 0:
assert(torch.cuda.is_available())
net.to(gpu_ids[0])
assert(device != 'cpu')
assert(hasattr(torch, device))
assert(getattr(torch, device).is_available())
net.to('{}:{}'.format(device, gpu_ids[0]))
net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs
init_weights(net, init_type, init_gain=init_gain)
return net


def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[]):
def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[], device='cuda'):
"""Create a generator
Parameters:
@@ -128,6 +131,7 @@ def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, in
init_type (str) -- the name of our initialization method.
init_gain (float) -- scaling factor for normal, xavier and orthogonal.
gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
device (str) -- device type: cpu/cuda/xpu
Returns a generator
@@ -155,21 +159,22 @@ def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, in
net = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
else:
raise NotImplementedError('Generator model name [%s] is not recognized' % netG)
return init_net(net, init_type, init_gain, gpu_ids)
return init_net(net, init_type, init_gain, gpu_ids, device)


def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, gpu_ids=[]):
def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, gpu_ids=[], device='cuda'):
"""Create a discriminator
Parameters:
input_nc (int) -- the number of channels in input images
ndf (int) -- the number of filters in the first conv layer
netD (str) -- the architecture's name: basic | n_layers | pixel
n_layers_D (int) -- the number of conv layers in the discriminator; effective when netD=='n_layers'
norm (str) -- the type of normalization layers used in the network.
init_type (str) -- the name of the initialization method.
init_gain (float) -- scaling factor for normal, xavier and orthogonal.
gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
input_nc (int) -- the number of channels in input images
ndf (int) -- the number of filters in the first conv layer
netD (str) -- the architecture's name: basic | n_layers | pixel
n_layers_D (int) -- the number of conv layers in the discriminator; effective when netD=='n_layers'
norm (str) -- the type of normalization layers used in the network.
init_type (str) -- the name of the initialization method.
init_gain (float) -- scaling factor for normal, xavier and orthogonal.
gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
device (str) -- device type: cpu/cuda/xpu
Returns a discriminator
@@ -199,7 +204,7 @@ def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal'
net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer)
else:
raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD)
return init_net(net, init_type, init_gain, gpu_ids)
return init_net(net, init_type, init_gain, gpu_ids, device)


##############################################################################
@@ -281,7 +286,7 @@ def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', const
netD (network) -- discriminator network
real_data (tensor array) -- real images
fake_data (tensor array) -- generated images from the generator
device (str) -- GPU / CPU: from torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
device (str) -- cpu / cuda / xpu
type (str) -- if we mix real and fake data or not [real | fake | mixed].
constant (float) -- the constant used in formula ( ||gradient||_2 - constant)^2
lambda_gp (float) -- weight for this loss
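The init_net hunk above replaces the hard-coded torch.cuda.is_available() assertion with a check that accepts any accelerator backend torch exposes as a submodule. A standalone sketch of that check; the helper name is assumed, not part of the repository.

```python
import torch

def assert_accelerator_available(device: str) -> None:
    # Accept any non-CPU backend exposed as a torch submodule with an
    # is_available() probe, e.g. torch.cuda or torch.xpu.
    assert device != 'cpu'
    assert hasattr(torch, device), 'unknown backend: {}'.format(device)
    assert getattr(torch, device).is_available(), '{} backend is not available'.format(device)
```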
@@ -54,11 +54,11 @@ def __init__(self, opt):
self.model_names = ['G']
# define networks (both generator and discriminator)
self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm,
not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, self.device_type)

if self.isTrain: # define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc
self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids, self.device_type)

if self.isTrain:
# define loss functions
@@ -57,7 +57,7 @@ def __init__(self, opt):
# you can use opt.isTrain to specify different behaviors for training and test. For example, some networks will not be used during test, and you don't need to load them.
self.model_names = ['G']
# define networks; you can use opt.isTrain to specify different behaviors for training and test.
self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, gpu_ids=self.gpu_ids)
self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, gpu_ids=self.gpu_ids, device=self.device_type)
if self.isTrain: # only defined during training time
# define your loss functions. You can use losses provided by torch.nn such as torch.nn.L1Loss.
# We also provide a GANLoss class "networks.GANLoss". self.criterionGAN = networks.GANLoss().to(self.device)
@@ -43,7 +43,7 @@ def __init__(self, opt):
# specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>
self.model_names = ['G' + opt.model_suffix] # only generator is needed.
self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG,
opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, self.device_type)

# assigns the model to self.netG_[suffix] so that it can be loaded
# please see <BaseModel.load_networks>
