-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathgenerate_test_data.py
64 lines (52 loc) · 2.27 KB
/
generate_test_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# Generate test data for ROC-curve plotting, as well as the test bench
# used during hls4ml C simulation.
import common as com
import os
import numpy as np
import argparse
from tqdm import tqdm
def _as_saveable_array(nested):
    """Convert *nested* into an ndarray that ``np.save`` can write.

    ``np.save`` runs ``np.asanyarray`` on its argument. If the nested data is
    ragged (e.g. machine IDs with different numbers of test files), NumPy >= 1.24
    raises ``ValueError`` instead of implicitly building an object array as
    older releases did. In that case a 1-D ``dtype=object`` array holding the
    top-level items unchanged is built explicitly.

    Args:
        nested: A (possibly ragged) sequence of arrays / lists.

    Returns:
        ``np.asarray(nested)`` for rectangular data, otherwise a 1-D object
        array of length ``len(nested)``. Object arrays are pickled by
        ``np.save``; readers need ``np.load(..., allow_pickle=True)``.
    """
    try:
        return np.asarray(nested)
    except ValueError:
        packed = np.empty(len(nested), dtype=object)
        for idx, item in enumerate(nested):
            # Scalar-index assignment stores the item itself (no broadcasting).
            packed[idx] = item
        return packed


def _save_npy(path, array):
    """Save *array* to *path* with ``np.save``, creating the parent directory first."""
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    np.save(path, array)


def main(args):
    """Build and save test data for ROC plotting and the hls4ml test bench.

    Loads the YAML config at ``args.config`` via ``com.yaml_load``. It then
    takes the first directory returned by ``com.select_dirs`` and converts
    every test file of every machine ID with ``com.file_to_vector_array``.
    Four ``.npy`` files are written, with paths taken from the config's
    ``convert`` section:

    * ``x_npy_plot_roc`` / ``y_npy_plot_roc``: features and labels for all
      machine IDs (one entry per machine ID).
    * ``x_npy_test_bench``: rows ``0:10`` of the first test file's features.
    * ``y_npy_test_bench``: the first 10 labels of the first machine ID.

    Args:
        args: Parsed command-line namespace with a ``config`` attribute
            (path to the YAML config).
    """
    config = com.yaml_load(args.config)
    convert = config["convert"]
    param = config["train"]
    feature = param["feature"]

    # make output result directory
    os.makedirs(param["result_directory"], exist_ok=True)

    # only the first selected directory is processed
    target_dir = com.select_dirs(param=param)[0]
    print("target dir is {}".format(target_dir))
    machine_type = os.path.split(target_dir)[1]
    print(machine_type)
    machine_id_list = com.get_machine_id_list_for_test(target_dir)
    print(machine_id_list)

    X = []  # one list of per-file feature arrays per machine ID
    y = []  # one label sequence per machine ID
    for id_str in machine_id_list:
        test_files, y_true = com.test_file_list_generator(target_dir, id_str)
        print("\n============== CREATING TEST DATA FOR A MACHINE ID ==============")
        machine_features = []
        for file_path in tqdm(test_files, total=len(test_files)):
            machine_features.append(
                com.file_to_vector_array(
                    file_path,
                    n_mels=feature["n_mels"],
                    frames=feature["frames"],
                    n_fft=feature["n_fft"],
                    hop_length=feature["hop_length"],
                    power=feature["power"],
                    downsample=feature["downsample"],
                )
            )
        X.append(machine_features)
        y.append(y_true)

    # Legacy output directory, still created for tooling that may expect it;
    # _save_npy additionally creates whatever directory each configured path needs.
    os.makedirs('test_data/anomaly_detection/', exist_ok=True)
    # Machine IDs can have different file counts -> ragged data, which modern
    # NumPy refuses to convert implicitly.
    _save_npy(convert['x_npy_plot_roc'], _as_saveable_array(X))
    _save_npy(convert['y_npy_plot_roc'], _as_saveable_array(y))
    # NOTE(review): the X slice is 10 feature rows of ONE file, while the y slice
    # is the labels of the first 10 FILES - confirm this pairing is intended.
    _save_npy(convert['x_npy_test_bench'], X[0][0][0:10])
    _save_npy(convert['y_npy_test_bench'], y[0][0:10])
if __name__ == "__main__":
    # Command-line entry point: the only option selects the YAML config file.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '-c',
        '--config',
        type=str,
        default="baseline.yml",
        help="specify yaml config",
    )
    main(cli.parse_args())