-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata_select_lib.py
84 lines (56 loc) · 1.38 KB
/
data_select_lib.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
from ptb import *
def stratified_select(num, cat, data):
    """Collect up to ``num`` dataset indices per distinct token at a fixed input position.

    Args:
        num: Maximum number of indices kept per stratum. Values below 1 keep
            no indices, but every stratum encountered still appears as a key.
        cat: Input position to stratify on: ``'key'`` (position 2) or
            ``'time'`` (position 1).
        data: Indexable dataset supporting ``len(data)`` and ``data[i]``; each
            item is a mapping whose ``'input'`` entry is indexable by position.

    Returns:
        dict mapping each token found at the chosen position to the list of
        dataset indices carrying it, in ascending order, capped at ``num``.

    Raises:
        ValueError: if ``cat`` is not ``'key'`` or ``'time'``.
    """
    positions = {'key': 2, 'time': 1}
    if cat not in positions:
        # Fail fast instead of indexing the input with a meaningless placeholder.
        raise ValueError(f"unknown category {cat!r}; expected 'key' or 'time'")
    index = positions[cat]
    dic = {}
    # Index via range(len(data)) rather than iterating: iterating a dataset that
    # only defines __getitem__ relies on IndexError at the end, which a
    # dict-backed dataset may not raise.
    for i in range(len(data)):
        tok = data[i]['input'][index]
        bucket = dic.setdefault(tok, [])
        if len(bucket) < num:
            bucket.append(i)
    return dic
def stratified_bar_select(num, cat, data):
    """Collect up to ``num`` dataset indices per distinct value of field ``cat``.

    Args:
        num: Maximum number of indices kept per stratum. Values below 1 keep
            no indices, but every stratum encountered still appears as a key.
        cat: Key looked up directly on each dataset item (``data[i][cat]``).
        data: Indexable dataset supporting ``len(data)`` and ``data[i]``.

    Returns:
        dict mapping each distinct ``data[i][cat]`` value to the list of
        dataset indices carrying it, in ascending order, capped at ``num``.
    """
    dic = {}
    # Index via range(len(data)) rather than iterating: iterating a dataset that
    # only defines __getitem__ relies on IndexError at the end, which a
    # dict-backed dataset may not raise.
    for i in range(len(data)):
        bucket = dic.setdefault(data[i][cat], [])
        if len(bucket) < num:
            bucket.append(i)
    return dic
def index_stratified_bar_select(num, index, data):
    """Collect up to ``num`` dataset indices per distinct input prefix.

    Items are grouped by the string form of ``data[i]['input'][:index]``. The
    prefix is stringified so that unhashable sequences (e.g. lists or arrays)
    can serve as dict keys.

    Args:
        num: Maximum number of indices kept per stratum. Values below 1 keep
            no indices, but every stratum encountered still appears as a key.
        index: Slice end applied to each item's ``'input'`` sequence.
        data: Indexable dataset supporting ``len(data)`` and ``data[i]``; each
            item is a mapping whose ``'input'`` entry supports slicing.

    Returns:
        dict mapping ``str(prefix)`` to the list of dataset indices sharing
        that prefix, in ascending order, capped at ``num``.
    """
    dic = {}
    # Index via range(len(data)) rather than iterating: iterating a dataset that
    # only defines __getitem__ relies on IndexError at the end, which a
    # dict-backed dataset may not raise.
    for i in range(len(data)):
        tok = str(data[i]['input'][:index])
        bucket = dic.setdefault(tok, [])
        if len(bucket) < num:
            bucket.append(i)
    return dic
# NOTE(review): the triple-quoted string below is an unused module-level
# expression. It is never executed and appears to be a disabled usage example:
# load the 'test' split via PTB, stratify it by the 'key' position with
# stratified_select, then print each stratum token through the vocab's i2w map.
# If re-enabled, it needs PTB (presumably provided by `from ptb import *` —
# confirm) and `json`, which is not imported anywhere in this file's visible
# imports.
'''
# load tunes from split
data = PTB(
data_dir='data',
create_data=False,
split='test',
max_sequence_length=256,
data_prefix='data_v2_cleaned',
conditioned=False,
bars=False
)
d = stratified_select(100, 'key', data)
print(d)
with open(f'data/data_v2_cleaned.vocab.json', 'r') as file:
vocab = json.load(file)
w2i, i2w = vocab['w2i'], vocab['i2w']
for x in d.keys():
print(i2w[str(x)])
'''