forked from ycjungSubhuman/DeepDeformable3DCaricatures
-
Notifications
You must be signed in to change notification settings - Fork 1
/
latent_manipulation_interfacegan.py
executable file
·149 lines (119 loc) · 5.9 KB
/
latent_manipulation_interfacegan.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import sys
import os
sys.path.append( os.path.dirname( os.path.dirname( os.path.abspath(__file__) ) ) )
import yaml
import tqdm
import io
import numpy as np
import dataset_caricshop3d as dataset
import training_loop_surface as training_loop
import utils, loss, modules, meta_modules
import torch
from torch.utils.data import DataLoader
import configargparse
from torch import nn
from surface_net import SurfaceDeformationField
from surface_deformation import create_mesh_single
from attr.process_attr import load_dict
from attr.train_boundary import _train_boundary
from attr.helper.manipulator import linear_interpolate
#----------------------------------------------------------------------------
def _build_map(PATH):
"""
Build map: index --> path in the dataset
"""
result = {}
with open(PATH, 'r') as fin:
lines = fin.readlines()
for line in lines:
tokens = line.split()
value = ' '.join(tokens[:-2])
key = int(tokens[-1]) - 1
result[key] = value
return result
#----------------------------------------------------------------------------
# Expose only GPU 0 to this process. This has to happen before CUDA is first
# initialised (the model's .cuda() call further down) for it to take effect.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

# Command-line / config-file options. Every option is also passed on to
# SurfaceDeformationField through vars(opt), which is presumably why training
# options (batch size, lr, epochs, ...) are accepted here as well.
p = configargparse.ArgumentParser()
p.add_argument('--config', required=True, is_config_file=True, help='Evaluation configuration')
p.add_argument('--dir_caricshop', type=str, default='./3dcaricshop', help='3DCaricShop dataset root')
p.add_argument('--logging_root', type=str, default='./logs', help='root for logging')
p.add_argument('--summary_root', type=str, default='./summaries', help='root for summary')
p.add_argument('--checkpoint_path', type=str, default='', help='checkpoint to use for eval')
p.add_argument('--experiment_name', type=str, default='default',
               help='Name of subdirectory in logging_root where summaries and checkpoints will be saved.')

# General training options
p.add_argument('--batch_size', type=int, default=256, help='training batch size.')
p.add_argument('--lr', type=float, default=1e-4, help='learning rate. default=1e-4')
p.add_argument('--epochs', type=int, default=8000, help='Number of epochs to train for.')
p.add_argument('--epochs_til_checkpoint', type=int, default=10,
               help='Number of epochs between checkpoint saves.')
p.add_argument('--steps_til_summary', type=int, default=100,
               help='Number of steps between tensorboard summary writes.')
p.add_argument('--model_type', type=str, default='sine',
               help='Options are "sine" (all sine activations) and "mixed" (first layer sine, other layers tanh)')
p.add_argument('--latent_dim', type=int, default=128, help='latent code dimension.')
p.add_argument('--hidden_num', type=int, default=128, help='hidden layer dimension of deform-net.')
p.add_argument('--num_hidden_layers', type=int, default=3, help='number of hidden layers of deform-net.')
p.add_argument('--hyper_hidden_layers', type=int, default=1, help='number of hidden layers hyper-net.')

# Latent manipulation options
p.add_argument('--attr_index', type=int, default=1, help='index of target attribute')
# %(default)s is expanded by argparse, so the advertised default always
# matches the real one.
p.add_argument('--start_distance', type=float, default=-0.01,
               help='Start point for manipulation in latent space. (default: %(default)s)')
p.add_argument('--end_distance', type=float, default=0.01,
               help='End point for manipulation in latent space. (default: %(default)s)')
p.add_argument('--steps', type=int, default=11,
               help='Number of steps for image editing. (default: %(default)s)')

# load configs
opt = p.parse_args()
meta_params = vars(opt)

# Fail early with a usage message instead of an opaque error from
# torch.load('') when no checkpoint was supplied (the option defaults to '').
if not meta_params['checkpoint_path']:
    p.error('--checkpoint_path is required')

