diff --git a/wenet/ssl/init_dataset.py b/wenet/ssl/init_dataset.py
index 10072a5c2..67417ec6b 100644
--- a/wenet/ssl/init_dataset.py
+++ b/wenet/ssl/init_dataset.py
@@ -1,10 +1,61 @@
+from collections.abc import Callable
 from functools import partial
 import sys
+from typing import List
 import torch
 from torch.nn.utils.rnn import pad_sequence
+from torch.utils.data import IterDataPipe, functional_datapipe
 
 from wenet.dataset import processor
-from wenet.dataset.datapipes import WenetRawDatasetSource, WenetTarShardDatasetSource
+from wenet.dataset.datapipes import (WenetRawDatasetSource,
+                                     WenetTarShardDatasetSource)
+
+
+@functional_datapipe("pack_speech")
+class PackSpeechDatapipe(IterDataPipe):
+    """Pack consecutive utterances into longer ones, up to max_length."""
+
+    def __init__(
+        self,
+        dataset: IterDataPipe,
+        length_fn: Callable,
+        merge_speech_fn: Callable,
+        max_length: int = 30000,
+    ) -> None:
+        super().__init__()
+        self.dp = dataset
+        self.length_fn = length_fn
+        self.max_length = max_length
+        self.merge_fn = merge_speech_fn
+        self.buf = []
+        self.length = 0
+
+    def __iter__(self):
+        for elem in self.dp:
+            elem_length = self.length_fn(elem)
+            current_length = self.length + elem_length
+            # Flush before the buffer would exceed max_length; the emptiness
+            # check avoids merging an empty buffer on the first element.
+            if current_length >= self.max_length and len(self.buf) > 0:
+                yield self.merge_fn(self.buf)
+                self.buf = []
+                self.length = 0
+            self.buf.append(elem)
+            self.length += elem_length
+        # Flush whatever remains once the source datapipe is exhausted.
+        if len(self.buf) > 0:
+            yield self.merge_fn(self.buf)
+            self.buf = []
+            self.length = 0
+
+
+def cat_speech(buffer: List):
+    """Concatenate buffered waveforms along the time dimension."""
+    assert len(buffer) > 0
+    waves = [sample['wav'] for sample in buffer]
+    sample_rate = buffer[0]['sample_rate']
+    wav = torch.cat(waves, dim=1)
+    return {"wav": wav, "sample_rate": sample_rate}
 
 
 def padding(data):
@@ -39,6 +90,11 @@ def padding(data):
     return batch
 
 
+def wav_length_fn(sample):
+    wav = sample['wav']
+    return wav.size(1)
+
+
 def Dataset(data_type, data_list_file, conf=None, partition=True):
     """ Construct dataset from arguments for ssl model
 
@@ -81,6 +137,10 @@ def Dataset(data_type, data_list_file, conf=None, partition=True):
     dataset = dataset.map(
         partial(processor.singal_channel, **singal_channel_conf))
 
+    pack_conf = conf.get('pack_conf', {})
+    if pack_conf:
+        dataset = dataset.pack_speech(wav_length_fn, cat_speech,
+                                      pack_conf['max_speech_length'])
     filter_conf = conf.get('filter_conf', {})
    dataset = dataset.filter(partial(processor.filter, **filter_conf))
 
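
A minimal sketch of how the new `pack_speech` stage behaves. `ToySource` is a hypothetical stand-in for the real WeNet source datapipes; the sample dict layout (`{'wav': (channel, time) tensor, 'sample_rate': int}`) matches what `cat_speech` and `wav_length_fn` expect.

```python
import torch
from torch.utils.data import IterDataPipe

# Importing init_dataset registers "pack_speech" on every IterDataPipe
# via the @functional_datapipe decorator.
from wenet.ssl.init_dataset import cat_speech, wav_length_fn


class ToySource(IterDataPipe):
    """Hypothetical stand-in for a WeNet source: yields mono utterances."""

    def __init__(self, lengths, sample_rate=16000):
        super().__init__()
        self.lengths = lengths
        self.sample_rate = sample_rate

    def __iter__(self):
        for n in self.lengths:
            # Same sample layout the new helpers expect: (channel, time).
            yield {"wav": torch.zeros(1, n), "sample_rate": self.sample_rate}


# Pack four 8000-sample utterances under a 20000-sample budget.
packed = ToySource([8000, 8000, 8000, 8000]).pack_speech(
    wav_length_fn, cat_speech, max_length=20000)
for sample in packed:
    print(sample["wav"].shape)
# -> torch.Size([1, 16000]) twice: the buffer is flushed before the third
#    utterance would push it past max_length, and again at the end.
```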
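
`Dataset(...)` only enables the stage when `pack_conf` is present, and packing runs before filter/resample/fbank, so the merged waveforms flow through the rest of the pipeline unchanged. A hedged sketch of the conf fragment; the key layout follows this diff, but the value and path below are illustrative, not taken from this change.

```python
# Hypothetical conf fragment for Dataset(); only 'pack_conf' is new here,
# the other keys follow the existing wenet dataset_conf layout.
conf = {
    'filter_conf': {},
    # max_speech_length is counted in raw audio samples (wav.size(1), per
    # wav_length_fn); e.g. ~30 s of 16 kHz audio would be 480000. The value
    # is illustrative, not a recommendation from this change.
    'pack_conf': {'max_speech_length': 480000},
    # ... remaining dataset_conf entries unchanged ...
}
# dataset = Dataset('shard', 'data/train/data.list', conf)  # path illustrative
```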