Skip to content

Commit

Permalink
[binding] refactor wenet_decode api, remove wenet_get_result
Browse files Browse the repository at this point in the history
  • Loading branch information
robin1001 committed Aug 23, 2023
1 parent f393be9 commit d30e286
Show file tree
Hide file tree
Showing 6 changed files with 28 additions and 34 deletions.
5 changes: 2 additions & 3 deletions runtime/binding/python/cpp/binding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,8 @@ PYBIND11_MODULE(_wenet, m) {
"wenet init");
m.def("wenet_free", &wenet_free, "wenet free");
m.def("wenet_reset", &wenet_reset, "wenet reset");
m.def("wenet_decode", &wenet_decode, "wenet decode");
m.def("wenet_get_result", &wenet_get_result, py::return_value_policy::copy,
"wenet get result");
m.def("wenet_decode", &wenet_decode, py::return_value_policy::copy,
"wenet decode");
m.def("wenet_set_log_level", &wenet_set_log_level, "set log level");
m.def("wenet_set_nbest", &wenet_set_nbest, "set nbest");
m.def("wenet_set_timestamp", &wenet_set_timestamp, "set timestamp flag");
Expand Down
3 changes: 1 addition & 2 deletions runtime/binding/python/py/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,7 @@ def decode(self, pcm: bytes, last: bool = True) -> str:
"""
assert isinstance(pcm, bytes)
finish = 1 if last else 0
_wenet.wenet_decode(self.d, pcm, len(pcm), finish)
result = _wenet.wenet_get_result(self.d)
result = _wenet.wenet_decode(self.d, pcm, len(pcm), finish)
if last: # Reset status for next decoding automatically
self.reset()
return result
Expand Down
2 changes: 1 addition & 1 deletion runtime/binding/python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def read_long_description():
cmdclass={"build_ext": BuildExtension},
zip_safe=False,
setup_requires=["tqdm"],
install_requires=["torch", "tqdm"] if "ONNX=ON" not in
install_requires=["torch>=1.10.0", "tqdm"] if "ONNX=ON" not in
os.environ.get("WENET_CMAKE_ARGS", "") else ["tqdm"],
classifiers=[
"Programming Language :: C++",
Expand Down
31 changes: 14 additions & 17 deletions runtime/core/api/wenet_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@ class Recognizer {
if (decoder_ != nullptr) {
decoder_->Reset();
}
result_.clear();
}

void InitDecoder() {
Expand All @@ -115,7 +114,7 @@ class Recognizer {
*decode_options_);
}

void Decode(const char* data, int len, int last) {
std::string Decode(const char* data, int len, int last) {
using wenet::DecodeState;
// Init decoder when it is called first time
if (decoder_ == nullptr) {
Expand All @@ -129,25 +128,29 @@ class Recognizer {
feature_pipeline_->set_input_finished();
}

std::string result = "{}"; // empty json
while (true) {
DecodeState state = decoder_->Decode(false);
if (state == DecodeState::kWaitFeats) {
result = UpdateResult(false);
break;
} else if (state == DecodeState::kEndFeats) {
decoder_->Rescoring();
UpdateResult(true);
result = UpdateResult(true);
break;
} else if (state == DecodeState::kEndpoint && continuous_decoding_) {
decoder_->Rescoring();
UpdateResult(true);
result = UpdateResult(true);
decoder_->ResetContinuousDecoding();
break;
} else { // kEndBatch
UpdateResult(false);
result = UpdateResult(false);
}
}
return result;
}

void UpdateResult(bool final_result) {
std::string UpdateResult(bool final_result) {
json::JSON obj;
obj["type"] = final_result ? "final_result" : "partial_result";
int nbest = final_result ? nbest_ : 1;
Expand All @@ -168,11 +171,9 @@ class Recognizer {
one["sentence"] = decoder_->result()[i].sentence;
obj["nbest"].append(one);
}
result_ = obj.dump();
return obj.dump();
}

const char* GetResult() { return result_.c_str(); }

void set_nbest(int n) { nbest_ = n; }
void set_enable_timestamp(bool flag) { enable_timestamp_ = flag; }
void AddContext(const char* word) { context_.emplace_back(word); }
Expand All @@ -191,7 +192,6 @@ class Recognizer {
std::shared_ptr<wenet::PostProcessOptions> post_process_opts_ = nullptr;

int nbest_ = 1;
std::string result_;
bool enable_timestamp_ = false;
std::vector<std::string> context_;
float context_score_;
Expand All @@ -213,14 +213,11 @@ void wenet_reset(void* decoder) {
recognizer->Reset();
}

void wenet_decode(void* decoder, const char* data, int len, int last) {
Recognizer* recognizer = reinterpret_cast<Recognizer*>(decoder);
recognizer->Decode(data, len, last);
}

const char* wenet_get_result(void* decoder) {
const char* wenet_decode(void* decoder, const char* data, int len, int last) {
static std::string result;
Recognizer* recognizer = reinterpret_cast<Recognizer*>(decoder);
return recognizer->GetResult();
result = recognizer->Decode(data, len, last);
return result.c_str();
}

void wenet_set_log_level(int level) {
Expand Down
15 changes: 7 additions & 8 deletions runtime/core/api/wenet_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,7 @@ void wenet_free(void* decoder);
*/
void wenet_reset(void* decoder);

/** Decode the input wav data
* @param data: pcm data, encoded as int16_t(16 bits)
* @param len: data length
* @param last: if it is the last package
*/
void wenet_decode(void* decoder, const char* data, int len, int last);

/** Get decode result in json format
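/** Decode the input PCM data and return the decode result in JSON format.
* The returned pointer refers to a static buffer inside wenet_decode; it is
* overwritten by the next call, so copy the string if it must be kept.
* It returns a partial result when last is 0
* It returns the final result when last is 1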
Expand All @@ -68,7 +61,13 @@ void wenet_decode(void* decoder, const char* data, int len, int last);
"nbest": nbest is enabled when n > 1 in final_result
"sentence": the ASR result
"word_pieces": optional, output timestamp when enabled
* @param data: pcm data, encoded as int16_t(16 bits)
* @param len: data length
* @param last: if it is the last package
*/
const char* wenet_decode(void* decoder, const char* data, int len, int last);

const char* wenet_get_result(void* decoder);

/** Set n-best, range 1~10
Expand Down
6 changes: 3 additions & 3 deletions runtime/core/bin/api_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ int main(int argc, char* argv[]) {

for (int i = 0; i < 10; i++) {
// Return the final result when last is 1
wenet_decode(decoder, reinterpret_cast<const char*>(data.data()),
data.size() * 2, 1);
const char* result = wenet_get_result(decoder);
const char* result =
wenet_decode(decoder, reinterpret_cast<const char*>(data.data()),
data.size() * 2, 1);
LOG(INFO) << i << " " << result;
wenet_reset(decoder);
}
Expand Down

0 comments on commit d30e286

Please sign in to comment.