diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py index d99fc5613..7a79d63aa 100644 --- a/funasr/auto/auto_model.py +++ b/funasr/auto/auto_model.py @@ -174,7 +174,7 @@ def build_model(self, **kwargs): # build model model_class = tables.model_classes.get(kwargs["model"]) model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=vocab_size) - model.eval() + model.to(device) # init_param @@ -209,6 +209,7 @@ def inference(self, input, input_len=None, model=None, kwargs=None, key=None, ** kwargs = self.kwargs if kwargs is None else kwargs kwargs.update(cfg) model = self.model if model is None else model + model.eval() batch_size = kwargs.get("batch_size", 1) # if kwargs.get("device", "cpu") == "cpu": diff --git a/funasr/models/fsmn_vad_streaming/model.py b/funasr/models/fsmn_vad_streaming/model.py index 5fc6aae2f..4fd18c85f 100644 --- a/funasr/models/fsmn_vad_streaming/model.py +++ b/funasr/models/fsmn_vad_streaming/model.py @@ -575,7 +575,8 @@ def inference(self, time1 = time.perf_counter() is_streaming_input = kwargs.get("is_streaming_input", False) if chunk_size >= 15000 else kwargs.get("is_streaming_input", True) - cfg = {"is_final": kwargs.get("is_final", False), "is_streaming_input": is_streaming_input} + is_final = kwargs.get("is_final", False) if is_streaming_input else kwargs.get("is_final", True) + cfg = {"is_final": is_final, "is_streaming_input": is_streaming_input} audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000),