diff --git a/src/capture/capture.cpp b/src/capture/capture.cpp index 2c2392b4f5..5032e2841e 100644 --- a/src/capture/capture.cpp +++ b/src/capture/capture.cpp @@ -58,7 +58,7 @@ #include "capture/title_extractors/agent_title_extractor.hpp" #include "capture/title_extractors/ocr_title_extractor_builder.hpp" -#include "capture/redis.hpp" +#include "capture/redis_session.hpp" #include "capture/capture_params.hpp" #include "capture/drawable_params.hpp" diff --git a/src/capture/redis.cpp b/src/capture/redis.cpp index 599a3a849a..c9e9574c61 100644 --- a/src/capture/redis.cpp +++ b/src/capture/redis.cpp @@ -20,15 +20,12 @@ Author(s): Proxies Team #include "capture/redis.hpp" #include "utils/sugar/int_to_chars.hpp" -#include "utils/static_string.hpp" -#include "utils/to_timeval.hpp" -#include "utils/netutils.hpp" -#include "utils/select.hpp" #include "cxx/cxx.hpp" #include #include +#include #include @@ -375,168 +372,3 @@ RedisWriter::IOCode RedisWriter::ssl_result_to_io_code(int res, IOCode code) ssl_errnum = ssl_error; return code; } - -RedisSyncSession::IOCode RedisSyncSession::open( - zstring_view address, unsigned int port, - bounded_chars_view<0, 256> password, unsigned int db, - std::chrono::milliseconds timeout, TlsParams tls_params) -{ - using IOCode = IOCode; - - tv = to_timeval(timeout); - - // open socket - close(); - int fd = addr_connect(address, checked_int(port), false).release(); - if (fd == -1) { - return IOCode::ConnectError; - } - - writer.set_fd(fd); - - io_fd_zero(rfds); - io_fd_zero(wfds); - - // enable tls - if (tls_params.enable_tls) { - error_msg = writer.enable_tls(tls_params.ca_cert_file, - tls_params.cert_file, - tls_params.key_file); - if (error_msg) { - return IOCode::CertificateError; - } - - auto code = loop_event( - [&](bytes_view){ return RedisWriter::IOResult{writer.ssl_connect(), 0}; }, - ""_av, IOCode::WantWrite - ); - if (code != IOCode::Ok) { - return code; - } - } - - RedisAuth auth(password, db); - if (auth.count_command() == 2) { - state = State::WaitPassword; - } - else { - assert(auth.count_command() == 1); - state = State::WaitResponse; - } - - return send_impl(auth.packet()); -} - -void RedisSyncSession::close() -{ - error_msg = nullptr; - int fd = writer.get_fd(); - writer.close(); - if (fd != -1) { - ::close(fd); - } -} - -RedisSyncSession::IOCode RedisSyncSession::send(bytes_view buffer) -{ - assert(writer.get_fd() != -1); - - using namespace std::string_view_literals; - constexpr auto expected_resp = "+OK\r\n"sv; - static_assert(resp_buffer_len == expected_resp.size()); - - for (;;) { - auto remaining = make_writable_array_view(resp_buffer).first(expected_resp.size()); - auto code = recv_impl(remaining); - - if (code == IOCode::Ok) { - auto resp = make_writable_array_view(resp_buffer); - if (resp.first(expected_resp.size()).as() != expected_resp) { - error_msg = resp.data(); - resp.back() = '\0'; - return IOCode::UnknownResponse; - } - - if (state == State::WaitResponse) { - return send_impl(buffer); - } - state = State::WaitResponse; - } - else { - return code; - } - } -} - -int RedisSyncSession::get_last_errno() const noexcept -{ - return writer.get_last_errno(); -} - -char const* RedisSyncSession::get_last_error_message() const -{ - return error_msg ? error_msg : writer.get_last_error_message(); -} - -RedisSyncSession::IOCode RedisSyncSession::send_impl(bytes_view buffer) -{ - return loop_event( - [&](bytes_view buffer){ return writer.send(buffer); }, - buffer, IOCode::WantWrite - ); -} - -RedisSyncSession::IOCode RedisSyncSession::recv_impl(writable_bytes_view buffer) -{ - return loop_event( - [&](writable_bytes_view buffer){ return writer.recv(buffer); }, - buffer, IOCode::WantRead - ); -} - -template -RedisSyncSession::IOCode RedisSyncSession::loop_event(F&& f, Buffer buffer, IOCode code_for_waiting) -{ - using IOCode = IOCode; - - fd_set* rfds_ref = nullptr; - fd_set* wfds_ref = nullptr; - - int fd = writer.get_fd(); - - for (;;) { - auto result = f(buffer); - if (result.code == IOCode::Ok) { - if (result.len == buffer.size()) { - return result.code; - } - buffer = buffer.drop_front(result.len); - result.code = code_for_waiting; - } - - if (result.code == IOCode::WantRead) { - rfds_ref = &rfds; - io_fd_set(fd, rfds); - } - else if (result.code == IOCode::WantWrite) { - wfds_ref = &wfds; - io_fd_set(fd, wfds); - } - else { - close(); - return result.code; - } - - auto tv_remaining = tv; - int nfds = select(fd+1, rfds_ref, wfds_ref, nullptr, &tv_remaining); - - if (nfds > 0) { - // ok, continue - } - else if (nfds == 0 || (errno != EINTR && errno != EAGAIN)) { - // possibly EINVAL -> negative timeout - error_msg = strerror(errno); - return (nfds == 0) ? IOCode::Timeout : IOCode::ConnectError; - } - } -} diff --git a/src/capture/redis.hpp b/src/capture/redis.hpp index 5cd051ad94..195101fe11 100644 --- a/src/capture/redis.hpp +++ b/src/capture/redis.hpp @@ -25,13 +25,9 @@ Author(s): Proxies Team #include "utils/sugar/bounded_array_view.hpp" #include "utils/static_string.hpp" -#include #include #include -#include -#include - // redis_command_set(key, value) with accumulator class RedisCmdSet @@ -149,50 +145,3 @@ struct RedisWriter int fd = -1; int ssl_errnum = 0; }; - - -struct RedisSyncSession -{ - using IOCode = RedisIOCode; - - struct TlsParams - { - bool enable_tls; - const char * cert_file; - const char * key_file; - const char * ca_cert_file; - }; - - IOCode open( - zstring_view address, unsigned port, - bounded_chars_view<0, 256> password, unsigned db, - std::chrono::milliseconds timeout, TlsParams tls_params); - - void close(); - - IOCode send(bytes_view buffer); - - int get_last_errno() const noexcept; - char const* get_last_error_message() const; - -private: - template - IOCode loop_event(F&& f, Buffer buffer, IOCode code_for_waiting); - IOCode send_impl(bytes_view buffer); - IOCode recv_impl(writable_bytes_view buffer); - - enum class State : bool - { - WaitResponse, - WaitPassword, - }; - - timeval tv; - State state; - fd_set rfds; - fd_set wfds; - RedisWriter writer; - static constexpr std::size_t resp_buffer_len = 5; - char resp_buffer[resp_buffer_len + 1]; - char const* error_msg = nullptr; -}; diff --git a/src/capture/redis_session.cpp b/src/capture/redis_session.cpp new file mode 100644 index 0000000000..bc65cb73e6 --- /dev/null +++ b/src/capture/redis_session.cpp @@ -0,0 +1,199 @@ +/* +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License +along with this program; if not, write to the Free Software +Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. + +Product name: redemption, a FLOSS RDP proxy +Copyright (C) Wallix 2021 +Author(s): Proxies Team +*/ + +#include "capture/redis_session.hpp" +#include "utils/sugar/int_to_chars.hpp" +#include "utils/static_string.hpp" +#include "utils/to_timeval.hpp" +#include "utils/netutils.hpp" +#include "utils/select.hpp" +#include "cxx/cxx.hpp" + +#include +#include + +#include +#include + + +RedisSyncSession::IOCode RedisSyncSession::open( + zstring_view address, unsigned int port, + bounded_chars_view<0, 256> password, unsigned int db, + std::chrono::milliseconds timeout, TlsParams tls_params) +{ + using IOCode = IOCode; + + tv = to_timeval(timeout); + + // open socket + close(); + int fd = addr_connect(address, checked_int(port), false).release(); + if (fd == -1) { + return IOCode::ConnectError; + } + + writer.set_fd(fd); + + io_fd_zero(rfds); + io_fd_zero(wfds); + + // enable tls + if (tls_params.enable_tls) { + error_msg = writer.enable_tls(tls_params.ca_cert_file, + tls_params.cert_file, + tls_params.key_file); + if (error_msg) { + return IOCode::CertificateError; + } + + auto code = loop_event( + [&](bytes_view){ return RedisWriter::IOResult{writer.ssl_connect(), 0}; }, + ""_av, IOCode::WantWrite + ); + if (code != IOCode::Ok) { + return code; + } + } + + RedisAuth auth(password, db); + if (auth.count_command() == 2) { + state = State::WaitPassword; + } + else { + assert(auth.count_command() == 1); + state = State::WaitResponse; + } + + return send_impl(auth.packet()); +} + +void RedisSyncSession::close() +{ + error_msg = nullptr; + int fd = writer.get_fd(); + writer.close(); + if (fd != -1) { + ::close(fd); + } +} + +RedisSyncSession::IOCode RedisSyncSession::send(bytes_view buffer) +{ + assert(writer.get_fd() != -1); + + using namespace std::string_view_literals; + constexpr auto expected_resp = "+OK\r\n"sv; + static_assert(resp_buffer_len == expected_resp.size()); + + for (;;) { + auto remaining = make_writable_array_view(resp_buffer).first(expected_resp.size()); + auto code = recv_impl(remaining); + + if (code == IOCode::Ok) { + auto resp = make_writable_array_view(resp_buffer); + if (resp.first(expected_resp.size()).as() != expected_resp) { + error_msg = resp.data(); + resp.back() = '\0'; + return IOCode::UnknownResponse; + } + + if (state == State::WaitResponse) { + return send_impl(buffer); + } + state = State::WaitResponse; + } + else { + return code; + } + } +} + +int RedisSyncSession::get_last_errno() const noexcept +{ + return writer.get_last_errno(); +} + +char const* RedisSyncSession::get_last_error_message() const +{ + return error_msg ? error_msg : writer.get_last_error_message(); +} + +RedisSyncSession::IOCode RedisSyncSession::send_impl(bytes_view buffer) +{ + return loop_event( + [&](bytes_view buffer){ return writer.send(buffer); }, + buffer, IOCode::WantWrite + ); +} + +RedisSyncSession::IOCode RedisSyncSession::recv_impl(writable_bytes_view buffer) +{ + return loop_event( + [&](writable_bytes_view buffer){ return writer.recv(buffer); }, + buffer, IOCode::WantRead + ); +} + +template +RedisSyncSession::IOCode RedisSyncSession::loop_event(F&& f, Buffer buffer, IOCode code_for_waiting) +{ + using IOCode = IOCode; + + fd_set* rfds_ref = nullptr; + fd_set* wfds_ref = nullptr; + + int fd = writer.get_fd(); + + for (;;) { + auto result = f(buffer); + if (result.code == IOCode::Ok) { + if (result.len == buffer.size()) { + return result.code; + } + buffer = buffer.drop_front(result.len); + result.code = code_for_waiting; + } + + if (result.code == IOCode::WantRead) { + rfds_ref = &rfds; + io_fd_set(fd, rfds); + } + else if (result.code == IOCode::WantWrite) { + wfds_ref = &wfds; + io_fd_set(fd, wfds); + } + else { + close(); + return result.code; + } + + auto tv_remaining = tv; + int nfds = select(fd+1, rfds_ref, wfds_ref, nullptr, &tv_remaining); + + if (nfds > 0) { + // ok, continue + } + else if (nfds == 0 || (errno != EINTR && errno != EAGAIN)) { + // possibly EINVAL -> negative timeout + error_msg = strerror(errno); + return (nfds == 0) ? IOCode::Timeout : IOCode::ConnectError; + } + } +} diff --git a/src/capture/redis_session.hpp b/src/capture/redis_session.hpp new file mode 100644 index 0000000000..a1d3e0d0a5 --- /dev/null +++ b/src/capture/redis_session.hpp @@ -0,0 +1,73 @@ +/* +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License +along with this program; if not, write to the Free Software +Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. + +Product name: redemption, a FLOSS RDP proxy +Copyright (C) Wallix 2021 +Author(s): Proxies Team +*/ + +#pragma once + +#include "capture/redis.hpp" + +#include + +#include + +struct RedisSyncSession +{ + using IOCode = RedisIOCode; + + struct TlsParams + { + bool enable_tls; + const char * cert_file; + const char * key_file; + const char * ca_cert_file; + }; + + IOCode open( + zstring_view address, unsigned port, + bounded_chars_view<0, 256> password, unsigned db, + std::chrono::milliseconds timeout, TlsParams tls_params); + + void close(); + + IOCode send(bytes_view buffer); + + int get_last_errno() const noexcept; + char const* get_last_error_message() const; + +private: + template + IOCode loop_event(F&& f, Buffer buffer, IOCode code_for_waiting); + IOCode send_impl(bytes_view buffer); + IOCode recv_impl(writable_bytes_view buffer); + + enum class State : bool + { + WaitResponse, + WaitPassword, + }; + + timeval tv; + State state; + fd_set rfds; + fd_set wfds; + RedisWriter writer; + static constexpr std::size_t resp_buffer_len = 5; + char resp_buffer[resp_buffer_len + 1]; + char const* error_msg = nullptr; +}; diff --git a/targets.jam b/targets.jam index 2df12fa43f..ec8dffb54f 100644 --- a/targets.jam +++ b/targets.jam @@ -37,6 +37,7 @@ exe rdpproxy : src/capture/rail_screen_computation.o src/capture/rdp_ppocr/get_ocr_constants_from_locale_id.o src/capture/redis.o + src/capture/redis_session.o src/capture/session_update_buffer.o src/capture/utils/pattern_searcher.o src/capture/utils/pattutils.o @@ -390,7 +391,6 @@ lib libcredis : log.o openssl src/capture/redis.o - src/utils/netutils.o src/utils/uninit_buffer.o $(LIB_DEPENDENCIES) ; @@ -414,6 +414,7 @@ lib libredrec : src/capture/rail_screen_computation.o src/capture/rdp_ppocr/get_ocr_constants_from_locale_id.o src/capture/redis.o + src/capture/redis_session.o src/capture/regions_capture.o src/capture/utils/pattern_searcher.o src/capture/utils/pattutils.o @@ -774,6 +775,7 @@ test-run tests/capture/test_capture : src/capture/rail_screen_computation.o src/capture/rdp_ppocr/get_ocr_constants_from_locale_id.o src/capture/redis.o + src/capture/redis_session.o src/capture/utils/pattern_searcher.o src/capture/utils/pattutils.o src/capture/video_capture.o @@ -900,6 +902,7 @@ test-run tests/capture/test_redis : crypto openssl src/capture/redis.o + src/capture/redis_session.o src/utils/netutils.o -pthread ; @@ -2094,6 +2097,7 @@ test-run tests/front/test_front : src/capture/rail_screen_computation.o src/capture/rdp_ppocr/get_ocr_constants_from_locale_id.o src/capture/redis.o + src/capture/redis_session.o src/capture/session_update_buffer.o src/capture/utils/pattern_searcher.o src/capture/utils/pattutils.o @@ -2375,6 +2379,7 @@ test-run tests/lib/test_do_recorder : src/capture/rail_screen_computation.o src/capture/rdp_ppocr/get_ocr_constants_from_locale_id.o src/capture/redis.o + src/capture/redis_session.o src/capture/regions_capture.o src/capture/utils/pattern_searcher.o src/capture/utils/pattutils.o @@ -3983,6 +3988,7 @@ test-run tests/server/test_mstsc_client : src/capture/rail_screen_computation.o src/capture/rdp_ppocr/get_ocr_constants_from_locale_id.o src/capture/redis.o + src/capture/redis_session.o src/capture/session_update_buffer.o src/capture/utils/pattern_searcher.o src/capture/utils/pattutils.o @@ -4058,6 +4064,7 @@ test-run tests/server/test_mstsc_client_rdp50bulk : src/capture/rail_screen_computation.o src/capture/rdp_ppocr/get_ocr_constants_from_locale_id.o src/capture/redis.o + src/capture/redis_session.o src/capture/session_update_buffer.o src/capture/utils/pattern_searcher.o src/capture/utils/pattutils.o @@ -4133,6 +4140,7 @@ test-run tests/server/test_rdesktop_client : src/capture/rail_screen_computation.o src/capture/rdp_ppocr/get_ocr_constants_from_locale_id.o src/capture/redis.o + src/capture/redis_session.o src/capture/session_update_buffer.o src/capture/utils/pattern_searcher.o src/capture/utils/pattutils.o @@ -4919,6 +4927,7 @@ obj src/capture/params_from_ini.o : $(REDEMPTION_SRC_PATH)/capture/params_from_i obj src/capture/rail_screen_computation.o : $(REDEMPTION_SRC_PATH)/capture/rail_screen_computation.cpp ; obj src/capture/rdp_ppocr/get_ocr_constants_from_locale_id.o : $(REDEMPTION_SRC_PATH)/capture/rdp_ppocr/get_ocr_constants_from_locale_id.cpp ; obj src/capture/redis.o : $(REDEMPTION_SRC_PATH)/capture/redis.cpp ; +obj src/capture/redis_session.o : $(REDEMPTION_SRC_PATH)/capture/redis_session.cpp ; obj src/capture/regions_capture.o : $(REDEMPTION_SRC_PATH)/capture/regions_capture.cpp : ZLIB_CONST ; obj src/capture/session_update_buffer.o : $(REDEMPTION_SRC_PATH)/capture/session_update_buffer.cpp ; obj src/capture/utils/pattern_searcher.o : $(REDEMPTION_SRC_PATH)/capture/utils/pattern_searcher.cpp : $(REDEMPTION_HYPERSCAN_FLAGS) ; @@ -6018,6 +6027,7 @@ explicit src/capture/rail_screen_computation.o src/capture/rdp_ppocr/get_ocr_constants_from_locale_id.o src/capture/redis.o + src/capture/redis_session.o src/capture/regions_capture.o src/capture/session_update_buffer.o src/capture/utils/pattern_searcher.o diff --git a/tests/capture/test_redis.cpp b/tests/capture/test_redis.cpp index 533bcd7208..f8c9d7e234 100644 --- a/tests/capture/test_redis.cpp +++ b/tests/capture/test_redis.cpp @@ -21,6 +21,7 @@ Author(s): Proxies Team #include "test_only/test_framework/redemption_unit_tests.hpp" #include "capture/redis.hpp" +#include "capture/redis_session.hpp" #include "core/listen.hpp" #include