# Define DIF-Net. 1268 is presumably the number of per-shape latent codes in
# the checkpoint (the same count as trainset_length used for latent lookup).
model = SurfaceDeformationField(1268, **meta_params)
# Load tensors onto the CPU first, so a checkpoint saved on any device can be
# restored; the whole model is moved to the GPU right after.
model.load_state_dict(torch.load(meta_params['checkpoint_path'], map_location='cpu'))

# The network should be fixed for evaluation.
if hasattr(model, 'hyper_net'):
    for param in model.hyper_net.parameters():
        param.requires_grad = False
model.cuda()

# create save path
root_path = os.path.join(meta_params['logging_root'], meta_params['experiment_name'])
utils.cond_mkdir(root_path)
# Attribute annotations: attr_list maps an attribute index to its name, and
# attr_dict is indexed below as attr_dict[identity_dir][file_name][attr_index].
attr_dict = load_dict("./attr/attr_data/attr_dict.pkl")
attr_list = load_dict("./attr/attr_data/attr_list.pkl")
latent_codes = []
attr_scores = []
attr_index = opt.attr_index
attr_name = attr_list[attr_index]
print(f"A target attribute is {attr_name}.")

# Number of training shapes whose latent codes are read from the model.
trainset_length = 1268
db_path = './sort_info.txt'
# 0-based shape index -> dataset path. Named index_to_path rather than `map`
# so the builtin is not shadowed.
index_to_path = _build_map(db_path)

# Collect one (latent code, attribute score) pair per training shape. The
# lookups need no autograd graph, so they run under no_grad.
with torch.no_grad():
    for i in range(trainset_length):
        path_name = index_to_path[i]
        id_name = os.path.dirname(path_name)
        f_name = os.path.splitext(os.path.basename(path_name))[0]
        latent = model.latent_codes(torch.tensor([i], dtype=torch.long, device='cuda'))
        latent_codes.append(latent.detach().cpu().squeeze().numpy())
        attr_scores.append(attr_dict[id_name][f_name][attr_index])
latent_codes = np.array(latent_codes)
# Scores are passed to _train_boundary as a column vector of shape (N, 1).
attr_scores = np.expand_dims(np.array(attr_scores), axis=1)

# Fit an InterfaceGAN-style boundary for the attribute. If fitting fails, fall
# back to a boundary previously saved in the same directory.
boundary_dir = f"./attr/attr_data/{attr_name}_boundary"
try:
    boundary = _train_boundary(boundary_dir, latent_codes, attr_scores, split_ratio=0.8)
except Exception as err:
    # Catch Exception rather than everything, so Ctrl-C / SystemExit still
    # abort, and report why the saved boundary is being used.
    print(f"Boundary training failed ({err!r}); loading {boundary_dir}/boundary.npy instead.")
    boundary = np.load(f'{boundary_dir}/boundary.npy')
print(np.linalg.norm(boundary, ord=2))
# Template mesh (reference vertices and faces) shared by every generated mesh.
caric = dataset.CaricShop3D(meta_params['dir_caricshop'], skip=True)
attr_name = attr_list[attr_index]
# Number of training shapes (latent indices 0..num_samples-1) to manipulate.
num_samples = 100
# Nominal distance for each interpolation step, used only in the output file
# name. np.linspace yields start + (end - start) * j / (steps - 1) for
# steps > 1, and [start] for steps == 1 instead of a ZeroDivisionError.
distances = np.linspace(opt.start_distance, opt.end_distance, opt.steps)

for i in range(num_samples):
    dir_path = os.path.join(root_path, str(i))
    ckpt_path = os.path.join(dir_path, "checkpoints")
    utils.cond_mkdir(ckpt_path)

    with torch.no_grad():
        latent = model.latent_codes(torch.tensor([i], dtype=torch.long, device='cuda'))
    # Row vector of shape (1, latent_dim) for linear_interpolate.
    latent = latent.detach().cpu().squeeze().unsqueeze(0).numpy()
    interpolations = linear_interpolate(latent,
                                        boundary,
                                        start_distance=opt.start_distance,
                                        end_distance=opt.end_distance,
                                        steps=opt.steps)

    for distance, intp in zip(distances, interpolations):
        create_mesh_single(
            model,
            os.path.join(ckpt_path, f'{i:04d}_{attr_name}_{distance:.4f}.obj'),
            torch.Tensor(caric.V_ref),
            caric.F,
            embedding=torch.Tensor(intp).cuda(),
        )