From 7458e39ff0756d0bae38b139e0e534e61e1fa0cf Mon Sep 17 00:00:00 2001
From: "shixian.shi"
Date: Wed, 17 Jan 2024 19:21:08 +0800
Subject: [PATCH] bug fix

---
 .../paraformer/demo.py                  |  4 ++-
 funasr/models/bicif_paraformer/model.py | 34 +++++++++----------
 2 files changed, 20 insertions(+), 18 deletions(-)

diff --git a/examples/industrial_data_pretraining/paraformer/demo.py b/examples/industrial_data_pretraining/paraformer/demo.py
index ef33bf40d..78af3aa1d 100644
--- a/examples/industrial_data_pretraining/paraformer/demo.py
+++ b/examples/industrial_data_pretraining/paraformer/demo.py
@@ -11,6 +11,7 @@
 print(res)
 
+''' can not use currently
 from funasr import AutoFrontend
 
 frontend = AutoFrontend(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", model_revision="v2.0.2")
@@ -19,4 +20,5 @@
 
 for batch_idx, fbank_dict in enumerate(fbanks):
     res = model.generate(**fbank_dict)
-    print(res)
\ No newline at end of file
+    print(res)
+'''
\ No newline at end of file
diff --git a/funasr/models/bicif_paraformer/model.py b/funasr/models/bicif_paraformer/model.py
index 01f19c697..0069b8c98 100644
--- a/funasr/models/bicif_paraformer/model.py
+++ b/funasr/models/bicif_paraformer/model.py
@@ -235,23 +235,23 @@ def inference(self,
         self.nbest = kwargs.get("nbest", 1)
 
         meta_data = {}
-        if isinstance(data_in, torch.Tensor): # fbank
-            speech, speech_lengths = data_in, data_lengths
-            if len(speech.shape) < 3:
-                speech = speech[None, :, :]
-            if speech_lengths is None:
-                speech_lengths = speech.shape[1]
-        else:
-            # extract fbank feats
-            time1 = time.perf_counter()
-            audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000))
-            time2 = time.perf_counter()
-            meta_data["load_data"] = f"{time2 - time1:0.3f}"
-            speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
-                                                   frontend=frontend)
-            time3 = time.perf_counter()
-            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
-            meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
+        # if isinstance(data_in, torch.Tensor): # fbank
+        #     speech, speech_lengths = data_in, data_lengths
+        #     if len(speech.shape) < 3:
+        #         speech = speech[None, :, :]
+        #     if speech_lengths is None:
+        #         speech_lengths = speech.shape[1]
+        # else:
+        # extract fbank feats
+        time1 = time.perf_counter()
+        audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000))
+        time2 = time.perf_counter()
+        meta_data["load_data"] = f"{time2 - time1:0.3f}"
+        speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
+                                               frontend=frontend)
+        time3 = time.perf_counter()
+        meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+        meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
 
         speech = speech.to(device=kwargs["device"])
         speech_lengths = speech_lengths.to(device=kwargs["device"])
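
Note: a minimal sketch of how the patched inference path would be exercised,
assuming the FunASR AutoModel API already used at the top of demo.py. With the
pre-computed fbank branch commented out, inference always routes the input
through load_audio_text_image_video() and extract_fbank(), so callers should
pass a wav path (or raw audio samples) rather than a feature tensor. The model
name and wav path below are illustrative, not part of this patch.

    # sketch: exercising bicif_paraformer inference after this patch
    from funasr import AutoModel

    model = AutoModel(
        model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
        model_revision="v2.0.2",
    )

    # Input must now be audio (a path or raw samples); the fbank-tensor
    # fast path is disabled by this change.
    res = model.generate(input="example/asr_example.wav")  # hypothetical path
    print(res)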