diff --git a/runtime/core/bin/websocket_client_main.cc b/runtime/core/bin/websocket_client_main.cc index 3eaa96069d..6a96f37574 100644 --- a/runtime/core/bin/websocket_client_main.cc +++ b/runtime/core/bin/websocket_client_main.cc @@ -22,18 +22,19 @@ DEFINE_int32(port, 10086, "port of websocket server"); DEFINE_int32(nbest, 1, "n-best of decode result"); DEFINE_string(wav_path, "", "test wav file path"); DEFINE_bool(continuous_decoding, false, "continuous decoding mode"); +DEFINE_int32(sr, 16000, "audio sample rate"); int main(int argc, char* argv[]) { gflags::ParseCommandLineFlags(&argc, &argv, false); google::InitGoogleLogging(argv[0]); wenet::WebSocketClient client(FLAGS_hostname, FLAGS_port); client.set_nbest(FLAGS_nbest); + client.set_sr(FLAGS_sr); client.set_continuous_decoding(FLAGS_continuous_decoding); client.SendStartSignal(); wenet::WavReader wav_reader(FLAGS_wav_path); - const int sample_rate = 16000; - // Only support 16K + int sample_rate = client.sample_rate_; CHECK_EQ(wav_reader.sample_rate(), sample_rate); const int num_samples = wav_reader.num_samples(); // Send data every 0.5 second diff --git a/runtime/core/websocket/websocket_client.h b/runtime/core/websocket/websocket_client.h index 76ec3aa451..6e1b42e595 100644 --- a/runtime/core/websocket/websocket_client.h +++ b/runtime/core/websocket/websocket_client.h @@ -38,7 +38,7 @@ using tcp = boost::asio::ip::tcp; // from class WebSocketClient { public: WebSocketClient(const std::string& host, int port); - + int sample_rate_; void SendTextData(const std::string& data); void SendBinaryData(const void* data, size_t size); void ReadLoopFunc(); @@ -47,6 +47,7 @@ class WebSocketClient { void SendStartSignal(); void SendEndSignal(); void set_nbest(int nbest) { nbest_ = nbest; } + void set_sr(int sr) { sample_rate_ = sr; } void set_continuous_decoding(bool continuous_decoding) { continuous_decoding_ = continuous_decoding; }