From 812ae1a0cf404984c8f7677cd85050d721de4fa9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Wed, 31 Jan 2024 16:20:27 +0800 Subject: [PATCH 1/9] funasr1.0.5 --- .../conformer/demo.py | 13 +++ .../conformer/infer.sh | 11 +++ funasr/models/conformer/template.yaml | 12 +-- funasr/models/transformer/model.py | 87 +++++++++---------- funasr/models/transformer/search.py | 2 +- funasr/models/transformer/template.yaml | 1 - 6 files changed, 74 insertions(+), 52 deletions(-) create mode 100644 examples/industrial_data_pretraining/conformer/demo.py create mode 100644 examples/industrial_data_pretraining/conformer/infer.sh diff --git a/examples/industrial_data_pretraining/conformer/demo.py b/examples/industrial_data_pretraining/conformer/demo.py new file mode 100644 index 000000000..358a1f800 --- /dev/null +++ b/examples/industrial_data_pretraining/conformer/demo.py @@ -0,0 +1,13 @@ +#!/usr/bin/env python3 +# -*- encoding: utf-8 -*- +# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. +# MIT License (https://opensource.org/licenses/MIT) + +from funasr import AutoModel + +model = AutoModel(model="iic/speech_conformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch", model_revision="v2.0.4", + ) + +res = model.generate(input="https://modelscope.oss-cn-beijing.aliyuncs.com/test/audios/asr_example.wav") +print(res) + diff --git a/examples/industrial_data_pretraining/conformer/infer.sh b/examples/industrial_data_pretraining/conformer/infer.sh new file mode 100644 index 000000000..c259799f3 --- /dev/null +++ b/examples/industrial_data_pretraining/conformer/infer.sh @@ -0,0 +1,11 @@ + +model="iic/speech_conformer_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch" +model_revision="v2.0.4" + +python funasr/bin/inference.py \ ++model=${model} \ ++model_revision=${model_revision} \ ++input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav" \ ++output_dir="./outputs/debug" \ ++device="cpu" \ + diff --git a/funasr/models/conformer/template.yaml b/funasr/models/conformer/template.yaml index 4cbeca46f..f646acc9d 100644 --- a/funasr/models/conformer/template.yaml +++ b/funasr/models/conformer/template.yaml @@ -6,8 +6,7 @@ # tables.print() # network architecture -#model: funasr.models.paraformer.model:Paraformer -model: Transformer +model: Conformer model_conf: ctc_weight: 0.3 lsm_weight: 0.1 # label smoothing option @@ -16,14 +15,14 @@ model_conf: # encoder encoder: ConformerEncoder encoder_conf: - output_size: 256 # dimension of attention + output_size: 256 attention_heads: 4 - linear_units: 2048 # the number of units of position-wise feed forward - num_blocks: 12 # the number of encoder blocks + linear_units: 2048 + num_blocks: 12 dropout_rate: 0.1 positional_dropout_rate: 0.1 attention_dropout_rate: 0.0 - input_layer: conv2d # encoder architecture type + input_layer: conv2d normalize_before: true pos_enc_layer_type: rel_pos selfattention_layer_type: rel_selfattn @@ -52,6 +51,7 @@ frontend_conf: n_mels: 80 frame_length: 25 frame_shift: 10 + dither: 0.0 lfr_m: 1 lfr_n: 1 diff --git a/funasr/models/transformer/model.py b/funasr/models/transformer/model.py index 7e40060dc..4ad466b4f 100644 --- a/funasr/models/transformer/model.py +++ b/funasr/models/transformer/model.py @@ -24,18 +24,16 @@ class Transformer(nn.Module): def __init__( self, - frontend: Optional[str] = None, - frontend_conf: Optional[Dict] = None, - specaug: Optional[str] = None, - specaug_conf: Optional[Dict] = None, + specaug: str = None, + specaug_conf: dict = None, normalize: str = None, - normalize_conf: Optional[Dict] = None, + normalize_conf: dict = None, encoder: str = None, - encoder_conf: Optional[Dict] = None, + encoder_conf: dict = None, decoder: str = None, - decoder_conf: Optional[Dict] = None, + decoder_conf: dict = None, ctc: str = None, - ctc_conf: Optional[Dict] = None, + ctc_conf: dict = None, ctc_weight: float = 0.5, interctc_weight: float = 0.0, input_size: int = 80, @@ -59,20 +57,17 @@ def __init__( super().__init__() - if frontend is not None: - frontend_class = tables.frontend_classes.get_class(frontend) - frontend = frontend_class(**frontend_conf) if specaug is not None: - specaug_class = tables.specaug_classes.get_class(specaug) + specaug_class = tables.specaug_classes.get(specaug) specaug = specaug_class(**specaug_conf) if normalize is not None: - normalize_class = tables.normalize_classes.get_class(normalize) + normalize_class = tables.normalize_classes.get(normalize) normalize = normalize_class(**normalize_conf) - encoder_class = tables.encoder_classes.get_class(encoder) + encoder_class = tables.encoder_classes.get(encoder) encoder = encoder_class(input_size=input_size, **encoder_conf) encoder_output_size = encoder.output_size() if decoder is not None: - decoder_class = tables.decoder_classes.get_class(decoder) + decoder_class = tables.decoder_classes.get(decoder) decoder = decoder_class( vocab_size=vocab_size, encoder_output_size=encoder_output_size, @@ -93,7 +88,6 @@ def __init__( self.vocab_size = vocab_size self.ignore_id = ignore_id self.ctc_weight = ctc_weight - self.frontend = frontend self.specaug = specaug self.normalize = normalize self.encoder = encoder @@ -338,6 +332,7 @@ def init_beam_search(self, ) token_list = kwargs.get("token_list") scorers.update( + decoder=self.decoder, length_bonus=LengthBonus(len(token_list)), ) @@ -348,14 +343,14 @@ def init_beam_search(self, scorers["ngram"] = ngram weights = dict( - decoder=1.0 - kwargs.get("decoding_ctc_weight", 0.0), - ctc=kwargs.get("decoding_ctc_weight", 0.0), + decoder=1.0 - kwargs.get("decoding_ctc_weight", 0.5), + ctc=kwargs.get("decoding_ctc_weight", 0.5), lm=kwargs.get("lm_weight", 0.0), ngram=kwargs.get("ngram_weight", 0.0), length_bonus=kwargs.get("penalty", 0.0), ) beam_search = BeamSearch( - beam_size=kwargs.get("beam_size", 2), + beam_size=kwargs.get("beam_size", 10), weights=weights, scorers=scorers, sos=self.sos, @@ -364,17 +359,15 @@ def init_beam_search(self, token_list=token_list, pre_beam_score_key=None if self.ctc_weight == 1.0 else "full", ) - # beam_search.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval() - # for scorer in scorers.values(): - # if isinstance(scorer, torch.nn.Module): - # scorer.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval() + self.beam_search = beam_search - def generate(self, - data_in: list, - data_lengths: list=None, + def inference(self, + data_in, + data_lengths=None, key: list=None, tokenizer=None, + frontend=None, **kwargs, ): @@ -382,27 +375,34 @@ def generate(self, raise NotImplementedError("batch decoding is not implemented") # init beamsearch - is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None - is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None - if self.beam_search is None and (is_use_lm or is_use_ctc): + if self.beam_search is None: logging.info("enable beam_search") self.init_beam_search(**kwargs) self.nbest = kwargs.get("nbest", 1) - + meta_data = {} - # extract fbank feats - time1 = time.perf_counter() - audio_sample_list = load_audio_text_image_video(data_in, fs=self.frontend.fs, audio_fs=kwargs.get("fs", 16000)) - time2 = time.perf_counter() - meta_data["load_data"] = f"{time2 - time1:0.3f}" - speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=self.frontend) - time3 = time.perf_counter() - meta_data["extract_feat"] = f"{time3 - time2:0.3f}" - meta_data["batch_data_time"] = speech_lengths.sum().item() * self.frontend.frame_shift * self.frontend.lfr_n / 1000 - + if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank": # fbank + speech, speech_lengths = data_in, data_lengths + if len(speech.shape) < 3: + speech = speech[None, :, :] + if speech_lengths is None: + speech_lengths = speech.shape[1] + else: + # extract fbank feats + time1 = time.perf_counter() + audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000), + data_type=kwargs.get("data_type", "sound"), + tokenizer=tokenizer) + time2 = time.perf_counter() + meta_data["load_data"] = f"{time2 - time1:0.3f}" + speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"), + frontend=frontend) + time3 = time.perf_counter() + meta_data["extract_feat"] = f"{time3 - time2:0.3f}" + meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000 + speech = speech.to(device=kwargs["device"]) speech_lengths = speech_lengths.to(device=kwargs["device"]) - # Encoder encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) if isinstance(encoder_out, tuple): @@ -439,14 +439,13 @@ def generate(self, token = tokenizer.ids2tokens(token_int) text = tokenizer.tokens2text(token) - text_postprocessed, _ = postprocess_utils.sentence_postprocess(token) - result_i = {"key": key[i], "token": token, "text": text, "text_postprocessed": text_postprocessed} + # text_postprocessed, _ = postprocess_utils.sentence_postprocess(token) + result_i = {"key": key[i], "token": token, "text": text} results.append(result_i) if ibest_writer is not None: ibest_writer["token"][key[i]] = " ".join(token) ibest_writer["text"][key[i]] = text - ibest_writer["text_postprocessed"][key[i]] = text_postprocessed return results, meta_data diff --git a/funasr/models/transformer/search.py b/funasr/models/transformer/search.py index 39c4f8c48..ab7ac7d78 100644 --- a/funasr/models/transformer/search.py +++ b/funasr/models/transformer/search.py @@ -9,7 +9,7 @@ import torch -from funasr.metrics import end_detect +from funasr.metrics.common import end_detect from funasr.models.transformer.scorers.scorer_interface import PartialScorerInterface from funasr.models.transformer.scorers.scorer_interface import ScorerInterface diff --git a/funasr/models/transformer/template.yaml b/funasr/models/transformer/template.yaml index c9228f433..87814dc3b 100644 --- a/funasr/models/transformer/template.yaml +++ b/funasr/models/transformer/template.yaml @@ -6,7 +6,6 @@ # tables.print() # network architecture -#model: funasr.models.paraformer.model:Paraformer model: Transformer model_conf: ctc_weight: 0.3 From b2126521459e8be6a082f3f5c4bc46121c68cd90 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Wed, 31 Jan 2024 22:39:13 +0800 Subject: [PATCH 2/9] funasr1.0.5 audio samples input --- .../seaco_paraformer/demo.py | 20 ++++++++++++++++++- funasr/auto/auto_model.py | 2 +- funasr/models/transformer/model.py | 6 +++--- 3 files changed, 23 insertions(+), 5 deletions(-) diff --git a/examples/industrial_data_pretraining/seaco_paraformer/demo.py b/examples/industrial_data_pretraining/seaco_paraformer/demo.py index 065b698a3..e9e226d1c 100644 --- a/examples/industrial_data_pretraining/seaco_paraformer/demo.py +++ b/examples/industrial_data_pretraining/seaco_paraformer/demo.py @@ -15,8 +15,26 @@ # spk_model_revision="v2.0.2", ) + +# example1 res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav", hotword='达摩院 魔搭', # sentence_timestamp=True, # return sentence level information when spk_model is not given ) -print(res) \ No newline at end of file +print(res) + +# example2 +import torchaudio +wav_file = os.path.join(model.model_path, "example/asr_example.wav") +input_tensor, sample_rate = torchaudio.load(wav_file) +input_tensor = input_tensor.mean(0) +res = model.generate(input=[input_tensor], batch_size_s=300, is_final=True) + + +# example3 +import soundfile +import os +wav_file = os.path.join(model.model_path, "example/asr_example.wav") +speech, sample_rate = soundfile.read(wav_file) +res = model.generate(input=[speech], batch_size_s=300, is_final=True) + diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py index 3986a110c..d99fc5613 100644 --- a/funasr/auto/auto_model.py +++ b/funasr/auto/auto_model.py @@ -228,7 +228,7 @@ def inference(self, input, input_len=None, model=None, kwargs=None, key=None, ** data_batch = data_list[beg_idx:end_idx] key_batch = key_list[beg_idx:end_idx] batch = {"data_in": data_batch, "key": key_batch} - if (end_idx - beg_idx) == 1 and isinstance(data_batch[0], torch.Tensor): # fbank + if (end_idx - beg_idx) == 1 and kwargs.get("data_type", None) == "fbank": # fbank batch["data_in"] = data_batch[0] batch["data_lengths"] = input_len diff --git a/funasr/models/transformer/model.py b/funasr/models/transformer/model.py index 4ad466b4f..e813e2205 100644 --- a/funasr/models/transformer/model.py +++ b/funasr/models/transformer/model.py @@ -439,13 +439,13 @@ def inference(self, token = tokenizer.ids2tokens(token_int) text = tokenizer.tokens2text(token) - # text_postprocessed, _ = postprocess_utils.sentence_postprocess(token) - result_i = {"key": key[i], "token": token, "text": text} + text_postprocessed, _ = postprocess_utils.sentence_postprocess(token) + result_i = {"key": key[i], "token": token, "text": text_postprocessed} results.append(result_i) if ibest_writer is not None: ibest_writer["token"][key[i]] = " ".join(token) - ibest_writer["text"][key[i]] = text + ibest_writer["text"][key[i]] = text_postprocessed return results, meta_data From ad6eafe7d07bc058d00556250ec08d0dbdb560a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Thu, 1 Feb 2024 16:26:55 +0800 Subject: [PATCH 3/9] batch_type token --- .../seaco_paraformer/demo.py | 3 ++- funasr/datasets/audio_datasets/samplers.py | 10 +++++++--- funasr/models/whisper/model.py | 8 ++++++-- 3 files changed, 15 insertions(+), 6 deletions(-) diff --git a/examples/industrial_data_pretraining/seaco_paraformer/demo.py b/examples/industrial_data_pretraining/seaco_paraformer/demo.py index e9e226d1c..85d989e44 100644 --- a/examples/industrial_data_pretraining/seaco_paraformer/demo.py +++ b/examples/industrial_data_pretraining/seaco_paraformer/demo.py @@ -25,6 +25,7 @@ # example2 import torchaudio +import os wav_file = os.path.join(model.model_path, "example/asr_example.wav") input_tensor, sample_rate = torchaudio.load(wav_file) input_tensor = input_tensor.mean(0) @@ -33,7 +34,7 @@ # example3 import soundfile -import os + wav_file = os.path.join(model.model_path, "example/asr_example.wav") speech, sample_rate = soundfile.read(wav_file) res = model.generate(input=[speech], batch_size_s=300, is_final=True) diff --git a/funasr/datasets/audio_datasets/samplers.py b/funasr/datasets/audio_datasets/samplers.py index 0d9309814..535df5d05 100644 --- a/funasr/datasets/audio_datasets/samplers.py +++ b/funasr/datasets/audio_datasets/samplers.py @@ -26,6 +26,8 @@ def __init__(self, dataset, self.max_token_length = kwargs.get("max_token_length", 5000) self.shuffle_idx = np.arange(self.total_samples) self.shuffle = shuffle and is_training + self.length_scale_source = kwargs.get("length_scale_source", 1.0) + def __len__(self): return (self.total_samples-1) // self.batch_size + 1 @@ -53,8 +55,10 @@ def __iter__(self): idx_map = self.shuffle_idx[idx] # prompt = self.dataset.indexed_dataset[idx_map]["prompt"] - sample_len_cur = self.dataset.get_source_len(idx_map) + \ - self.dataset.get_target_len(idx_map) + target_len = self.dataset.get_target_len(idx_map) if self.batch_type == 'length' else 0.0 + source_len = self.dataset.get_source_len(idx_map) / self.length_scale_source + sample_len_cur = source_len + target_len + datalen_with_index.append([idx, sample_len_cur]) @@ -66,7 +70,7 @@ def __iter__(self): max_token_cur = max(max_token, sample_len_cur_raw) max_token_padding = 1 + num_sample - if self.batch_type == 'length': + if self.batch_type != 'example': max_token_padding *= max_token_cur if max_token_padding <= self.batch_size: batch.append(idx) diff --git a/funasr/models/whisper/model.py b/funasr/models/whisper/model.py index 6a6d47c6e..381a50161 100644 --- a/funasr/models/whisper/model.py +++ b/funasr/models/whisper/model.py @@ -10,6 +10,8 @@ from funasr.models.whisper.utils.decoding import detect_language as detect_language_function, decode as decode_function +from funasr.register import tables + @dataclass class ModelDimensions: @@ -128,6 +130,8 @@ def forward( return x + +@tables.register("encoder_classes", "WhisperEncoder") class AudioEncoder(nn.Module): def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int): super().__init__() @@ -158,7 +162,7 @@ def forward(self, x: Tensor): x = self.ln_post(x) return x - +@tables.register("decoder_classes", "WhisperDecoder") class TextDecoder(nn.Module): def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int): super().__init__() @@ -193,7 +197,7 @@ def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None): return logits - +@tables.register("model_classes", "Whisper") class Whisper(nn.Module): def __init__(self, dims: dict): super().__init__() From 7ae2902a9414fb1f09507522a6b7f71e5137decc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Thu, 1 Feb 2024 16:38:21 +0800 Subject: [PATCH 4/9] batch_type token --- funasr/bin/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/funasr/bin/train.py b/funasr/bin/train.py index d9d4d6241..8ea0c0db5 100644 --- a/funasr/bin/train.py +++ b/funasr/bin/train.py @@ -154,7 +154,7 @@ def main(**kwargs): if batch_sampler is not None: batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler) batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf")) - batch_sampler_val = batch_sampler_class(dataset_tr, is_training=False, **kwargs.get("dataset_conf")) + batch_sampler_val = batch_sampler_class(dataset_val, is_training=False, **kwargs.get("dataset_conf")) dataloader_tr = torch.utils.data.DataLoader(dataset_tr, collate_fn=dataset_tr.collator, batch_sampler=batch_sampler, From bce516da29eb5bdc098c53f5506941210ff7e92e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Fri, 2 Feb 2024 00:08:34 +0800 Subject: [PATCH 5/9] huggingface model zoo --- README.md | 20 ++++++++++---------- README_zh.md | 16 ++++++++-------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/README.md b/README.md index 9f345533a..bdedfb268 100644 --- a/README.md +++ b/README.md @@ -69,16 +69,16 @@ FunASR has open-sourced a large number of pre-trained models on industrial data. (Note: 🤗 represents the Huggingface model zoo link, ⭐ represents the ModelScope model zoo link) -| Model Name | Task Details | Training Data | Parameters | -|:------------------------------------------------------------------------------------------------------------------------------------------------------------------:|:--------------------------------------------------:|:--------------------------------:|:----------:| -| paraformer-zh
([⭐](https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary) [🤗]() ) | speech recognition, with timestamps, non-streaming | 60000 hours, Mandarin | 220M | -| paraformer-zh-online
( [⭐](https://modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/summary) [🤗]() )
| speech recognition, streaming | 60000 hours, Mandarin | 220M | -| paraformer-en
( [⭐](https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020/summary) [🤗]() ) | speech recognition, with timestamps, non-streaming | 50000 hours, English | 220M | -| conformer-en
( [⭐](https://modelscope.cn/models/damo/speech_conformer_asr-en-16k-vocab4199-pytorch/summary) [🤗]() ) | speech recognition, non-streaming | 50000 hours, English | 220M | -| ct-punc
( [⭐](https://modelscope.cn/models/damo/punc_ct-transformer_cn-en-common-vocab471067-large/summary) [🤗]() ) | punctuation restoration | 100M, Mandarin and English | 1.1G | -| fsmn-vad
( [⭐](https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/summary) [🤗]() ) | voice activity detection | 5000 hours, Mandarin and English | 0.4M | -| fa-zh
( [⭐](https://modelscope.cn/models/damo/speech_timestamp_prediction-v1-16k-offline/summary) [🤗]() ) | timestamp prediction | 5000 hours, Mandarin | 38M | -| cam++
( [⭐](https://modelscope.cn/models/iic/speech_campplus_sv_zh-cn_16k-common/summary) [🤗]() ) | speaker verification/diarization | 5000 hours | 7.2M | +| Model Name | Task Details | Training Data | Parameters | +|:--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:|:--------------------------------------------------:|:--------------------------------:|:----------:| +| paraformer-zh
([⭐](https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary) [🤗](https://huggingface.co/funasr/paraformer-tp) ) | speech recognition, with timestamps, non-streaming | 60000 hours, Mandarin | 220M | +| paraformer-zh-streaming
( [⭐](https://modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/summary) [🤗](https://huggingface.co/funasr/paraformer-zh-streaming) )
| speech recognition, streaming | 60000 hours, Mandarin | 220M | +| paraformer-en
( [⭐](https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020/summary) [🤗](https://huggingface.co/funasr/paraformer-en) ) | speech recognition, with timestamps, non-streaming | 50000 hours, English | 220M | +| conformer-en
( [⭐](https://modelscope.cn/models/damo/speech_conformer_asr-en-16k-vocab4199-pytorch/summary) [🤗](https://huggingface.co/funasr/conformer-en) ) | speech recognition, non-streaming | 50000 hours, English | 220M | +| ct-punc
( [⭐](https://modelscope.cn/models/damo/punc_ct-transformer_cn-en-common-vocab471067-large/summary) [🤗](https://huggingface.co/funasr/ct-punc) ) | punctuation restoration | 100M, Mandarin and English | 1.1G | +| fsmn-vad
( [⭐](https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/summary) [🤗](https://huggingface.co/funasr/fsmn-vad) ) | voice activity detection | 5000 hours, Mandarin and English | 0.4M | +| fa-zh
( [⭐](https://modelscope.cn/models/damo/speech_timestamp_prediction-v1-16k-offline/summary) [🤗](https://huggingface.co/funasr/fa-zh) ) | timestamp prediction | 5000 hours, Mandarin | 38M | +| cam++
( [⭐](https://modelscope.cn/models/iic/speech_campplus_sv_zh-cn_16k-common/summary) [🤗](https://huggingface.co/funasr/campplus) ) | speaker verification/diarization | 5000 hours | 7.2M | diff --git a/README_zh.md b/README_zh.md index ed25b2a28..5d9061b62 100644 --- a/README_zh.md +++ b/README_zh.md @@ -73,14 +73,14 @@ FunASR开源了大量在工业数据上预训练模型,您可以在[模型许 | 模型名字 | 任务详情 | 训练数据 | 参数量 | |:------------------------------------------------------------------------------------------------------------------------------------------------------------:|:------------------:|:------------:|:----:| -| paraformer-zh
([⭐](https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary) [🤗]() ) | 语音识别,带时间戳输出,非实时 | 60000小时,中文 | 220M | -| paraformer-zh-streaming
( [⭐](https://modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/summary) [🤗]() ) | 语音识别,实时 | 60000小时,中文 | 220M | -| paraformer-en
( [⭐](https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020/summary) [🤗]() ) | 语音识别,非实时 | 50000小时,英文 | 220M | -| conformer-en
( [⭐](https://modelscope.cn/models/damo/speech_conformer_asr-en-16k-vocab4199-pytorch/summary) [🤗]() ) | 语音识别,非实时 | 50000小时,英文 | 220M | -| ct-punc
( [⭐](https://modelscope.cn/models/damo/punc_ct-transformer_cn-en-common-vocab471067-large/summary) [🤗]() ) | 标点恢复 | 100M,中文与英文 | 1.1G | -| fsmn-vad
( [⭐](https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/summary) [🤗]() ) | 语音端点检测,实时 | 5000小时,中文与英文 | 0.4M | -| fa-zh
( [⭐](https://modelscope.cn/models/damo/speech_timestamp_prediction-v1-16k-offline/summary) [🤗]() ) | 字级别时间戳预测 | 50000小时,中文 | 38M | -| cam++
( [⭐](https://modelscope.cn/models/iic/speech_campplus_sv_zh-cn_16k-common/summary) [🤗]() ) | 说话人确认/分割 | 5000小时 | 7.2M | +| paraformer-zh
([⭐](https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary) [🤗](https://huggingface.co/funasr/paraformer-tp) ) | 语音识别,带时间戳输出,非实时 | 60000小时,中文 | 220M | +| paraformer-zh-streaming
( [⭐](https://modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/summary) [🤗](https://huggingface.co/funasr/paraformer-zh-streaming) ) | 语音识别,实时 | 60000小时,中文 | 220M | +| paraformer-en
( [⭐](https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020/summary) [🤗](https://huggingface.co/funasr/paraformer-en) ) | 语音识别,非实时 | 50000小时,英文 | 220M | +| conformer-en
( [⭐](https://modelscope.cn/models/damo/speech_conformer_asr-en-16k-vocab4199-pytorch/summary) [🤗](https://huggingface.co/funasr/conformer-en) ) | 语音识别,非实时 | 50000小时,英文 | 220M | +| ct-punc
( [⭐](https://modelscope.cn/models/damo/punc_ct-transformer_cn-en-common-vocab471067-large/summary) [🤗](https://huggingface.co/funasr/ct-punc) ) | 标点恢复 | 100M,中文与英文 | 1.1G | +| fsmn-vad
( [⭐](https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/summary) [🤗](https://huggingface.co/funasr/fsmn-vad) ) | 语音端点检测,实时 | 5000小时,中文与英文 | 0.4M | +| fa-zh
( [⭐](https://modelscope.cn/models/damo/speech_timestamp_prediction-v1-16k-offline/summary) [🤗](https://huggingface.co/funasr/fa-zh) ) | 字级别时间戳预测 | 50000小时,中文 | 38M | +| cam++
( [⭐](https://modelscope.cn/models/iic/speech_campplus_sv_zh-cn_16k-common/summary) [🤗](https://huggingface.co/funasr/campplus) ) | 说话人确认/分割 | 5000小时 | 7.2M | From 7051dee8cc4f130be0799de0c56e4c3ce6965610 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Fri, 2 Feb 2024 11:33:52 +0800 Subject: [PATCH 6/9] dataloader --- funasr/datasets/audio_datasets/index_ds.py | 54 ++++++++++++- funasr/datasets/audio_datasets/samplers.py | 94 ++++++++++++++++++++++ 2 files changed, 146 insertions(+), 2 deletions(-) diff --git a/funasr/datasets/audio_datasets/index_ds.py b/funasr/datasets/audio_datasets/index_ds.py index c94d20961..008b08ff1 100644 --- a/funasr/datasets/audio_datasets/index_ds.py +++ b/funasr/datasets/audio_datasets/index_ds.py @@ -6,8 +6,8 @@ from funasr.register import tables -@tables.register("index_ds_classes", "IndexDSJsonl") -class IndexDSJsonl(torch.utils.data.Dataset): +@tables.register("index_ds_classes", "IndexDSJsonlRankSplit") +class IndexDSJsonlRankSplit(torch.utils.data.Dataset): def __init__(self, path): super().__init__() @@ -66,3 +66,53 @@ def get_source_len(self, data_dict): def get_target_len(self, data_dict): return data_dict["target_len"] if "target_len" in data_dict else 0 + +@tables.register("index_ds_classes", "IndexDSJsonl") +@tables.register("index_ds_classes", "IndexDSJsonlRankFull") +class IndexDSJsonlRankFull(torch.utils.data.Dataset): + + def __init__(self, path): + super().__init__() + + contents = [] + with open(path, encoding='utf-8') as fin: + for line in fin: + data = json.loads(line.strip()) + if "text" in data: # for sft + self.contents.append(data['text']) + if "source" in data: # for speech lab pretrain + prompt = data.get("prompt", "") + source = data["source"] + target = data["target"] + source_len = data.get("source_len", 1) + target_len = data.get("target_len", 0) + + contents.append({"source": source, + "prompt": prompt, + "target": target, + "source_len": source_len, + "target_len": target_len, + } + ) + + self.contents = contents + + logging.info( + "total_num of samplers across ranks: {}".format(len(self.contents))) + + def __len__(self): + return len(self.contents) + + def __getitem__(self, index): + try: + data = self.contents[index] + except: + print(index) + return data + + def get_source_len(self, data_dict): + return data_dict.get("source_len", 1) + + def get_target_len(self, data_dict): + + return data_dict.get("target_len", 0) diff --git a/funasr/datasets/audio_datasets/samplers.py b/funasr/datasets/audio_datasets/samplers.py index 535df5d05..b3bf36b1f 100644 --- a/funasr/datasets/audio_datasets/samplers.py +++ b/funasr/datasets/audio_datasets/samplers.py @@ -1,5 +1,7 @@ import torch import numpy as np +import logging +import torch.distributed as dist from funasr.register import tables @@ -82,3 +84,95 @@ def __iter__(self): max_token = sample_len_cur_raw num_sample = 1 + +@tables.register("batch_sampler_classes", "BatchSampler") +@tables.register("batch_sampler_classes", "RankFullLocalShuffleBatchSampler") +class RankFullLocalShuffleBatchSampler(torch.utils.data.BatchSampler): + + def __init__(self, dataset, + batch_type: str = "example", + batch_size: int = 100, + buffer_size: int = 30, + drop_last: bool = True, + shuffle: bool = True, + is_training: bool = True, + **kwargs): + + self.drop_last = drop_last + self.pre_idx = -1 + self.dataset = dataset + self.total_samples = len(dataset) + self.batch_type = batch_type + self.batch_size = int(batch_size) + self.buffer_size = buffer_size + self.max_token_length = kwargs.get("max_token_length", 1500) + self.shuffle_idx = np.arange(self.total_samples) + self.shuffle = shuffle and is_training + self.length_scale_source = kwargs.get("length_scale_source", 1.0) + + try: + rank = dist.get_rank() + world_size = dist.get_world_size() + except: + rank = 0 + world_size = 1 + self.rank = rank + self.world_size = world_size + + def __len__(self): + return (self.total_samples - 1) // self.batch_size + 1 + + def set_epoch(self, epoch): + np.random.seed(epoch) + + def __iter__(self): + + batch_size_total = self.batch_size * self.world_size + if self.shuffle: + np.random.shuffle(self.shuffle_idx) + + batch = [] + max_token = 0 + num_sample = 0 + + iter_num = (self.total_samples - 1) // self.buffer_size + 1 + # print("iter_num: ", iter_num) + for iter in range(self.pre_idx + 1, iter_num): + if iter == iter_num -1 and self.drop_last: + continue + datalen_with_index = [] + for i in range(self.buffer_size): + idx = iter * self.buffer_size + i + if idx >= self.total_samples: + continue + + idx_map = self.shuffle_idx[idx] + # prompt = self.dataset.indexed_dataset[idx_map]["prompt"] + + source_len = self.dataset.get_source_len(idx_map) / self.length_scale_source + target_len = self.dataset.get_target_len(idx_map) if self.batch_type == 'length' else 0.0 + sample_len_cur = source_len + target_len + + datalen_with_index.append([idx, sample_len_cur]) + + datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1]) + for item in datalen_with_index_sort: + idx, sample_len_cur_raw = item + if sample_len_cur_raw > self.max_token_length: + continue + + max_token_cur = max(max_token, sample_len_cur_raw) + max_token_padding = 1 + num_sample + # if self.batch_type != 'example': + # max_token_padding *= max_token_cur + if max_token_padding <= batch_size_total: + batch.append(idx) + max_token = max_token_cur + num_sample += 1 + else: + batch_rank = batch[self.rank*self.batch_size: (self.rank+1)*self.batch_size] + yield batch_rank + batch = [idx] + max_token = sample_len_cur_raw + num_sample = 1 + From 7a108176063300b0de355b8fd36e23dd19220167 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Sun, 4 Feb 2024 23:34:58 +0800 Subject: [PATCH 7/9] dataloader --- funasr/datasets/audio_datasets/samplers.py | 105 ++++++++++++++++++++- funasr/models/paraformer/cif_predictor.py | 2 +- funasr/train_utils/trainer.py | 18 ++++ 3 files changed, 121 insertions(+), 4 deletions(-) diff --git a/funasr/datasets/audio_datasets/samplers.py b/funasr/datasets/audio_datasets/samplers.py index b3bf36b1f..914e77692 100644 --- a/funasr/datasets/audio_datasets/samplers.py +++ b/funasr/datasets/audio_datasets/samplers.py @@ -120,7 +120,7 @@ def __init__(self, dataset, self.world_size = world_size def __len__(self): - return (self.total_samples - 1) // self.batch_size + 1 + return (self.total_samples - 1) // (self.batch_size * self.world_size) + 1 def set_epoch(self, epoch): np.random.seed(epoch) @@ -128,6 +128,7 @@ def set_epoch(self, epoch): def __iter__(self): batch_size_total = self.batch_size * self.world_size + if self.shuffle: np.random.shuffle(self.shuffle_idx) @@ -138,8 +139,8 @@ def __iter__(self): iter_num = (self.total_samples - 1) // self.buffer_size + 1 # print("iter_num: ", iter_num) for iter in range(self.pre_idx + 1, iter_num): - if iter == iter_num -1 and self.drop_last: - continue + # if iter == iter_num -1 and self.drop_last: + # continue datalen_with_index = [] for i in range(self.buffer_size): idx = iter * self.buffer_size + i @@ -176,3 +177,101 @@ def __iter__(self): max_token = sample_len_cur_raw num_sample = 1 + +@tables.register("batch_sampler_classes", "RankFullLocalShuffleDynamicBatchSampler") +class RankFullLocalShuffleDynamicBatchSampler(torch.utils.data.BatchSampler): + + def __init__(self, dataset, + batch_type: str = "example", + batch_size: int = 100, + buffer_size: int = 30, + drop_last: bool = True, + shuffle: bool = True, + is_training: bool = True, + **kwargs): + + self.drop_last = drop_last + self.pre_idx = -1 + self.dataset = dataset + self.total_samples = len(dataset) + self.batch_type = batch_type + self.batch_size = int(batch_size) + self.buffer_size = buffer_size + self.max_token_length = kwargs.get("max_token_length", 1500) + self.shuffle_idx = np.arange(self.total_samples) + self.shuffle = shuffle and is_training + self.length_scale_source = kwargs.get("length_scale_source", 1.0) + + try: + rank = dist.get_rank() + world_size = dist.get_world_size() + except: + rank = 0 + world_size = 1 + self.rank = rank + self.world_size = world_size + + def __len__(self): + return (self.total_samples - 1) // (self.batch_size * self.world_size) + 1 + + def set_epoch(self, epoch): + np.random.seed(epoch) + + def __iter__(self): + + batch_size_total = self.batch_size * self.world_size + if self.shuffle: + np.random.shuffle(self.shuffle_idx) + + batch_list_all_rank = [] + batch_list_cur = [] + max_token = 0 + num_sample = 0 + + iter_num = (self.total_samples - 1) // self.buffer_size + 1 + # print("iter_num: ", iter_num) + for iter in range(self.pre_idx + 1, iter_num): + # if iter == iter_num - 1 and self.drop_last: + # continue + datalen_with_index = [] + for i in range(self.buffer_size): + idx = iter * self.buffer_size + i + if idx >= self.total_samples: + continue + + idx_map = self.shuffle_idx[idx] + # prompt = self.dataset.indexed_dataset[idx_map]["prompt"] + + source_len = self.dataset.get_source_len(idx_map) / self.length_scale_source + target_len = self.dataset.get_target_len(idx_map) if self.batch_type == 'length' else 0.0 + sample_len_cur = source_len + target_len + + datalen_with_index.append([idx, sample_len_cur]) + + datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1]) + for ii, item in enumerate(datalen_with_index_sort): + is_last_batch = iter == iter_num - 1 and ii == len(datalen_with_index_sort) + idx, sample_len_cur_raw = item + if sample_len_cur_raw > self.max_token_length: + continue + + max_token_cur = max(max_token, sample_len_cur_raw) + max_token_padding = 1 + num_sample + + if self.batch_type != 'example': + max_token_padding *= max_token_cur + if len(batch_list_all_rank) < self.world_size: + + if max_token_padding <= self.batch_size: + batch_list_cur.append(idx) + max_token = max_token_cur + num_sample += 1 + else: + batch_list_all_rank.append(batch_list_cur) + batch_list_cur = [] + else: + batch_rank = batch_list_all_rank[self.rank] + yield batch_rank + batch_list_all_rank = [idx] + max_token = sample_len_cur_raw + num_sample = 1 diff --git a/funasr/models/paraformer/cif_predictor.py b/funasr/models/paraformer/cif_predictor.py index a5086c3c2..60ddc24e0 100644 --- a/funasr/models/paraformer/cif_predictor.py +++ b/funasr/models/paraformer/cif_predictor.py @@ -186,7 +186,7 @@ def forward(self, hidden, target_label=None, mask=None, ignore_id=-1, mask_chunk alphas = alphas.squeeze(-1) mask = mask.squeeze(-1) if target_label_length is not None: - target_length = target_label_length + target_length = target_label_length.squeeze(-1) elif target_label is not None: target_length = (target_label != ignore_id).float().sum(-1) else: diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py index 414c0d7ca..d144019aa 100644 --- a/funasr/train_utils/trainer.py +++ b/funasr/train_utils/trainer.py @@ -204,7 +204,25 @@ def _train_epoch(self, epoch): my_context = self.model.no_sync if batch_idx % accum_grad != 0 else nullcontext with my_context(): time2 = time.perf_counter() + print("before, GPU, memory: {:.1} MB, " + "{:.1} MB, " + "{:.1} MB, " + "{:.1} MB".format(torch.cuda.memory_allocated()/1024/1024/1024, + torch.cuda.max_memory_allocated()/1024/1024/1024, + torch.cuda.memory_reserved()/1024/1024/1024, + torch.cuda.max_memory_reserved()/1024/1024/1024, + )) + retval = self.model(**batch) + torch.cuda.empty_cache() + print("after, GPU, memory: {:.1} MB, " + "{:.1} MB, " + "{:.1} MB, " + "{:.1} MB".format(torch.cuda.memory_allocated()/1024/1024/1024, + torch.cuda.max_memory_allocated()/1024/1024/1024, + torch.cuda.memory_reserved()/1024/1024/1024, + torch.cuda.max_memory_reserved()/1024/1024/1024, + )) time3 = time.perf_counter() speed_stats["forward_time"] = f"{time3 - time2:0.3f}" loss, stats, weight = retval From 3b2cf1731232d618c959f5232fec3fe8c3092fd0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Tue, 6 Feb 2024 10:17:18 +0800 Subject: [PATCH 8/9] fbank input --- funasr/models/paraformer/model.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/funasr/models/paraformer/model.py b/funasr/models/paraformer/model.py index 6e422ad75..77471466b 100644 --- a/funasr/models/paraformer/model.py +++ b/funasr/models/paraformer/model.py @@ -491,6 +491,8 @@ def inference(self, b, n, d = decoder_out.size() if isinstance(key[0], (list, tuple)): key = key[0] + if len(key) < b: + key = key*b for i in range(b): x = encoder_out[i, :encoder_out_lens[i], :] am_scores = decoder_out[i, :pre_token_length[i], :] From f0fcb28bdee9543048de08cd6badbf43b925c469 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Tue, 6 Feb 2024 21:17:11 +0800 Subject: [PATCH 9/9] vad is_final=True bugfix --- funasr/auto/auto_model.py | 3 ++- funasr/models/fsmn_vad_streaming/model.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py index d99fc5613..7a79d63aa 100644 --- a/funasr/auto/auto_model.py +++ b/funasr/auto/auto_model.py @@ -174,7 +174,7 @@ def build_model(self, **kwargs): # build model model_class = tables.model_classes.get(kwargs["model"]) model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=vocab_size) - model.eval() + model.to(device) # init_param @@ -209,6 +209,7 @@ def inference(self, input, input_len=None, model=None, kwargs=None, key=None, ** kwargs = self.kwargs if kwargs is None else kwargs kwargs.update(cfg) model = self.model if model is None else model + model.eval() batch_size = kwargs.get("batch_size", 1) # if kwargs.get("device", "cpu") == "cpu": diff --git a/funasr/models/fsmn_vad_streaming/model.py b/funasr/models/fsmn_vad_streaming/model.py index 5fc6aae2f..4fd18c85f 100644 --- a/funasr/models/fsmn_vad_streaming/model.py +++ b/funasr/models/fsmn_vad_streaming/model.py @@ -575,7 +575,8 @@ def inference(self, time1 = time.perf_counter() is_streaming_input = kwargs.get("is_streaming_input", False) if chunk_size >= 15000 else kwargs.get("is_streaming_input", True) - cfg = {"is_final": kwargs.get("is_final", False), "is_streaming_input": is_streaming_input} + is_final = kwargs.get("is_final", False) if is_streaming_input else kwargs.get("is_final", True) + cfg = {"is_final": is_final, "is_streaming_input": is_streaming_input} audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000),