forked from nobody132/masr
-
Notifications
You must be signed in to change notification settings - Fork 0
/
data.py
97 lines (78 loc) · 2.7 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import torch
import librosa
import wave
import numpy as np
import scipy
import json
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
# --- STFT front-end configuration -------------------------------------------
# Expected sample rate of the input audio; the window/stride values below are
# durations that are multiplied by this rate to get sample counts.
sample_rate = 16000
window_size = 0.02  # analysis-window duration (seconds at sample_rate Hz)
window_stride = 0.01  # hop duration between successive windows
window = "hamming"  # window function name handed to librosa.stft

# Derived sample counts (320 and 160 for the values above).
n_fft = int(window_size * sample_rate)
hop_length = int(window_stride * sample_rate)
win_length = n_fft  # window spans the full FFT frame
def load_audio(wav_path, normalize=True):  # -> numpy array
    """Read a 16-bit PCM WAV file into a float64 NumPy array.

    Args:
        wav_path: Path (or file object accepted by ``wave.open``) of the file.
        normalize: When True, shift the samples to zero mean and scale them to
            unit standard deviation.

    Returns:
        1-D ``float64`` array of samples. Multi-channel files are down-mixed to
        a single channel by averaging the channels of each frame. A constant
        (e.g. silent) signal is only mean-centred when normalizing, so the
        result is all zeros instead of NaN.

    Raises:
        ValueError: if the file's sample width is not 2 bytes (16-bit), since
            the raw frames are decoded as int16.
    """
    with wave.open(wav_path, "rb") as wav_file:
        n_channels = wav_file.getnchannels()
        sample_width = wav_file.getsampwidth()
        raw = wav_file.readframes(wav_file.getnframes())

    if sample_width != 2:
        raise ValueError(
            f"expected 16-bit PCM audio, got sample width {sample_width} bytes"
        )

    samples = np.frombuffer(raw, dtype="int16").astype("float")
    if n_channels > 1:
        # readframes returns whole interleaved frames (c0, c1, ..., c0, c1, ...);
        # average each frame so the result is one value per time step.
        samples = samples.reshape(-1, n_channels).mean(axis=1)

    if not normalize:
        return samples

    centered = samples - samples.mean()
    std = samples.std()
    if std == 0:
        # Dividing by a zero std would turn every sample into NaN.
        return centered
    return centered / std
def spectrogram(wav, normalize=True):
    """Compute a log-magnitude STFT spectrogram of a waveform.

    The STFT uses the module-level settings ``n_fft``, ``hop_length``,
    ``win_length`` and ``window``.

    Args:
        wav: Waveform array, presumably the output of ``load_audio`` — confirm.
        normalize: When True, standardize the spectrogram to zero mean and unit
            standard deviation computed over all of its elements. A constant
            spectrogram (std 0) is only mean-centred, so it does not become NaN.

    Returns:
        ``torch.FloatTensor`` of ``log1p(|STFT|)`` in the layout produced by
        ``librosa.stft``. ``_collate_fn`` treats ``size(0)`` as frequency and
        ``size(1)`` as frames.
    """
    stft = librosa.stft(
        wav, n_fft=n_fft, hop_length=hop_length, win_length=win_length, window=window
    )
    magnitude, _ = librosa.magphase(stft)  # phase is not used
    spec = torch.FloatTensor(np.log1p(magnitude))
    if normalize:
        mean = spec.mean()
        std = spec.std()
        spec = spec - mean
        # With all-equal values (e.g. an all-zero input) std is 0, and dividing
        # would fill the tensor with NaNs that poison training.
        if std > 0:
            spec = spec / std
    return spec
class MASRDataset(Dataset):
    """Speech dataset described by an index file and a JSON label list.

    Each non-blank line of the index file is ``<wav_path>,<transcript>``. Only
    the first comma separates the fields, so a transcript may contain commas.
    The labels file is a JSON array, and each entry's position is its integer id.
    """

    def __init__(self, index_path, labels_path):
        # Read as UTF-8 explicitly: transcripts and labels may be non-ASCII, and
        # otherwise the platform default encoding (e.g. on Windows) would apply.
        with open(index_path, encoding="utf-8") as f:
            # Blank lines are skipped. Without this, they would be counted by
            # __len__ and then fail to unpack in __getitem__.
            self.idx = [line.strip().split(",", 1) for line in f if line.strip()]
        with open(labels_path, encoding="utf-8") as f:
            labels = json.load(f)
        # labels maps character -> id; labels_str keeps the id -> character list.
        self.labels = {label: i for i, label in enumerate(labels)}
        self.labels_str = labels

    def __getitem__(self, index):
        """Return ``(spectrogram, label_ids)`` for the sample at ``index``."""
        wav_path, transcript = self.idx[index]
        spect = spectrogram(load_audio(wav_path))
        # dict.get yields None for characters outside the vocabulary, and the
        # truthiness test drops them. Note that it ALSO drops id 0.
        # NOTE(review): id 0 is presumably the CTC blank, so dropping it may be
        # intentional — confirm before switching to an `is not None` test.
        label_ids = [i for i in map(self.labels.get, transcript) if i]
        return spect, label_ids

    def __len__(self):
        """Number of samples listed in the index file."""
        return len(self.idx)
def _collate_fn(batch):
def func(p):
return p[0].size(1)
batch = sorted(batch, key=lambda sample: sample[0].size(1), reverse=True)
longest_sample = max(batch, key=func)[0]
freq_size = longest_sample.size(0)
minibatch_size = len(batch)
max_seqlength = longest_sample.size(1)
inputs = torch.zeros(minibatch_size, freq_size, max_seqlength)
input_lens = torch.IntTensor(minibatch_size)
target_lens = torch.IntTensor(minibatch_size)
targets = []
for x in range(minibatch_size):
sample = batch[x]
tensor = sample[0]
target = sample[1]
seq_length = tensor.size(1)
inputs[x].narrow(1, 0, seq_length).copy_(tensor)
input_lens[x] = seq_length
target_lens[x] = len(target)
targets.extend(target)
targets = torch.IntTensor(targets)
return inputs, targets, input_lens, target_lens
class MASRDataLoader(DataLoader):
    """``DataLoader`` that always batches samples with ``_collate_fn``.

    All positional and keyword arguments are forwarded unchanged to
    ``torch.utils.data.DataLoader``. Any ``collate_fn`` the caller supplied is
    replaced once the base class has been initialised.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Assigned after the base initializer, so this overrides whatever
        # collate function was passed in.
        self.collate_fn = _collate_fn