-
Notifications
You must be signed in to change notification settings - Fork 25
/
preprocess.py
65 lines (48 loc) · 1.62 KB
/
preprocess.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import os
import sys
import json
import pickle
import nltk
import tqdm
from torchvision import transforms
from PIL import Image
from transforms import Scale
def process_question(root, split, word_dic=None, answer_dic=None):
    """Encode the questions of one CLEVR split as integer ids and pickle them.

    Reads ``<root>/questions/CLEVR_<split>_questions.json``, tokenizes each
    question with ``nltk.word_tokenize`` and maps every token and every
    answer to an integer id, registering unseen entries in the vocabularies
    as they are encountered. The encoded records are written to
    ``data/<split>.pkl`` as a list of
    ``(image_filename, question_token_ids, answer_id, question_family_index)``
    tuples.

    Args:
        root: Directory that contains the ``questions`` folder.
        split: Split name used in both the input and output file names.
        word_dic: Optional existing token -> id mapping. It is extended in
            place; a new empty dict is used when ``None``.
        answer_dic: Optional existing answer -> id mapping. It is extended in
            place; a new empty dict is used when ``None``.

    Returns:
        The tuple ``(word_dic, answer_dic)`` including any new entries.
    """
    if word_dic is None:
        word_dic = {}
    if answer_dic is None:
        answer_dic = {}

    question_path = os.path.join(root, 'questions',
                                 f'CLEVR_{split}_questions.json')
    with open(question_path) as f:
        data = json.load(f)

    # Continue numbering after the ids already present so that entries first
    # seen in a later split never collide with ids assigned by an earlier one.
    # With empty vocabularies this yields 1 for words and 0 for answers.
    # NOTE(review): word id 0 is never assigned; presumably it is reserved
    # (e.g. for padding) -- confirm against the consumer of these files.
    word_index = max(word_dic.values(), default=0) + 1
    answer_index = max(answer_dic.values(), default=-1) + 1

    result = []
    for question in tqdm.tqdm(data['questions']):
        question_token = []
        for word in nltk.word_tokenize(question['question']):
            if word not in word_dic:
                word_dic[word] = word_index
                word_index += 1
            question_token.append(word_dic[word])

        answer_word = question['answer']
        if answer_word not in answer_dic:
            answer_dic[answer_word] = answer_index
            answer_index += 1
        answer = answer_dic[answer_word]

        result.append((question['image_filename'], question_token, answer,
                       question['question_family_index']))

    # The output directory is relative to the current working directory;
    # create it so a first run does not fail after all the encoding work.
    os.makedirs('data', exist_ok=True)
    with open(f'data/{split}.pkl', 'wb') as f:
        pickle.dump(result, f)

    return word_dic, answer_dic
if __name__ == '__main__':
    # The first command-line argument is the dataset root passed to
    # process_question; fail with a usage message instead of an IndexError
    # traceback when it is missing.
    if len(sys.argv) < 2:
        sys.exit(f'usage: {sys.argv[0]} <root>')
    root = sys.argv[1]

    # Build the vocabularies on the training split, then reuse and extend
    # them while encoding the validation split.
    word_dic, answer_dic = process_question(root, 'train')
    word_dic, answer_dic = process_question(root, 'val', word_dic, answer_dic)

    # Persist the final vocabularies next to the encoded splits.
    with open('data/dic.pkl', 'wb') as f:
        pickle.dump({'word_dic': word_dic, 'answer_dic': answer_dic}, f)