-
Notifications
You must be signed in to change notification settings - Fork 0
/
dataloader.py
38 lines (34 loc) · 1.16 KB
/
dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from torch.utils.data import DataLoader
from transformers import PreTrainedTokenizer
from data import KorQuadDataset
import json
# Get torch dataloader from KorQuad dataset
class KorQuadDataLoader:
    """Factory that builds torch ``DataLoader`` objects from KorQuad JSON files.

    The instance stores the tokenizer and the length settings once; every
    call to :meth:`get_dataloader` forwards them, together with the parsed
    JSON file, to ``KorQuadDataset``.
    """

    def __init__(
        self,
        tokenizer: "PreTrainedTokenizer",
        max_seq_length: int = None,
        doc_stride: int = None,
        max_query_length: int = None,
    ):
        """Store the tokenizer and the settings handed to ``KorQuadDataset``.

        Args:
            tokenizer: Tokenizer forwarded to ``KorQuadDataset``.
            max_seq_length: Maximum sequence length. When omitted (or any
                falsy value) ``tokenizer.model_max_length`` is used instead.
            doc_stride: Required despite the ``None`` default; forwarded
                unchanged to ``KorQuadDataset``.
            max_query_length: Required despite the ``None`` default;
                forwarded unchanged to ``KorQuadDataset``.

        Raises:
            ValueError: If ``doc_stride`` or ``max_query_length`` is ``None``.
        """
        # Raise explicitly rather than ``assert``: assertions are removed
        # when Python runs with -O, which would let None slip through.
        if doc_stride is None:
            raise ValueError("doc_stride must not be None")
        if max_query_length is None:
            raise ValueError("max_query_length must not be None")

        self.tokenizer = tokenizer
        # Falsy max_seq_length (None or 0) falls back to the tokenizer limit.
        self.max_seq_length = max_seq_length or tokenizer.model_max_length
        self.doc_stride = doc_stride
        self.max_query_length = max_query_length

    def read_json(self, path):
        """Parse the JSON file at ``path`` and return the decoded object.

        The file is always decoded as UTF-8 so the result does not depend
        on the platform's default locale encoding.
        """
        with open(path, encoding="utf-8") as f:
            return json.load(f)

    def get_dataloader(self, file_path, batch_size, **kwargs):
        """Return a ``DataLoader`` over the KorQuad file at ``file_path``.

        Args:
            file_path: Path of the JSON file to load.
            batch_size: Batch size passed to ``DataLoader``.
            **kwargs: Extra keyword arguments passed through to
                ``DataLoader``. ``shuffle`` defaults to ``False`` but may be
                overridden here.

        Returns:
            A ``DataLoader`` wrapping a ``KorQuadDataset`` built from the
            file contents and this instance's settings.
        """
        data = self.read_json(file_path)
        dataset = KorQuadDataset(
            data,
            self.tokenizer,
            self.max_seq_length,
            self.doc_stride,
            self.max_query_length,
        )
        # A default instead of a hard-coded keyword: passing shuffle via
        # kwargs used to raise TypeError for a duplicate keyword argument.
        kwargs.setdefault("shuffle", False)
        return DataLoader(dataset, batch_size=batch_size, **kwargs)