-
Notifications
You must be signed in to change notification settings - Fork 1
/
datamodule.py
67 lines (56 loc) · 1.87 KB
/
datamodule.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import torch
from utils import TextProcess
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
class VivosDataModule(pl.LightningDataModule):
    """Lightning DataModule for the VIVOS speech dataset.

    Wraps the train/test datasets in DataLoaders and pads each batch of
    (spectrogram, transcript) pairs into dense tensors plus their true
    lengths, as needed for CTC-style training.

    NOTE(review): the test set is reused as the validation set — confirm
    this is intentional (no separate dev split is passed in).
    """

    def __init__(
        self,
        trainset: Dataset,
        testset: Dataset,
        text_process: TextProcess,
        batch_size: int,
        num_workers: int = 0,
    ):
        """
        Args:
            trainset: training dataset yielding (spectrogram, transcript) pairs.
            testset: dataset reused for both validation and testing.
            text_process: converts a transcript string to an integer tensor
                via ``text2int``.
            batch_size: samples per batch for all three loaders.
            num_workers: DataLoader worker processes; 0 (the default, matching
                the previous behavior) loads data in the main process.
        """
        super().__init__()
        self.trainset = trainset
        self.valset = testset  # no dedicated validation split: reuse test set
        self.testset = testset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.text_process = text_process

    def train_dataloader(self):
        """Shuffled training loader."""
        return DataLoader(
            self.trainset,
            batch_size=self.batch_size,
            collate_fn=self._collate_fn,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True,
        )

    def val_dataloader(self):
        """Validation loader (no shuffling)."""
        return DataLoader(
            self.valset,
            batch_size=self.batch_size,
            collate_fn=self._collate_fn,
            num_workers=self.num_workers,
            pin_memory=True,
        )

    def test_dataloader(self):
        """Test loader (no shuffling)."""
        return DataLoader(
            self.testset,
            batch_size=self.batch_size,
            collate_fn=self._collate_fn,
            num_workers=self.num_workers,
            pin_memory=True,
        )

    def _collate_fn(self, batch):
        """Pad a batch of (spectrogram, transcript) samples.

        Args:
            batch: iterable of ``(spec, transcript)`` where ``spec`` is a
                tensor whose first dimension is time (assumed
                (time, feature) — TODO confirm against the dataset) and
                ``transcript`` is text accepted by ``text_process.text2int``.

        Returns:
            Tuple of:
                specs: (batch, 1, time, feature) zero-padded spectrograms
                    (channel dim added for conv frontends).
                input_lengths: LongTensor of original time lengths.
                trans: (batch, max_len) zero-padded integer transcripts.
                target_lengths: LongTensor of original token lengths.
        """
        specs = [sample[0] for sample in batch]
        input_lengths = torch.LongTensor([s.size(0) for s in specs])
        trans = [sample[1] for sample in batch]

        # Pad along the time axis, then add a channel dimension.
        specs = torch.nn.utils.rnn.pad_sequence(specs, batch_first=True)
        specs = specs.unsqueeze(1)  # batch, channel, time, feature

        trans = [self.text_process.text2int(s) for s in trans]
        target_lengths = torch.LongTensor([s.size(0) for s in trans])
        trans = torch.nn.utils.rnn.pad_sequence(trans, batch_first=True).to(
            dtype=torch.long
        )
        return specs, input_lengths, trans, target_lengths