diff --git a/lpips_2dirs.py b/lpips_2dirs.py index f0857e04..0d92e1f9 100644 --- a/lpips_2dirs.py +++ b/lpips_2dirs.py @@ -5,6 +5,7 @@ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument('-d0','--dir0', type=str, default='./imgs/ex_dir0') parser.add_argument('-d1','--dir1', type=str, default='./imgs/ex_dir1') +parser.add_argument('-n','--net', type=str, default='alex') # Can be 'alex', 'vgg' parser.add_argument('-o','--out', type=str, default='./imgs/example_dists.txt') parser.add_argument('-v','--version', type=str, default='0.1') parser.add_argument('--use_gpu', action='store_true', help='turn on flag to use GPU') @@ -12,7 +13,7 @@ opt = parser.parse_args() ## Initializing the model -loss_fn = lpips.LPIPS(net='alex',version=opt.version) +loss_fn = lpips.LPIPS(net=opt.net,version=opt.version) if(opt.use_gpu): loss_fn.cuda()