Skip to content

Commit

Permalink
bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
R1ckShi committed Jan 17, 2024
1 parent 9a9c3b7 commit 7458e39
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 18 deletions.
4 changes: 3 additions & 1 deletion examples/industrial_data_pretraining/paraformer/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
print(res)


''' can not use currently
from funasr import AutoFrontend
frontend = AutoFrontend(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", model_revision="v2.0.2")
Expand All @@ -19,4 +20,5 @@
for batch_idx, fbank_dict in enumerate(fbanks):
res = model.generate(**fbank_dict)
print(res)
print(res)
'''
34 changes: 17 additions & 17 deletions funasr/models/bicif_paraformer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,23 +235,23 @@ def inference(self,
self.nbest = kwargs.get("nbest", 1)

meta_data = {}
if isinstance(data_in, torch.Tensor): # fbank
speech, speech_lengths = data_in, data_lengths
if len(speech.shape) < 3:
speech = speech[None, :, :]
if speech_lengths is None:
speech_lengths = speech.shape[1]
else:
# extract fbank feats
time1 = time.perf_counter()
audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000))
time2 = time.perf_counter()
meta_data["load_data"] = f"{time2 - time1:0.3f}"
speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
frontend=frontend)
time3 = time.perf_counter()
meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
# if isinstance(data_in, torch.Tensor): # fbank
# speech, speech_lengths = data_in, data_lengths
# if len(speech.shape) < 3:
# speech = speech[None, :, :]
# if speech_lengths is None:
# speech_lengths = speech.shape[1]
# else:
# extract fbank feats
time1 = time.perf_counter()
audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000))
time2 = time.perf_counter()
meta_data["load_data"] = f"{time2 - time1:0.3f}"
speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
frontend=frontend)
time3 = time.perf_counter()
meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000

speech = speech.to(device=kwargs["device"])
speech_lengths = speech_lengths.to(device=kwargs["device"])
Expand Down

0 comments on commit 7458e39

Please sign in to comment.