diff --git a/.circleci/continue-config.yml b/.circleci/continue-config.yml index c196fa1b..449958b8 100644 --- a/.circleci/continue-config.yml +++ b/.circleci/continue-config.yml @@ -1,11 +1,11 @@ # Copyright 2023 Ant Group Co., Ltd. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -43,7 +43,7 @@ jobs: set +e declare -i test_status bazel test //... -c opt --ui_event_filters=-info,-debug,-warning --test_output=errors --jobs 16 | tee test_result.log; test_status=${PIPESTATUS[0]} - + git clone https://github.com/secretflow/devtools.git sh devtools/rename-junit-xml.sh @@ -75,10 +75,10 @@ jobs: set +e declare -i test_status bazel test //... -c opt --ui_event_filters=-info,-debug,-warning --test_output=errors | tee test_result.log; test_status=${PIPESTATUS[0]} - + git clone https://github.com/secretflow/devtools.git sh devtools/rename-junit-xml.sh - + find bazel-bin/ -perm +111 -type f -name "*_test" -print0 | xargs -0 tar -cvzf test_binary.tar.gz find bazel-testlogs/ -type f -name "test.log" -print0 | xargs -0 tar -cvzf test_logs.tar.gz exit ${test_status} @@ -93,7 +93,7 @@ jobs: # See: https://circleci.com/docs/2.0/configuration-reference/#workflows workflows: unittest: - when: << pipeline.parameters.build-and-run >> + when: << pipeline.parameters.build-and-run >> jobs: - linux_ut: matrix: diff --git a/bazel/msgpack.BUILD b/bazel/msgpack.BUILD index 51f9d98e..7d8ca687 100644 --- a/bazel/msgpack.BUILD +++ b/bazel/msgpack.BUILD @@ -25,10 +25,12 @@ yacl_cmake_external( name = "msgpack", cache_entries = { "MSGPACK_CXX17": "ON", + "MSGPACK_USE_BOOST": "OFF", "MSGPACK_BUILD_EXAMPLES": "OFF", "BUILD_SHARED_LIBS": "OFF", "MSGPACK_BUILD_TESTS": "OFF", }, + defines = ["MSGPACK_NO_BOOST"], lib_source = ":all_srcs", out_headers_only = True, ) diff --git a/bazel/openssl.BUILD b/bazel/openssl.BUILD index a85bb90d..558957ed 100644 --- a/bazel/openssl.BUILD +++ b/bazel/openssl.BUILD @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -load("@rules_foreign_cc//foreign_cc:defs.bzl", "configure_make") +load("@yacl//bazel:yacl.bzl", "yacl_configure_make") package(default_visibility = ["//visibility:public"]) @@ -28,7 +28,7 @@ config_setting( visibility = ["//visibility:private"], ) -configure_make( +yacl_configure_make( name = "openssl", configure_command = select( { @@ -66,10 +66,7 @@ configure_make( "libcrypto.a", ], targets = [ - # NOTE: $(nproc --all) returns host number of cpus in ANT ACI pod - # Hence we choose a fixed number of 4. - # "make -j`nproc --all`", - "-s -j4", + "-s", "-s install_sw", ], ) diff --git a/bazel/patches/brpc-2324.patch b/bazel/patches/brpc-2324.patch deleted file mode 100644 index 20189e48..00000000 --- a/bazel/patches/brpc-2324.patch +++ /dev/null @@ -1,28 +0,0 @@ -diff --git a/src/butil/thread_local.h b/src/butil/thread_local.h -index 74f9533b9e..a3cb1ff087 100644 ---- a/src/butil/thread_local.h -+++ b/src/butil/thread_local.h -@@ -31,7 +31,7 @@ - #endif // _MSC_VER - - #define BAIDU_VOLATILE_THREAD_LOCAL(type, var_name, default_value) \ -- BAIDU_THREAD_LOCAL type var_name = default_value; \ -+ BAIDU_THREAD_LOCAL type var_name = default_value; \ - static __attribute__((noinline, unused)) type get_##var_name(void) { \ - asm volatile(""); \ - return var_name; \ -@@ -46,10 +46,10 @@ - var_name = v; \ - } - --#if defined(__clang__) --// Clang compiler is incorrectly caching the address of thread_local variables --// across a suspend-point. The following macros used to disable the volatile --// thread local access optimization. -+#if (defined (__aarch64__) && defined (__GNUC__)) || defined(__clang__) -+// GNU compiler under aarch and Clang compiler is incorrectly caching the -+// address of thread_local variables across a suspend-point. The following -+// macros used to disable the volatile thread local access optimization. - #define BAIDU_GET_VOLATILE_THREAD_LOCAL(var_name) get_##var_name() - #define BAIDU_GET_PTR_VOLATILE_THREAD_LOCAL(var_name) get_ptr_##var_name() - #define BAIDU_SET_VOLATILE_THREAD_LOCAL(var_name, value) set_##var_name(value) diff --git a/bazel/patches/brpc.patch b/bazel/patches/brpc.patch index d3849937..284e12be 100644 --- a/bazel/patches/brpc.patch +++ b/bazel/patches/brpc.patch @@ -142,15 +142,3 @@ index a2eea9cf..d5c7372f 100644 "//conditions:default": [ "test_file_util_linux.cc", "proc_maps_linux_unittest.cc", -diff --git a/BUILD.bazel b/BUILD.bazel -index 58e8863d..1b35348a 100644 ---- a/BUILD.bazel -+++ b/BUILD.bazel -@@ -256,6 +256,7 @@ objc_library( - "src/butil/memory/scoped_ptr.h", - "src/butil/move.h", - "src/butil/port.h", -+ "src/butil/string_printf.h", - "src/butil/posix/eintr_wrapper.h", - "src/butil/scoped_generic.h", - "src/butil/strings/string16.h", diff --git a/bazel/patches/cpu_features.patch b/bazel/patches/cpu_features.patch deleted file mode 100644 index 625fbe75..00000000 --- a/bazel/patches/cpu_features.patch +++ /dev/null @@ -1,13 +0,0 @@ -diff --git a/src/impl_aarch64_linux_or_android.c b/src/impl_aarch64_linux_or_android.c -index ef923d9..18f90e4 100644 ---- a/src/impl_aarch64_linux_or_android.c -+++ b/src/impl_aarch64_linux_or_android.c -@@ -15,7 +15,7 @@ - #include "cpu_features_macros.h" - - #ifdef CPU_FEATURES_ARCH_AARCH64 --#if defined(CPU_FEATURES_OS_LINUX) || defined(CPU_FEATURES_OS_ANDROID) -+#if defined(CPU_FEATURES_OS_LINUX) || defined(CPU_FEATURES_OS_ANDROID) || defined(CPU_FEATURES_OS_MACOS) - - #include "cpuinfo_aarch64.h" - diff --git a/bazel/patches/ippcp.patch b/bazel/patches/ippcp.patch index af635051..87763608 100644 --- a/bazel/patches/ippcp.patch +++ b/bazel/patches/ippcp.patch @@ -170,3 +170,81 @@ index 0015431..f93411c 100644 # DEBUG flags - optimization level = 0, generation GDB information (-g) set (CMAKE_C_FLAGS_DEBUG " -O0 -g") + +diff --git a/sources/include/dispatcher.h b/sources/include/dispatcher.h +index 8290df6..a2f93d7 100644 +--- a/sources/include/dispatcher.h ++++ b/sources/include/dispatcher.h +@@ -92,9 +92,13 @@ extern "C" { + #define LIB_W7 LIB_S8 + #elif defined( _ARCH_EM64T ) && !defined( OSXEM64T ) && !defined( WIN32E ) /* Linux* OS Intel64 supports N0 */ + enum lib_enum { +- LIB_M7=0, LIB_N8=1, LIB_Y8=2, LIB_E9=3, LIB_L9=4, LIB_N0=5, LIB_K0=6, LIB_K1=7,LIB_NOMORE ++ LIB_E9=0, LIB_L9=1, LIB_K0=2, LIB_K1=3,LIB_NOMORE + }; +- #define LIB_PX LIB_M7 ++ #define LIB_PX LIB_E9 ++ #define LIB_M7 LIB_E9 ++ #define LIB_N8 LIB_E9 ++ #define LIB_Y8 LIB_E9 ++ #define LIB_N0 LIB_L9 + #elif defined( _ARCH_EM64T ) && !defined( OSXEM64T ) /* Windows* OS Intel64 doesn't support N0 */ + enum lib_enum { + LIB_M7=0, LIB_N8=1, LIB_Y8=2, LIB_E9=3, LIB_L9=4, LIB_K0=5, LIB_K1=6, LIB_NOMORE +@@ -103,11 +107,12 @@ extern "C" { + #define LIB_N0 LIB_L9 + #elif defined( OSXEM64T ) + enum lib_enum { +- LIB_Y8=0, LIB_E9=1, LIB_L9=2, LIB_K0=3, LIB_K1=4, LIB_NOMORE ++ LIB_E9=0, LIB_L9=1, LIB_K0=2, LIB_K1=3, LIB_NOMORE + }; +- #define LIB_PX LIB_Y8 +- #define LIB_M7 LIB_Y8 +- #define LIB_N8 LIB_Y8 ++ #define LIB_PX LIB_E9 ++ #define LIB_M7 LIB_E9 ++ #define LIB_N8 LIB_E9 ++ #define LIB_Y8 LIB_E9 + #define LIB_N0 LIB_L9 + #elif defined( _ARCH_LRB2 ) + enum lib_enum { +diff --git a/sources/include/owndefs.h b/sources/include/owndefs.h +index dcc1ede..7c1e93e 100644 +--- a/sources/include/owndefs.h ++++ b/sources/include/owndefs.h +@@ -632,14 +632,14 @@ extern double __intel_castu64_f64(unsigned __int64 val); + + #elif defined(linux) + /* LIN-32, LIN-64 */ +- #if ( defined(_W7) || defined(_M7) ) ++ #if ( defined(_W7) || defined(_E9) ) + #define _IPP_DATA 1 + #endif + + + /* OSX-32, OSX-64 */ + #elif defined(OSX32) || defined(OSXEM64T) +- #if ( defined(_Y8) ) ++ #if ( defined(_E9) ) + #define _IPP_DATA 1 + #endif + #endif +diff --git a/sources/ippcp/CMakeLists.txt b/sources/ippcp/CMakeLists.txt +index 315d1a3..8b11c7a 100644 +--- a/sources/ippcp/CMakeLists.txt ++++ b/sources/ippcp/CMakeLists.txt +@@ -40,12 +40,12 @@ if(WIN32) + endif(WIN32) + if(UNIX) + if(APPLE) +- set(BASE_PLATFORM_LIST ${BASE_PLATFORM_LIST} y8 e9 l9 k0 k1) ++ set(BASE_PLATFORM_LIST ${BASE_PLATFORM_LIST} e9 l9 k0 k1) + else() + if (${ARCH} MATCHES "ia32") + set(BASE_PLATFORM_LIST ${BASE_PLATFORM_LIST} w7 s8 p8 g9 h9) + else() +- set(BASE_PLATFORM_LIST ${BASE_PLATFORM_LIST} m7 n8 y8 e9 l9 n0 k0 k1) ++ set(BASE_PLATFORM_LIST ${BASE_PLATFORM_LIST} e9 l9 k0 k1) + endif(${ARCH} MATCHES "ia32") + endif(APPLE) + endif(UNIX) diff --git a/bazel/repositories.bzl b/bazel/repositories.bzl index 02882f5a..7ca7e61c 100644 --- a/bazel/repositories.bzl +++ b/bazel/repositories.bzl @@ -41,7 +41,6 @@ def yacl_deps(): # crypto related _com_github_openssl_openssl() - _com_github_floodyberry_curve25519_donna() _com_github_blake3team_blake3() _com_github_intel_ipp() _com_github_libsodium() @@ -67,17 +66,16 @@ def _com_github_brpc_brpc(): maybe( http_archive, name = "com_github_brpc_brpc", - sha256 = "d286d520ec4d317180d91ea3970494c1b8319c8867229e5c4784998c4536718f", - strip_prefix = "brpc-1.6.0", + sha256 = "48668cbc943edd1b72551e99c58516249d15767b46ea13a843eb8df1d3d1bc42", + strip_prefix = "brpc-1.7.0", type = "tar.gz", patch_args = ["-p1"], patches = [ "@yacl//bazel:patches/brpc.patch", "@yacl//bazel:patches/brpc_m1.patch", - "@yacl//bazel:patches/brpc-2324.patch", ], urls = [ - "https://github.com/apache/brpc/archive/refs/tags/1.6.0.tar.gz", + "https://github.com/apache/brpc/archive/refs/tags/1.7.0.tar.gz", ], ) @@ -137,11 +135,11 @@ def _com_google_absl(): maybe( http_archive, name = "com_google_absl", - sha256 = "5366d7e7fa7ba0d915014d387b66d0d002c03236448e1ba9ef98122c13b35c36", + sha256 = "987ce98f02eefbaf930d6e38ab16aa05737234d7afbab2d5c4ea7adbe50c28ed", type = "tar.gz", - strip_prefix = "abseil-cpp-20230125.3", + strip_prefix = "abseil-cpp-20230802.1", urls = [ - "https://github.com/abseil/abseil-cpp/archive/refs/tags/20230125.3.tar.gz", + "https://github.com/abseil/abseil-cpp/archive/refs/tags/20230802.1.tar.gz", ], ) @@ -162,11 +160,11 @@ def _com_github_fmtlib_fmt(): maybe( http_archive, name = "com_github_fmtlib_fmt", - strip_prefix = "fmt-10.0.0", - sha256 = "ede1b6b42188163a3f2e0f25ad5c0637eca564bd8df74d02e31a311dd6b37ad8", + strip_prefix = "fmt-10.1.1", + sha256 = "78b8c0a72b1c35e4443a7e308df52498252d1cefc2b08c9a97bc9ee6cfe61f8b", build_file = "@yacl//bazel:fmtlib.BUILD", urls = [ - "https://github.com/fmtlib/fmt/archive/refs/tags/10.0.0.tar.gz", + "https://github.com/fmtlib/fmt/archive/refs/tags/10.1.1.tar.gz", ], ) @@ -219,6 +217,7 @@ def _com_github_blake3team_blake3(): ], ) +# Required by protobuf def _rule_python(): maybe( http_archive, @@ -234,23 +233,10 @@ def _rules_foreign_cc(): maybe( http_archive, name = "rules_foreign_cc", - sha256 = "2a4d07cd64b0719b39a7c12218a3e507672b82a97b98c6a89d38565894cf7c51", - strip_prefix = "rules_foreign_cc-0.9.0", + sha256 = "476303bd0f1b04cc311fc258f1708a5f6ef82d3091e53fd1977fa20383425a6a", + strip_prefix = "rules_foreign_cc-0.10.1", urls = [ - "https://github.com/bazelbuild/rules_foreign_cc/archive/refs/tags/0.9.0.tar.gz", - ], - ) - -def _com_github_floodyberry_curve25519_donna(): - maybe( - http_archive, - name = "com_github_floodyberry_curve25519_donna", - strip_prefix = "curve25519-donna-2fe66b65ea1acb788024f40a3373b8b3e6f4bbb2", - sha256 = "ba57d538c241ad30ff85f49102ab2c8dd996148456ed238a8c319f263b7b149a", - type = "tar.gz", - build_file = "@yacl//bazel:curve25519-donna.BUILD", - urls = [ - "https://github.com/floodyberry/curve25519-donna/archive/2fe66b65ea1acb788024f40a3373b8b3e6f4bbb2.tar.gz", + "https://github.com/bazelbuild/rules_foreign_cc/archive/refs/tags/0.10.1.tar.gz", ], ) @@ -287,14 +273,12 @@ def _com_github_google_cpu_features(): maybe( http_archive, name = "com_github_google_cpu_features", - strip_prefix = "cpu_features-0.8.0", + strip_prefix = "cpu_features-0.9.0", type = "tar.gz", - sha256 = "7021729f2db97aa34f218d12727314f23e8b11eaa2d5a907e8426bcb41d7eaac", build_file = "@yacl//bazel:cpu_features.BUILD", - patch_args = ["-p1"], - patches = ["@yacl//bazel:patches/cpu_features.patch"], + sha256 = "bdb3484de8297c49b59955c3b22dba834401bc2df984ef5cfc17acbe69c5018e", urls = [ - "https://github.com/google/cpu_features/archive/refs/tags/v0.8.0.tar.gz", + "https://github.com/google/cpu_features/archive/refs/tags/v0.9.0.tar.gz", ], ) @@ -333,10 +317,10 @@ def _com_github_msgpack_msgpack(): http_archive, name = "com_github_msgpack_msgpack", type = "tar.gz", - strip_prefix = "msgpack-c-cpp-3.3.0", - sha256 = "754c3ace499a63e45b77ef4bcab4ee602c2c414f58403bce826b76ffc2f77d0b", + strip_prefix = "msgpack-c-cpp-6.1.0", + sha256 = "5e63e4d9b12ab528fccf197f7e6908031039b1fc89cd8da0e97fbcbf5a6c6d3a", urls = [ - "https://github.com/msgpack/msgpack-c/archive/refs/tags/cpp-3.3.0.tar.gz", + "https://github.com/msgpack/msgpack-c/archive/refs/tags/cpp-6.1.0.tar.gz", ], build_file = "@yacl//bazel:msgpack.BUILD", ) diff --git a/bazel/yacl.bzl b/bazel/yacl.bzl index 2ee7388b..a516c2c3 100644 --- a/bazel/yacl.bzl +++ b/bazel/yacl.bzl @@ -74,7 +74,7 @@ def yacl_cmake_external(**attrs): def yacl_configure_make(**attrs): if "args" not in attrs: - attrs["args"] = ["-j 4"] + attrs["args"] = ["-j 8"] return configure_make(**attrs) def yacl_cc_test( diff --git a/yacl/base/buffer.h b/yacl/base/buffer.h index aba17282..28ab5f0b 100644 --- a/yacl/base/buffer.h +++ b/yacl/base/buffer.h @@ -24,7 +24,7 @@ namespace yacl { -// A buffer is a RAII object which represents an in memory buffer. +// A buffer is a RAII object that represents an in memory buffer. class Buffer final { std::byte* ptr_{nullptr}; int64_t size_{0}; @@ -44,6 +44,14 @@ class Buffer final { ptr_ = new std::byte[size]; } + Buffer(int64_t size, int64_t cap) : size_(size), capacity_(cap) { + YACL_ENFORCE(size >= 0 && cap >= size, + "Illegal size & cap, size={}, cap={}", size, cap); + // C++17 ensures alignment of allocated memory is >= + // __STDCPP_DEFAULT_NEW_ALIGNMENT__ Which should be 16 + ptr_ = new std::byte[cap]; + } + template = true> @@ -57,7 +65,7 @@ class Buffer final { } } - // Construct Buffer object from a block of already allocated memory + // Construct a Buffer object from a block of already allocated memory // Buffer will take the ownership of ptr Buffer(void* ptr, size_t size, const std::function& deleter) { YACL_ENFORCE(reinterpret_cast(ptr) % 16 == 0, @@ -124,7 +132,7 @@ class Buffer final { } std::byte* new_ptr = nullptr; - if (new_size != 0) { + if (new_size > 0) { new_ptr = new std::byte[new_size]; if (ptr_ != nullptr) { std::copy(ptr_, ptr_ + std::min(new_size, size_), new_ptr); @@ -139,6 +147,29 @@ class Buffer final { YACL_ENFORCE(size_ == 0 || ptr_ != nullptr, "new size = {}", new_size); } + void reserve(int64_t new_cap) { + YACL_ENFORCE(new_cap >= size_, + "reserve() cannot be used to reduce the size of Buffer,to " + "that end resize() is provided. size()={}, new_cap={}", + size_, new_cap); + + if (new_cap <= capacity_) { + return; + } + + auto* new_ptr = new std::byte[new_cap]; + if (ptr_ != nullptr && size_ > 0) { + std::copy(ptr_, ptr_ + size_, new_ptr); + } + + auto sz = size_; + reset(); + + ptr_ = new_ptr; + size_ = sz; + capacity_ = new_cap; + } + void* release() { void* tmp = ptr_; ptr_ = nullptr; diff --git a/yacl/base/buffer_test.cc b/yacl/base/buffer_test.cc index 5d84b4d4..470ff314 100644 --- a/yacl/base/buffer_test.cc +++ b/yacl/base/buffer_test.cc @@ -22,7 +22,7 @@ namespace yacl::test { -TEST(BufferTest, BasicWorks) { +TEST(BufferTest, ParallelWorks) { std::vector v; v.resize(100000); parallel_for(0, v.size(), 1, [&](int64_t beg, int64_t end) { @@ -32,4 +32,19 @@ TEST(BufferTest, BasicWorks) { }); } +TEST(BufferTest, ReserveWorks) { + Buffer buf("abc\0", 4); + ASSERT_EQ(buf.size(), 4); + + buf.reserve(10); + EXPECT_EQ(buf.size(), 4); + EXPECT_EQ(buf.capacity(), 10); + EXPECT_STREQ(buf.data(), "abc"); + + EXPECT_ANY_THROW(buf.reserve(1)); + EXPECT_NO_THROW(buf.resize(1)); + EXPECT_EQ(buf.size(), 1); + EXPECT_EQ(buf.capacity(), 10); +} + } // namespace yacl::test diff --git a/yacl/base/exception.h b/yacl/base/exception.h index 1f298f02..2932dfc6 100644 --- a/yacl/base/exception.h +++ b/yacl/base/exception.h @@ -135,6 +135,23 @@ class NetworkError : public IoError { using IoError::IoError; }; +class LinkError : public NetworkError { + public: + LinkError() = delete; + explicit LinkError(const std::string& msg, int code, int http_code = 0) + : NetworkError(msg), code_(code), http_code_(http_code) {} + explicit LinkError(const std::string& msg, void** stacks, int dep, int code, + int http_code = 0) + : NetworkError(msg, stacks, dep), code_(code), http_code_(http_code) {} + + int code() const noexcept { return code_; } + int http_code() const noexcept { return http_code_; } + + private: + int code_; + int http_code_; +}; + #define YACL_ERROR_MSG(...) \ fmt::format("[{}:{}] {}", __FILE__, __LINE__, fmt::format(__VA_ARGS__)) @@ -183,6 +200,15 @@ using stacktrace_t = std::array; dep); \ } while (false) +#define YACL_THROW_LINK_ERROR(code, http_code, ...) \ + do { \ + ::yacl::stacktrace_t stacks; \ + int dep = absl::GetStackTrace(stacks.data(), \ + ::yacl::internal::kMaxStackTraceDep, 0); \ + throw ::yacl::LinkError(YACL_ERROR_MSG(__VA_ARGS__), stacks.data(), dep, \ + code, http_code); \ + } while (false) + #define YACL_THROW_INVALID_FORMAT(...) \ do { \ ::yacl::stacktrace_t stacks; \ diff --git a/yacl/crypto/base/ecc/openssl/openssl_factory.cc b/yacl/crypto/base/ecc/openssl/openssl_factory.cc index 1b00b894..e916c0f6 100644 --- a/yacl/crypto/base/ecc/openssl/openssl_factory.cc +++ b/yacl/crypto/base/ecc/openssl/openssl_factory.cc @@ -53,18 +53,22 @@ std::map kName2Nid = { {"secp160r1", NID_secp160r1}, {"secp160r2", NID_secp160r2}, {"secp192k1", NID_secp192k1}, + // SECG secp192r1 is the same as X9.62 prime192v1 and NIST P-192. Openssl + // use NID_X9_62_prime192v1 to indicate them + {"secp192r1", NID_X9_62_prime192v1}, {"secp224k1", NID_secp224k1}, {"secp224r1", NID_secp224r1}, + // SECG secp256r1 is the same as X9.62 prime256v1 and NIST P-256. Openssl + // use NID_X9_62_prime256v1 to indicate them + {"secp256r1", NID_X9_62_prime256v1}, {"secp256k1", NID_secp256k1}, {"secp384r1", NID_secp384r1}, {"secp521r1", NID_secp521r1}, - {"secp192r1", NID_X9_62_prime192v1}, {"prime192v2", NID_X9_62_prime192v2}, {"prime192v3", NID_X9_62_prime192v3}, {"prime239v1", NID_X9_62_prime239v1}, {"prime239v2", NID_X9_62_prime239v2}, {"prime239v3", NID_X9_62_prime239v3}, - {"secp256r1", NID_X9_62_prime256v1}, {"sect113r1", NID_sect113r1}, {"sect113r2", NID_sect113r2}, {"sect131r1", NID_sect131r1}, @@ -82,6 +86,8 @@ std::map kName2Nid = { {"sect409k1", NID_sect409k1}, {"sect409r1", NID_sect409r1}, {"sect571k1", NID_sect571k1}, + // SECG sect571r1 is the same as ANSI X9.63 ansit571r1 and NIST B-571. + // Openssl use NID_sect571r1 to indicate them {"sect571r1", NID_sect571r1}, {"c2pnb163v1", NID_X9_62_c2pnb163v1}, {"c2pnb163v2", NID_X9_62_c2pnb163v2}, @@ -110,6 +116,9 @@ std::map kName2Nid = { {"wap-wsg-idm-ecid-wtls10", NID_wap_wsg_idm_ecid_wtls10}, {"wap-wsg-idm-ecid-wtls11", NID_wap_wsg_idm_ecid_wtls11}, {"wap-wsg-idm-ecid-wtls12", NID_wap_wsg_idm_ecid_wtls12}, + {"Oakley Group 3", NID_ipsec3}, + {"Oakley Group 4", NID_ipsec4}, + {"brainpoolP160r1", NID_brainpoolP160r1}, {"brainpoolP160t1", NID_brainpoolP160t1}, {"brainpoolP192r1", NID_brainpoolP192r1}, {"brainpoolP192t1", NID_brainpoolP192t1}, diff --git a/yacl/crypto/base/ecc/openssl/openssl_group.cc b/yacl/crypto/base/ecc/openssl/openssl_group.cc index 38047dd4..59e38a66 100644 --- a/yacl/crypto/base/ecc/openssl/openssl_group.cc +++ b/yacl/crypto/base/ecc/openssl/openssl_group.cc @@ -307,7 +307,7 @@ EcPoint OpensslGroup::HashToCurve(HashToCurveStrategy strategy, } break; case HashToCurveStrategy::TryAndRehash_SHA3: - YACL_THROW("Openssl lib does not support TryAndRehash_SHA3 strategy now"); + YACL_THROW("Openssl does not support TryAndRehash_SHA3 strategy now"); break; case HashToCurveStrategy::TryAndRehash_SM: hash_algorithm = HashAlgorithm::SM3; @@ -317,9 +317,8 @@ EcPoint OpensslGroup::HashToCurve(HashToCurveStrategy strategy, hash_algorithm = HashAlgorithm::BLAKE3; break; default: - YACL_THROW( - "Openssl lib only supports TryAndRehash strategy now. select={}", - (int)strategy); + YACL_THROW("Openssl only supports TryAndRehash strategy now. select={}", + (int)strategy); } auto point = MakeOpensslPoint(); diff --git a/yacl/crypto/base/test/BUILD.bazel b/yacl/crypto/base/test/BUILD.bazel index d635bee6..b3f25999 100644 --- a/yacl/crypto/base/test/BUILD.bazel +++ b/yacl/crypto/base/test/BUILD.bazel @@ -21,16 +21,11 @@ yacl_cc_binary( srcs = [ "alg_data_test.cc", ], - defines = ["CURVE25519_DONNA"] + select({ - "@bazel_tools//src/conditions:darwin": ["USE_LIBDISPATCH"], - "//conditions:default": [], - }), deps = [ "//yacl/base:int128", "//yacl/crypto/base:symmetric_crypto", "//yacl/crypto/base/hash:ssl_hash", "//yacl/crypto/tools:prg", - "@com_github_floodyberry_curve25519_donna//:curve25519_donna", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest", diff --git a/yacl/crypto/base/test/alg_data_test.cc b/yacl/crypto/base/test/alg_data_test.cc index 18086584..5f85b620 100644 --- a/yacl/crypto/base/test/alg_data_test.cc +++ b/yacl/crypto/base/test/alg_data_test.cc @@ -23,14 +23,8 @@ #include "yacl/crypto/base/symmetric_crypto.h" #include "yacl/crypto/tools/prg.h" -extern "C" { - -#include "curve25519.h" -} - namespace yacl::crypto { -constexpr int kCurve25519ElemSize = 32; constexpr int kDataNum = 10; std::vector aes_data_len = {16, 16, 16, 32, 32, 32, 48, 48, 48, 64}; std::vector sha256_data_len = {3, 6, 9, 12, 15, 16, 16, 16, 32, 32}; @@ -111,61 +105,6 @@ void Sha256TestData(std::string &file_name) { } } -void Curve25519TestData(std::string &file_name) { - std::ofstream out_file; - out_file.open(file_name, std::ios::out /*| std::ios::app*/); - - // prg - Prg prg(0, PRG_MODE::kNistAesCtrDrbg); - - for (size_t i = 0; i < sha256_data_len.size(); ++i) { - std::array private_x; - std::array public_x; - std::array private_y; - std::array public_y; - std::array share_x; - std::array share_y; - - prg.Fill(absl::MakeSpan(private_x.data(), private_x.size())); - prg.Fill(absl::MakeSpan(private_y.data(), private_y.size())); - - curve25519_donna_basepoint(public_x.data(), private_x.data()); - curve25519_donna_basepoint(public_y.data(), private_y.data()); - - curve25519_donna(share_x.data(), private_x.data(), public_y.data()); - curve25519_donna(share_y.data(), private_y.data(), public_x.data()); - - EXPECT_EQ(share_x, share_y); - - out_file << "=====" << i << "=====" << std::endl; - out_file << "private x(" << private_x.size() << "):" << std::endl; - out_file << absl::BytesToHexString(std::string_view( - reinterpret_cast(private_x.data()), - share_x.size())) - << std::endl; - out_file << "public x(" << public_x.size() << "):" << std::endl; - out_file << absl::BytesToHexString(std::string_view( - reinterpret_cast(public_x.data()), - share_x.size())) - << std::endl; - out_file << "private y(" << private_y.size() << "):" << std::endl; - out_file << absl::BytesToHexString(std::string_view( - reinterpret_cast(private_y.data()), - share_x.size())) - << std::endl; - out_file << "public y(" << public_y.size() << "):" << std::endl; - out_file << absl::BytesToHexString(std::string_view( - reinterpret_cast(public_y.data()), - share_x.size())) - << std::endl; - out_file << "share y(" << share_x.size() << "):" << std::endl; - out_file << absl::BytesToHexString(std::string_view( - reinterpret_cast(share_x.data()), - share_x.size())) - << std::endl; - } -} - } // namespace yacl::crypto int main(int /*argc*/, char ** /*argv*/) { @@ -179,7 +118,6 @@ int main(int /*argc*/, char ** /*argv*/) { AesTestData(aes_ctr_file_name, yacl::crypto::SymmetricCrypto::CryptoType::AES128_CTR); yacl::crypto::Sha256TestData(sha256_file_name); - yacl::crypto::Curve25519TestData(curve25519_file_name); return 0; -} \ No newline at end of file +} diff --git a/bazel/curve25519-donna.BUILD b/yacl/io/msgpack/BUILD.bazel similarity index 58% rename from bazel/curve25519-donna.BUILD rename to yacl/io/msgpack/BUILD.bazel index 70182636..0d74d013 100644 --- a/bazel/curve25519-donna.BUILD +++ b/yacl/io/msgpack/BUILD.bazel @@ -12,11 +12,27 @@ # See the License for the specific language governing permissions and # limitations under the License. -load("@rules_cc//cc:defs.bzl", "cc_library") +load("//bazel:yacl.bzl", "yacl_cc_library", "yacl_cc_test") -cc_library( - name = "curve25519_donna", - srcs = ["curve25519.c"], - hdrs = glob(["*.h"]), - visibility = ["//visibility:public"], +package(default_visibility = ["//visibility:public"]) + +yacl_cc_library( + name = "buffer", + hdrs = ["buffer.h"], + deps = [ + "//yacl/base:buffer", + ], +) + +yacl_cc_library( + name = "spec_traits", + hdrs = ["spec_traits.h"], +) + +yacl_cc_test( + name = "buffer_test", + srcs = ["buffer_test.cc"], + deps = [ + ":buffer", + ], ) diff --git a/yacl/io/msgpack/buffer.h b/yacl/io/msgpack/buffer.h new file mode 100644 index 00000000..46d67728 --- /dev/null +++ b/yacl/io/msgpack/buffer.h @@ -0,0 +1,84 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "yacl/base/buffer.h" + +namespace yacl::io { + +class StreamBuffer { + public: + explicit StreamBuffer(Buffer *buf) : buf_(buf) { + YACL_ENFORCE(buf != nullptr, "buf is nullptr"); + buf_->resize(0); + } + + void write(const char *str, size_t len) { + auto old_sz = buf_->size(); + buf_->resize(buf_->size() + len); + std::memcpy(buf_->data() + old_sz, str, len); + } + + void write(std::string_view str) { write(str.data(), str.length()); } + + // reserve how much *extra* space + void Expand(size_t extra_space) { + buf_->reserve(static_cast(buf_->size() + extra_space)); + } + + char *PosLoc() const { return buf_->data() + buf_->size(); } + void IncPos(int64_t delta) { buf_->resize(buf_->size() + delta); } + + size_t WrittenSize() { return buf_->size(); } + size_t FreeSize() const { return buf_->capacity() - buf_->size(); } + + private: + Buffer *buf_; +}; + +class FixedBuffer { + public: + FixedBuffer(char *buf, size_t buf_len) : buf_(buf), buf_len_(buf_len) {} + + void write(const char *str, size_t len) { + YACL_ENFORCE( + pos_ + len <= buf_len_, + "Dangerous!! buffer overflow, buf_len={}, pos={}, write_size={}", + buf_len_, pos_, len); + + std::copy(str, str + len, buf_ + pos_); + pos_ += len; + } + + char *PosLoc() const { return buf_ + pos_; } + + void IncPos(int64_t delta) { + int64_t new_pos = static_cast(pos_) + delta; + YACL_ENFORCE(new_pos >= 0 && new_pos <= (int64_t)buf_len_, + "cannot update pos, delta={}, cur_pos={}, buf_len={}", delta, + pos_, buf_len_); + pos_ = new_pos; + } + + size_t WrittenSize() const { return pos_; } + size_t FreeSize() const { return buf_len_ - pos_; } + + private: + char *buf_; + size_t buf_len_; + size_t pos_ = 0; +}; + +} // namespace yacl::io diff --git a/yacl/io/msgpack/buffer_test.cc b/yacl/io/msgpack/buffer_test.cc new file mode 100644 index 00000000..d2f86461 --- /dev/null +++ b/yacl/io/msgpack/buffer_test.cc @@ -0,0 +1,34 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "yacl/io/msgpack/buffer.h" + +#include "gtest/gtest.h" + +namespace yacl::io::test { + +TEST(TestStreamBuffer, SimpleWorks) { + Buffer buffer(3); + StreamBuffer sbuf(&buffer); + + EXPECT_EQ(sbuf.WrittenSize(), 0); + + std::string abc = "abc"; + sbuf.write(std::string_view(abc)); + sbuf.write("d", 1); + + EXPECT_EQ(std::string(buffer.data(), buffer.size()), "abcd"); +} + +} // namespace yacl::io::test diff --git a/yacl/io/msgpack/spec_traits.h b/yacl/io/msgpack/spec_traits.h new file mode 100644 index 00000000..44b3ce0f --- /dev/null +++ b/yacl/io/msgpack/spec_traits.h @@ -0,0 +1,83 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +namespace yacl::io::msgpack_traits { + +// The msgpack spec: +// https://github.com/msgpack/msgpack/blob/master/spec.md + +size_t HeadSizeOfStr(size_t str_len) { + // Str format family stores a byte array in 1, 2, 3, or 5 bytes of extra bytes + // in addition to the size of the byte array. + if (str_len < 32) { + // fixstr stores a byte array whose length is upto 31 bytes: + // +--------+========+ + // |101XXXXX| data | + // +--------+========+ + return 1; + + } else if (str_len < 256) { + // str 8 stores a byte array whose length is upto (2^8)-1 bytes: + // +--------+--------+========+ + // | 0xd9 |YYYYYYYY| data | + // +--------+--------+========+ + return 2; + + } else if (str_len < 65536) { + // str 16 stores a byte array whose length is upto (2^16)-1 bytes: + // +--------+--------+--------+========+ + // | 0xda |ZZZZZZZZ|ZZZZZZZZ| data | + // +--------+--------+--------+========+ + return 3; + + } else { + // str 32 stores a byte array whose length is upto (2^32)-1 bytes: + // +--------+--------+--------+--------+--------+========+ + // | 0xdb |AAAAAAAA|AAAAAAAA|AAAAAAAA|AAAAAAAA| data | + // +--------+--------+--------+--------+--------+========+ + return 5; + } +} + +size_t HeadSizeOfArray(size_t array_len) { + // Array format family stores a sequence of elements in 1, 3, or 5 bytes of + // extra bytes in addition to the elements. + if (array_len < 16) { + // fixarray stores an array whose length is upto 15 elements: + // +--------+~~~~~~~~~~~~~~~~~+ + // |1001XXXX| N objects | + // +--------+~~~~~~~~~~~~~~~~~+ + return 1; + + } else if (array_len < 65536) { + // array 16 stores an array whose length is upto (2^16)-1 elements: + // +--------+--------+--------+~~~~~~~~~~~~~~~~~+ + // | 0xdc |YYYYYYYY|YYYYYYYY| N objects | + // +--------+--------+--------+~~~~~~~~~~~~~~~~~+ + return 3; + + } else { + // array 32 stores an array whose length is upto (2^32)-1 elements: + // +--------+--------+--------+--------+--------+~~~~~~~~~~~~~~~~~+ + // | 0xdd |ZZZZZZZZ|ZZZZZZZZ|ZZZZZZZZ|ZZZZZZZZ| N objects | + // +--------+--------+--------+--------+--------+~~~~~~~~~~~~~~~~~+ + return 5; + } +} + +} // namespace yacl::io::msgpack_traits diff --git a/yacl/io/stream/interface.h b/yacl/io/stream/interface.h index 0efd1945..5e91baa8 100644 --- a/yacl/io/stream/interface.h +++ b/yacl/io/stream/interface.h @@ -39,26 +39,26 @@ class InputStream { /** * read line from file. - * raise exception if any error happend except EOF. + * raise exception if any error happened except EOF. */ virtual InputStream& GetLine(std::string* ret, char delim) = 0; virtual InputStream& GetLine(std::string* ret) { return GetLine(ret, '\n'); } /** * Read length bytes from the file. - * raise exception if any error happend except EOF. + * raise exception if any error happened except EOF. */ virtual InputStream& Read(void* buf, size_t length) = 0; /** * Sets input position. - * raise exception if any error happend. + * raise exception if any error happened. */ virtual InputStream& Seekg(size_t pos) = 0; /** * Sets input position. - * raise exception if any error happend. + * raise exception if any error happened. */ virtual size_t Tellg() = 0; @@ -91,7 +91,7 @@ class InputStream { }; /** - * Always trunc file if target file exist. + * Always trunc file if target file exists. */ class OutputStream { public: @@ -99,7 +99,7 @@ class OutputStream { /** * Write/Append length bytes pointed by buf to the file stream - * raise exception if any error happend. + * raise exception if any error happened. */ virtual void Write(const void* buf, size_t length) = 0; virtual void Write(std::string_view buf) = 0; diff --git a/yacl/link/BUILD.bazel b/yacl/link/BUILD.bazel index 4d0179e1..6c9f924a 100644 --- a/yacl/link/BUILD.bazel +++ b/yacl/link/BUILD.bazel @@ -39,6 +39,14 @@ yacl_cc_library( ], ) +yacl_cc_library( + name = "retry_options", + hdrs = ["retry_options.h"], + deps = [ + ":link_cc_proto", + ], +) + yacl_cc_library( name = "ssl_options", hdrs = ["ssl_options.h"], @@ -53,6 +61,7 @@ yacl_cc_library( hdrs = ["context.h"], deps = [ ":link_cc_proto", + ":retry_options", ":ssl_options", ":trace", "//yacl/base:byte_container_view", diff --git a/yacl/link/context.h b/yacl/link/context.h index a4593b42..8f4626dc 100644 --- a/yacl/link/context.h +++ b/yacl/link/context.h @@ -24,6 +24,7 @@ #include #include "yacl/base/byte_container_view.h" +#include "yacl/link/retry_options.h" #include "yacl/link/ssl_options.h" #include "yacl/link/transport/channel.h" #include "yacl/utils/hash.h" @@ -46,8 +47,6 @@ struct ContextDesc { static constexpr uint32_t kDefaultChunkParallelSendSize = 8; static constexpr char kDefaultBrpcChannelProtocol[] = "baidu_std"; static constexpr char kDefaultLinkType[] = "normal"; - static constexpr uint32_t kDefaultBrpcRetryCount = 3; - static constexpr uint32_t kDefaultBrpcRetryInterval = 1000; struct Party { std::string id; @@ -139,13 +138,7 @@ struct ContextDesc { // "blackbox" or "normal", default: "normal" std::string link_type = kDefaultLinkType; - // request retry count - uint32_t brpc_retry_count = kDefaultBrpcRetryCount; - // request retry interval - uint32_t brpc_retry_interval_ms = kDefaultBrpcRetryInterval; // 1 seconds. - // do aggressive retry, this means that retries will be made on additional - // error codes - bool brpc_aggressive_retry = true; + RetryOptions retry_opts; bool operator==(const ContextDesc& other) const { return (id == other.id) && (parties == other.parties); @@ -182,12 +175,7 @@ struct ContextDesc { client_ssl_opts(pb.client_ssl_opts()), server_ssl_opts(pb.server_ssl_opts()), link_type(kDefaultLinkType), - brpc_retry_count(pb.brpc_retry_count() ? pb.brpc_retry_count() - : kDefaultBrpcRetryCount), - brpc_retry_interval_ms(pb.brpc_retry_interval_ms() - ? pb.brpc_retry_interval_ms() - : kDefaultBrpcRetryInterval), - brpc_aggressive_retry(pb.brpc_aggressive_retry()) { + retry_opts(pb.retry_opts()) { for (const auto& party_pb : pb.parties()) { parties.emplace_back(party_pb); } @@ -207,9 +195,7 @@ struct ContextDescHasher { desc.connect_retry_interval_ms, desc.recv_timeout_ms, desc.http_max_payload_size, desc.http_timeout_ms, desc.throttle_window_size, desc.brpc_channel_protocol, - desc.brpc_channel_connection_type, desc.link_type, - desc.brpc_retry_count, desc.brpc_retry_interval_ms, - desc.brpc_aggressive_retry); + desc.brpc_channel_connection_type, desc.link_type); return seed; } diff --git a/yacl/link/factory_brpc.cc b/yacl/link/factory_brpc.cc index 80e3e5d4..b44e08f8 100644 --- a/yacl/link/factory_brpc.cc +++ b/yacl/link/factory_brpc.cc @@ -48,9 +48,7 @@ std::shared_ptr FactoryBrpc::CreateContext(const ContextDesc& desc, auto opts = transport::BrpcLink::GetDefaultOptions(); opts = transport::BrpcLink::MakeOptions( opts, desc.http_timeout_ms, desc.http_max_payload_size, - desc.brpc_channel_protocol, desc.brpc_channel_connection_type, - desc.brpc_retry_count, desc.brpc_retry_interval_ms, - desc.brpc_aggressive_retry); + desc.brpc_channel_protocol, desc.brpc_channel_connection_type); auto msg_loop = std::make_unique(); std::vector> channels(world_size); @@ -65,7 +63,8 @@ std::shared_ptr FactoryBrpc::CreateContext(const ContextDesc& desc, desc.enable_ssl ? &desc.client_ssl_opts : nullptr); auto channel = std::make_shared( - delegate, desc.recv_timeout_ms, desc.exit_if_async_error); + delegate, desc.recv_timeout_ms, desc.exit_if_async_error, + desc.retry_opts); channel->SetThrottleWindowSize(desc.throttle_window_size); msg_loop->AddListener(rank, channel); channels[rank] = std::move(channel); diff --git a/yacl/link/factory_brpc_blackbox.cc b/yacl/link/factory_brpc_blackbox.cc index e9304906..e5b61b9d 100644 --- a/yacl/link/factory_brpc_blackbox.cc +++ b/yacl/link/factory_brpc_blackbox.cc @@ -80,9 +80,7 @@ std::shared_ptr FactoryBrpcBlackBox::CreateContext( auto options = transport::BrpcBlackBoxLink::GetDefaultOptions(); options = transport::BrpcBlackBoxLink::MakeOptions( options, desc.http_timeout_ms, desc.http_max_payload_size, - desc.brpc_channel_protocol, desc.brpc_channel_connection_type, - desc.brpc_retry_count, desc.brpc_retry_interval_ms, - desc.brpc_aggressive_retry); + desc.brpc_channel_protocol, desc.brpc_channel_connection_type); if (options.channel_protocol != "http" && options.channel_protocol != "h2") { YACL_THROW_LOGIC_ERROR( @@ -105,7 +103,7 @@ std::shared_ptr FactoryBrpcBlackBox::CreateContext( desc.enable_ssl ? &desc.client_ssl_opts : nullptr); auto channel = std::make_shared( - delegate, desc.recv_timeout_ms, false); + delegate, desc.recv_timeout_ms, false, desc.retry_opts); channel->SetThrottleWindowSize(desc.throttle_window_size); msg_loop->AddLinkAndChannel(rank, channel, delegate); diff --git a/yacl/link/link.proto b/yacl/link/link.proto index 30da48ae..0430927d 100644 --- a/yacl/link/link.proto +++ b/yacl/link/link.proto @@ -39,6 +39,29 @@ message SSLOptionsProto { string ca_file_path = 4; } +// Retry options. +message RetryOptionsProto { + // max retry count + // default 3 + uint32 max_retry = 1; + // time between retries at first retry + // default 1 second + uint32 retry_interval_ms = 2; + // The amount of time to increase the interval between retries + // default 2s + uint32 retry_interval_incr_ms = 3; + // The maximum interval between retries + // default 10s + uint32 max_retry_interval_ms = 4; + // retry on these brpc error codes, if empty, retry on all codes + repeated uint32 error_codes = 5; + // retry on these http codes, if empty, retry on all http codes + repeated uint32 http_codes = 6; + // do aggressive retry, this means that retries will be made on additional + // error codes + bool aggressive_retry = 7; +} + // Configuration for link config. message ContextDescProto { // the UUID of this communication. @@ -104,14 +127,10 @@ message ContextDescProto { // this config is ignored if enable_ssl == false; SSLOptionsProto server_ssl_opts = 13; - // max retry count - uint32 brpc_retry_count = 14; - // the time between retries - uint32 brpc_retry_interval_ms = 15; - // do aggressive retry, this means that retries will be made on additional - // error codes - bool brpc_aggressive_retry = 16; // chunk parallel send size for channel. if need chunked send when send // message, the max paralleled send size is chunk_parallel_send_size uint32 chunk_parallel_send_size = 17; + + // retry options + RetryOptionsProto retry_opts = 14; } diff --git a/yacl/link/retry_options.h b/yacl/link/retry_options.h new file mode 100644 index 00000000..7d16c39c --- /dev/null +++ b/yacl/link/retry_options.h @@ -0,0 +1,72 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include "yacl/link/link.pb.h" + +namespace yacl::link { + +struct RetryOptions { + static constexpr uint32_t kDefaultMaxRetryCount = 3; + static constexpr uint32_t kDefaultRetryInterval = 1000; + static constexpr uint32_t kDefaultRetryIncrement = 2000; + static constexpr uint32_t kDefaultMaxRetryInterval = 10000; + + // max retry count + uint32_t max_retry; + // at first retry interval + uint32_t retry_interval_ms; + // the amount of time to increase between retries + uint32_t retry_interval_incr_ms; + // the max interval between retries + uint32_t max_retry_interval_ms; + // retry on these error codes, if empty, retry on all codes + std::unordered_set error_codes; + // retry on these http codes, if empty, retry on all http codes + std::unordered_set http_codes; + // do aggressive retry + bool aggressive_retry; + + RetryOptions() + : max_retry(kDefaultMaxRetryCount), + retry_interval_ms(kDefaultRetryInterval), + retry_interval_incr_ms(kDefaultRetryIncrement), + max_retry_interval_ms(kDefaultMaxRetryInterval), + aggressive_retry(true) {} + + RetryOptions(const RetryOptions&) = default; + + RetryOptions(const RetryOptionsProto& pb) { + max_retry = pb.max_retry() ? pb.max_retry() : kDefaultMaxRetryCount; + retry_interval_ms = + pb.retry_interval_ms() ? pb.retry_interval_ms() : kDefaultRetryInterval; + retry_interval_incr_ms = pb.retry_interval_incr_ms() + ? pb.retry_interval_ms() + : kDefaultRetryIncrement; + max_retry_interval_ms = pb.max_retry_interval_ms() + ? pb.max_retry_interval_ms() + : kDefaultMaxRetryInterval; + std::for_each(pb.error_codes().begin(), pb.error_codes().end(), + [this](const auto& codes) { error_codes.insert(codes); }); + std::for_each(pb.http_codes().begin(), pb.http_codes().end(), + [this](const auto& codes) { http_codes.insert(codes); }); + aggressive_retry = pb.aggressive_retry(); + } +}; + +} // namespace yacl::link diff --git a/yacl/link/transport/BUILD.bazel b/yacl/link/transport/BUILD.bazel index 1ba60fc4..9f898f83 100644 --- a/yacl/link/transport/BUILD.bazel +++ b/yacl/link/transport/BUILD.bazel @@ -26,12 +26,22 @@ yacl_cc_library( "//yacl/base:buffer", "//yacl/base:byte_container_view", "//yacl/base:exception", + "//yacl/link:retry_options", "//yacl/link:ssl_options", "//yacl/utils:segment_tree", "@com_github_brpc_brpc//:brpc", ], ) +yacl_cc_test( + name = "channel_test", + srcs = ["channel_test.cc"], + deps = [ + ":channel", + ":ic_transport_proto", + ], +) + yacl_cc_library( name = "channel_mem", srcs = ["channel_mem.cc"], @@ -46,15 +56,6 @@ cc_proto_library( deps = ["@org_interconnection//interconnection/link"], ) -yacl_cc_library( - name = "default_brpc_retry_policy", - srcs = ["default_brpc_retry_policy.cc"], - hdrs = ["default_brpc_retry_policy.h"], - deps = [ - "@com_github_brpc_brpc//:brpc", - ], -) - yacl_cc_library( name = "interconnection_link", srcs = ["interconnection_link.cc"], @@ -70,7 +71,6 @@ yacl_cc_library( srcs = ["brpc_link.cc"], hdrs = ["brpc_link.h"], deps = [ - ":default_brpc_retry_policy", ":interconnection_link", ], ) @@ -88,7 +88,6 @@ yacl_cc_library( srcs = ["brpc_blackbox_link.cc"], hdrs = ["brpc_blackbox_link.h"], deps = [ - ":default_brpc_retry_policy", ":interconnection_link", "//yacl/link/transport/blackbox_interconnect:blackbox_service_errorcode", "//yacl/link/transport/blackbox_interconnect:blackbox_service_proto", diff --git a/yacl/link/transport/blackbox_interconnect/blackbox_dummy_service_impl.cc b/yacl/link/transport/blackbox_interconnect/blackbox_dummy_service_impl.cc index bd019b85..19e19f16 100644 --- a/yacl/link/transport/blackbox_interconnect/blackbox_dummy_service_impl.cc +++ b/yacl/link/transport/blackbox_interconnect/blackbox_dummy_service_impl.cc @@ -56,7 +56,7 @@ void DummyBlackBoxServiceImpl::OnInvoke(const std::string* topic, std::string request) { SPDLOG_INFO("invoke: topic {}", *topic); std::lock_guard guard(msg_mtx_); - msg_db_[*topic].push(std::move(request)); + recv_msgs_[*topic].push(std::move(request)); msg_cond_.notify_all(); } @@ -128,10 +128,10 @@ void DummyBlackBoxServiceImpl::OnPop(brpc::Controller* cntl, } std::unique_lock lock(msg_mtx_); msg_cond_.wait_for(lock, std::chrono::seconds(timeout), - [&] { return !msg_db_[*topic].empty(); }); - if (!msg_db_[*topic].empty()) { - response.set_payload(msg_db_[*topic].front()); - msg_db_[*topic].pop(); + [&] { return !recv_msgs_[*topic].empty(); }); + if (!recv_msgs_[*topic].empty()) { + response.set_payload(recv_msgs_[*topic].front()); + recv_msgs_[*topic].pop(); } else { // return OK when timeout, but payload is empty. response.set_code(bb_ic::error_code::Code("OK")); diff --git a/yacl/link/transport/blackbox_interconnect/blackbox_dummy_service_impl.h b/yacl/link/transport/blackbox_interconnect/blackbox_dummy_service_impl.h index 0a44e80d..c5f6be73 100644 --- a/yacl/link/transport/blackbox_interconnect/blackbox_dummy_service_impl.h +++ b/yacl/link/transport/blackbox_interconnect/blackbox_dummy_service_impl.h @@ -58,7 +58,7 @@ class DummyBlackBoxServiceImpl std::string peer_url_; std::mutex msg_mtx_; std::condition_variable msg_cond_; - std::map> msg_db_; + std::map> recv_msgs_; uint32_t invoke_max_retry_cnt_ = 10; uint32_t invoke_retry_interval_ms_ = 1000; }; diff --git a/yacl/link/transport/brpc_blackbox_link.cc b/yacl/link/transport/brpc_blackbox_link.cc index 6d5f7a05..478022f6 100644 --- a/yacl/link/transport/brpc_blackbox_link.cc +++ b/yacl/link/transport/brpc_blackbox_link.cc @@ -28,7 +28,6 @@ #include "yacl/base/exception.h" #include "yacl/link/transport/blackbox_interconnect/blackbox_service_errorcode.h" #include "yacl/link/transport/channel.h" -#include "yacl/link/transport/default_brpc_retry_policy.h" #include "yacl/link/transport/blackbox_interconnect/blackbox_service.pb.h" namespace yacl::link::transport { @@ -60,35 +59,6 @@ const auto* const kHttpHeadHost = "host"; } // namespace -class BlackboxRetryPolicy : public DefaultBrpcRetryPolicy { - public: - BlackboxRetryPolicy(uint32_t retry_interval_ms, bool aggressive_retry, - uint32_t push_wait_s) - : DefaultBrpcRetryPolicy(retry_interval_ms, aggressive_retry), - push_wait_s_(push_wait_s) {} - - bool OnRpcSuccess(const brpc::Controller* cntl) const override { - bb_ic::TransportOutbound response; - if (!response.ParseFromString(cntl->response_attachment().to_string())) { - SPDLOG_ERROR("Parse message failed."); - return false; - } - - if (response.code() == bb_ic::error_code::Code("QueueFull")) { - SPDLOG_WARN( - "{} push error due to transport service queue is full, try again...", - cntl->retried_count()); - bthread_usleep(push_wait_s_ * 1e6); - return true; - } - - return false; - } - - private: - uint32_t push_wait_s_; -}; - ReceiverLoopBlackBox::~ReceiverLoopBlackBox() { Stop(); } void ReceiverLoopBlackBox::Stop() { @@ -181,8 +151,6 @@ std::optional BrpcBlackBoxLink::TryReceive() { brpc::ChannelOptions BrpcBlackBoxLink::GetChannelOption( const SSLOptions* ssl_opts) { brpc::ChannelOptions options; - auto retry_policy = std::make_unique( - options_.retry_interval_ms, options_.aggressive_retry, push_wait_ms_); { if (options_.channel_protocol != "http" && options_.channel_protocol != "h2") { @@ -194,8 +162,7 @@ brpc::ChannelOptions BrpcBlackBoxLink::GetChannelOption( options.connection_type = options_.channel_connection_type; options.connect_timeout_ms = 20000; options.timeout_ms = options_.http_timeout_ms; - options.max_retry = options_.max_retry; - options.retry_policy = retry_policy.get(); + options.max_retry = 0; if (ssl_opts != nullptr) { options.mutable_ssl_options()->client_cert.certificate = ssl_opts->cert.certificate_path; @@ -207,7 +174,6 @@ brpc::ChannelOptions BrpcBlackBoxLink::GetChannelOption( ssl_opts->verify.ca_file_path; } } - retry_policy_ = std::move(retry_policy); return options; } @@ -271,7 +237,7 @@ void BrpcBlackBoxLink::SetPeerHost(const std::string& self_id, } void BrpcBlackBoxLink::SetHttpHeader(brpc::Controller* controller, - const std::string& topic) { + const std::string& topic) const { for (auto& [k, v] : http_headers_) { controller->http_request().SetHeader(k, v); } @@ -279,8 +245,8 @@ void BrpcBlackBoxLink::SetHttpHeader(brpc::Controller* controller, controller->http_request().set_method(brpc::HTTP_METHOD_POST); } -void BrpcBlackBoxLink::SendRequest(const ::google::protobuf::Message& request, - uint32_t timeout_ms) { +void BrpcBlackBoxLink::SendRequest(const Request& request, + uint32_t timeout_ms) const { bb_ic::TransportOutbound response; auto request_str = request.SerializeAsString(); int pushs = 0; @@ -298,8 +264,7 @@ void BrpcBlackBoxLink::SendRequest(const ::google::protobuf::Message& request, channel_->CallMethod(nullptr, &cntl, nullptr, nullptr, nullptr); if (cntl.Failed()) { - YACL_THROW_NETWORK_ERROR("send, rpc failed={}, message={}", - cntl.ErrorCode(), cntl.ErrorText()); + ThrowLinkErrorByBrpcCntl(cntl); } YACL_ENFORCE( @@ -311,9 +276,7 @@ void BrpcBlackBoxLink::SendRequest(const ::google::protobuf::Message& request, } if (response.code() != bb_ic::error_code::Code("QueueFull")) { - YACL_THROW_NETWORK_ERROR( - "response with error, code={}({}) message={}", response.code(), - bb_ic::error_code::Desc(response.code()), response.message()); + ThrowLinkErrorByBrpcCntl(cntl); } else { SPDLOG_WARN( "{} push error due to transport service queue is full, try " diff --git a/yacl/link/transport/brpc_blackbox_link.h b/yacl/link/transport/brpc_blackbox_link.h index a9b4a999..e480354c 100644 --- a/yacl/link/transport/brpc_blackbox_link.h +++ b/yacl/link/transport/brpc_blackbox_link.h @@ -26,7 +26,6 @@ #include "yacl/link/ssl_options.h" #include "yacl/link/transport/channel.h" -#include "yacl/link/transport/default_brpc_retry_policy.h" #include "yacl/link/transport/interconnection_link.h" #include "interconnection/link/transport.pb.h" @@ -79,8 +78,7 @@ class ReceiverLoopBlackBox final : public IReceiverLoop { class BrpcBlackBoxLink final : public InterconnectionLink { public: static InterconnectionLink::Options GetDefaultOptions() { - return InterconnectionLink::Options{10 * 1000, 512 * 1024, "http", - "", 3, 1000}; + return InterconnectionLink::Options{10 * 1000, 512 * 1024, "http", ""}; } using InterconnectionLink::InterconnectionLink; @@ -92,7 +90,7 @@ class BrpcBlackBoxLink final : public InterconnectionLink { } void SendRequest(const ::google::protobuf::Message& request, - uint32_t timeout_ms) override; + uint32_t timeout_ms) const override; void SetPeerHost(const std::string& self_id, const std::string& self_node_id, const std::string& peer_id, const std::string& peer_node_id, @@ -103,7 +101,8 @@ class BrpcBlackBoxLink final : public InterconnectionLink { brpc::ChannelOptions GetChannelOption(const SSLOptions* ssl_opts); uint32_t GetQueueFullWaitTime() const { return push_wait_ms_; } - void SetHttpHeader(brpc::Controller* controller, const std::string& topic); + void SetHttpHeader(brpc::Controller* controller, + const std::string& topic) const; void OnPopResponse(blackbox_interconnect::TransportOutbound* response); // receive related @@ -119,7 +118,6 @@ class BrpcBlackBoxLink final : public InterconnectionLink { protected: // brpc channel related. - std::unique_ptr retry_policy_; std::shared_ptr channel_; std::string send_topic_; std::string recv_topic_; diff --git a/yacl/link/transport/brpc_blackbox_link_test.cc b/yacl/link/transport/brpc_blackbox_link_test.cc index af3ec0f4..13226601 100644 --- a/yacl/link/transport/brpc_blackbox_link_test.cc +++ b/yacl/link/transport/brpc_blackbox_link_test.cc @@ -114,8 +114,9 @@ class BlackBoxLinkTest : public ::testing::Test { std::make_shared(send_rank, recv_rank, options); auto receiver_delegate = std::make_shared(recv_rank, send_rank, options); - sender_ = std::make_shared(sender_delegate, false); - receiver_ = std::make_shared(receiver_delegate, false); + sender_ = std::make_shared(sender_delegate, false, RetryOptions()); + receiver_ = + std::make_shared(receiver_delegate, false, RetryOptions()); std::string server_addr = "127.0.0.1:0"; diff --git a/yacl/link/transport/brpc_link.cc b/yacl/link/transport/brpc_link.cc index 91ba83f0..2697fa98 100644 --- a/yacl/link/transport/brpc_link.cc +++ b/yacl/link/transport/brpc_link.cc @@ -22,7 +22,6 @@ #include "yacl/base/exception.h" #include "yacl/link/transport/channel.h" -#include "yacl/link/transport/default_brpc_retry_policy.h" #include "interconnection/link/transport.pb.h" @@ -30,35 +29,6 @@ namespace yacl::link::transport { namespace ic = org::interconnection; namespace ic_pb = org::interconnection::link; - -class BrpcRetryPolicy : public DefaultBrpcRetryPolicy { - public: - using DefaultBrpcRetryPolicy::DefaultBrpcRetryPolicy; - bool OnRpcSuccess(const brpc::Controller* cntl) const override { - if (cntl->response() == nullptr) { - SPDLOG_ERROR("response is null"); - return false; - } - - if (cntl->response()->GetDescriptor() != - ic_pb::PushResponse::descriptor()) { - SPDLOG_ERROR("unexpected response typename: {}", - cntl->response()->GetTypeName()); - return false; - } - - auto* response = static_cast(cntl->response()); - auto error_code = response->header().error_code(); - if (error_code == ic::ErrorCode::NETWORK_ERROR) { - SPDLOG_INFO("NETWORK_ERROR, sleep={}us and retry", retry_interval_us_); - bthread_usleep(retry_interval_us_); - return true; - } else { - return false; - } - } -}; - namespace internal { class ReceiverServiceImpl : public ic_pb::ReceiverService { @@ -137,17 +107,13 @@ void BrpcLink::SetPeerHost(const std::string& peer_host, const SSLOptions* ssl_opts) { auto brpc_channel = std::make_unique(); const auto load_balancer = ""; - auto retry_policy = std::make_unique( - options_.retry_interval_ms, options_.aggressive_retry); brpc::ChannelOptions options; { options.protocol = options_.channel_protocol; options.connection_type = options_.channel_connection_type; options.connect_timeout_ms = 20000; options.timeout_ms = options_.http_timeout_ms; - - options.max_retry = options_.max_retry; - options.retry_policy = retry_policy.get(); + options.max_retry = 0; if (ssl_opts != nullptr) { options.mutable_ssl_options()->client_cert.certificate = @@ -167,12 +133,10 @@ void BrpcLink::SetPeerHost(const std::string& peer_host, } delegate_channel_ = std::move(brpc_channel); - retry_policy_ = std::move(retry_policy); peer_host_ = peer_host; } -void BrpcLink::SendRequest(const ::google::protobuf::Message& request, - uint32_t timeout) { +void BrpcLink::SendRequest(const Request& request, uint32_t timeout) const { ic_pb::PushResponse response; brpc::Controller cntl; cntl.ignore_eovercrowded(); @@ -184,9 +148,9 @@ void BrpcLink::SendRequest(const ::google::protobuf::Message& request, nullptr); // handle failures. if (cntl.Failed()) { - YACL_THROW_NETWORK_ERROR("send, rpc failed={}, message={}", - cntl.ErrorCode(), cntl.ErrorText()); + ThrowLinkErrorByBrpcCntl(cntl); } + if (response.header().error_code() != ic::ErrorCode::OK) { YACL_THROW("send, peer failed message={}", response.header().error_msg()); } diff --git a/yacl/link/transport/brpc_link.h b/yacl/link/transport/brpc_link.h index fd3b9439..f27d0917 100644 --- a/yacl/link/transport/brpc_link.h +++ b/yacl/link/transport/brpc_link.h @@ -27,7 +27,6 @@ #include "yacl/base/exception.h" #include "yacl/link/ssl_options.h" #include "yacl/link/transport/channel.h" -#include "yacl/link/transport/default_brpc_retry_policy.h" #include "yacl/link/transport/interconnection_link.h" namespace yacl::link::transport { @@ -65,11 +64,10 @@ class BrpcLink final : public InterconnectionLink { static InterconnectionLink::Options GetDefaultOptions() { return InterconnectionLink::Options{10 * 1000, 512 * 1024, "baidu_std", - "single", 3, 1000}; + "single"}; } - void SendRequest(const ::google::protobuf::Message& request, - uint32_t timeout_override_ms) override; + uint32_t timeout_override_ms) const override; void SetPeerHost(const std::string& peer_host, const SSLOptions* ssl_opts = nullptr); @@ -77,7 +75,6 @@ class BrpcLink final : public InterconnectionLink { protected: // brpc channel related. std::string peer_host_; - std::unique_ptr retry_policy_; std::shared_ptr delegate_channel_; }; diff --git a/yacl/link/transport/brpc_link_test.cc b/yacl/link/transport/brpc_link_test.cc index b5f5482a..7eae90f4 100644 --- a/yacl/link/transport/brpc_link_test.cc +++ b/yacl/link/transport/brpc_link_test.cc @@ -65,8 +65,9 @@ class BrpcLinkTest : public ::testing::Test { auto receive_delegate = std::make_shared(recv_rank, send_rank, options); - sender_ = std::make_shared(sender_delegate, false); - receiver_ = std::make_shared(receive_delegate, false); + sender_ = std::make_shared(sender_delegate, false, RetryOptions()); + receiver_ = + std::make_shared(receive_delegate, false, RetryOptions()); // let sender rank as 0, receiver rank as 1. // receiver_ listen messages from sender(rank 0). @@ -129,11 +130,15 @@ TEST_F(BrpcLinkTest, Normal_Len100) { class BrpcLinkWithLimitTest : public BrpcLinkTest, - public ::testing::WithParamInterface> {}; + public ::testing::WithParamInterface> { +}; TEST_P(BrpcLinkWithLimitTest, SendAsync) { const size_t size_limit_per_call = std::get<0>(GetParam()); const size_t size_to_send = std::get<1>(GetParam()); + const size_t parallel_size = std::get<2>(GetParam()); + sender_->SetChunkParallelSendSize(parallel_size); + receiver_->SetChunkParallelSendSize(parallel_size); sender_->GetLink()->SetMaxBytesPerChunk(size_limit_per_call); @@ -148,6 +153,9 @@ TEST_P(BrpcLinkWithLimitTest, SendAsync) { TEST_P(BrpcLinkWithLimitTest, Unread) { const size_t size_limit_per_call = std::get<0>(GetParam()); const size_t size_to_send = std::get<1>(GetParam()); + const size_t parallel_size = std::get<2>(GetParam()); + sender_->SetChunkParallelSendSize(parallel_size); + receiver_->SetChunkParallelSendSize(parallel_size); sender_->GetLink()->SetMaxBytesPerChunk(size_limit_per_call); @@ -165,6 +173,10 @@ TEST_P(BrpcLinkWithLimitTest, Unread) { TEST_P(BrpcLinkWithLimitTest, Async) { const size_t size_limit_per_call = std::get<0>(GetParam()); const size_t size_to_send = std::get<1>(GetParam()); + const size_t parallel_size = std::get<2>(GetParam()); + sender_->SetChunkParallelSendSize(parallel_size); + receiver_->SetChunkParallelSendSize(parallel_size); + sender_->SetThrottleWindowSize(size_to_send); sender_->GetLink()->SetMaxBytesPerChunk(size_limit_per_call); const size_t test_size = 128 + (std::rand() % 128); @@ -202,6 +214,9 @@ TEST_P(BrpcLinkWithLimitTest, Async) { TEST_P(BrpcLinkWithLimitTest, AsyncWithThrottleLimit) { const size_t size_limit_per_call = std::get<0>(GetParam()); const size_t size_to_send = std::get<1>(GetParam()); + const size_t parallel_size = std::get<2>(GetParam()); + sender_->SetChunkParallelSendSize(parallel_size); + receiver_->SetChunkParallelSendSize(parallel_size); sender_->SetThrottleWindowSize(size_to_send); sender_->GetLink()->SetMaxBytesPerChunk(size_limit_per_call); const size_t test_size = 128 + (std::rand() % 128); @@ -242,6 +257,10 @@ TEST_P(BrpcLinkWithLimitTest, AsyncWithThrottleLimit) { TEST_P(BrpcLinkWithLimitTest, ThrottleWindowUnread) { const size_t size_limit_per_call = std::get<0>(GetParam()); const size_t size_to_send = std::get<1>(GetParam()); + const size_t parallel_size = std::get<2>(GetParam()); + sender_->SetChunkParallelSendSize(parallel_size); + receiver_->SetChunkParallelSendSize(parallel_size); + sender_->SetThrottleWindowSize(size_to_send); sender_->GetLink()->SetMaxBytesPerChunk(size_limit_per_call); const size_t test_size = 128 + (std::rand() % 128); @@ -285,6 +304,9 @@ TEST_P(BrpcLinkWithLimitTest, ThrottleWindowUnread) { TEST_P(BrpcLinkWithLimitTest, Send) { const size_t size_limit_per_call = std::get<0>(GetParam()); const size_t size_to_send = std::get<1>(GetParam()); + const size_t parallel_size = std::get<2>(GetParam()); + sender_->SetChunkParallelSendSize(parallel_size); + receiver_->SetChunkParallelSendSize(parallel_size); sender_->GetLink()->SetMaxBytesPerChunk(size_limit_per_call); @@ -299,10 +321,12 @@ TEST_P(BrpcLinkWithLimitTest, Send) { INSTANTIATE_TEST_SUITE_P( Normal_Instances, BrpcLinkWithLimitTest, testing::Combine(testing::Values(9, 17), - testing::Values(1, 2, 9, 10, 11, 20, 19, 21, 1001)), + testing::Values(1, 2, 9, 10, 11, 20, 19, 21, 1001), + testing::Values(1, 8)), [](const testing::TestParamInfo& info) { - std::string name = fmt::format("Limit_{}_Len_{}", std::get<0>(info.param), - std::get<1>(info.param)); + std::string name = + fmt::format("Limit_{}_Len_{}_parallel_{}", std::get<0>(info.param), + std::get<1>(info.param), std::get<2>(info.param)); return name; }); @@ -377,8 +401,10 @@ TEST_F(BrpcLinkSSLTest, OneWaySSL) { std::make_shared(send_rank, recv_rank, channel_options_); auto receiver_delegate = std::make_shared(recv_rank, send_rank, channel_options_); - auto sender = std::make_shared(sender_delegate, false); - auto receiver = std::make_shared(receiver_delegate, false); + auto sender = + std::make_shared(sender_delegate, false, RetryOptions()); + auto receiver = + std::make_shared(receiver_delegate, false, RetryOptions()); // let sender rank as 0, receiver rank as 1. // receiver listen messages from sender(rank 0). @@ -423,8 +449,10 @@ TEST_F(BrpcLinkSSLTest, TwoWaySSL) { std::make_shared(send_rank, recv_rank, channel_options_); auto receiver_delegate = std::make_shared(recv_rank, send_rank, channel_options_); - auto sender = std::make_shared(sender_delegate, false); - auto receiver = std::make_shared(receiver_delegate, false); + auto sender = + std::make_shared(sender_delegate, false, RetryOptions()); + auto receiver = + std::make_shared(receiver_delegate, false, RetryOptions()); // let sender rank as 0, receiver rank as 1. // receiver listen messages from sender(rank 0). diff --git a/yacl/link/transport/channel.cc b/yacl/link/transport/channel.cc index 4b847f4a..708cfee2 100644 --- a/yacl/link/transport/channel.cc +++ b/yacl/link/transport/channel.cc @@ -94,7 +94,7 @@ class SendTask { std::unique_ptr task(static_cast(args)); try { task->channel_->SendImpl(task->msg_.msg_key_, task->msg_.value_); - } catch (const yacl::Exception& e) { + } catch (const std::exception& e) { SPDLOG_ERROR("SendImpl error {}", e.what()); if (task->exit_if_async_error_) { exit(-1); @@ -148,7 +148,7 @@ std::optional Channel::MessageQueue::Pop(bool block) { void Channel::SendThread() { while (!send_thread_stopped_.load()) { - auto msg = msg_queue_.Pop(true); + auto msg = send_msgs_.Pop(true); if (!msg.has_value()) { continue; } @@ -159,7 +159,7 @@ void Channel::SendThread() { // link is closing, send all pending msgs while (true) { - auto msg = msg_queue_.Pop(false); + auto msg = send_msgs_.Pop(false); if (!msg.has_value()) { break; } @@ -174,11 +174,11 @@ class SendChunkedWindow YACL_ENFORCE(parallel_limit_ > 0); } - void OnPushDone(std::optional e) noexcept { + void OnPushDone(std::unique_ptr e) noexcept { std::unique_lock lock(mutex_); running_push_--; - if (e.has_value()) { + if (e) { async_exception_ = std::move(e); } cond_.notify_all(); @@ -188,8 +188,8 @@ class SendChunkedWindow std::unique_lock lock(mutex_); while (running_push_ != 0) { cond_.wait(lock); - if (async_exception_.has_value()) { - throw async_exception_.value(); + if (async_exception_) { + throw yacl::Exception(async_exception_->what()); } } } @@ -199,28 +199,27 @@ class SendChunkedWindow explicit Token(std::shared_ptr window) : window_(std::move(window)) {} - void SetException(std::optional& e) { - if (!e.has_value()) { - return; + void SetException(std::unique_ptr e) { + if (e) { + exception_ = std::move(e); } - exception_ = std::move(e); } - ~Token() { window_->OnPushDone(exception_); } + ~Token() { window_->OnPushDone(std::move(exception_)); } private: std::shared_ptr window_; - std::optional exception_; + std::unique_ptr exception_; }; std::unique_ptr GetToken() { std::unique_lock lock(mutex_); running_push_++; - while (running_push_ >= parallel_limit_) { + while (running_push_ > parallel_limit_) { cond_.wait(lock); - if (async_exception_.has_value()) { - throw async_exception_.value(); + if (async_exception_) { + throw(*async_exception_); } } @@ -232,42 +231,42 @@ class SendChunkedWindow int64_t running_push_ = 0; bthread::Mutex mutex_; bthread::ConditionVariable cond_; - std::optional async_exception_; + std::unique_ptr async_exception_; }; class SendChunkedTask { public: - SendChunkedTask(std::shared_ptr delegate, + SendChunkedTask(std::shared_ptr channel, std::unique_ptr token, std::unique_ptr<::google::protobuf::Message> request) - : link_(std::move(delegate)), + : channel_(std::move(channel)), token_(std::move(token)), request_(std::move(request)) { YACL_ENFORCE(request_, "request is null"); YACL_ENFORCE(token_, "token is null"); - YACL_ENFORCE(link_, "channel is null"); + YACL_ENFORCE(channel_, "channel is null"); } static void* Proc(void* param) { std::unique_ptr task(static_cast(param)); - std::optional except; + std::unique_ptr except; try { - task->link_->SendRequest(*(task->request_), 0); - } catch (Exception& e) { - except = e; - task->token_->SetException(except); + task->channel_->SendRequestWithRetry(*(task->request_), 0); + } catch (const Exception& e) { + except = std::make_unique(e); + task->token_->SetException(std::move(except)); } - return nullptr; - }; + } private: - std::shared_ptr link_; + std::shared_ptr channel_; std::unique_ptr token_; std::unique_ptr<::google::protobuf::Message> request_; }; -void Channel::SendChunked(const std::string& key, ByteContainerView value) { +void Channel::SendChunked(const std::string& key, + ByteContainerView value) const { const size_t bytes_per_chunk = link_->GetMaxBytesPerChunk(); const size_t num_bytes = value.size(); const size_t num_chunks = (num_bytes + bytes_per_chunk - 1) / bytes_per_chunk; @@ -285,8 +284,8 @@ void Channel::SendChunked(const std::string& key, ByteContainerView value) { std::min(bytes_per_chunk, value.size() - chunk_offset)), chunk_offset, num_bytes); - auto task = std::make_unique(link_, window->GetToken(), - std::move(request)); + auto task = std::make_unique( + this->shared_from_this(), window->GetToken(), std::move(request)); bthread_t tid; if (bthread_start_background(&tid, nullptr, SendChunkedTask::Proc, task.get()) == 0) { @@ -299,6 +298,61 @@ void Channel::SendChunked(const std::string& key, ByteContainerView value) { window->Finished(); } +void Channel::SendMono(const std::string& key, ByteContainerView value, + uint32_t timeout_override_ms, + spdlog::level::level_enum log_level) const { + auto request = link_->PackMonoRequest(key, value); + SendRequestWithRetry(*request, timeout_override_ms, log_level); +} + +void Channel::SendRequestWithRetry(const ::google::protobuf::Message& request, + uint32_t timeout_override_ms, + spdlog::level::level_enum log_level) const { + uint32_t retry_count = 0; + while (true) { + try { + link_->SendRequest(request, timeout_override_ms); + break; + } catch (const yacl::LinkError& e) { + auto should_retry = [&](const RetryOptions& retry_options) -> bool { + if (retry_options.aggressive_retry) { + return true; + } + if (retry_options.error_codes.empty() || + retry_options.error_codes.count(e.code())) { + return true; + } + + if (e.http_code() && (retry_options.http_codes.empty() || + retry_options.http_codes.count(e.http_code()))) { + return true; + } + return false; + }; + + if (!should_retry(retry_options_)) { + SPDLOG_WARN("send request failed and no retry, message={}", e.what()); + throw e; + } + + if (retry_count >= retry_options_.max_retry) { + throw e; + } + uint32_t interval_ms = + std::min(retry_options_.retry_interval_ms + + retry_count * retry_options_.retry_interval_incr_ms, + retry_options_.max_retry_interval_ms); + retry_count++; + SPDLOG_LOGGER_CALL( + spdlog::default_logger_raw(), log_level, + "send request failed and retry, retry_count={}, max_retry={}, " + "interval_ms={}, message={}", + retry_count, retry_options_.max_retry, interval_ms, e.what()); + std::this_thread::sleep_for(std::chrono::milliseconds(interval_ms)); + } + } +} + void Channel::SendTaskSynchronizer::SendTaskStartNotify() { std::unique_lock lock(mutex_); running_tasks_++; @@ -335,12 +389,12 @@ Buffer Channel::Recv(const std::string& msg_key) { { std::unique_lock lock(msg_mutex_); auto stop_waiting = [&] { - auto itr = this->msg_db_.find(msg_key); - if (itr == this->msg_db_.end()) { + auto itr = this->recv_msgs_.find(msg_key); + if (itr == this->recv_msgs_.end()) { return false; } else { std::tie(value, seq_id) = std::move(itr->second); - this->msg_db_.erase(itr); + this->recv_msgs_.erase(itr); return true; } }; @@ -379,7 +433,7 @@ void Channel::OnNormalMessage(const std::string& key, T&& v) { if (!waiting_finish_.load()) { auto pair = - msg_db_.emplace(msg_key, std::make_pair(std::forward(v), seq_id)); + recv_msgs_.emplace(msg_key, std::make_pair(std::forward(v), seq_id)); if (seq_id > 0 && !pair.second) { YACL_THROW( "For developer: BUG! PLS do not use same key for multiple msg, " @@ -504,7 +558,7 @@ void Channel::SendAsync(const std::string& msg_key, Buffer&& value) { NormalMessageKeyEnforce(msg_key); size_t seq_id = msg_seq_id_.fetch_add(1) + 1; auto key = BuildChannelKey(msg_key, seq_id); - msg_queue_.Push(Message(seq_id, std::move(key), std::move(value))); + send_msgs_.Push(Message(seq_id, std::move(key), std::move(value))); } void Channel::Send(const std::string& msg_key, ByteContainerView value) { @@ -513,7 +567,7 @@ void Channel::Send(const std::string& msg_key, ByteContainerView value) { NormalMessageKeyEnforce(msg_key); size_t seq_id = msg_seq_id_.fetch_add(1) + 1; auto key = BuildChannelKey(msg_key, seq_id); - msg_queue_.Push(Message(seq_id, std::move(key), value)); + send_msgs_.Push(Message(seq_id, std::move(key), value)); send_sync_.WaitSeqIdSendFinished(seq_id); } @@ -528,7 +582,7 @@ void Channel::SendAsyncThrottled(const std::string& msg_key, Buffer&& value) { NormalMessageKeyEnforce(msg_key); size_t seq_id = msg_seq_id_.fetch_add(1) + 1; auto key = BuildChannelKey(msg_key, seq_id); - msg_queue_.Push(Message(seq_id, std::move(key), std::move(value))); + send_msgs_.Push(Message(seq_id, std::move(key), std::move(value))); ThrottleWindowWait(seq_id); } @@ -537,7 +591,7 @@ void Channel::TestSend(uint32_t timeout) { "TestSend is not allowed when channel is closing"); const auto msg_key = fmt::format("connect_{}", link_->LocalRank()); const auto key = BuildChannelKey(msg_key, 0); - SendImpl(key, "", timeout); + SendImpl(key, "", timeout, spdlog::level::debug); } void Channel::TestRecv() { @@ -563,7 +617,7 @@ void Channel::ThrottleWindowWait(size_t wait_count) { void Channel::WaitAsyncSendToFinish() { send_thread_stopped_.store(true); - msg_queue_.EmptyNotify(); + send_msgs_.EmptyNotify(); send_thread_.join(); send_sync_.WaitAllSendFinished(); } @@ -601,13 +655,13 @@ void Channel::WaitForFinAndFlyingMsg() { void Channel::StopReceivingAndAckUnreadMsgs() { std::unique_lock lock(msg_mutex_); waiting_finish_.store(true); - for (auto& msg : msg_db_) { + for (auto& msg : recv_msgs_) { auto seq_id = msg.second.second; SPDLOG_WARN("Asymmetric logic exist, clear unread key {}, seq_id {}", msg.first, seq_id); SendAck(seq_id); } - msg_db_.clear(); + recv_msgs_.clear(); } void Channel::WaitForFlyingAck() { @@ -628,7 +682,7 @@ void Channel::WaitForFlyingAck() { void Channel::WaitLinkTaskFinish() { // 4 steps to total stop link. - // send ack for msg exist in msg_db_ that unread by up layer logic. + // send ack for msg exist in recv_msgs_ that unread by up layer logic. // stop OnMessage & auto ack all normal msg from now on. StopReceivingAndAckUnreadMsgs(); // wait for last fin msg contain peer's send msg count. diff --git a/yacl/link/transport/channel.h b/yacl/link/transport/channel.h index 1179ad57..0df5e0a7 100644 --- a/yacl/link/transport/channel.h +++ b/yacl/link/transport/channel.h @@ -24,6 +24,7 @@ #include #include +#include "brpc/controller.h" #include "bthread/bthread.h" #include "bthread/condition_variable.h" #include "google/protobuf/message.h" @@ -32,6 +33,7 @@ #include "yacl/base/buffer.h" #include "yacl/base/byte_container_view.h" #include "yacl/base/exception.h" +#include "yacl/link/retry_options.h" #include "yacl/utils/segment_tree.h" namespace yacl::link::transport { @@ -93,36 +95,36 @@ class IChannel { }; class TransportLink { + public: using Request = ::google::protobuf::Message; using Response = ::google::protobuf::Message; - public: TransportLink(size_t self_rank, size_t peer_rank) : self_rank_(self_rank), peer_rank_(peer_rank) {} virtual ~TransportLink() = default; - virtual size_t GetMaxBytesPerChunk() = 0; + virtual size_t GetMaxBytesPerChunk() const = 0; virtual void SetMaxBytesPerChunk(size_t bytes) = 0; - virtual std::unique_ptr PackMonoRequest(const std::string& key, - ByteContainerView value) = 0; - virtual std::unique_ptr PackChunkedRequest(const std::string& key, - ByteContainerView value, - size_t offset, - size_t total_length) = 0; + virtual std::unique_ptr PackMonoRequest( + const std::string& key, ByteContainerView value) const = 0; + virtual std::unique_ptr PackChunkedRequest( + const std::string& key, ByteContainerView value, size_t offset, + size_t total_length) const = 0; virtual void UnpackMonoRequest(const Request& request, std::string* key, - ByteContainerView* value) = 0; + ByteContainerView* value) const = 0; virtual void UnpackChunckRequest(const Request& request, std::string* key, ByteContainerView* value, size_t* offset, - size_t* total_length) = 0; - virtual void FillResponseOk(const Request& request, Response* response) = 0; + size_t* total_length) const = 0; + virtual void FillResponseOk(const Request& request, + Response* response) const = 0; virtual void FillResponseError(const Request& request, - Response* response) = 0; - virtual bool IsChunkedRequest(const Request& request) = 0; - virtual bool IsMonoRequest(const Request& request) = 0; + Response* response) const = 0; + virtual bool IsChunkedRequest(const Request& request) const = 0; + virtual bool IsMonoRequest(const Request& request) const = 0; virtual void SendRequest(const Request& request, - uint32_t timeout_override_ms) = 0; + uint32_t timeout_override_ms) const = 0; size_t LocalRank() const { return self_rank_; } size_t RemoteRank() const { return peer_rank_; } @@ -137,15 +139,19 @@ class ChunkedMessage; class Channel : public IChannel, public std::enable_shared_from_this { public: - Channel(std::shared_ptr delegate, bool exit_if_async_error) - : exit_if_async_error_(exit_if_async_error), link_(std::move(delegate)) { + Channel(std::shared_ptr delegate, bool exit_if_async_error, + const RetryOptions& retry_options) + : exit_if_async_error_(exit_if_async_error), + link_(std::move(delegate)), + retry_options_(retry_options) { StartSendThread(); } Channel(std::shared_ptr delegate, size_t recv_timeout_ms, - bool exit_if_async_error) + bool exit_if_async_error, const RetryOptions& retry_options) : recv_timeout_ms_(recv_timeout_ms), exit_if_async_error_(exit_if_async_error), - link_(std::move(delegate)) { + link_(std::move(delegate)), + retry_options_(retry_options) { StartSendThread(); } @@ -211,24 +217,31 @@ class Channel : public IChannel, public std::enable_shared_from_this { chunk_parallel_send_size_ = size; } + void SendRequestWithRetry( + const ::google::protobuf::Message& request, uint32_t timeout_override_ms, + spdlog::level::level_enum log_level = spdlog::level::info) const; + protected: - void SendChunked(const std::string& key, ByteContainerView value); + void SendChunked(const std::string& key, ByteContainerView value) const; - void SendImpl(const std::string& key, ByteContainerView value) { - SendImpl(key, value, 0); + void SendMono(const std::string& key, ByteContainerView value, + uint32_t timeout_override_ms, + spdlog::level::level_enum log_level) const; + + void SendImpl(const std::string& key, ByteContainerView value) const { + SendImpl(key, value, 0, spdlog::level::info); } void SendImpl(const std::string& key, ByteContainerView value, - uint32_t timeout_override_ms) { + uint32_t timeout_override_ms, + spdlog::level::level_enum log_level) const { YACL_ENFORCE(link_ != nullptr, "delegate has not been setted."); SPDLOG_DEBUG("{} send {}", link_->LocalRank(), key); if (value.size() > link_->GetMaxBytesPerChunk()) { SendChunked(key, value); - return; + } else { + SendMono(key, value, timeout_override_ms, log_level); } - - auto request = link_->PackMonoRequest(key, value); - link_->SendRequest(*request, timeout_override_ms); } private: @@ -303,7 +316,7 @@ class Channel : public IChannel, public std::enable_shared_from_this { protected: uint64_t recv_timeout_ms_ = 3UL * 60 * 1000; // 3 minites - MessageQueue msg_queue_; + MessageQueue send_msgs_; std::thread send_thread_; std::atomic send_thread_stopped_ = false; SendTaskSynchronizer send_sync_; @@ -317,7 +330,7 @@ class Channel : public IChannel, public std::enable_shared_from_this { bthread::Mutex msg_mutex_; bthread::ConditionVariable msg_db_cond_; // msg_key -> - std::map> msg_db_; + std::map> recv_msgs_; // for Throttle Window std::atomic throttle_window_size_ = 0; @@ -341,6 +354,8 @@ class Channel : public IChannel, public std::enable_shared_from_this { const bool exit_if_async_error_; std::shared_ptr link_; + + RetryOptions retry_options_; }; // A receiver loop is a thread loop which receives messages from the world. diff --git a/yacl/link/transport/channel_mem.cc b/yacl/link/transport/channel_mem.cc index ba813826..630538df 100644 --- a/yacl/link/transport/channel_mem.cc +++ b/yacl/link/transport/channel_mem.cc @@ -39,12 +39,12 @@ Buffer ChannelMem::Recv(const std::string& key) { { std::unique_lock lock(msg_mutex_); auto stop_waiting = [&] { - auto itr = this->msg_db_.find(key); - if (itr == this->msg_db_.end()) { + auto itr = this->recv_msgs_.find(key); + if (itr == this->recv_msgs_.end()) { return false; } else { value = std::move(itr->second); - this->msg_db_.erase(itr); + this->recv_msgs_.erase(itr); return true; } }; @@ -60,7 +60,7 @@ void ChannelMem::OnMessage(const std::string& msg_key, ByteContainerView value) { { std::unique_lock lock(msg_mutex_); - msg_db_.emplace(msg_key, value); + recv_msgs_.emplace(msg_key, value); } msg_db_cond_.notify_all(); } diff --git a/yacl/link/transport/channel_mem.h b/yacl/link/transport/channel_mem.h index 62feff33..23716243 100644 --- a/yacl/link/transport/channel_mem.h +++ b/yacl/link/transport/channel_mem.h @@ -86,7 +86,7 @@ class ChannelMem final : public IChannel { // message database related. std::mutex msg_mutex_; std::condition_variable msg_db_cond_; - std::unordered_map msg_db_; + std::unordered_map recv_msgs_; std::chrono::milliseconds recv_timeout_ms_ = 3UL * 60 * std::chrono::milliseconds(1000); diff --git a/yacl/link/transport/channel_test.cc b/yacl/link/transport/channel_test.cc new file mode 100644 index 00000000..0d7b9450 --- /dev/null +++ b/yacl/link/transport/channel_test.cc @@ -0,0 +1,219 @@ +// Copyright 2019 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "yacl/link/transport/channel.h" + +#include "fmt/format.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +#include "brpc/errno.pb.h" +#include "interconnection/link/transport.pb.h" + +namespace yacl::link::transport::test { + +namespace ic_pb = org::interconnection::link; + +static std::string RandStr(size_t length) { + auto randchar = []() -> char { + const char charset[] = + "0123456789" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghijklmnopqrstuvwxyz"; + const size_t max_index = (sizeof(charset) - 1); + return charset[rand() % max_index]; + }; + std::string str(length, 0); + std::generate_n(str.begin(), length, randchar); + return str; +} + +class MockTransportLink : public TransportLink { + public: + using TransportLink::TransportLink; + MOCK_METHOD(void, SetMaxBytesPerChunk, (size_t), (override)); + MOCK_METHOD(std::unique_ptr, PackMonoRequest, + (const std::string&, ByteContainerView), (const, override)); + MOCK_METHOD(void, UnpackMonoRequest, + (const Request&, std::string*, ByteContainerView*), + (const, override)); + MOCK_METHOD(void, UnpackChunckRequest, + (const Request& request, std::string*, ByteContainerView*, + size_t*, size_t*), + (const, override)); + MOCK_METHOD(void, FillResponseOk, (const Request&, Response*), + (const, override)); + MOCK_METHOD(void, FillResponseError, (const Request&, Response*), + (const, override)); + MOCK_METHOD(bool, IsChunkedRequest, (const Request&), (const, override)); + MOCK_METHOD(bool, IsMonoRequest, (const Request&), (const, override)); + MOCK_METHOD(void, SendRequest, (const Request&, uint32_t), (const, override)); + + // provided an implementation in order to test SendChunked + size_t GetMaxBytesPerChunk() const override { + static size_t max_bytes_per_chunk = 2; + return max_bytes_per_chunk; + } + std::unique_ptr<::google::protobuf::Message> PackChunkedRequest( + const std::string& key, ByteContainerView value, size_t offset, + size_t total_length) const override { + auto request = std::make_unique(); + { + request->set_sender_rank(self_rank_); + request->set_key(key); + request->set_value(value.data(), value.size()); + request->set_trans_type(ic_pb::TransType::CHUNKED); + request->mutable_chunk_info()->set_chunk_offset(offset); + request->mutable_chunk_info()->set_message_length(total_length); + } + return request; + } +}; + +class ChannelSendRetryTest : public testing::Test { + protected: + void SetUp() override { + const size_t send_rank = 0; + const size_t recv_rank = 1; + sender_delegate_ = + std::make_shared(send_rank, recv_rank); + RetryOptions retry_options; + retry_options.aggressive_retry = false; + // reset time intervals to accelerate + retry_options.retry_interval_ms = 10; + retry_options.retry_interval_incr_ms = 20; + retry_options.max_retry_interval_ms = 100; + retry_options.error_codes = {brpc::ENOSERVICE, brpc::ENOMETHOD}; + retry_options.http_codes = {brpc::HTTP_STATUS_BAD_GATEWAY, + brpc::HTTP_STATUS_MULTIPLE_CHOICES}; + + sender_ = std::make_shared(sender_delegate_, false, retry_options); + } + + void TearDown() override { + sender_delegate_.reset(); + sender_.reset(); + } + + std::shared_ptr sender_delegate_; + std::shared_ptr sender_; +}; + +TEST_F(ChannelSendRetryTest, NoRetrySuccess) { + EXPECT_CALL(*sender_delegate_, SendRequest(::testing::_, ::testing::_)) + .Times(1) + .WillRepeatedly([]() {}); + const std::string key = "key"; + sender_->Send(key, {}); +} + +TEST_F(ChannelSendRetryTest, NoRetryFail) { + EXPECT_CALL(*sender_delegate_, SendRequest(::testing::_, ::testing::_)) + .Times(1) + .WillRepeatedly([]() { + throw yacl::LinkError("not valid error code.", brpc::EUNUSED); + }); + const std::string key = "key"; + sender_->Send(key, {}); +} + +TEST_F(ChannelSendRetryTest, RetrySuccess) { + EXPECT_CALL(*sender_delegate_, SendRequest(::testing::_, ::testing::_)) + .Times(2) + .WillOnce([]() { + throw yacl::LinkError("valid error code, will retry once and success.", + brpc::ENOSERVICE); + }) + .WillRepeatedly([]() {}); + const std::string key = "key"; + sender_->Send(key, {}); +} + +TEST_F(ChannelSendRetryTest, RetryFail) { + EXPECT_CALL(*sender_delegate_, SendRequest(::testing::_, ::testing::_)) + .Times(4) + .WillRepeatedly([]() { + throw yacl::LinkError( + "valid error code, will retry max count and fail.", + brpc::ENOSERVICE); + }); + const std::string key = "key"; + sender_->Send(key, {}); +} + +TEST_F(ChannelSendRetryTest, HttpNoRetryFail) { + EXPECT_CALL(*sender_delegate_, SendRequest(::testing::_, ::testing::_)) + .Times(1) + .WillRepeatedly([]() { + throw yacl::LinkError("not valid http code and no retry.", brpc::EHTTP, + brpc::HTTP_STATUS_GATEWAY_TIMEOUT); + }); + const std::string key = "key"; + sender_->Send(key, {}); +} + +TEST_F(ChannelSendRetryTest, HttpRetrySuccess) { + EXPECT_CALL(*sender_delegate_, SendRequest(::testing::_, ::testing::_)) + .Times(2) + .WillOnce([]() { + throw yacl::LinkError("valid http code, will retry once and success.", + brpc::EHTTP, brpc::HTTP_STATUS_BAD_GATEWAY); + }) + .WillRepeatedly([]() {}); + const std::string key = "key"; + sender_->Send(key, {}); +} + +TEST_F(ChannelSendRetryTest, HttpRetryFail) { + EXPECT_CALL(*sender_delegate_, SendRequest(::testing::_, ::testing::_)) + .Times(4) + .WillRepeatedly([]() { + throw yacl::LinkError("valid http code, will retry max count and fail.", + brpc::EHTTP, brpc::HTTP_STATUS_BAD_GATEWAY); + }); + const std::string key = "key"; + sender_->Send(key, {}); +} + +TEST_F(ChannelSendRetryTest, ChunkedNoRetrySuccess) { + EXPECT_CALL(*sender_delegate_, SendRequest(::testing::_, ::testing::_)) + .Times(2) + .WillRepeatedly([]() {}); + const std::string key = "key"; + sender_->Send(key, RandStr(3)); +} + +TEST_F(ChannelSendRetryTest, ChunkedNoRetryFail) { + EXPECT_CALL(*sender_delegate_, SendRequest(::testing::_, ::testing::_)) + .Times(2) // chunk1 fail | chunk2 fail + .WillRepeatedly([]() { + throw yacl::LinkError("not valid error code.", brpc::EUNUSED); + }); + const std::string key = "key"; + sender_->Send(key, RandStr(3)); +} + +TEST_F(ChannelSendRetryTest, ChunkedRetrySuccess) { + EXPECT_CALL(*sender_delegate_, SendRequest(::testing::_, ::testing::_)) + .Times(3) // chunk1 fail | chunk1 retry | chunk2 success + .WillOnce([]() { + throw yacl::LinkError("valid error code, will retry once and success.", + brpc::ENOSERVICE); + }) + .WillRepeatedly([]() {}); + const std::string key = "key"; + sender_->Send(key, RandStr(3)); +} + +} // namespace yacl::link::transport::test diff --git a/yacl/link/transport/default_brpc_retry_policy.cc b/yacl/link/transport/default_brpc_retry_policy.cc deleted file mode 100644 index 68489a1e..00000000 --- a/yacl/link/transport/default_brpc_retry_policy.cc +++ /dev/null @@ -1,85 +0,0 @@ -// Copyright 2023 Ant Group Co., Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "yacl/link/transport/default_brpc_retry_policy.h" - -#include "bthread/bthread.h" -#include "spdlog/spdlog.h" - -namespace yacl::link::transport { - -void LogHttpDetail(const brpc::Controller* cntl) { - const auto& response_header = cntl->http_response(); - std::string str_header; - for (auto it = response_header.HeaderBegin(); - it != response_header.HeaderEnd(); ++it) { - str_header += fmt::format("[{}]:[{}];", it->first, it->second); - } - SPDLOG_INFO( - "cntl ErrorCode '{}', http status code '{}', response header " - "'{}', error msg '{}'", - cntl->ErrorCode(), cntl->http_response().status_code(), str_header, - cntl->ErrorText()); -} - -bool DefaultBrpcRetryPolicy::OnRpcSuccess( - const brpc::Controller* /* cntl */) const { - SPDLOG_DEBUG("rpc success, no retry."); - return false; -} - -// From brpc::RetryPolicy -bool DefaultBrpcRetryPolicy::DoRetry(const brpc::Controller* cntl) const { - if (cntl->ErrorCode() == 0) { - // successful RPC, check response header - return OnRpcSuccess(cntl); - } else if (cntl->ErrorCode() == ECONNREFUSED || - cntl->ErrorCode() == ECONNRESET) { - // if a lot of parallel job and each of which may send a lot of async - // request, the remote (gateway) may reset the connection, this branch - // tries to handle this by sleep a little bit. - SPDLOG_INFO("socket error, sleep={}us and retry", retry_interval_us_); - bthread_usleep(retry_interval_us_); - return true; - } else if (cntl->ErrorCode() == brpc::EHTTP && - cntl->http_response().status_code() == - brpc::HTTP_STATUS_BAD_GATEWAY) { - // rejected by peer gateway, do retry. - LogHttpDetail(cntl); - SPDLOG_INFO("rejected by remote gateway, sleep={}us and retry", - retry_interval_us_); - bthread_usleep(retry_interval_us_); - return true; - } else if (cntl->ErrorCode() == brpc::ERPCTIMEDOUT || - cntl->ErrorCode() == ECANCELED) { - // do not retry on ERPCTIMEDOUT / ECANCELED, otherwise brpc will trigger - // assert. - // https://github.com/apache/brpc/blob/1.0.0-rc02/src/brpc/controller.cpp#L625 - SPDLOG_INFO( - "not retry for reached rcp timeout, ErrorCode '{}', error msg '{}'", - cntl->ErrorCode(), cntl->ErrorText()); - return false; - } else if (aggressive_retry_) { - LogHttpDetail(cntl); - SPDLOG_INFO("aggressive retry, sleep={}us and retry", retry_interval_us_); - bthread_usleep(retry_interval_us_); - - return true; - } - - // leave others to brpc::DefaultRetryPolicy() - return brpc::DefaultRetryPolicy()->DoRetry(cntl); -} - -} // namespace yacl::link::transport diff --git a/yacl/link/transport/default_brpc_retry_policy.h b/yacl/link/transport/default_brpc_retry_policy.h deleted file mode 100644 index 8356f076..00000000 --- a/yacl/link/transport/default_brpc_retry_policy.h +++ /dev/null @@ -1,41 +0,0 @@ -// Copyright 2023 Ant Group Co., Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include - -#include "brpc/retry_policy.h" - -namespace yacl::link::transport { - -class DefaultBrpcRetryPolicy : public brpc::RetryPolicy { - public: - explicit DefaultBrpcRetryPolicy(uint32_t retry_interval_ms, - bool aggressive_retry) - : retry_interval_us_(retry_interval_ms * 1000), - aggressive_retry_(aggressive_retry) {} - - virtual bool OnRpcSuccess(const brpc::Controller* cntl) const; - - // From brpc::RetryPolicy - bool DoRetry(const brpc::Controller* cntl) const final; - - protected: - // for bthread_usleep - const uint32_t retry_interval_us_; - bool aggressive_retry_ = true; -}; - -} // namespace yacl::link::transport diff --git a/yacl/link/transport/interconnection_link.cc b/yacl/link/transport/interconnection_link.cc index 097d3355..20aca7db 100644 --- a/yacl/link/transport/interconnection_link.cc +++ b/yacl/link/transport/interconnection_link.cc @@ -34,12 +34,38 @@ DECLARE_int32(h2_client_stream_window_size); namespace yacl::link::transport { +void ThrowLinkErrorByBrpcCntl(const brpc::Controller& cntl) { + int code = cntl.ErrorCode(); + int http_code = 0; + if (code == brpc::EHTTP) { + http_code = cntl.http_response().status_code(); + } + const auto& response_header = cntl.http_response(); + std::string http_headers; + constexpr size_t kMaxResponsePrefix = 20; + std::string http_response; + if (cntl.has_http_request()) { + for (auto it = response_header.HeaderBegin(); + it != response_header.HeaderEnd(); ++it) { + http_headers += fmt::format("[{}]:[{}];", it->first, it->second); + } + http_response = + cntl.response_attachment().to_string().substr(0, kMaxResponsePrefix); + } + + YACL_THROW_LINK_ERROR(code, http_code, + "cntl ErrorCode '{}', http status code '{}', response " + "header '{}', response body '{}', error msg '{}'", + code, http_code, http_headers, http_response, + cntl.ErrorText()); +} + namespace ic = org::interconnection; namespace ic_pb = org::interconnection::link; std::unique_ptr<::google::protobuf::Message> InterconnectionLink::PackMonoRequest(const std::string& key, - ByteContainerView value) { + ByteContainerView value) const { auto request = std::make_unique(); { request->set_sender_rank(self_rank_); @@ -53,7 +79,7 @@ InterconnectionLink::PackMonoRequest(const std::string& key, std::unique_ptr<::google::protobuf::Message> InterconnectionLink::PackChunkedRequest(const std::string& key, ByteContainerView value, size_t offset, - size_t total_length) { + size_t total_length) const { auto request = std::make_unique(); { request->set_sender_rank(self_rank_); @@ -68,7 +94,7 @@ InterconnectionLink::PackChunkedRequest(const std::string& key, void InterconnectionLink::UnpackMonoRequest( const ::google::protobuf::Message& request, std::string* key, - ByteContainerView* value) { + ByteContainerView* value) const { YACL_ENFORCE(key != nullptr, "key should not be null"); YACL_ENFORCE(value != nullptr, "value should not be null"); auto real_request = static_cast(&request); @@ -78,7 +104,7 @@ void InterconnectionLink::UnpackMonoRequest( void InterconnectionLink::UnpackChunckRequest( const ::google::protobuf::Message& request, std::string* key, - ByteContainerView* value, size_t* offset, size_t* total_length) { + ByteContainerView* value, size_t* offset, size_t* total_length) const { YACL_ENFORCE(key != nullptr, "key should not be null"); YACL_ENFORCE(value != nullptr, "value should not be null"); YACL_ENFORCE(offset != nullptr, "offset should not be null"); @@ -93,7 +119,7 @@ void InterconnectionLink::UnpackChunckRequest( void InterconnectionLink::FillResponseOk( const ::google::protobuf::Message& /*request*/, - ::google::protobuf::Message* response) { + ::google::protobuf::Message* response) const { YACL_ENFORCE(response != nullptr, "response should not be null"); auto real_response = static_cast(response); @@ -103,7 +129,7 @@ void InterconnectionLink::FillResponseOk( void InterconnectionLink::FillResponseError( const ::google::protobuf::Message& request, - ::google::protobuf::Message* response) { + ::google::protobuf::Message* response) const { YACL_ENFORCE(response != nullptr, "response should not be null"); auto real_response = static_cast(response); @@ -117,13 +143,13 @@ void InterconnectionLink::FillResponseError( } bool InterconnectionLink::IsChunkedRequest( - const ::google::protobuf::Message& request) { + const ::google::protobuf::Message& request) const { return static_cast(&request)->trans_type() == ic_pb::TransType::CHUNKED; } bool InterconnectionLink::IsMonoRequest( - const ::google::protobuf::Message& request) { + const ::google::protobuf::Message& request) const { return static_cast(&request)->trans_type() == ic_pb::TransType::MONO; } @@ -131,8 +157,7 @@ bool InterconnectionLink::IsMonoRequest( auto InterconnectionLink::MakeOptions( Options& default_opt, uint32_t http_timeout_ms, uint32_t http_max_payload_bytes, const std::string& brpc_channel_protocol, - const std::string& brpc_channel_connection_type, uint32_t retry_count, - uint32_t retry_interval, bool aggressive_retry) -> Options { + const std::string& brpc_channel_connection_type) -> Options { auto opts = default_opt; if (http_timeout_ms != 0) { opts.http_timeout_ms = http_timeout_ms; @@ -143,18 +168,6 @@ auto InterconnectionLink::MakeOptions( if (!brpc_channel_protocol.empty()) { opts.channel_protocol = brpc_channel_protocol; } - if (retry_count != 0) { - opts.max_retry = retry_count; - } - if (retry_interval != 0) { - opts.retry_interval_ms = retry_interval; - } - opts.aggressive_retry = aggressive_retry; - - YACL_ENFORCE( - opts.http_timeout_ms >= (opts.max_retry * opts.retry_interval_ms), - "retry_count() * retry_interval() should less than timeout({})", - opts.max_retry, opts.retry_interval_ms, opts.http_timeout_ms); if (absl::StartsWith(opts.channel_protocol, "h2")) { YACL_ENFORCE(opts.http_max_payload_bytes > 4096, diff --git a/yacl/link/transport/interconnection_link.h b/yacl/link/transport/interconnection_link.h index 39be9b29..2188c53e 100644 --- a/yacl/link/transport/interconnection_link.h +++ b/yacl/link/transport/interconnection_link.h @@ -34,6 +34,8 @@ class PushResponse; namespace yacl::link::transport { +void ThrowLinkErrorByBrpcCntl(const brpc::Controller& cntl); + class InterconnectionLink : public TransportLink { public: struct Options { @@ -41,41 +43,43 @@ class InterconnectionLink : public TransportLink { uint32_t http_max_payload_bytes = 512 * 1024; // 512k bytes std::string channel_protocol; std::string channel_connection_type; - uint32_t max_retry = 3; - uint32_t retry_interval_ms = 1000; - bool aggressive_retry = true; }; + static Options MakeOptions(Options& default_opt, uint32_t http_timeout_ms, uint32_t http_max_payload_size, const std::string& brpc_channel_protocol, - const std::string& brpc_channel_connection_type, - uint32_t retry_count, uint32_t retry_interval, - bool aggressive_retry); + const std::string& brpc_channel_connection_type); InterconnectionLink(size_t self_rank, size_t peer_rank, Options options) : TransportLink(self_rank, peer_rank), options_(std::move(options)) {} void SetMaxBytesPerChunk(size_t bytes) override { options_.http_max_payload_bytes = bytes; } - size_t GetMaxBytesPerChunk() override { + size_t GetMaxBytesPerChunk() const override { return options_.http_max_payload_bytes; } + std::unique_ptr<::google::protobuf::Message> PackMonoRequest( - const std::string& key, ByteContainerView value) override; + const std::string& key, ByteContainerView value) const override; std::unique_ptr<::google::protobuf::Message> PackChunkedRequest( const std::string& key, ByteContainerView value, size_t offset, - size_t total_length) override; + size_t total_length) const override; void UnpackMonoRequest(const ::google::protobuf::Message& request, - std::string* key, ByteContainerView* value) override; + std::string* key, + ByteContainerView* value) const override; void UnpackChunckRequest(const ::google::protobuf::Message& request, std::string* key, ByteContainerView* value, - size_t* offset, size_t* total_length) override; + size_t* offset, size_t* total_length) const override; void FillResponseOk(const ::google::protobuf::Message& request, - ::google::protobuf::Message* response) override; + ::google::protobuf::Message* response) const override; void FillResponseError(const ::google::protobuf::Message& request, - ::google::protobuf::Message* response) override; - bool IsChunkedRequest(const ::google::protobuf::Message& request) override; - bool IsMonoRequest(const ::google::protobuf::Message& request) override; + ::google::protobuf::Message* response) const override; + bool IsChunkedRequest( + const ::google::protobuf::Message& request) const override; + bool IsMonoRequest(const ::google::protobuf::Message& request) const override; + + // void SendRequest(const Request& request, + // uint32_t timeout_override_ms) const override; protected: Options options_; diff --git a/yacl/math/mpint/mp_int.cc b/yacl/math/mpint/mp_int.cc index f0b09048..bc7f893d 100644 --- a/yacl/math/mpint/mp_int.cc +++ b/yacl/math/mpint/mp_int.cc @@ -31,6 +31,7 @@ yacl::math::MPInt operator""_mp(unsigned long long num) { namespace yacl::math { +const MPInt MPInt::_0_(0); const MPInt MPInt::_1_(1); const MPInt MPInt::_2_(2); @@ -276,8 +277,6 @@ void MPInt::Set(const std::string &num, int radix) { } } -bool MPInt::IsZero() const { return mp_iszero(&n_); } - void MPInt::SetZero() { mp_zero(&n_); } uint8_t MPInt::operator[](int idx) const { return GetBit(idx); } @@ -472,6 +471,10 @@ MPInt MPInt::SubMod(const MPInt &b, const MPInt &mod) const { return res; } +void MPInt::SubMod(const MPInt &a, const MPInt &b, const MPInt &mod, MPInt *d) { + MPINT_ENFORCE_OK(mp_submod(&a.n_, &b.n_, &mod.n_, &d->n_)); +} + void MPInt::Mul(const MPInt &a, const MPInt &b, MPInt *c) { MPINT_ENFORCE_OK(mp_mul(&a.n_, &b.n_, &c->n_)); } @@ -626,6 +629,14 @@ yacl::Buffer MPInt::Serialize() const { return buffer; } +size_t MPInt::Serialize(uint8_t *buf, size_t buf_len) const { + if (buf == nullptr) { + return mp_ext_serialize_size(n_); + } + + return mp_ext_serialize(n_, buf, buf_len); +} + void MPInt::Deserialize(yacl::ByteContainerView buffer) { mp_ext_deserialize(&n_, buffer.data(), buffer.size()); } diff --git a/yacl/math/mpint/mp_int.h b/yacl/math/mpint/mp_int.h index 1738907e..860c6cd3 100644 --- a/yacl/math/mpint/mp_int.h +++ b/yacl/math/mpint/mp_int.h @@ -43,6 +43,7 @@ class MPInt { // Pre-defined variables ... // WARNING: Do not use these variables in the global environment. Your global // variables may be initialized earlier than these. + static const MPInt _0_; static const MPInt _1_; static const MPInt _2_; @@ -65,8 +66,8 @@ class MPInt { } // if radix == 0, the real radix will be auto-detected: - // - if num is start with 0x, the radix is 16 - // - if num is start with 0 , the radix is 8 + // - if num is started with 0x, the radix is 16 + // - if num is started with 0, the radix is 8 // - otherwise radix is 10 // Plus and minus signs are supported, e.g. MPInt("-0xFF") is -255 explicit MPInt(const std::string &num, size_t radix = 0); @@ -97,16 +98,11 @@ class MPInt { void Set(const std::string &num, int radix = 0); void SetZero(); - [[nodiscard]] bool IsZero() const; - - // Get the i'th bit. Always return 0 or 1 - uint8_t operator[](int idx) const; - uint8_t GetBit(int idx) const; - void SetBit(int idx, uint8_t bit); - - [[nodiscard]] size_t BitCount() const; - + [[nodiscard]] bool IsZero() const { return mp_iszero(&n_); } + [[nodiscard]] bool IsOne() const { return mp_isone(&n_); } [[nodiscard]] bool IsNegative() const { return mp_isneg(&n_); } + // is natural number (>= 0) + [[nodiscard]] bool IsNatural() const { return !mp_isneg(&n_); } [[nodiscard]] bool IsPositive() const { return !mp_iszero(&n_) && !mp_isneg(&n_); } @@ -114,6 +110,13 @@ class MPInt { [[nodiscard]] bool IsOdd() const { return mp_isodd(&n_); } [[nodiscard]] bool IsEven() const { return mp_iseven(&n_); } + // Get the i'th bit. Always return 0 or 1 + uint8_t operator[](int idx) const; + uint8_t GetBit(int idx) const; + void SetBit(int idx, uint8_t bit); + + [[nodiscard]] size_t BitCount() const; + size_t SizeAllocated() { return n_.alloc * sizeof(mp_digit); } size_t SizeUsed() { return n_.used * sizeof(mp_digit); } @@ -191,7 +194,10 @@ class MPInt { // (*c) = a - b static void Sub(const MPInt &a, const MPInt &b, MPInt *c); + MPInt SubMod(const MPInt &b, const MPInt &mod) const; + static void SubMod(const MPInt &a, const MPInt &b, const MPInt &mod, + MPInt *d); // (*c) = a * b static void Mul(const MPInt &a, const MPInt &b, MPInt *c); @@ -203,13 +209,14 @@ class MPInt { MPInt *d); // *d = (a**b) mod c - static void PowMod(const MPInt &a, const MPInt &b, const MPInt &mod, - MPInt *d); static void Pow(const MPInt &a, uint32_t b, MPInt *c); - MPInt PowMod(const MPInt &b, const MPInt &mod) const; MPInt Pow(uint32_t b) const; void PowInplace(uint32_t b); + static void PowMod(const MPInt &a, const MPInt &b, const MPInt &mod, + MPInt *d); + MPInt PowMod(const MPInt &b, const MPInt &mod) const; + /* a/b => cb + d == a */ static void Div(const MPInt &a, const MPInt &b, MPInt *c, MPInt *d); @@ -298,6 +305,12 @@ class MPInt { friend std::ostream &operator<<(std::ostream &os, const MPInt &an_int); [[nodiscard]] yacl::Buffer Serialize() const; + // serialize mpint to already allocated buffer. + // if buf is nullptr, then calc serialize size only + // @return: the actual size of serialized buffer + // @throw: if buf_len is too small, an exception will be thrown + size_t Serialize(uint8_t *buf, size_t buf_len) const; + void Deserialize(yacl::ByteContainerView buffer); [[nodiscard]] std::string ToString() const; [[nodiscard]] std::string ToHexString() const; diff --git a/yacl/math/mpint/tommath_ext_features.cc b/yacl/math/mpint/tommath_ext_features.cc index 4c136d8f..a078c69b 100644 --- a/yacl/math/mpint/tommath_ext_features.cc +++ b/yacl/math/mpint/tommath_ext_features.cc @@ -430,29 +430,44 @@ void mp_ext_from_mag_bytes(mp_int *num, const uint8_t *buf, size_t buf_len, mp_clamp(num); } +// The mpint serialize protocol: +// The value is stored in little-endian format, and the MSB of the last byte is +// the sign bit. +// +// +--------+--------+--------+--------+ +// |DDDDDDDD|DDDDDDDD|DDDDDDDD|SDDDDDDD| +// +--------+--------+--------+--------+ +// ▲ +// │ +// sign bit +// D = data/payload; S = sign bit size_t mp_ext_serialize_size(const mp_int &num) { - return mp_ext_mag_bytes_size(num) + 1; // we add an extra meta byte + return mp_ext_count_bits_fast(num) / CHAR_BIT + 1; } -void mp_ext_serialize(const mp_int &num, uint8_t *buf, size_t buf_len) { - YACL_ENFORCE(buf_len > 0, "buf_len is zero"); - - // buf[0] is meta byte - if (mp_isneg(&num)) { - buf[0] = 1; - } else { - buf[0] = 0; - } +size_t mp_ext_serialize(const mp_int &num, uint8_t *buf, size_t buf_len) { + auto total_buf = mp_ext_serialize_size(num); + YACL_ENFORCE(buf_len >= total_buf, + "buf is too small, min required={}, actual={}", total_buf, + buf_len); // store num in Little-Endian - mp_ext_to_mag_bytes(num, buf + 1, buf_len - 1, Endian::little); + buf[total_buf - 1] = 0; + auto value_buf = mp_ext_to_mag_bytes(num, buf, buf_len, Endian::little); + YACL_ENFORCE(total_buf == value_buf || total_buf == value_buf + 1, + "bug: buf len mismatch, {} vs {}", total_buf, value_buf); + // write sign bit + uint8_t sign = (mp_isneg(&num) ? 1 : 0) << 7; + buf[total_buf - 1] |= sign; + return total_buf; } void mp_ext_deserialize(mp_int *num, const uint8_t *buf, size_t buf_len) { YACL_ENFORCE(buf_len > 0, "mp_int deserialize: empty buffer"); - - mp_ext_from_mag_bytes(num, buf + 1, buf_len - 1, Endian::little); - num->sign = buf[0] == 0 ? MP_ZPOS : MP_NEG; + // since buf is const, we cannot clear the sign bit + mp_ext_from_mag_bytes(num, buf, buf_len, Endian::little); + num->sign = ((buf[buf_len - 1] >> 7) == 1 ? MP_NEG : MP_ZPOS); + mp_ext_set_bit(num, buf_len * CHAR_BIT - 1, 0); // clear sign bit } uint8_t mp_ext_get_bit(const mp_int &a, int index) { diff --git a/yacl/math/mpint/tommath_ext_features.h b/yacl/math/mpint/tommath_ext_features.h index 1ce282ec..cdba5963 100644 --- a/yacl/math/mpint/tommath_ext_features.h +++ b/yacl/math/mpint/tommath_ext_features.h @@ -41,7 +41,7 @@ void mp_ext_from_mag_bytes(mp_int *num, const uint8_t *buf, size_t buf_len, int mp_ext_count_bits_fast(const mp_int &a); size_t mp_ext_serialize_size(const mp_int &num); -void mp_ext_serialize(const mp_int &num, uint8_t *buf, size_t buf_len); +size_t mp_ext_serialize(const mp_int &num, uint8_t *buf, size_t buf_len); void mp_ext_deserialize(mp_int *num, const uint8_t *buf, size_t buf_len); // return 0 or 1 diff --git a/yacl/math/mpint/tommath_ext_test.cc b/yacl/math/mpint/tommath_ext_test.cc index bd0864f1..05112041 100644 --- a/yacl/math/mpint/tommath_ext_test.cc +++ b/yacl/math/mpint/tommath_ext_test.cc @@ -32,7 +32,7 @@ std::string Info(const mp_int& n) { output.resize(size); MP_ASSERT_OK(mp_to_radix(&n, output.data(), size, nullptr, 16)); output.pop_back(); // remove tailing '\0' - return fmt::format("[used={}, bits={}, dp={}]", n.used, mp_count_bits(&n), + return fmt::format("[used={}, bits={}, value={}]", n.used, mp_count_bits(&n), output); } diff --git a/yacl/utils/spi/BUILD.bazel b/yacl/utils/spi/BUILD.bazel index 244d406a..8aeaab56 100644 --- a/yacl/utils/spi/BUILD.bazel +++ b/yacl/utils/spi/BUILD.bazel @@ -27,9 +27,12 @@ yacl_cc_library( yacl_cc_library( name = "item", + srcs = ["item.cc"], hdrs = ["item.h"], deps = [ "//yacl/base:exception", + "//yacl/base:int128", + "//yacl/math/mpint", "@com_google_absl//absl/types:span", ], ) diff --git a/yacl/utils/spi/argument/arg_k.h b/yacl/utils/spi/argument/arg_k.h index 97822c9a..f8cfeb81 100644 --- a/yacl/utils/spi/argument/arg_k.h +++ b/yacl/utils/spi/argument/arg_k.h @@ -24,6 +24,8 @@ namespace yacl { template class SpiArgKey { public: + using ValueType = T; + explicit SpiArgKey(const std::string &key) : key_(absl::AsciiStrToLower(key)) {} diff --git a/yacl/utils/spi/argument/arg_set.h b/yacl/utils/spi/argument/arg_set.h index 99f225eb..8113c833 100644 --- a/yacl/utils/spi/argument/arg_set.h +++ b/yacl/utils/spi/argument/arg_set.h @@ -41,7 +41,8 @@ class SpiArgs : public std::map { // If the user sets this parameter, but the type is not T, then an exception // is thrown template - T GetRequired(const SpiArgKey &key) const { + auto GetRequired(const SpiArgKey &key) const -> + typename SpiArgKey::ValueType { auto it = find((key.Key())); YACL_ENFORCE(it != end(), "Missing required argument {}", key.Key()); return it->second.template Value(); diff --git a/yacl/utils/spi/item.cc b/yacl/utils/spi/item.cc new file mode 100644 index 00000000..3518e092 --- /dev/null +++ b/yacl/utils/spi/item.cc @@ -0,0 +1,67 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "yacl/utils/spi/item.h" + +#include "yacl/base/int128.h" +#include "yacl/math/mpint/mp_int.h" + +namespace yacl { + +namespace { + +std::string TryRead(const std::any &v) { +#define TRY_TYPE(type) \ + if (t == typeid(type)) { \ + return fmt::to_string(std::any_cast(v)); \ + } + + const auto &t = v.type(); + TRY_TYPE(bool); + TRY_TYPE(int8_t); + TRY_TYPE(int16_t); + TRY_TYPE(int32_t); + TRY_TYPE(int64_t); + TRY_TYPE(int128_t); + TRY_TYPE(uint8_t); + TRY_TYPE(uint16_t); + TRY_TYPE(uint32_t); + TRY_TYPE(uint64_t); + TRY_TYPE(uint128_t); + TRY_TYPE(double); + TRY_TYPE(float); + TRY_TYPE(char); + TRY_TYPE(unsigned char); + TRY_TYPE(math::MPInt); // MPInt is a first-class citizen in SPI + return ""; +} + +} // namespace + +std::string Item::ToString() const { + if (IsArray()) { + return fmt::format("{} Item, element_type={}, RO={}", + IsView() ? "Span" : "Vector", v_.type().name(), + IsReadOnly()); + } else { + return fmt::format("Scalar item, type={}, RO={}, Content={}", + v_.type().name(), IsReadOnly(), TryRead(v_)); + } +} + +std::ostream &operator<<(std::ostream &os, const Item &a) { + return os << a.ToString(); +} + +} // namespace yacl diff --git a/yacl/utils/spi/item.h b/yacl/utils/spi/item.h index 5a2e5f15..055180ae 100644 --- a/yacl/utils/spi/item.h +++ b/yacl/utils/spi/item.h @@ -48,10 +48,17 @@ constexpr bool is_const_t = template using remove_cvref_t = std::remove_cv_t>; +enum class OperandType : int { + Scalar2Scalar = 0b00, + Scalar2Vector = 0b01, + Vector2Scalar = 0b10, + Vector2Vector = 0b11, +}; + class Item { public: template >> - Item(T&& value) : v_(std::forward(value)) { + /* implicit */ Item(T&& value) : v_(std::forward(value)) { Setup(false, false, false); } @@ -82,13 +89,30 @@ class Item { return item; } + template + explicit operator T() const& { + YACL_ENFORCE(v_.type() == typeid(T), "Type mismatch: convert {} to {} fail", + v_.type().name(), typeid(T).name()); + return As(); + } + + template + explicit operator T() && { + YACL_ENFORCE(v_.type() == typeid(T), + "Type mismatch: convert rvalue {} to {} fail", + v_.type().name(), typeid(T).name()); + return As(); + } + template std::enable_if_t, T>& As() & { try { return std::any_cast(v_); } catch (const std::bad_any_cast& e) { - YACL_THROW("Item as lvalue: cannot cast from {} to {}", v_.type().name(), - typeid(T).name()); + YACL_THROW( + "{}, Item as lvalue: cannot cast from {} to {}, please use 'c++filt " + "-t ' tool to see a human-readable name.", + ToString(), v_.type().name(), typeid(T&).name()); } } @@ -97,8 +121,10 @@ class Item { try { return std::any_cast(std::move(v_)); } catch (const std::bad_any_cast& e) { - YACL_THROW("Item as rvalue: cannot cast from {} to {}", v_.type().name(), - typeid(T).name()); + YACL_THROW( + "{}, Item as rvalue: cannot cast from {} to {}, please use 'c++filt " + "-t ' tool to see a human-readable name", + ToString(), v_.type().name(), typeid(T&&).name()); } } @@ -107,8 +133,10 @@ class Item { try { return std::any_cast(v_); } catch (const std::bad_any_cast& e) { - YACL_THROW("Item as const ref: cannot cast from {} to {}", - v_.type().name(), typeid(T).name()); + YACL_THROW( + "{}, Item as a const ref: cannot cast from {} to {}, please use " + "'c++filt -t ' tool to see a human-readable name. ", + ToString(), v_.type().name(), typeid(const T&).name()); } } @@ -117,14 +145,26 @@ class Item { try { return std::any_cast>(&v_); } catch (const std::bad_any_cast& e) { - YACL_THROW("Item as pointer: cannot cast from {} to {}", v_.type().name(), - typeid(T).name()); + YACL_THROW( + "{}, Item as a pointer: cannot cast from {} to {}, please use " + "'c++filt -t ' tool to see a human-readable name", + ToString(), v_.type().name(), + typeid(std::remove_pointer_t).name()); } } template absl::Span AsSpan() { - YACL_ENFORCE(IsArray(), "Item is not an array"); + if (!IsArray()) { + // single element as span + YACL_ENFORCE(!IsReadOnly(), "Single const T is not supported now"); + + return absl::MakeSpan(&As(), 1); + } + + using RawT = std::remove_cv_t; + + // const array case if (IsReadOnly()) { YACL_ENFORCE( is_const_t, @@ -135,22 +175,27 @@ class Item { return As>(); } else { // vector - return absl::MakeSpan(As>>()); + return absl::MakeSpan(As>()); } } - // non-const value -> (const) T + // non-const array case if (IsView()) { - return As>>(); + // span + return As>(); } else { // vector - return absl::MakeSpan(As>>()); + return absl::MakeSpan(As>()); } } template absl::Span AsSpan() const { - YACL_ENFORCE(IsArray(), "Item is not an array"); + if (!IsArray()) { + // single element as span + return absl::MakeSpan(&As(), 1); + } + using RawT = std::remove_cv_t; if (IsView()) { @@ -191,6 +236,10 @@ class Item { return !operator==(other); } + OperandType operator,(const Item& other) const { + return static_cast(((meta_ & 1) << 1) | (other.meta_ & 1)); + }; + // operations only for array template Item SubSpan(size_t pos, size_t len = absl::Span::npos) { @@ -210,6 +259,8 @@ class Item { return SubConstSpanImpl>(pos, len); } + std::string ToString() const; + private: constexpr Item() {} @@ -219,8 +270,9 @@ class Item { meta_ |= static_cast(is_readonly) << 2; } + // For developer: This is not a const func. DO NOT add const qualifier template - Item SubSpanImpl(size_t pos, size_t len) const { + Item SubSpanImpl(size_t pos, size_t len) { YACL_ENFORCE(!IsReadOnly(), "Cannot make a read-write subspan of a const span"); @@ -265,4 +317,6 @@ class Item { std::any v_; }; +std::ostream& operator<<(std::ostream& os, const Item& a); + } // namespace yacl