From d30e28686e3b8352b9a91491c8ad8ea47b4f3c24 Mon Sep 17 00:00:00 2001 From: Binbin Zhang Date: Wed, 23 Aug 2023 22:31:04 +0800 Subject: [PATCH] [binding] refactor wenet_decode api, remove wenet_get_result --- runtime/binding/python/cpp/binding.cc | 5 ++--- runtime/binding/python/py/decoder.py | 3 +-- runtime/binding/python/setup.py | 2 +- runtime/core/api/wenet_api.cc | 31 ++++++++++++--------------- runtime/core/api/wenet_api.h | 15 ++++++------- runtime/core/bin/api_main.cc | 6 +++--- 6 files changed, 28 insertions(+), 34 deletions(-) diff --git a/runtime/binding/python/cpp/binding.cc b/runtime/binding/python/cpp/binding.cc index cff4f545e..42578f211 100644 --- a/runtime/binding/python/cpp/binding.cc +++ b/runtime/binding/python/cpp/binding.cc @@ -25,9 +25,8 @@ PYBIND11_MODULE(_wenet, m) { "wenet init"); m.def("wenet_free", &wenet_free, "wenet free"); m.def("wenet_reset", &wenet_reset, "wenet reset"); - m.def("wenet_decode", &wenet_decode, "wenet decode"); - m.def("wenet_get_result", &wenet_get_result, py::return_value_policy::copy, - "wenet get result"); + m.def("wenet_decode", &wenet_decode, py::return_value_policy::copy, + "wenet decode"); m.def("wenet_set_log_level", &wenet_set_log_level, "set log level"); m.def("wenet_set_nbest", &wenet_set_nbest, "set nbest"); m.def("wenet_set_timestamp", &wenet_set_timestamp, "set timestamp flag"); diff --git a/runtime/binding/python/py/decoder.py b/runtime/binding/python/py/decoder.py index 42d79ab1a..a93750ec8 100644 --- a/runtime/binding/python/py/decoder.py +++ b/runtime/binding/python/py/decoder.py @@ -96,8 +96,7 @@ def decode(self, pcm: bytes, last: bool = True) -> str: """ assert isinstance(pcm, bytes) finish = 1 if last else 0 - _wenet.wenet_decode(self.d, pcm, len(pcm), finish) - result = _wenet.wenet_get_result(self.d) + result = _wenet.wenet_decode(self.d, pcm, len(pcm), finish) if last: # Reset status for next decoding automatically self.reset() return result diff --git a/runtime/binding/python/setup.py b/runtime/binding/python/setup.py index e61775768..7c31f6c34 100644 --- a/runtime/binding/python/setup.py +++ b/runtime/binding/python/setup.py @@ -79,7 +79,7 @@ def read_long_description(): cmdclass={"build_ext": BuildExtension}, zip_safe=False, setup_requires=["tqdm"], - install_requires=["torch", "tqdm"] if "ONNX=ON" not in + install_requires=["torch>=1.10.0", "tqdm"] if "ONNX=ON" not in os.environ.get("WENET_CMAKE_ARGS", "") else ["tqdm"], classifiers=[ "Programming Language :: C++", diff --git a/runtime/core/api/wenet_api.cc b/runtime/core/api/wenet_api.cc index 36a0b7f48..15a01426b 100644 --- a/runtime/core/api/wenet_api.cc +++ b/runtime/core/api/wenet_api.cc @@ -89,7 +89,6 @@ class Recognizer { if (decoder_ != nullptr) { decoder_->Reset(); } - result_.clear(); } void InitDecoder() { @@ -115,7 +114,7 @@ class Recognizer { *decode_options_); } - void Decode(const char* data, int len, int last) { + std::string Decode(const char* data, int len, int last) { using wenet::DecodeState; // Init decoder when it is called first time if (decoder_ == nullptr) { @@ -129,25 +128,29 @@ class Recognizer { feature_pipeline_->set_input_finished(); } + std::string result = "{}"; // empty json while (true) { DecodeState state = decoder_->Decode(false); if (state == DecodeState::kWaitFeats) { + result = UpdateResult(false); break; } else if (state == DecodeState::kEndFeats) { decoder_->Rescoring(); - UpdateResult(true); + result = UpdateResult(true); break; } else if (state == DecodeState::kEndpoint && continuous_decoding_) { decoder_->Rescoring(); - UpdateResult(true); + result = UpdateResult(true); decoder_->ResetContinuousDecoding(); + break; } else { // kEndBatch - UpdateResult(false); + result = UpdateResult(false); } } + return result; } - void UpdateResult(bool final_result) { + std::string UpdateResult(bool final_result) { json::JSON obj; obj["type"] = final_result ? "final_result" : "partial_result"; int nbest = final_result ? nbest_ : 1; @@ -168,11 +171,9 @@ class Recognizer { one["sentence"] = decoder_->result()[i].sentence; obj["nbest"].append(one); } - result_ = obj.dump(); + return obj.dump(); } - const char* GetResult() { return result_.c_str(); } - void set_nbest(int n) { nbest_ = n; } void set_enable_timestamp(bool flag) { enable_timestamp_ = flag; } void AddContext(const char* word) { context_.emplace_back(word); } @@ -191,7 +192,6 @@ class Recognizer { std::shared_ptr post_process_opts_ = nullptr; int nbest_ = 1; - std::string result_; bool enable_timestamp_ = false; std::vector context_; float context_score_; @@ -213,14 +213,11 @@ void wenet_reset(void* decoder) { recognizer->Reset(); } -void wenet_decode(void* decoder, const char* data, int len, int last) { - Recognizer* recognizer = reinterpret_cast(decoder); - recognizer->Decode(data, len, last); -} - -const char* wenet_get_result(void* decoder) { +const char* wenet_decode(void* decoder, const char* data, int len, int last) { + static std::string result; Recognizer* recognizer = reinterpret_cast(decoder); - return recognizer->GetResult(); + result = recognizer->Decode(data, len, last); + return result.c_str(); } void wenet_set_log_level(int level) { diff --git a/runtime/core/api/wenet_api.h b/runtime/core/api/wenet_api.h index e839aaa40..fe524fc51 100644 --- a/runtime/core/api/wenet_api.h +++ b/runtime/core/api/wenet_api.h @@ -34,14 +34,7 @@ void wenet_free(void* decoder); */ void wenet_reset(void* decoder); -/** Decode the input wav data - * @param data: pcm data, encoded as int16_t(16 bits) - * @param len: data length - * @param last: if it is the last package - */ -void wenet_decode(void* decoder, const char* data, int len, int last); - -/** Get decode result in json format +/** Decode the input wav data and get decode result in json format * It returns partial result when last is 0 * It returns final result when last is 1 @@ -68,7 +61,13 @@ void wenet_decode(void* decoder, const char* data, int len, int last); "nbest": nbest is enabled when n > 1 in final_result "sentence": the ASR result "word_pieces": optional, output timestamp when enabled + + * @param data: pcm data, encoded as int16_t(16 bits) + * @param len: data length + * @param last: if it is the last package */ +const char* wenet_decode(void* decoder, const char* data, int len, int last); + const char* wenet_get_result(void* decoder); /** Set n-best, range 1~10 diff --git a/runtime/core/bin/api_main.cc b/runtime/core/bin/api_main.cc index 94b20d52a..d35a4aa27 100644 --- a/runtime/core/bin/api_main.cc +++ b/runtime/core/bin/api_main.cc @@ -36,9 +36,9 @@ int main(int argc, char* argv[]) { for (int i = 0; i < 10; i++) { // Return the final result when last is 1 - wenet_decode(decoder, reinterpret_cast(data.data()), - data.size() * 2, 1); - const char* result = wenet_get_result(decoder); + const char* result = + wenet_decode(decoder, reinterpret_cast(data.data()), + data.size() * 2, 1); LOG(INFO) << i << " " << result; wenet_reset(decoder); }