-
Notifications
You must be signed in to change notification settings - Fork 83
/
Copy pathcreate_validation.py
100 lines (79 loc) · 3.3 KB
/
create_validation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import random
import argparse
import os
import shutil  # NOTE(review): shutil appears unused anywhere in this file — confirm before removing
# Seed the RNG from the OS entropy source / current time, so each run produces
# a different train/val split; pass an explicit seed here for reproducibility.
random.seed()
def main():
    """Split every train ``.ndjson`` file under DATA_BLOCK/<path>/train/ into
    train and val copies under DATA_BLOCK/<path>_split/.

    Lines containing ``'scene'`` are sampled into exactly one split
    (val with probability ``--val_ratio``); all other lines (tracks) are
    copied into BOTH splits so each split remains self-contained.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--path', default='trajdata',
                        help='glob expression for data files')
    parser.add_argument('--val_ratio', default=0.2, type=float,
                        help='sample ratio of val set given the train set')
    args = parser.parse_args()
    args.path = 'DATA_BLOCK/' + args.path

    ## Prepare destination folder containing dataset with train and val split.
    ## exist_ok=True avoids the check-then-create race of the original code
    ## and creates missing parents in one call.
    args.dest_path = args.path + '_split'
    os.makedirs('{}/train/'.format(args.dest_path), exist_ok=True)
    os.makedirs('{}/val/'.format(args.dest_path), exist_ok=True)

    ## List train file names. splitext keeps dotted basenames intact
    ## ('a.b.ndjson' -> 'a.b'), unlike split('.')[-2] which dropped
    ## everything before the last interior dot and broke the reopen below.
    files = [os.path.splitext(f)[0]
             for f in os.listdir(args.path + '/train/')
             if f.endswith('.ndjson')]
    print(files)

    for file in files:
        ## Read all samples from the original train file.
        orig_train_file = args.path + '/train/' + file + '.ndjson'
        with open(orig_train_file, "r") as f:
            lines = f.readlines()

        ## Context managers guarantee both output files are flushed and
        ## closed even if a write raises (the originals leaked on error).
        with open(args.dest_path + '/train/' + file + ".ndjson", "w") as train_file, \
             open(args.dest_path + '/val/' + file + ".ndjson", "w") as val_file:
            for line in lines:
                ## Sample scene lines into exactly one of the two splits.
                if 'scene' in line:
                    if random.random() < args.val_ratio:
                        val_file.write(line)
                    else:
                        train_file.write(line)
                    continue
                ## Track lines are written to both splits.
                train_file.write(line)
                val_file.write(line)
# ## Assert val folder does not exist
# if os.path.isdir(args.path + '/val'):
# print("Validation folder already exists")
# exit()
# Run only when executed as a script, not when imported as a module.
if __name__ == '__main__':
    main()
# ## Iterate over file names
# for file in files:
# with open("DATA_BLOCK/honda/test/honda_v1.ndjson", "r") as f
# reader = trajnetplusplustools.Reader(path + '/train/' + file + '.ndjson', scene_type='paths')
# ## Necessary modification of train scene to add filename
# scene = [(file, s_id, s) for s_id, s in reader.scenes()]
# all_scenes += scene
# with open("test_dummy1.ndjson", "w") as f:
# for line in lines:
# if 'scene' in line and random.random() < 0.8:
# continue
# f.write(line)
# ## read goal files
# all_goals = {}
# all_scenes = []
# ## List file names
# files = [f.split('.')[-2] for f in os.listdir(path + subset) if f.endswith('.ndjson')]
# ## Iterate over file names
# for file in files:
# reader = trajnetplusplustools.Reader(path + subset + file + '.ndjson', scene_type='paths')
# ## Necessary modification of train scene to add filename
# scene = [(file, s_id, s) for s_id, s in reader.scenes(sample=sample)]
# if goals:
# goal_dict = pickle.load(open('goal_files/' + subset + file +'.pkl', "rb"))
# ## Get goals corresponding to train scene
# all_goals[file] = {s_id: [goal_dict[path[0].pedestrian] for path in s] for _, s_id, s in scene}
# all_scenes += scene