diff --git a/runtime/core/cmake/openfst.cmake b/runtime/core/cmake/openfst.cmake index 1d9965248..ac6f9524c 100644 --- a/runtime/core/cmake/openfst.cmake +++ b/runtime/core/cmake/openfst.cmake @@ -1,47 +1,47 @@ -if(NOT ANDROID) - include(gflags) - # We can't build glog with gflags, unless gflags is pre-installed. - # If build glog with pre-installed gflags, there will be conflict. - set(WITH_GFLAGS OFF CACHE BOOL "whether build glog with gflags" FORCE) - include(glog) - - if(NOT GRAPH_TOOLS) - set(HAVE_BIN OFF CACHE BOOL "Build the fst binaries" FORCE) - set(HAVE_SCRIPT OFF CACHE BOOL "Build the fstscript" FORCE) - endif() - set(HAVE_COMPACT OFF CACHE BOOL "Build compact" FORCE) - set(HAVE_CONST OFF CACHE BOOL "Build const" FORCE) - set(HAVE_GRM OFF CACHE BOOL "Build grm" FORCE) - set(HAVE_FAR OFF CACHE BOOL "Build far" FORCE) - set(HAVE_PDT OFF CACHE BOOL "Build pdt" FORCE) - set(HAVE_MPDT OFF CACHE BOOL "Build mpdt" FORCE) - set(HAVE_LINEAR OFF CACHE BOOL "Build linear" FORCE) - set(HAVE_LOOKAHEAD OFF CACHE BOOL "Build lookahead" FORCE) - set(HAVE_NGRAM OFF CACHE BOOL "Build ngram" FORCE) - set(HAVE_SPECIAL OFF CACHE BOOL "Build special" FORCE) - - if(MSVC) - add_compile_options(/W0 /wd4244 /wd4267) - endif() - - # "OpenFST port for Windows" builds openfst with cmake for multiple platforms. - # Openfst is compiled with glog/gflags to avoid log and flag conflicts with log and flags in wenet/libtorch. - # To build openfst with gflags and glog, we comment out some vars of {flags, log}.h and flags.cc. - set(openfst_SOURCE_DIR ${fc_base}/openfst-src CACHE PATH "OpenFST source directory") - FetchContent_Declare(openfst - URL https://github.com/kkm000/openfst/archive/refs/tags/win/1.7.2.1.tar.gz - URL_HASH SHA256=e04e1dabcecf3a687ace699ccb43a8a27da385777a56e69da6e103344cc66bca - #URL https://github.com/kkm000/openfst/archive/refs/tags/win/1.6.5.1.tar.gz - #URL_HASH SHA256=02c49b559c3976a536876063369efc0e41ab374be1035918036474343877046e - PATCH_COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_CURRENT_SOURCE_DIR}/patch/openfst ${openfst_SOURCE_DIR} - ) - FetchContent_MakeAvailable(openfst) - add_dependencies(fst gflags glog) - target_link_libraries(fst PUBLIC gflags_nothreads_static glog) - include_directories(${openfst_SOURCE_DIR}/src/include) -else() - set(openfst_BINARY_DIR ${build_DIR}/wenet-openfst-android-1.0.2.aar/jni) - include_directories(${openfst_BINARY_DIR}/include) - link_directories(${openfst_BINARY_DIR}/${ANDROID_ABI}) - link_libraries(log gflags_nothreads glog fst) -endif() +if(NOT ANDROID) + include(gflags) + # We can't build glog with gflags, unless gflags is pre-installed. + # If build glog with pre-installed gflags, there will be conflict. + set(WITH_GFLAGS OFF CACHE BOOL "whether build glog with gflags" FORCE) + include(glog) + + if(NOT GRAPH_TOOLS) + set(HAVE_BIN OFF CACHE BOOL "Build the fst binaries" FORCE) + set(HAVE_SCRIPT OFF CACHE BOOL "Build the fstscript" FORCE) + endif() + set(HAVE_COMPACT OFF CACHE BOOL "Build compact" FORCE) + set(HAVE_CONST OFF CACHE BOOL "Build const" FORCE) + set(HAVE_GRM OFF CACHE BOOL "Build grm" FORCE) + set(HAVE_FAR OFF CACHE BOOL "Build far" FORCE) + set(HAVE_PDT OFF CACHE BOOL "Build pdt" FORCE) + set(HAVE_MPDT OFF CACHE BOOL "Build mpdt" FORCE) + set(HAVE_LINEAR OFF CACHE BOOL "Build linear" FORCE) + set(HAVE_LOOKAHEAD OFF CACHE BOOL "Build lookahead" FORCE) + set(HAVE_NGRAM OFF CACHE BOOL "Build ngram" FORCE) + set(HAVE_SPECIAL OFF CACHE BOOL "Build special" FORCE) + + if(MSVC) + add_compile_options(/W0 /wd4244 /wd4267) + endif() + + # "OpenFST port for Windows" builds openfst with cmake for multiple platforms. + # Openfst is compiled with glog/gflags to avoid log and flag conflicts with log and flags in wenet/libtorch. + # To build openfst with gflags and glog, we comment out some vars of {flags, log}.h and flags.cc. + set(openfst_SOURCE_DIR ${fc_base}/openfst-src CACHE PATH "OpenFST source directory") + FetchContent_Declare(openfst + URL https://github.com/kkm000/openfst/archive/refs/tags/win/1.7.2.1.tar.gz + URL_HASH SHA256=e04e1dabcecf3a687ace699ccb43a8a27da385777a56e69da6e103344cc66bca + #URL https://github.com/kkm000/openfst/archive/refs/tags/win/1.6.5.1.tar.gz + #URL_HASH SHA256=02c49b559c3976a536876063369efc0e41ab374be1035918036474343877046e + PATCH_COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_CURRENT_SOURCE_DIR}/patch/openfst ${openfst_SOURCE_DIR} + ) + FetchContent_MakeAvailable(openfst) + add_dependencies(fst gflags glog) + target_link_libraries(fst PUBLIC gflags_nothreads_static glog) + include_directories(${openfst_SOURCE_DIR}/src/include) +else() + set(openfst_BINARY_DIR ${build_DIR}/wenet-openfst-android-1.0.2.aar/jni) + include_directories(${openfst_BINARY_DIR}/include) + link_directories(${openfst_BINARY_DIR}/${ANDROID_ABI}) + link_libraries(log gflags_nothreads glog fst) +endif() diff --git a/runtime/core/decoder/asr_decoder.h b/runtime/core/decoder/asr_decoder.h index e40d46c61..6f0380d78 100644 --- a/runtime/core/decoder/asr_decoder.h +++ b/runtime/core/decoder/asr_decoder.h @@ -1,170 +1,170 @@ -// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang, Di Wu) -// 2022 Binbin Zhang (binbzha@qq.com) -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef DECODER_ASR_DECODER_H_ -#define DECODER_ASR_DECODER_H_ - -#include -#include -#include -#include -#include - -#include "fst/fstlib.h" -#include "fst/symbol-table.h" - -#include "decoder/asr_model.h" -#include "decoder/context_graph.h" -#include "decoder/ctc_endpoint.h" -#include "decoder/ctc_prefix_beam_search.h" -#include "decoder/ctc_wfst_beam_search.h" -#include "decoder/search_interface.h" -#include "frontend/feature_pipeline.h" -#include "post_processor/post_processor.h" -#include "utils/utils.h" - -namespace wenet { - -struct DecodeOptions { - // chunk_size is the frame number of one chunk after subsampling. - // e.g. if subsample rate is 4 and chunk_size = 16, the frames in - // one chunk are 64 = 16*4 - int chunk_size = 16; - int num_left_chunks = -1; - - // final_score = rescoring_weight * rescoring_score + ctc_weight * ctc_score; - // rescoring_score = left_to_right_score * (1 - reverse_weight) + - // right_to_left_score * reverse_weight - // Please note the concept of ctc_scores in the following two search - // methods are different. - // For CtcPrefixBeamSearch, it's a sum(prefix) score + context score - // For CtcWfstBeamSearch, it's a max(viterbi) path score + context score - // So we should carefully set ctc_weight according to the search methods. - float ctc_weight = 0.5; - float rescoring_weight = 1.0; - float reverse_weight = 0.0; - CtcEndpointConfig ctc_endpoint_config; - CtcPrefixBeamSearchOptions ctc_prefix_search_opts; - CtcWfstBeamSearchOptions ctc_wfst_search_opts; -}; - -struct WordPiece { - std::string word; - int start = -1; - int end = -1; - - WordPiece(std::string word, int start, int end) - : word(std::move(word)), start(start), end(end) {} -}; - -struct DecodeResult { - float score = -kFloatMax; - std::string sentence; - std::unordered_set contexts; - std::vector word_pieces; - - static bool CompareFunc(const DecodeResult& a, const DecodeResult& b) { - return a.score > b.score; - } -}; - -enum DecodeState { - kEndBatch = 0x00, // End of current decoding batch, normal case - kEndpoint = 0x01, // Endpoint is detected - kEndFeats = 0x02, // All feature is decoded - kWaitFeats = 0x03 // Feat is not enough for one chunk inference, wait -}; - -// DecodeResource is thread safe, which can be shared for multiple -// decoding threads -struct DecodeResource { - std::shared_ptr model = nullptr; - std::shared_ptr symbol_table = nullptr; - // std::shared_ptr> fst = nullptr; - std::shared_ptr> fst = nullptr; - std::shared_ptr unit_table = nullptr; - std::shared_ptr context_graph = nullptr; - std::shared_ptr post_processor = nullptr; -}; - -// Torch ASR decoder -class AsrDecoder { - public: - AsrDecoder(std::shared_ptr feature_pipeline, - std::shared_ptr resource, - const DecodeOptions& opts); - // @param block: if true, block when feature is not enough for one chunk - // inference. Otherwise, return kWaitFeats. - DecodeState Decode(bool block = true); - void Rescoring(); - void Reset(); - void ResetContinuousDecoding(); - bool DecodedSomething() const { - return !result_.empty() && !result_[0].sentence.empty(); - } - - // This method is used for time benchmark - int num_frames_in_current_chunk() const { - return num_frames_in_current_chunk_; - } - int frame_shift_in_ms() const { - return model_->subsampling_rate() * - feature_pipeline_->config().frame_shift * 1000 / - feature_pipeline_->config().sample_rate; - } - int feature_frame_shift_in_ms() const { - return feature_pipeline_->config().frame_shift * 1000 / - feature_pipeline_->config().sample_rate; - } - const std::vector& result() const { return result_; } - - private: - DecodeState AdvanceDecoding(bool block = true); - void AttentionRescoring(); - - void UpdateResult(bool finish = false); - - std::shared_ptr feature_pipeline_; - std::shared_ptr model_; - std::shared_ptr post_processor_; - std::shared_ptr context_graph_; - - // std::shared_ptr> fst_ = nullptr; - std::shared_ptr> fst_ = nullptr; - // output symbol table - std::shared_ptr symbol_table_; - // e2e unit symbol table - std::shared_ptr unit_table_ = nullptr; - const DecodeOptions& opts_; - // cache feature - bool start_ = false; - // For continuous decoding - int num_frames_ = 0; - int global_frame_offset_ = 0; - const int time_stamp_gap_ = 100; // timestamp gap between words in a sentence - - std::unique_ptr searcher_; - std::unique_ptr ctc_endpointer_; - - int num_frames_in_current_chunk_ = 0; - std::vector result_; - - public: - WENET_DISALLOW_COPY_AND_ASSIGN(AsrDecoder); -}; - -} // namespace wenet - -#endif // DECODER_ASR_DECODER_H_ +// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang, Di Wu) +// 2022 Binbin Zhang (binbzha@qq.com) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DECODER_ASR_DECODER_H_ +#define DECODER_ASR_DECODER_H_ + +#include +#include +#include +#include +#include + +#include "fst/fstlib.h" +#include "fst/symbol-table.h" + +#include "decoder/asr_model.h" +#include "decoder/context_graph.h" +#include "decoder/ctc_endpoint.h" +#include "decoder/ctc_prefix_beam_search.h" +#include "decoder/ctc_wfst_beam_search.h" +#include "decoder/search_interface.h" +#include "frontend/feature_pipeline.h" +#include "post_processor/post_processor.h" +#include "utils/utils.h" + +namespace wenet { + +struct DecodeOptions { + // chunk_size is the frame number of one chunk after subsampling. + // e.g. if subsample rate is 4 and chunk_size = 16, the frames in + // one chunk are 64 = 16*4 + int chunk_size = 16; + int num_left_chunks = -1; + + // final_score = rescoring_weight * rescoring_score + ctc_weight * ctc_score; + // rescoring_score = left_to_right_score * (1 - reverse_weight) + + // right_to_left_score * reverse_weight + // Please note the concept of ctc_scores in the following two search + // methods are different. + // For CtcPrefixBeamSearch, it's a sum(prefix) score + context score + // For CtcWfstBeamSearch, it's a max(viterbi) path score + context score + // So we should carefully set ctc_weight according to the search methods. + float ctc_weight = 0.5; + float rescoring_weight = 1.0; + float reverse_weight = 0.0; + CtcEndpointConfig ctc_endpoint_config; + CtcPrefixBeamSearchOptions ctc_prefix_search_opts; + CtcWfstBeamSearchOptions ctc_wfst_search_opts; +}; + +struct WordPiece { + std::string word; + int start = -1; + int end = -1; + + WordPiece(std::string word, int start, int end) + : word(std::move(word)), start(start), end(end) {} +}; + +struct DecodeResult { + float score = -kFloatMax; + std::string sentence; + std::unordered_set contexts; + std::vector word_pieces; + + static bool CompareFunc(const DecodeResult& a, const DecodeResult& b) { + return a.score > b.score; + } +}; + +enum DecodeState { + kEndBatch = 0x00, // End of current decoding batch, normal case + kEndpoint = 0x01, // Endpoint is detected + kEndFeats = 0x02, // All feature is decoded + kWaitFeats = 0x03 // Feat is not enough for one chunk inference, wait +}; + +// DecodeResource is thread safe, which can be shared for multiple +// decoding threads +struct DecodeResource { + std::shared_ptr model = nullptr; + std::shared_ptr symbol_table = nullptr; + // std::shared_ptr> fst = nullptr; + std::shared_ptr> fst = nullptr; + std::shared_ptr unit_table = nullptr; + std::shared_ptr context_graph = nullptr; + std::shared_ptr post_processor = nullptr; +}; + +// Torch ASR decoder +class AsrDecoder { + public: + AsrDecoder(std::shared_ptr feature_pipeline, + std::shared_ptr resource, + const DecodeOptions& opts); + // @param block: if true, block when feature is not enough for one chunk + // inference. Otherwise, return kWaitFeats. + DecodeState Decode(bool block = true); + void Rescoring(); + void Reset(); + void ResetContinuousDecoding(); + bool DecodedSomething() const { + return !result_.empty() && !result_[0].sentence.empty(); + } + + // This method is used for time benchmark + int num_frames_in_current_chunk() const { + return num_frames_in_current_chunk_; + } + int frame_shift_in_ms() const { + return model_->subsampling_rate() * + feature_pipeline_->config().frame_shift * 1000 / + feature_pipeline_->config().sample_rate; + } + int feature_frame_shift_in_ms() const { + return feature_pipeline_->config().frame_shift * 1000 / + feature_pipeline_->config().sample_rate; + } + const std::vector& result() const { return result_; } + + private: + DecodeState AdvanceDecoding(bool block = true); + void AttentionRescoring(); + + void UpdateResult(bool finish = false); + + std::shared_ptr feature_pipeline_; + std::shared_ptr model_; + std::shared_ptr post_processor_; + std::shared_ptr context_graph_; + + // std::shared_ptr> fst_ = nullptr; + std::shared_ptr> fst_ = nullptr; + // output symbol table + std::shared_ptr symbol_table_; + // e2e unit symbol table + std::shared_ptr unit_table_ = nullptr; + const DecodeOptions& opts_; + // cache feature + bool start_ = false; + // For continuous decoding + int num_frames_ = 0; + int global_frame_offset_ = 0; + const int time_stamp_gap_ = 100; // timestamp gap between words in a sentence + + std::unique_ptr searcher_; + std::unique_ptr ctc_endpointer_; + + int num_frames_in_current_chunk_ = 0; + std::vector result_; + + public: + WENET_DISALLOW_COPY_AND_ASSIGN(AsrDecoder); +}; + +} // namespace wenet + +#endif // DECODER_ASR_DECODER_H_ diff --git a/runtime/core/post_processor/post_processor.cc b/runtime/core/post_processor/post_processor.cc index 9af50095c..cd03dc472 100644 --- a/runtime/core/post_processor/post_processor.cc +++ b/runtime/core/post_processor/post_processor.cc @@ -1,96 +1,96 @@ -// Copyright (c) 2021 Xingchen Song sxc19@mails.tsinghua.edu.cn -// 2023 Jing Du (thuduj12@163.com) -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License - -#include "post_processor/post_processor.h" -#include -#include - -#include "utils/string.h" - -namespace wenet { -void PostProcessor::InitITNResource(const std::string& tagger_path, - const std::string& verbalizer_path) { - auto itn_processor = - std::make_shared(tagger_path, verbalizer_path); - itn_resource = itn_processor; -} - -std::string PostProcessor::ProcessSpace(const std::string& str) { - std::string result = str; - // 1. remove ' ' if needed - // only spaces between mandarin words need to be removed, please note that - // if str contains '_', we assume that the decoding type must be - // `CtcPrefixBeamSearch` and this branch will do nothing since str must be - // obtained via "".join() (in function `AsrDecoder::UpdateResult()`) - if (opts_.language_type == kMandarinEnglish && !str.empty()) { - result.clear(); - // split str by ' ' - std::vector words; - std::stringstream ss(str); - std::string tmp; - while (ss >> tmp) { - words.push_back(tmp); - } - // check english word - bool is_englishword_prev = false; - bool is_englishword_now = false; - for (std::string& w : words) { - is_englishword_now = CheckEnglishWord(w); - if (is_englishword_prev && is_englishword_now) { - result += (' ' + w); - } else { - result += (w); - } - is_englishword_prev = is_englishword_now; - } - } - // 2. replace '_' with ' ' - // this should be done for all cases (both kMandarinEnglish and kIndoEuropean) - result = ProcessBlank(result, opts_.lowercase); - return result; -} - -std::string del_substr(const std::string& str, const std::string& sub) { - std::string result = str; - int pos = 0; - while (string::npos != (pos = result.find(sub))) { - result.erase(pos, sub.size()); - } - return result; -} - -std::string PostProcessor::ProcessSymbols(const std::string& str) { - std::string result = str; - result = del_substr(result, ""); - result = del_substr(result, ""); - result = del_substr(result, ""); - return result; -} - -std::string PostProcessor::Process(const std::string& str, bool finish) { - std::string result; - // remove symbols with "<>" first - result = ProcessSymbols(str); - result = ProcessSpace(result); - // TODO(xcsong): do itn/punctuation if finish == true - if (finish == true && opts_.itn) { - if (nullptr != itn_resource) { - result = itn_resource->normalize(result); - } - } - return result; -} - -} // namespace wenet +// Copyright (c) 2021 Xingchen Song sxc19@mails.tsinghua.edu.cn +// 2023 Jing Du (thuduj12@163.com) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License + +#include "post_processor/post_processor.h" +#include +#include + +#include "utils/string.h" + +namespace wenet { +void PostProcessor::InitITNResource(const std::string& tagger_path, + const std::string& verbalizer_path) { + auto itn_processor = + std::make_shared(tagger_path, verbalizer_path); + itn_resource = itn_processor; +} + +std::string PostProcessor::ProcessSpace(const std::string& str) { + std::string result = str; + // 1. remove ' ' if needed + // only spaces between mandarin words need to be removed, please note that + // if str contains '_', we assume that the decoding type must be + // `CtcPrefixBeamSearch` and this branch will do nothing since str must be + // obtained via "".join() (in function `AsrDecoder::UpdateResult()`) + if (opts_.language_type == kMandarinEnglish && !str.empty()) { + result.clear(); + // split str by ' ' + std::vector words; + std::stringstream ss(str); + std::string tmp; + while (ss >> tmp) { + words.push_back(tmp); + } + // check english word + bool is_englishword_prev = false; + bool is_englishword_now = false; + for (std::string& w : words) { + is_englishword_now = CheckEnglishWord(w); + if (is_englishword_prev && is_englishword_now) { + result += (' ' + w); + } else { + result += (w); + } + is_englishword_prev = is_englishword_now; + } + } + // 2. replace '_' with ' ' + // this should be done for all cases (both kMandarinEnglish and kIndoEuropean) + result = ProcessBlank(result, opts_.lowercase); + return result; +} + +std::string del_substr(const std::string& str, const std::string& sub) { + std::string result = str; + int pos = 0; + while (string::npos != (pos = result.find(sub))) { + result.erase(pos, sub.size()); + } + return result; +} + +std::string PostProcessor::ProcessSymbols(const std::string& str) { + std::string result = str; + result = del_substr(result, ""); + result = del_substr(result, ""); + result = del_substr(result, ""); + return result; +} + +std::string PostProcessor::Process(const std::string& str, bool finish) { + std::string result; + // remove symbols with "<>" first + result = ProcessSymbols(str); + result = ProcessSpace(result); + // TODO(xcsong): do itn/punctuation if finish == true + if (finish == true && opts_.itn) { + if (nullptr != itn_resource) { + result = itn_resource->normalize(result); + } + } + return result; +} + +} // namespace wenet diff --git a/runtime/core/post_processor/post_processor.h b/runtime/core/post_processor/post_processor.h index f118b44d3..dee588174 100644 --- a/runtime/core/post_processor/post_processor.h +++ b/runtime/core/post_processor/post_processor.h @@ -1,79 +1,79 @@ -// Copyright (c) 2021 Xingchen Song sxc19@mails.tsinghua.edu.cn -// 2023 Jing Du (thuduj12@163.com) -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License - -#ifndef POST_PROCESSOR_POST_PROCESSOR_H_ -#define POST_PROCESSOR_POST_PROCESSOR_H_ - -#include -#include -#include - -#include "post_processor/processor.h" -#include "utils/utils.h" - -namespace wenet { - -enum LanguageType { - // spaces between **mandarin words** should be removed. - // cases of processing spaces with mandarin-only, english-only - // and mandarin-english code-switch can be found in post_processor_test.cc - kMandarinEnglish = 0x00, - // spaces should be kept for most of the - // Indo-European languages (i.e., deutsch or english-deutsch code-switch). - // cases of those languages can be found in post_processor_test.cc - kIndoEuropean = 0x01 -}; - -struct PostProcessOptions { - // space options - // The decoded result may contain spaces (' ' or '_'), - // we will process those spaces according to language_type. More details can - // be found in - // https://github.com/wenet-e2e/wenet/issues/583#issuecomment-907994058 - LanguageType language_type = kMandarinEnglish; - // whether lowercase letters are required - bool lowercase = true; - bool itn = false; -}; - -// TODO(xcsong): add punctuation related resource - -// Post Processor -class PostProcessor { - public: - explicit PostProcessor(PostProcessOptions&& opts) : opts_(std::move(opts)) {} - explicit PostProcessor(const PostProcessOptions& opts) : opts_(opts) {} - // call other functions to do post processing - std::string Process(const std::string& str, bool finish); - // process spaces according to configurations - std::string ProcessSpace(const std::string& str); - std::string ProcessSymbols(const std::string& str); - // TODO(xcsong): add punctuation - // void Punctuate(const std::string& str); - - void InitITNResource(const std::string& tagger_path, - const std::string& verbalizer_path); - - private: - std::shared_ptr itn_resource = nullptr; - const PostProcessOptions opts_; - - public: - WENET_DISALLOW_COPY_AND_ASSIGN(PostProcessor); -}; - -} // namespace wenet - -#endif // POST_PROCESSOR_POST_PROCESSOR_H_ +// Copyright (c) 2021 Xingchen Song sxc19@mails.tsinghua.edu.cn +// 2023 Jing Du (thuduj12@163.com) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License + +#ifndef POST_PROCESSOR_POST_PROCESSOR_H_ +#define POST_PROCESSOR_POST_PROCESSOR_H_ + +#include +#include +#include + +#include "post_processor/processor.h" +#include "utils/utils.h" + +namespace wenet { + +enum LanguageType { + // spaces between **mandarin words** should be removed. + // cases of processing spaces with mandarin-only, english-only + // and mandarin-english code-switch can be found in post_processor_test.cc + kMandarinEnglish = 0x00, + // spaces should be kept for most of the + // Indo-European languages (i.e., deutsch or english-deutsch code-switch). + // cases of those languages can be found in post_processor_test.cc + kIndoEuropean = 0x01 +}; + +struct PostProcessOptions { + // space options + // The decoded result may contain spaces (' ' or '_'), + // we will process those spaces according to language_type. More details can + // be found in + // https://github.com/wenet-e2e/wenet/issues/583#issuecomment-907994058 + LanguageType language_type = kMandarinEnglish; + // whether lowercase letters are required + bool lowercase = true; + bool itn = false; +}; + +// TODO(xcsong): add punctuation related resource + +// Post Processor +class PostProcessor { + public: + explicit PostProcessor(PostProcessOptions&& opts) : opts_(std::move(opts)) {} + explicit PostProcessor(const PostProcessOptions& opts) : opts_(opts) {} + // call other functions to do post processing + std::string Process(const std::string& str, bool finish); + // process spaces according to configurations + std::string ProcessSpace(const std::string& str); + std::string ProcessSymbols(const std::string& str); + // TODO(xcsong): add punctuation + // void Punctuate(const std::string& str); + + void InitITNResource(const std::string& tagger_path, + const std::string& verbalizer_path); + + private: + std::shared_ptr itn_resource = nullptr; + const PostProcessOptions opts_; + + public: + WENET_DISALLOW_COPY_AND_ASSIGN(PostProcessor); +}; + +} // namespace wenet + +#endif // POST_PROCESSOR_POST_PROCESSOR_H_ diff --git a/runtime/resource b/runtime/resource deleted file mode 120000 index 67a624052..000000000 --- a/runtime/resource +++ /dev/null @@ -1 +0,0 @@ -../../wenet/runtime/resource \ No newline at end of file