Skip to content

Commit

Permalink
pytorch_CycleGAN_and_pix2pix: support xpu devices
Browse files Browse the repository at this point in the history
  • Loading branch information
weishi-deng committed Dec 13, 2023
1 parent 7de2aed commit 340bd32
Show file tree
Hide file tree
Showing 7 changed files with 26 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ def __init__(self, test, device, batch_size=None, extra_args=[]):
device_arg = "--gpu_ids -1"
elif self.device == "cuda":
device_arg = "--gpu_ids 0"
elif self.device == "xpu":
device_arg = "--gpu_ids -2"
if self.test == "train":
train_args = f"--tb_device {self.device} --dataroot {data_root}/datasets/horse2zebra --name horse2zebra --model cycle_gan --display_id 0 --n_epochs 3 " + \
f"--n_epochs_decay 3 {device_arg} {checkpoints_arg}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,13 @@ def __init__(self, opt):
"""
self.opt = opt
self.gpu_ids = opt.gpu_ids
self.use_xpu = opt.use_xpu
self.isTrain = opt.isTrain
self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') # get device name: CPU or GPU
if self.use_xpu:
self.device = torch.device('xpu')
print("using xpu")
else:
self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') # get device name: CPU or GPU
self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) # save all the checkpoints to save_dir
if opt.preprocess != 'scale_width': # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.
torch.backends.cudnn.benchmark = True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def init_func(m): # define the initialization function
net.apply(init_func) # apply the initialization function <init_func>


def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[], use_xpu=False):
"""Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights
Parameters:
net (network) -- the network to be initialized
Expand All @@ -111,11 +111,13 @@ def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
assert(torch.cuda.is_available())
net.to(gpu_ids[0])
net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs
if use_xpu:
net.to('xpu')
init_weights(net, init_type, init_gain=init_gain)
return net


def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[]):
def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[], use_xpu=False):
"""Create a generator
Parameters:
Expand All @@ -128,6 +130,7 @@ def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, in
init_type (str) -- the name of our initialization method.
init_gain (float) -- scaling factor for normal, xavier and orthogonal.
gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
use_xpu (bool) -- if run on xpu devices
Returns a generator
Expand Down Expand Up @@ -155,10 +158,10 @@ def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, in
net = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
else:
raise NotImplementedError('Generator model name [%s] is not recognized' % netG)
return init_net(net, init_type, init_gain, gpu_ids)
return init_net(net, init_type, init_gain, gpu_ids, use_xpu)


def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, gpu_ids=[]):
def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, gpu_ids=[], use_xpu=False):
"""Create a discriminator
Parameters:
Expand All @@ -170,6 +173,7 @@ def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal'
init_type (str) -- the name of the initialization method.
init_gain (float) -- scaling factor for normal, xavier and orthogonal.
gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
use_xpu (bool) -- if run on xpu devices
Returns a discriminator
Expand Down Expand Up @@ -199,7 +203,7 @@ def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal'
net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer)
else:
raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD)
return init_net(net, init_type, init_gain, gpu_ids)
return init_net(net, init_type, init_gain, gpu_ids, use_xpu)


##############################################################################
Expand Down Expand Up @@ -281,7 +285,7 @@ def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', const
netD (network) -- discriminator network
real_data (tensor array) -- real images
fake_data (tensor array) -- generated images from the generator
device (str) -- GPU / CPU: from torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
device (str) -- GPU / CPU /XPU
type (str) -- if we mix real and fake data or not [real | fake | mixed].
constant (float) -- the constant used in formula ( ||gradient||_2 - constant)^2
lambda_gp (float) -- weight for this loss
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __init__(self, opt):
self.model_names = ['G']
# define networks (both generator and discriminator)
self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm,
not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, self.use_xpu)

if self.isTrain: # define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc
self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def __init__(self, opt):
# you can use opt.isTrain to specify different behaviors for training and test. For example, some networks will not be used during test, and you don't need to load them.
self.model_names = ['G']
# define networks; you can use opt.isTrain to specify different behaviors for training and test.
self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, gpu_ids=self.gpu_ids)
self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, gpu_ids=self.gpu_ids, use_xpu=self.use_xpu)
if self.isTrain: # only defined during training time
# define your loss functions. You can use losses provided by torch.nn such as torch.nn.L1Loss.
# We also provide a GANLoss class "networks.GANLoss". self.criterionGAN = networks.GANLoss().to(self.device)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def __init__(self, opt):
# specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>
self.model_names = ['G' + opt.model_suffix] # only generator is needed.
self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG,
opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, self.use_xpu)

# assigns the model to self.netG_[suffix] so that it can be loaded
# please see <BaseModel.load_networks>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def initialize(self, parser):
# basic parameters
parser.add_argument('--dataroot', required=True, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')
parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')
parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU, use -2 for XPU')
parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
# model parameters
parser.add_argument('--model', type=str, default='cycle_gan', help='chooses which model to use. [cycle_gan | pix2pix | test | colorization]')
Expand Down Expand Up @@ -123,11 +123,14 @@ def parse(self, args=None):
# set gpu ids
str_ids = opt.gpu_ids.split(',')
opt.gpu_ids = []
opt.use_xpu = False
for str_id in str_ids:
id = int(str_id)
if id >= 0:
opt.gpu_ids.append(id)
if len(opt.gpu_ids) > 0:
if id == -2:
opt.use_xpu = True
if len(opt.gpu_ids) > 0 and torch.cuda.avaliable():
torch.cuda.set_device(opt.gpu_ids[0])

self.opt = opt
Expand Down

0 comments on commit 340bd32

Please sign in to comment.