datasets.py
from torch.utils.data import Dataset
from json import load, dump, loads
from random import shuffle
from torch import tensor
from itertools import chain
from tqdm import tqdm

class TextualDataset(Dataset):
    """Tokenizes raw texts on construction and serves them as tensors."""

    def __init__(self, texts, tokenizer):
        print("encoding")
        self.examples = []
        for example in tqdm(texts, total=len(texts)):
            # Drop paragraph breaks so each text is encoded as one sequence.
            example = example.replace("\n\n", "")
            x = tokenizer.encode(example, add_special_tokens=False)
            self.examples.append(x)

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, i):
        return tensor(self.examples[i])

    def __jdump__(self, path):
        # Write one JSON-encoded token-id list per line (JSONL).
        print("saving")
        with open(path, "w") as jp:
            for x in tqdm(self.examples, total=self.__len__()):
                dump(x, jp)
                jp.write("\n")
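
# Usage sketch for TextualDataset. This assumes a Hugging Face tokenizer;
# the model name and texts below are illustrative, not part of this module:
#
#     from transformers import AutoTokenizer
#     tokenizer = AutoTokenizer.from_pretrained("gpt2")
#     ds = TextualDataset(["some text", "more text"], tokenizer)
#     ds.__jdump__("encoded.jsonl")  # one JSON token-id list per line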

class EncodedFiles2Dataset(Dataset):
    """Loads pre-encoded token ids from JSON/JSONL files.

    With `block`, the concatenated ids are cut into fixed-size chunks.
    With `trim`, documents are packed into examples of at most `trim`
    tokens, splitting at end-of-sequence ids (`eos`).
    """

    def __init__(self, path, files, shfl=True, trim=None, block=None, shfl_files=False, eos=2):
        self.examples = []
        if shfl_files:
            shuffle(files)
        for file in files:
            with open(path + file, "r") as jf:
                if block:
                    # Flatten the whole file and cut into fixed-size blocks.
                    self.examples += list(self.split(list(chain(*load(jf))), block))
                elif trim:
                    # Pack documents into examples of at most `trim` tokens,
                    # treating `eos` as the document boundary.
                    for x in tqdm(jf.readlines()):
                        example = []
                        buffer = []
                        for y in loads(x):
                            buffer.append(y)
                            if y == eos:
                                if len(buffer) > trim:
                                    # Single document longer than `trim`:
                                    # flush the current example, then truncate.
                                    if example:
                                        self.examples.append(example)
                                        example = []
                                    self.examples.append(buffer[:trim])
                                elif (len(example) + len(buffer)) > trim:
                                    # Buffer does not fit: start a new example.
                                    self.examples.append(example)
                                    example = buffer
                                else:
                                    example += buffer
                                buffer = []
                        # Handle a trailing document without a final `eos`.
                        # (`example` is reset after appending so the final
                        # `if example:` below cannot append it a second time.)
                        if len(buffer) > trim:
                            if example:
                                self.examples.append(example)
                            self.examples.append(buffer[:trim])
                            example = []
                        elif (len(example) + len(buffer)) > trim:
                            self.examples.append(example)
                            self.examples.append(buffer)
                            example = []
                        else:
                            example += buffer
                        if example:
                            self.examples.append(example)
                else:
                    # No chunking or packing: take the examples as stored.
                    self.examples += load(jf)
        if shfl:
            shuffle(self.examples)

    def __len__(self):
        return len(self.examples)

    def split(self, list_a, chunk_size):
        # Yield successive `chunk_size`-sized chunks; the last may be shorter.
        for i in range(0, len(list_a), chunk_size):
            yield list_a[i:i + chunk_size]

    def __getitem__(self, i):
        return tensor(self.examples[i])

    def __jdump__(self, path):
        # Write the whole dataset as a single JSON array.
        with open(path, "w") as jp:
            dump(self.examples, jp)

    def __jdumpwsplit__(self, path, dev_ratio=0.01, name=""):
        # Write a dev/train split as JSONL, one example per line.
        split_line = round(self.__len__() * dev_ratio)
        with open(path + "dev" + name + ".jsonl", "w") as jp:
            for x in self.examples[:split_line]:
                dump(x, jp)
                jp.write("\n")
        with open(path + "train" + name + ".jsonl", "w") as jp:
            for x in self.examples[split_line:]:
                dump(x, jp)
                jp.write("\n")
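
# Usage sketch for EncodedFiles2Dataset. Paths and filenames are illustrative;
# whether `eos=2` is correct depends on your tokenizer's vocabulary:
#
#     ds = EncodedFiles2Dataset("data/", ["train0.jsonl"], trim=512, eos=2)
#     ds.__jdumpwsplit__("data/", dev_ratio=0.01, name="_packed")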

class JsonDataset(Dataset):
    """Reads JSONL files; lines are kept as strings and parsed per item."""

    def __init__(self, jpath):
        if isinstance(jpath, str):
            with open(jpath, "r", encoding="utf-8") as jf:
                self.examples = list(jf)
        else:
            # A list of paths: concatenate their lines.
            self.examples = []
            for jp in jpath:
                with open(jp, "r", encoding="utf-8") as jf:
                    self.examples += list(jf)

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, i):
        return tensor(loads(self.examples[i])).long()
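
# Usage sketch: JsonDataset pairs naturally with a DataLoader. Since examples
# vary in length, a padding collate_fn is typically needed; this is a minimal
# sketch assuming pad id 0 — adjust to your tokenizer:
#
#     from torch.nn.utils.rnn import pad_sequence
#     from torch.utils.data import DataLoader
#
#     ds = JsonDataset("data/train.jsonl")
#     loader = DataLoader(
#         ds, batch_size=8,
#         collate_fn=lambda batch: pad_sequence(batch, batch_first=True, padding_value=0),
#     )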