-
Notifications
You must be signed in to change notification settings - Fork 3
/
reshape-predict.py
67 lines (54 loc) · 2.22 KB
/
reshape-predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
"""
predictをreshape
pose => pose_reshaped
(block,192,64) => (frames,192)
example
(2,192,64) => (128,192)
"""
#!/usr/bin/python
# -*- coding: utf-8 -*-
import numpy as np
import os
import argparse
from distutils.util import strtobool
import glob
# Command-line interface.  The arguments are parsed at import time and the
# resulting module-level ``args`` namespace is read by main() and
# data_denorm().
# NOTE: strtobool is imported from distutils, which was removed from the
# standard library in Python 3.12.
parser = argparse.ArgumentParser(description='speech to gesture by PyTorch')

# Whether to undo standardization before saving (parsed by strtobool).
parser.add_argument(
    '--denorm', '-d',
    default=1,
    type=strtobool,
    help='denorm(1) or not denorm(0)',
)
# Location of the saved average / standard-deviation .npy files.
parser.add_argument(
    '--denormpath',
    default="./norm/",
    type=str,
    help='denorm path',
)
# Selects which statistics files are loaded (used in their file names).
parser.add_argument(
    '--datatype',
    default="train",
    type=str,
    help='denorm datatype(train or dev)',
)
# Glob pattern selecting the prediction files to convert.
parser.add_argument(
    '--npypath',
    default="./test_inputs/*20200404-182816_weights.npy",
    type=str,
    help='npyfiles path',
)
# Directory that receives the reshaped text files.
parser.add_argument(
    '--outpath',
    default="./predict_reshaped/",
    type=str,
    help='out path',
)

args = parser.parse_args()
def main():
    """Flatten every prediction file matched by ``args.npypath`` and save it as text.

    Each matched ``.npy`` file is loaded, every block along its first axis is
    transposed, and the transposed blocks are stacked along axis 0, e.g.
    (2, 192, 64) => (128, 192).  When ``args.denorm`` is truthy the stacked
    array is passed through ``data_denorm`` before it is saved.

    The result is written with ``np.savetxt`` into the ``args.outpath``
    directory as ``<prefix>_posegan-denorm.txt`` (denorm enabled) or
    ``<prefix>_posegan.txt``, where ``<prefix>`` is the input file's base
    name up to its first ``-``.

    Reads the module-level ``args`` namespace and returns nothing.
    """
    files = glob.glob(args.npypath)
    print("処理ファイル数", len(files))
    os.makedirs(args.outpath, exist_ok=True)

    # The output suffix depends only on the command-line flag, so pick it once.
    suffix = "_posegan-denorm.txt" if args.denorm else "_posegan.txt"

    for filename in files:
        print("process file...", filename)
        if not os.path.exists(filename):
            # The path vanished after globbing (or is a dangling link); skip it
            # instead of continuing with stale or undefined data.
            continue

        predict = np.load(filename)

        # Transpose each block, then stack all of them in a single
        # concatenate call rather than growing the array with repeated
        # np.append copies.
        blocks = [np.transpose(data) for data in predict]
        pose_seq = np.concatenate(blocks, axis=0) if blocks else np.array([])

        # If the training data was standardized, map the prediction back to
        # the original scale (denorm).
        if args.denorm:
            pose_seq = data_denorm(pose_seq)

        # basename / join keep this correct on any platform and regardless of
        # whether --outpath ends with a path separator.
        prefix = os.path.basename(filename).split("-")[0]
        np.savetxt(os.path.join(args.outpath, prefix + suffix), pose_seq)
def data_denorm(pose_seq):
    """Undo standardization of a pose sequence using saved statistics.

    Loads ``ave_<datatype>_posegan.npy`` and ``std_<datatype>_posegan.npy``
    from the ``args.denormpath`` directory, where ``<datatype>`` is
    ``args.datatype``, and returns ``pose_seq * std + ave``.

    Args:
        pose_seq: Array to de-normalize.  Presumably its trailing axis matches
            the saved statistics so that NumPy broadcasting applies; this is
            not checked here.

    Returns:
        A new de-normalized array.
    """
    print("-denorm-")
    # os.path.join keeps this working whether or not --denormpath ends with a
    # path separator.
    ave_file = os.path.join(args.denormpath, "ave_" + args.datatype + "_posegan.npy")
    std_file = os.path.join(args.denormpath, "std_" + args.datatype + "_posegan.npy")
    ave = np.load(ave_file)
    std = np.load(std_file)
    return pose_seq * std + ave
# Script entry point: convert all matched files, then report completion.
if __name__ == "__main__":
    main()
    print("--complete--")