From 72bcc962fe1e1923a9a484a73ccaa596fd8a92f2 Mon Sep 17 00:00:00 2001 From: "shanzhu.cjm" Date: Sat, 20 Jan 2024 17:58:03 +0800 Subject: [PATCH] repo-sync-2024-01-20T17:57:27+0800 --- .bazelversion | 2 +- CHANGELOG.md | 3 + bazel/openssl.BUILD | 3 +- bazel/repositories.bzl | 6 +- yacl/base/dynamic_bitset.h | 20 + yacl/base/dynamic_bitset_test.cc | 33 +- yacl/crypto/base/BUILD.bazel | 1 - yacl/crypto/base/key_utils.cc | 5 +- yacl/crypto/base/openssl_wrappers.h | 12 + yacl/crypto/ossl-provider/BUILD.bazel | 2 + yacl/crypto/ossl-provider/provider_test.cc | 136 +++--- yacl/crypto/primitives/dpf/BUILD.bazel | 35 +- .../{vole/f2k/sparse_vole.cc => dpf/mpfss.cc} | 405 ++++++++++-------- yacl/crypto/primitives/dpf/mpfss.h | 214 +++++++++ yacl/crypto/primitives/dpf/mpfss_test.cc | 229 ++++++++++ yacl/crypto/primitives/ot/sgrr_ote.cc | 199 +++++---- yacl/crypto/primitives/ot/sgrr_ote.h | 43 +- yacl/crypto/primitives/ot/sgrr_ote_test.cc | 90 +++- yacl/crypto/primitives/ot/softspoken_ote.cc | 25 +- .../primitives/vole/{f2k => }/BUILD.bazel | 23 +- .../primitives/vole/{f2k => }/base_vole.h | 37 +- .../vole/{f2k => }/base_vole_test.cc | 76 ++-- yacl/crypto/primitives/vole/benchmark.cc | 268 ++++++++++++ yacl/crypto/primitives/vole/f2k/benchmark.cc | 264 ------------ .../primitives/vole/f2k/silent_vole_test.cc | 159 ------- yacl/crypto/primitives/vole/f2k/sparse_vole.h | 163 ------- .../primitives/vole/f2k/sparse_vole_test.cc | 269 ------------ yacl/crypto/primitives/vole/mp_vole.h | 215 ++++++++++ yacl/crypto/primitives/vole/mp_vole_test.cc | 116 +++++ .../primitives/vole/{f2k => }/silent_vole.cc | 148 +++---- .../primitives/vole/{f2k => }/silent_vole.h | 58 ++- .../primitives/vole/silent_vole_test.cc | 117 +++++ yacl/crypto/tools/BUILD.bazel | 9 + yacl/crypto/tools/common.h | 34 ++ yacl/crypto/utils/BUILD.bazel | 1 + yacl/crypto/utils/secparam.h | 29 ++ yacl/math/BUILD.bazel | 4 +- yacl/math/f2k/f2k.h | 13 +- yacl/math/gadget.h | 84 ++++ 39 files changed, 2135 insertions(+), 1415 deletions(-) rename yacl/crypto/primitives/{vole/f2k/sparse_vole.cc => dpf/mpfss.cc} (59%) create mode 100644 yacl/crypto/primitives/dpf/mpfss.h create mode 100644 yacl/crypto/primitives/dpf/mpfss_test.cc rename yacl/crypto/primitives/vole/{f2k => }/BUILD.bazel (87%) rename yacl/crypto/primitives/vole/{f2k => }/base_vole.h (86%) rename yacl/crypto/primitives/vole/{f2k => }/base_vole_test.cc (75%) create mode 100644 yacl/crypto/primitives/vole/benchmark.cc delete mode 100644 yacl/crypto/primitives/vole/f2k/benchmark.cc delete mode 100644 yacl/crypto/primitives/vole/f2k/silent_vole_test.cc delete mode 100644 yacl/crypto/primitives/vole/f2k/sparse_vole.h delete mode 100644 yacl/crypto/primitives/vole/f2k/sparse_vole_test.cc create mode 100644 yacl/crypto/primitives/vole/mp_vole.h create mode 100644 yacl/crypto/primitives/vole/mp_vole_test.cc rename yacl/crypto/primitives/vole/{f2k => }/silent_vole.cc (70%) rename yacl/crypto/primitives/vole/{f2k => }/silent_vole.h (89%) create mode 100644 yacl/crypto/primitives/vole/silent_vole_test.cc create mode 100644 yacl/crypto/tools/common.h diff --git a/.bazelversion b/.bazelversion index 024b066c..19b860c1 100644 --- a/.bazelversion +++ b/.bazelversion @@ -1 +1 @@ -6.2.1 +6.4.0 diff --git a/CHANGELOG.md b/CHANGELOG.md index a2c85fc6..1463e379 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,9 @@ All notable changes to this project will be documented in this file. > - Add `[Bugfix]` prefix for bug fixes > - Add `[API]` prefix for API changes +## Staging +- [Feature] Add Silent Vole (malicious version) + ## 2024-01-09 - [YACL] v0.4.2 - [Dependency] Bump: Openssl 3.0.12 (experimental) diff --git a/bazel/openssl.BUILD b/bazel/openssl.BUILD index d69a8605..37442510 100644 --- a/bazel/openssl.BUILD +++ b/bazel/openssl.BUILD @@ -33,8 +33,8 @@ CONFIGURE_OPTIONS = [ "--libdir=lib", "no-legacy", "no-weak-ssl-ciphers", - "no-shared", "no-tests", + "no-shared", "no-ui-console", ] @@ -59,7 +59,6 @@ yacl_configure_make( }), lib_name = "openssl", lib_source = ":all_srcs", - out_binaries = ["openssl"], # Note that for Linux builds, libssl must come before libcrypto on the linker command-line. # As such, libssl must be listed before libcrypto out_static_libs = [ diff --git a/bazel/repositories.bzl b/bazel/repositories.bzl index de077b24..7158d870 100644 --- a/bazel/repositories.bzl +++ b/bazel/repositories.bzl @@ -263,10 +263,10 @@ def _rules_foreign_cc(): maybe( http_archive, name = "rules_foreign_cc", - sha256 = "476303bd0f1b04cc311fc258f1708a5f6ef82d3091e53fd1977fa20383425a6a", - strip_prefix = "rules_foreign_cc-0.10.1", + sha256 = "2463288e7b2256a1dc61d62c0f970dcbe5dfc22e90c58e60d3119ce2e47209af", + strip_prefix = "rules_foreign_cc-c2e097455d2bbf92b2ae71611d1261ba79eb8aa8", urls = [ - "https://github.com/bazelbuild/rules_foreign_cc/archive/refs/tags/0.10.1.tar.gz", + "https://github.com/bazelbuild/rules_foreign_cc/archive/c2e097455d2bbf92b2ae71611d1261ba79eb8aa8.tar.gz", ], ) diff --git a/yacl/base/dynamic_bitset.h b/yacl/base/dynamic_bitset.h index 17664152..e76ccb0b 100644 --- a/yacl/base/dynamic_bitset.h +++ b/yacl/base/dynamic_bitset.h @@ -789,6 +789,8 @@ class dynamic_bitset { template constexpr void append(BlockInputIterator first, BlockInputIterator last); + constexpr void append(const dynamic_bitset& other); + /** * @brief Sets the bits to the result of binary AND on corresponding * pairs of bits of *this and @p rhs. @@ -2481,6 +2483,24 @@ constexpr void dynamic_bitset::append( assert(check_consistency()); } + +template +constexpr void dynamic_bitset::append( + const dynamic_bitset& other) { + const auto final_size = size() + other.size(); + const auto block_num = other.num_blocks(); + if (&other != this) { + auto other_data = other.data(); + append(other_data, other_data + block_num); + } else { + // Append a bitset to itself might cause an automatic reallocation + for (size_t i = 0; i < block_num; ++i) { + append(other.data()[i]); + } + } + resize(final_size); +} + template constexpr dynamic_bitset& dynamic_bitset::operator&=( diff --git a/yacl/base/dynamic_bitset_test.cc b/yacl/base/dynamic_bitset_test.cc index 88c88d3a..352207a0 100644 --- a/yacl/base/dynamic_bitset_test.cc +++ b/yacl/base/dynamic_bitset_test.cc @@ -123,7 +123,7 @@ TYPED_TEST(DynamicBitsetTest, PushPopTest) { EXPECT_EQ(bitset, check2); } -TYPED_TEST(DynamicBitsetTest, AppendTest) { +TYPED_TEST(DynamicBitsetTest, AppendBlockTest) { // GIVEN auto bitset = dynamic_bitset("0100101"); auto block = static_cast(crypto::FastRandU128()); @@ -137,6 +137,37 @@ TYPED_TEST(DynamicBitsetTest, AppendTest) { EXPECT_EQ(block, check); } +TYPED_TEST(DynamicBitsetTest, AppendBitSetTest) { + // GIVEN + auto bitset0 = dynamic_bitset("010010101010101"); + auto block = static_cast(*bitset0.data()); + + auto size = bitset0.size(); + // WHEN + bitset0.append(bitset0); + + // THEN + bitset0 >>= size; // right shift to remove the original bits + auto check = static_cast(*bitset0.data()); + EXPECT_EQ(block, check); +} + +TYPED_TEST(DynamicBitsetTest, AppendBitSetTest2) { + // GIVEN + auto bitset0 = dynamic_bitset("010010"); + auto bitset1 = dynamic_bitset("010010101010101"); + auto block = static_cast(*bitset1.data()); + + auto size = bitset0.size(); + // WHEN + bitset0.append(bitset1); + + // THEN + bitset0 >>= size; // right shift to remove the original bits + auto check = static_cast(*bitset1.data()); + EXPECT_EQ(block, check); +} + TYPED_TEST(DynamicBitsetTest, XorTest) { auto r1 = crypto::RandVec(kBlockNum); auto r2 = crypto::RandVec(kBlockNum); diff --git a/yacl/crypto/base/BUILD.bazel b/yacl/crypto/base/BUILD.bazel index f15ea987..80e9a13d 100644 --- a/yacl/crypto/base/BUILD.bazel +++ b/yacl/crypto/base/BUILD.bazel @@ -28,7 +28,6 @@ yacl_cc_library( "//yacl/utils:scope_guard", "@com_github_openssl_openssl//:openssl", ], - alwayslink = True, ) yacl_cc_library( diff --git a/yacl/crypto/base/key_utils.cc b/yacl/crypto/base/key_utils.cc index 7f5f110e..9ebd3835 100644 --- a/yacl/crypto/base/key_utils.cc +++ b/yacl/crypto/base/key_utils.cc @@ -16,6 +16,7 @@ #include +#include "yacl/crypto/base/openssl_wrappers.h" #include "yacl/io/stream/file_io.h" namespace yacl::crypto { @@ -351,8 +352,8 @@ Buffer ExportX509CertToBuf(const openssl::UniqueX509& x509) { openssl::UniqueBio bio(BIO_new(BIO_s_mem())); // create an empty bio // export certificate to bio - OSSL_RET_1(PEM_write_bio_X509(bio.get(), x509.get()), - "Failed PEM_export_bio_X509."); + OSSL_RET_1(PEM_write_bio_X509(bio.get(), x509.get())); + return BioToBuf(bio); } diff --git a/yacl/crypto/base/openssl_wrappers.h b/yacl/crypto/base/openssl_wrappers.h index bd8e225f..b5c47d4c 100644 --- a/yacl/crypto/base/openssl_wrappers.h +++ b/yacl/crypto/base/openssl_wrappers.h @@ -27,6 +27,7 @@ #include "openssl/decoder.h" #include "openssl/ec.h" #include "openssl/encoder.h" +#include "openssl/err.h" #include "openssl/evp.h" #include "openssl/pem.h" #include "openssl/provider.h" @@ -111,6 +112,17 @@ inline UniqueMac FetchEvpHmac() { return UniqueMac(EVP_MAC_fetch(nullptr, OSSL_MAC_NAME_HMAC, nullptr)); } +// see: https://en.wikibooks.org/wiki/OpenSSL/Error_handling +inline std::string GetOSSLErr() { + BIO* bio = BIO_new(BIO_s_mem()); + ERR_print_errors(bio); + char* buf; + size_t len = BIO_get_mem_data(bio, &buf); + std::string ret(buf, len); + BIO_free(bio); + return ret; +} + // --------------------------------- // Helpers for OpenSSL return values // --------------------------------- diff --git a/yacl/crypto/ossl-provider/BUILD.bazel b/yacl/crypto/ossl-provider/BUILD.bazel index 4e9c1fdf..7fe4d2fb 100644 --- a/yacl/crypto/ossl-provider/BUILD.bazel +++ b/yacl/crypto/ossl-provider/BUILD.bazel @@ -57,6 +57,8 @@ yacl_cc_library( name = "provider", srcs = [ "provider.cc", + ], + hdrs = [ "rand_impl.h", "version.h", ], diff --git a/yacl/crypto/ossl-provider/provider_test.cc b/yacl/crypto/ossl-provider/provider_test.cc index 9921d48e..db96e5ea 100644 --- a/yacl/crypto/ossl-provider/provider_test.cc +++ b/yacl/crypto/ossl-provider/provider_test.cc @@ -35,12 +35,13 @@ TEST(OpensslTest, ShouldWork) { // initialize a provider that was previously added with auto prov = openssl::UniqueProv( OSSL_PROVIDER_load(libctx.get(), GetProviderPath().c_str())); - YACL_ENFORCE(prov != nullptr); + YACL_ENFORCE(prov != nullptr, ERR_error_string(ERR_get_error(), nullptr)); // get provider's entropy source EVP_RAND* rand; - auto yes = EVP_RAND_fetch(libctx.get(), "Yes", - nullptr); /* yes = yacl entropy source */ - YACL_ENFORCE(yes != nullptr); + auto* yes = EVP_RAND_fetch(libctx.get(), "Yes", + nullptr); /* yes = yacl entropy source */ + + YACL_ENFORCE(yes != nullptr, ERR_error_string(ERR_get_error(), nullptr)); auto* yes_ctx = EVP_RAND_CTX_new(yes, nullptr); YACL_ENFORCE(yes_ctx != nullptr); EVP_RAND_instantiate(yes_ctx, 128, 0, nullptr, 0, nullptr); @@ -73,69 +74,68 @@ TEST(OpensslTest, ShouldWork) { EVP_RAND_CTX_free(rctx); } -// // https://www.openssl.org/docs/man3.0/man7/EVP_RAND-SEED-SRC.html -// TEST(OpensslTest, Example1) { -// EVP_RAND* rand; -// EVP_RAND_CTX* seed; -// EVP_RAND_CTX* rctx; -// unsigned char bytes[100]; -// OSSL_PARAM params[2]; -// OSSL_PARAM* p = params; -// unsigned int strength = 128; - -// /* Create a seed source */ -// rand = EVP_RAND_fetch(nullptr, "SEED-SRC", nullptr); -// seed = EVP_RAND_CTX_new(rand, nullptr); -// EVP_RAND_instantiate(seed, 128, 0, nullptr, 0, nullptr); - -// /* Feed this into a DRBG */ -// auto* tmp = EVP_RAND_fetch(nullptr, "CTR-DRBG", nullptr); -// // EVP_RAND_CTX_new() creates a new context for the RAND implementation -// rand. -// // If not NULL, parent specifies the seed source for this implementation. -// rctx = EVP_RAND_CTX_new(tmp, seed); -// YACL_ENFORCE(rctx != nullptr); - -// /* Configure the DRBG */ -// *p++ = OSSL_PARAM_construct_utf8_string(OSSL_DRBG_PARAM_CIPHER, -// (char*)"AES-256-CTR", 0); -// *p = OSSL_PARAM_construct_end(); -// EVP_RAND_instantiate(rctx, strength, 0, nullptr, 0, params); - -// int ret = -// EVP_RAND_generate(rctx, bytes, sizeof(bytes), strength, 0, nullptr, 0); -// EXPECT_EQ(ret, 1); - -// EVP_RAND_free(rand); -// EVP_RAND_free(tmp); -// EVP_RAND_CTX_free(rctx); -// EVP_RAND_CTX_free(seed); -// } - -// // https://www.openssl.org/docs/man3.0/man7/EVP_RAND-CTR-DRBG.html -// TEST(OpensslTest, Example2) { -// EVP_RAND* rand; -// EVP_RAND_CTX* rctx; -// unsigned char bytes[100]; -// OSSL_PARAM params[2]; -// OSSL_PARAM* p = params; -// unsigned int strength = 128; - -// rand = EVP_RAND_fetch(nullptr, "CTR-DRBG", nullptr); -// rctx = EVP_RAND_CTX_new(rand, nullptr); - -// *p++ = OSSL_PARAM_construct_utf8_string(OSSL_DRBG_PARAM_CIPHER, -// (char*)"AES-256-CTR", 0); -// *p = OSSL_PARAM_construct_end(); -// int ret0 = EVP_RAND_instantiate(rctx, strength, 0, nullptr, 0, params); -// EXPECT_EQ(ret0, 1); - -// int ret1 = -// EVP_RAND_generate(rctx, bytes, sizeof(bytes), strength, 0, nullptr, 0); -// EXPECT_EQ(ret1, 1); - -// EVP_RAND_free(rand); -// EVP_RAND_CTX_free(rctx); -// } +// https://www.openssl.org/docs/man3.0/man7/EVP_RAND-SEED-SRC.html +TEST(OpensslTest, Example1) { + EVP_RAND* rand; + EVP_RAND_CTX* seed; + EVP_RAND_CTX* rctx; + unsigned char bytes[100]; + OSSL_PARAM params[2]; + OSSL_PARAM* p = params; + unsigned int strength = 128; + + /* Create a seed source */ + rand = EVP_RAND_fetch(nullptr, "SEED-SRC", nullptr); + seed = EVP_RAND_CTX_new(rand, nullptr); + EVP_RAND_instantiate(seed, 128, 0, nullptr, 0, nullptr); + + /* Feed this into a DRBG */ + auto* tmp = EVP_RAND_fetch(nullptr, "CTR-DRBG", nullptr); + // EVP_RAND_CTX_new() creates a new context for the RAND implementation rand. + // If not NULL, parent specifies the seed source for this implementation. + rctx = EVP_RAND_CTX_new(tmp, seed); + YACL_ENFORCE(rctx != nullptr); + + /* Configure the DRBG */ + *p++ = OSSL_PARAM_construct_utf8_string(OSSL_DRBG_PARAM_CIPHER, + (char*)"AES-256-CTR", 0); + *p = OSSL_PARAM_construct_end(); + EVP_RAND_instantiate(rctx, strength, 0, nullptr, 0, params); + + int ret = + EVP_RAND_generate(rctx, bytes, sizeof(bytes), strength, 0, nullptr, 0); + EXPECT_EQ(ret, 1); + + EVP_RAND_free(rand); + EVP_RAND_free(tmp); + EVP_RAND_CTX_free(rctx); + EVP_RAND_CTX_free(seed); +} + +// https://www.openssl.org/docs/man3.0/man7/EVP_RAND-CTR-DRBG.html +TEST(OpensslTest, Example2) { + EVP_RAND* rand; + EVP_RAND_CTX* rctx; + unsigned char bytes[100]; + OSSL_PARAM params[2]; + OSSL_PARAM* p = params; + unsigned int strength = 128; + + rand = EVP_RAND_fetch(nullptr, "CTR-DRBG", nullptr); + rctx = EVP_RAND_CTX_new(rand, nullptr); + + *p++ = OSSL_PARAM_construct_utf8_string(OSSL_DRBG_PARAM_CIPHER, + (char*)"AES-256-CTR", 0); + *p = OSSL_PARAM_construct_end(); + int ret0 = EVP_RAND_instantiate(rctx, strength, 0, nullptr, 0, params); + EXPECT_EQ(ret0, 1); + + int ret1 = + EVP_RAND_generate(rctx, bytes, sizeof(bytes), strength, 0, nullptr, 0); + EXPECT_EQ(ret1, 1); + + EVP_RAND_free(rand); + EVP_RAND_CTX_free(rctx); +} } // namespace yacl::crypto diff --git a/yacl/crypto/primitives/dpf/BUILD.bazel b/yacl/crypto/primitives/dpf/BUILD.bazel index b0b54d4e..a01e2739 100644 --- a/yacl/crypto/primitives/dpf/BUILD.bazel +++ b/yacl/crypto/primitives/dpf/BUILD.bazel @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -load("//bazel:yacl.bzl", "yacl_cc_library", "yacl_cc_test") +load("//bazel:yacl.bzl", "AES_COPT_FLAGS", "yacl_cc_library", "yacl_cc_test") load("@rules_proto//proto:defs.bzl", "proto_library") load("@rules_cc//cc:defs.bzl", "cc_proto_library") @@ -52,3 +52,36 @@ cc_proto_library( name = "serializable_cc_proto", deps = [":serializable_proto"], ) + +yacl_cc_library( + name = "mpfss", + srcs = ["mpfss.cc"], + hdrs = ["mpfss.h"], + copts = AES_COPT_FLAGS, + deps = [ + "//yacl/base:aligned_vector", + "//yacl/base:dynamic_bitset", + "//yacl/base:int128", + "//yacl/crypto/primitives/ot:gywz_ote", + "//yacl/crypto/primitives/ot:ot_store", + "//yacl/crypto/primitives/ot:sgrr_ote", + "//yacl/crypto/tools:crhash", + "//yacl/crypto/utils:rand", + "//yacl/crypto/utils:secparam", + "//yacl/math:gadget", + "//yacl/math/f2k", + ], +) + +yacl_cc_test( + name = "mpfss_test", + srcs = ["mpfss_test.cc"], + copts = AES_COPT_FLAGS, + deps = [ + ":mpfss", + "//yacl/crypto/utils:rand", + "//yacl/link:test_util", + "//yacl/math:gadget", + "//yacl/math/f2k", + ], +) diff --git a/yacl/crypto/primitives/vole/f2k/sparse_vole.cc b/yacl/crypto/primitives/dpf/mpfss.cc similarity index 59% rename from yacl/crypto/primitives/vole/f2k/sparse_vole.cc rename to yacl/crypto/primitives/dpf/mpfss.cc index d9acf831..5dffc90d 100644 --- a/yacl/crypto/primitives/vole/f2k/sparse_vole.cc +++ b/yacl/crypto/primitives/dpf/mpfss.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "yacl/crypto/primitives/vole/f2k/sparse_vole.h" +#include "mpfss.h" #include #include @@ -20,9 +20,11 @@ #include "yacl/base/aligned_vector.h" #include "yacl/base/byte_container_view.h" #include "yacl/base/int128.h" +#include "yacl/crypto/primitives/ot/gywz_ote.h" +#include "yacl/crypto/primitives/ot/sgrr_ote.h" +#include "yacl/crypto/tools/crhash.h" #include "yacl/math/f2k/f2k.h" #include "yacl/math/gadget.h" -#include "yacl/utils/serialize.h" namespace yacl::crypto { @@ -30,114 +32,10 @@ namespace { constexpr uint32_t kSuperBatch = 16; } -// void SpVoleSend(const std::shared_ptr& ctx, -// const OtSendStore& /*rot*/ send_ot, uint32_t n, uint128_t w, -// absl::Span output) { -// SgrrOtExtSend(ctx, send_ot, n, output); -// uint128_t send_msg = w; -// for (uint32_t i = 0; i < n; ++i) { -// send_msg ^= output[i]; -// } -// ctx->SendAsync(ctx->NextRank(), yacl::SerializeUint128(send_msg), -// "SpVole_msg"); -// } - -// void SpVoleRecv(const std::shared_ptr& ctx, -// const OtRecvStore& /*rot*/ recv_ot, uint32_t n, uint32_t -// index, uint128_t v, absl::Span output) { -// SgrrOtExtRecv(ctx, recv_ot, n, index, output); -// auto recv_buff = ctx->Recv(ctx->NextRank(), "SpVole_msg"); -// auto recv_msg = DeserializeUint128(ByteContainerView(recv_buff)); -// for (uint32_t i = 0; i < n; ++i) { -// recv_msg ^= output[i]; -// } -// output[index] = recv_msg ^ v; -// } - -void SpVoleSend(const std::shared_ptr& ctx, - const OtSendStore& /*cot*/ send_ot, uint32_t n, uint128_t w, - absl::Span output) { - GywzOtExtSend(ctx, send_ot, n, output); - ParaCrHashInplace_128(output.subspan(0, n)); - uint128_t send_msg = w; - send_msg = std::reduce(output.begin(), output.begin() + n, send_msg, - std::bit_xor()); - ctx->SendAsync(ctx->NextRank(), SerializeUint128(send_msg), "SpVole_msg"); -} - -void SpVoleRecv(const std::shared_ptr& ctx, - const OtRecvStore& /*cot*/ recv_ot, uint32_t n, uint32_t index, - uint128_t v, absl::Span output) { - GywzOtExtRecv(ctx, recv_ot, n, index, output); - ParaCrHashInplace_128(output.subspan(0, n)); - output[index] = 0; - auto recv_buff = ctx->Recv(ctx->NextRank(), "SpVole_msg"); - auto recv_msg = DeserializeUint128(ByteContainerView(recv_buff)); - recv_msg = std::reduce(output.begin(), output.begin() + n, recv_msg, - std::bit_xor()); - output[index] = recv_msg ^ v; -} - -// void MpVoleSend(const std::shared_ptr& ctx, -// const OtSendStore& /*rot*/ send_ot, const MpVoleParam& param, -// absl::Span w, absl::Span output) { -// YACL_ENFORCE(param.assumption_ == LpnNoiseAsm::RegularNoise); -// YACL_ENFORCE(output.size() >= param.mp_vole_size_); -// YACL_ENFORCE(w.size() >= param.noise_num_); -// YACL_ENFORCE(send_ot.Size() >= param.require_ot_num_); - -// const auto& batch_num = param.noise_num_; -// const auto& batch_size = param.sp_vole_size_; -// const auto& last_batch_size = param.last_sp_vole_size_; - -// for (uint32_t i = 0; i < batch_num; ++i) { -// auto this_size = (i == batch_num - 1) ? last_batch_size : batch_size; - -// // TODO: @wenfan -// // "Slice" would force to slice original OtStore from "begin" to "end", -// // which might cause unexpected error. -// // It would be better to use "NextSlice" here, but it's not a const -// // function. -// auto ot_slice = send_ot.Slice( -// i * math::Log2Ceil(batch_size), -// i * math::Log2Ceil(batch_size) + math::Log2Ceil(this_size)); -// SpVoleSend(ctx, ot_slice, this_size, w[i], -// output.subspan(i * batch_size, this_size)); -// } -// } - -// void MpVoleRecv(const std::shared_ptr& ctx, -// const OtRecvStore& /*rot*/ recv_ot, const MpVoleParam& param, -// absl::Span v, absl::Span output) { -// YACL_ENFORCE(param.assumption_ == LpnNoiseAsm::RegularNoise); -// YACL_ENFORCE(output.size() >= param.mp_vole_size_); -// YACL_ENFORCE(v.size() >= param.noise_num_); -// YACL_ENFORCE(recv_ot.Size() >= param.require_ot_num_); - -// const auto& batch_num = param.noise_num_; -// const auto& batch_size = param.sp_vole_size_; -// const auto& last_batch_size = param.last_sp_vole_size_; -// const auto& indexes = param.indexes_; - -// for (uint32_t i = 0; i < batch_num; ++i) { -// auto this_size = (i == batch_num - 1) ? last_batch_size : batch_size; - -// // TODO: @wenfan -// // "Slice" would force to slice original OtStore from "begin" to "end", -// // which might cause unexpected error. -// // It would be better to use "NextSlice" here, but it's not a const -// // function. -// auto ot_slice = recv_ot.Slice( -// i * math::Log2Ceil(batch_size), -// i * math::Log2Ceil(batch_size) + math::Log2Ceil(this_size)); -// SpVoleRecv(ctx, ot_slice, this_size, indexes[i], v[i], -// output.subspan(i * batch_size, this_size)); -// } -// } - -void MpVoleSend(const std::shared_ptr& ctx, - const OtSendStore& /*cot*/ send_ot, const MpVoleParam& param, - absl::Span w, absl::Span output) { +void MpfssSend(const std::shared_ptr& ctx, + const OtSendStore& /*cot*/ send_ot, const MpFssParam& param, + absl::Span w, absl::Span output, + const MpfssOp& op) { YACL_ENFORCE(param.assumption_ == LpnNoiseAsm::RegularNoise); YACL_ENFORCE(output.size() >= param.mp_vole_size_); YACL_ENFORCE(w.size() >= param.noise_num_); @@ -147,7 +45,9 @@ void MpVoleSend(const std::shared_ptr& ctx, const auto& batch_size = param.sp_vole_size_; const auto& last_batch_size = param.last_sp_vole_size_; - auto send_msg = AlignedVector(w.data(), w.data() + batch_num); + AlignedVector send_msgs(batch_num, 0); + std::transform(send_msgs.cbegin(), send_msgs.cend(), w.cbegin(), + send_msgs.begin(), op.sub); for (uint32_t i = 0; i < batch_num; ++i) { auto this_size = (i == batch_num - 1) ? last_batch_size : batch_size; @@ -163,28 +63,22 @@ void MpVoleSend(const std::shared_ptr& ctx, i * math::Log2Ceil(batch_size) + math::Log2Ceil(this_size)); GywzOtExtSend(ctx, ot_slice, this_size, this_span); + // Break the correlation + ParaCrHashInplace_128(this_span); + send_msgs[i] = + std::reduce(this_span.begin(), this_span.end(), send_msgs[i], op.add); } - // Break the correlation - ParaCrHashInplace_128(output.subspan(0, param.mp_vole_size_)); - for (uint32_t i = 0; i < batch_num; ++i) { - auto this_size = (i == batch_num - 1) ? last_batch_size : batch_size; - auto this_span = output.subspan(i * batch_size, this_size); - send_msg[i] = std::reduce(this_span.begin(), this_span.end(), send_msg[i], - std::bit_xor()); - } - ctx->SendAsync( ctx->NextRank(), - ByteContainerView(send_msg.data(), send_msg.size() * sizeof(uint128_t)), + ByteContainerView(send_msgs.data(), send_msgs.size() * sizeof(uint128_t)), "MpVole_msg"); } -void MpVoleRecv(const std::shared_ptr& ctx, - const OtRecvStore& /*cot*/ recv_ot, const MpVoleParam& param, - absl::Span v, absl::Span output) { +void MpfssRecv(const std::shared_ptr& ctx, + const OtRecvStore& /*cot*/ recv_ot, const MpFssParam& param, + absl::Span output, const MpfssOp& op) { YACL_ENFORCE(param.assumption_ == LpnNoiseAsm::RegularNoise); YACL_ENFORCE(output.size() >= param.mp_vole_size_); - YACL_ENFORCE(v.size() >= param.noise_num_); YACL_ENFORCE(recv_ot.Size() >= param.require_ot_num_); const auto& batch_num = param.noise_num_; @@ -192,6 +86,8 @@ void MpVoleRecv(const std::shared_ptr& ctx, const auto& last_batch_size = param.last_sp_vole_size_; const auto& indexes = param.indexes_; + AlignedVector dpf_sum(batch_num, 0); + for (uint32_t i = 0; i < batch_num; ++i) { auto this_size = (i == batch_num - 1) ? last_batch_size : batch_size; auto this_span = output.subspan(i * batch_size, this_size); @@ -205,33 +101,138 @@ void MpVoleRecv(const std::shared_ptr& ctx, i * math::Log2Ceil(batch_size), i * math::Log2Ceil(batch_size) + math::Log2Ceil(this_size)); GywzOtExtRecv(ctx, ot_slice, this_size, indexes[i], this_span); + ParaCrHashInplace_128(this_span); + dpf_sum[i] = + std::reduce(this_span.begin(), this_span.end(), dpf_sum[i], op.add); } - ParaCrHashInplace_128(output.subspan(0, param.mp_vole_size_)); - auto recv_buff = ctx->Recv(ctx->NextRank(), "MpVole_msg"); - YACL_ENFORCE(static_cast(recv_buff.size()) >= batch_num * sizeof(uint128_t)); - auto recv_msg = + auto recv_msgs = absl::MakeSpan(reinterpret_cast(recv_buff.data()), batch_num); + for (uint32_t i = 0; i < batch_num; ++i) { + auto tmp = op.sub(recv_msgs[i], dpf_sum[i]); + output[i * batch_size + indexes[i]] = + op.add(output[i * batch_size + indexes[i]], tmp); + } +} + +void MpfssSend(const std::shared_ptr& ctx, + const OtSendStore& /*cot*/ send_ot, const MpFssParam& param, + absl::Span w, absl::Span output, + const MpfssOp& op) { + YACL_ENFORCE(param.assumption_ == LpnNoiseAsm::RegularNoise); + YACL_ENFORCE(output.size() >= param.mp_vole_size_); + YACL_ENFORCE(w.size() >= param.noise_num_); + YACL_ENFORCE(send_ot.Size() >= param.require_ot_num_); + + const auto& batch_num = param.noise_num_; + const auto& batch_size = param.sp_vole_size_; + const auto& last_batch_size = param.last_sp_vole_size_; + + AlignedVector send_msgs(batch_num); + std::transform(send_msgs.cbegin(), send_msgs.cend(), w.cbegin(), + send_msgs.begin(), op.sub); + + auto dpf_buff = + Buffer(std::max(batch_size, last_batch_size) * sizeof(uint128_t)); + auto dpf_span = absl::MakeSpan(dpf_buff.data(), + dpf_buff.size() / sizeof(uint128_t)); + // AlignedVector dpf_buff(std::max(batch_size, last_batch_size)); for (uint32_t i = 0; i < batch_num; ++i) { auto this_size = (i == batch_num - 1) ? last_batch_size : batch_size; - auto this_span = output.subspan(i * batch_size, this_size); - this_span[indexes[i]] = 0; // set punctured value as zero - recv_msg[i] = std::reduce(this_span.begin(), this_span.end(), recv_msg[i], - std::bit_xor()); - this_span[indexes[i]] = recv_msg[i] ^ v[i]; + auto this_span = dpf_span.subspan(0, this_size); + + // TODO: @wenfan + // "Slice" would force to slice original OtStore from "begin" to "end", + // which might cause unexpected error. + // It would be better to use "NextSlice" here, but it's not a const + // function. + auto ot_slice = send_ot.Slice( + i * math::Log2Ceil(batch_size), + i * math::Log2Ceil(batch_size) + math::Log2Ceil(this_size)); + + GywzOtExtSend(ctx, ot_slice, this_size, this_span); + ParaCrHashInplace_128(this_span); + + // Break the correlation + std::transform( + this_span.begin(), this_span.end(), output.data() + i * batch_size, + [](const uint128_t& val) { return static_cast(val); }); + + send_msgs[i] = std::reduce(output.data() + i * batch_size, + output.data() + i * batch_size + this_size, + send_msgs[i], op.add); + } + ctx->SendAsync( + ctx->NextRank(), + ByteContainerView(send_msgs.data(), send_msgs.size() * sizeof(uint128_t)), + "MpVole_msg"); +} + +void MpfssRecv(const std::shared_ptr& ctx, + const OtRecvStore& /*cot*/ recv_ot, const MpFssParam& param, + absl::Span output, const MpfssOp& op) { + YACL_ENFORCE(param.assumption_ == LpnNoiseAsm::RegularNoise); + YACL_ENFORCE(output.size() >= param.mp_vole_size_); + YACL_ENFORCE(recv_ot.Size() >= param.require_ot_num_); + + const auto& batch_num = param.noise_num_; + const auto& batch_size = param.sp_vole_size_; + const auto& last_batch_size = param.last_sp_vole_size_; + const auto& indexes = param.indexes_; + + auto dpf_buf = + Buffer(std::max(batch_size, last_batch_size) * sizeof(uint128_t)); + auto dpf_span = absl::MakeSpan(dpf_buf.data(), + dpf_buf.size() / sizeof(uint128_t)); + // AlignedVector dpf_buff(std::max(batch_size, last_batch_size)); + AlignedVector dpf_sum(batch_num, 0); + + for (uint32_t i = 0; i < batch_num; ++i) { + auto this_size = (i == batch_num - 1) ? last_batch_size : batch_size; + auto this_span = dpf_span.subspan(0, this_size); + + // TODO: @wenfan + // "Slice" would force to slice original OtStore from "begin" to "end", + // which might cause unexpected error. + // It would be better to use "NextSlice" here, but it's not a const + // function. + auto ot_slice = recv_ot.Slice( + i * math::Log2Ceil(batch_size), + i * math::Log2Ceil(batch_size) + math::Log2Ceil(this_size)); + GywzOtExtRecv(ctx, ot_slice, this_size, indexes[i], this_span); + ParaCrHashInplace_128(this_span); + + std::transform( + this_span.begin(), this_span.end(), output.data() + i * batch_size, + [](const uint128_t& val) { return static_cast(val); }); + dpf_sum[i] = std::reduce(output.data() + i * batch_size, + output.data() + i * batch_size + this_size, + dpf_sum[i], op.add); + } + + auto recv_buff = ctx->Recv(ctx->NextRank(), "MpVole_msg"); + YACL_ENFORCE(static_cast(recv_buff.size()) >= + batch_num * sizeof(uint64_t)); + + auto recv_msgs = + absl::MakeSpan(reinterpret_cast(recv_buff.data()), batch_num); + for (uint32_t i = 0; i < batch_num; ++i) { + auto tmp = op.sub(recv_msgs[i], dpf_sum[i]); + output[i * batch_size + indexes[i]] = + op.add(output[i * batch_size + indexes[i]], tmp); } } -void MpVoleSend_fixed_index(const std::shared_ptr& ctx, - const OtSendStore& /*cot*/ send_ot, - const MpVoleParam& param, - absl::Span w, - absl::Span output) { +void MpfssSend_fixed_index(const std::shared_ptr& ctx, + const OtSendStore& /*cot*/ send_ot, + MpFssParam& param, absl::Span w, + absl::Span output, + const MpfssOp& op) { YACL_ENFORCE(param.assumption_ == LpnNoiseAsm::RegularNoise); YACL_ENFORCE(output.size() >= param.mp_vole_size_); YACL_ENFORCE(w.size() >= param.noise_num_); @@ -244,7 +245,9 @@ void MpVoleSend_fixed_index(const std::shared_ptr& ctx, const auto last_batch_length = math::Log2Ceil(last_batch_size); // Copy vector w - auto spvole_sum = AlignedVector(w.data(), w.data() + batch_num); + AlignedVector dpf_sum(batch_num, 0); + std::transform(dpf_sum.cbegin(), dpf_sum.cend(), w.cbegin(), dpf_sum.begin(), + op.sub); // send message buff for GYWZ OTe auto gywz_send_msgs = AlignedVector( batch_length * (kSuperBatch - 1) + last_batch_length); @@ -277,9 +280,8 @@ void MpVoleSend_fixed_index(const std::shared_ptr& ctx, // Use CrHash to break the correlation ParaCrHashInplace_128(this_span); // this_span xor - spvole_sum[batch_idx] = - std::reduce(this_span.begin(), this_span.end(), spvole_sum[batch_idx], - std::bit_xor()); + dpf_sum[batch_idx] = std::reduce(this_span.begin(), this_span.end(), + dpf_sum[batch_idx], op.add); } auto msg_length = kSuperBatch * batch_length; @@ -292,21 +294,19 @@ void MpVoleSend_fixed_index(const std::shared_ptr& ctx, "GYWZ_OTE: messages"); } - auto& send_msgs = spvole_sum; + auto& send_msgs = dpf_sum; ctx->SendAsync( ctx->NextRank(), ByteContainerView(send_msgs.data(), send_msgs.size() * sizeof(uint128_t)), "MPVOLE:messages"); } -void MpVoleRecv_fixed_index(const std::shared_ptr& ctx, - const OtRecvStore& /*cot*/ recv_ot, - const MpVoleParam& param, - absl::Span v, - absl::Span output) { +void MpfssRecv_fixed_index(const std::shared_ptr& ctx, + const OtRecvStore& /*cot*/ recv_ot, + MpFssParam& param, absl::Span output, + const MpfssOp& op) { YACL_ENFORCE(param.assumption_ == LpnNoiseAsm::RegularNoise); YACL_ENFORCE(output.size() >= param.mp_vole_size_); - YACL_ENFORCE(v.size() >= param.noise_num_); YACL_ENFORCE(recv_ot.Size() >= param.require_ot_num_); const auto& batch_num = param.noise_num_; @@ -315,12 +315,12 @@ void MpVoleRecv_fixed_index(const std::shared_ptr& ctx, const auto batch_length = math::Log2Ceil(batch_size); const auto last_batch_length = math::Log2Ceil(last_batch_size); - const auto& indexes = param.indexes_; + auto& indexes = param.indexes_; const auto super_batch_num = math::DivCeil(batch_num, kSuperBatch); // Copy vector v - auto spvole_sum = AlignedVector(v.begin(), v.begin() + batch_num); + auto dpf_sum = AlignedVector(batch_num, 0); for (uint32_t s = 0; s < super_batch_num; ++s) { const uint32_t bound = @@ -353,14 +353,24 @@ void MpVoleRecv_fixed_index(const std::shared_ptr& ctx, batch_idx * batch_length + this_length); auto recv_span = absl::MakeSpan(gywz_recv_msgs.data() + i * batch_length, this_length); + + uint32_t real_index = 0; + for (size_t i = 0; i < this_length; ++i) { + real_index |= ot_slice.GetChoice(i) << i; + } + if (indexes[batch_idx] != real_index) { + SPDLOG_DEBUG( + "batch_idx {} , param.index_ ({}) and ot.choices mismatch ({}) !!!", + batch_idx, indexes[batch_idx], real_index); + indexes[batch_idx] = real_index; + } // GywzOtExt is single-point COT GywzOtExtRecv_fixed_index(ot_slice, this_size, this_span, recv_span); // Use CrHash to break the correlation ParaCrHashInplace_128(this_span); // this_span xor - spvole_sum[batch_idx] = - std::reduce(this_span.begin(), this_span.end(), spvole_sum[batch_idx], - std::bit_xor()); + dpf_sum[batch_idx] = std::reduce(this_span.begin(), this_span.end(), + dpf_sum[batch_idx], op.add); } } @@ -373,15 +383,17 @@ void MpVoleRecv_fixed_index(const std::shared_ptr& ctx, absl::MakeSpan(reinterpret_cast(recv_buff.data()), batch_num); for (uint32_t i = 0; i < batch_num; ++i) { - output[i * batch_size + indexes[i]] ^= recv_msgs[i] ^ spvole_sum[i]; + auto tmp = op.sub(recv_msgs[i], dpf_sum[i]); + output[i * batch_size + indexes[i]] = + op.add(output[i * batch_size + indexes[i]], tmp); } } -void MpVoleSend_fixed_index(const std::shared_ptr& ctx, - const OtSendStore& /*cot*/ send_ot, - const MpVoleParam& param, - absl::Span w, - absl::Span output) { +void MpfssSend_fixed_index(const std::shared_ptr& ctx, + const OtSendStore& /*cot*/ send_ot, + MpFssParam& param, absl::Span w, + absl::Span output, + const MpfssOp& op) { YACL_ENFORCE(param.assumption_ == LpnNoiseAsm::RegularNoise); YACL_ENFORCE(output.size() >= param.mp_vole_size_); YACL_ENFORCE(w.size() >= param.noise_num_); @@ -394,11 +406,17 @@ void MpVoleSend_fixed_index(const std::shared_ptr& ctx, const auto last_batch_length = math::Log2Ceil(last_batch_size); // copy w - auto spvole_sum = AlignedVector(w.begin(), w.begin() + batch_num); + AlignedVector dpf_sum(batch_num, 0); + std::transform(dpf_sum.cbegin(), dpf_sum.cend(), w.cbegin(), dpf_sum.begin(), + op.sub); // GywzOtExt need uint128_t buffer - auto spvole_buff = - AlignedVector(1 << std::max(batch_length, last_batch_length)); - auto spvole_span = absl::MakeSpan(spvole_buff); + auto dpf_buf = Buffer((1 << std::max(batch_length, last_batch_length)) * + sizeof(uint128_t)); + // auto dpf_buf = + // AlignedVector(1 << std::max(batch_length, + // last_batch_length)); + auto dpf_span = absl::MakeSpan(dpf_buf.data(), + dpf_buf.size() / sizeof(uint128_t)); // send message buffer for GYWZ OTe auto gywz_send_msgs = AlignedVector( batch_length * (kSuperBatch - 1) + last_batch_length); @@ -419,7 +437,7 @@ void MpVoleSend_fixed_index(const std::shared_ptr& ctx, // full_size = 1 << this_length, would avoid copying in GywzOtExt auto full_size = 1 << this_length; auto batch_idx = s * kSuperBatch + i; - auto this_span = spvole_span.subspan(0, full_size); + auto this_span = dpf_span.subspan(0, full_size); // TODO: @wenfan // "Slice" would force to slice original OtStore from "begin" to "end", @@ -438,10 +456,10 @@ void MpVoleSend_fixed_index(const std::shared_ptr& ctx, output.data() + batch_idx * batch_size, [](uint128_t t) -> uint64_t { return t; }); // this_span xor - spvole_sum[batch_idx] = + dpf_sum[batch_idx] = std::reduce(output.data() + batch_idx * batch_size, output.data() + batch_idx * batch_size + this_size, - spvole_sum[batch_idx], std::bit_xor()); + dpf_sum[batch_idx], op.add); } auto msg_length = kSuperBatch * batch_length; @@ -454,21 +472,19 @@ void MpVoleSend_fixed_index(const std::shared_ptr& ctx, "GYWZ_OTE: messages"); } - auto& send_msgs = spvole_sum; + auto& send_msgs = dpf_sum; ctx->SendAsync( ctx->NextRank(), ByteContainerView(send_msgs.data(), send_msgs.size() * sizeof(uint64_t)), "MPVOLE:messages"); } -void MpVoleRecv_fixed_index(const std::shared_ptr& ctx, - const OtRecvStore& /*cot*/ recv_ot, - const MpVoleParam& param, - absl::Span v, - absl::Span output) { +void MpfssRecv_fixed_index(const std::shared_ptr& ctx, + const OtRecvStore& /*cot*/ recv_ot, + MpFssParam& param, absl::Span output, + const MpfssOp& op) { YACL_ENFORCE(param.assumption_ == LpnNoiseAsm::RegularNoise); YACL_ENFORCE(output.size() >= param.mp_vole_size_); - YACL_ENFORCE(v.size() >= param.noise_num_); YACL_ENFORCE(recv_ot.Size() >= param.require_ot_num_); const auto& batch_num = param.noise_num_; @@ -477,16 +493,19 @@ void MpVoleRecv_fixed_index(const std::shared_ptr& ctx, const auto batch_length = math::Log2Ceil(batch_size); const auto last_batch_length = math::Log2Ceil(last_batch_size); - const auto& indexes = param.indexes_; + auto& indexes = param.indexes_; const auto super_batch_num = math::DivCeil(batch_num, kSuperBatch); - // Copy vector v - auto spvole_sum = AlignedVector(v.begin(), v.begin() + batch_num); + auto dpf_sum = AlignedVector(batch_num, 0); // GywzOtExt need uint128_t buffer - auto spvole_buff = - AlignedVector(1 << std::max(batch_length, last_batch_length)); - auto spvole_span = absl::MakeSpan(spvole_buff); + auto dpf_buf = Buffer((1 << std::max(batch_length, last_batch_length)) * + sizeof(uint128_t)); + // auto dpf_buf = + // AlignedVector(1 << std::max(batch_length, + // last_batch_length)); + auto dpf_span = absl::MakeSpan(dpf_buf.data(), + dpf_buf.size() / sizeof(uint128_t)); for (uint32_t s = 0; s < super_batch_num; ++s) { const uint32_t bound = @@ -512,13 +531,25 @@ void MpVoleRecv_fixed_index(const std::shared_ptr& ctx, // full_size = 1 << this_length, would avoid copying in GywzOtExt auto full_size = 1 << this_length; auto batch_idx = s * kSuperBatch + i; - auto this_span = spvole_span.subspan(0, full_size); + auto this_span = dpf_span.subspan(0, full_size); // TODO: @wenfan // "Slice" would force to slice original OtStore from "begin" to "end", // which might cause unexpected error. // It would be better to use "NextSlice" here, but it's not a const auto ot_slice = recv_ot.Slice(batch_idx * batch_length, batch_idx * batch_length + this_length); + + uint32_t real_index = 0; + for (size_t i = 0; i < this_length; ++i) { + real_index |= ot_slice.GetChoice(i) << i; + } + if (indexes[batch_idx] != real_index) { + SPDLOG_DEBUG( + "batch_idx {} , param.index_ ({}) and ot.choices mismatch ({}) !!!", + batch_idx, indexes[batch_idx], real_index); + indexes[batch_idx] = real_index; + } + auto recv_span = absl::MakeSpan(gywz_recv_msgs.data() + i * batch_length, this_length); // GywzOtExt is single-point COT @@ -528,12 +559,12 @@ void MpVoleRecv_fixed_index(const std::shared_ptr& ctx, // convert to uint64_t std::transform(this_span.begin(), this_span.begin() + this_size, output.data() + batch_idx * batch_size, - [](uint128_t t) -> uint64_t { return t; }); + [](uint128_t t) { return static_cast(t); }); // this_span xor - spvole_sum[batch_idx] = + dpf_sum[batch_idx] = std::reduce(output.data() + batch_idx * batch_size, output.data() + batch_idx * batch_size + this_size, - spvole_sum[batch_idx], std::bit_xor()); + dpf_sum[batch_idx], op.add); } } @@ -544,7 +575,9 @@ void MpVoleRecv_fixed_index(const std::shared_ptr& ctx, absl::MakeSpan(reinterpret_cast(recv_buff.data()), batch_num); for (uint32_t i = 0; i < batch_num; ++i) { - output[i * batch_size + indexes[i]] ^= recv_msgs[i] ^ spvole_sum[i]; + auto tmp = op.sub(recv_msgs[i], dpf_sum[i]); + output[i * batch_size + indexes[i]] = + op.add(output[i * batch_size + indexes[i]], tmp); } } diff --git a/yacl/crypto/primitives/dpf/mpfss.h b/yacl/crypto/primitives/dpf/mpfss.h new file mode 100644 index 00000000..70cc8778 --- /dev/null +++ b/yacl/crypto/primitives/dpf/mpfss.h @@ -0,0 +1,214 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +/* submodules */ +#include "yacl/crypto/primitives/ot/ot_store.h" +#include "yacl/crypto/utils/rand.h" +#include "yacl/crypto/utils/secparam.h" + +/* security parameter declaration */ +// this module is only a wrapper, no need for security parameter definition + +namespace yacl::crypto { + +// 2PC Multi-point functional secret sharing (MPFSS) implementation +// +// (n,t)-MPFSS is described in https://eprint.iacr.org/2019/273.pdf Section 4. +// In short, Sender and Receiver would input t-index (idx_1 , ... , idx_t) and +// t-element (val_1 , ... , val_t) respectively, and then get an output with +// n-element (output_1 , ... , output_n), such that: +// - for all k not in (idx_1 , ... , idx_t), Sender.output[k] = +// Receiver.output[k] +// - When k = idx_i, then Sender.output[k] = Receiver.output[k] + val_i +// +// Besides, in reference https://eprint.iacr.org/2019/1159.pdf Section 4, +// punctured PRF could be viewed as 2PC DPF. +// + +struct MpFssParam { + uint64_t base_vole_num_; + uint64_t noise_num_; + uint64_t sp_vole_size_; + uint64_t last_sp_vole_size_; + // mp_vole_size_ = sp_vole_size_ * (noise_num_ - 1) + last_sp_vole_size_ + uint64_t mp_vole_size_; // total size + uint64_t require_ot_num_; // total ot num + + LpnNoiseAsm assumption_ = LpnNoiseAsm::RegularNoise; + std::vector indexes_ = std::vector(0); // size zero + + bool is_mal_{false}; + + MpFssParam() : MpFssParam(1, 2, LpnNoiseAsm::RegularNoise, false) {} + + MpFssParam(uint64_t noise_num, uint64_t mp_vole_size, bool mal = false) + : MpFssParam(noise_num, mp_vole_size, LpnNoiseAsm::RegularNoise, mal) {} + + // full constructor + MpFssParam(uint64_t noise_num, uint64_t mp_vole_size, LpnNoiseAsm assumption, + bool mal = false) { + YACL_ENFORCE(assumption == LpnNoiseAsm::RegularNoise); + YACL_ENFORCE(noise_num > 0); + + is_mal_ = mal; + base_vole_num_ = (is_mal_ == false) ? noise_num : noise_num + 1; + noise_num_ = noise_num; + mp_vole_size_ = mp_vole_size; + assumption_ = assumption; + + sp_vole_size_ = mp_vole_size_ / noise_num_; + last_sp_vole_size_ = mp_vole_size_ - (noise_num_ - 1) * sp_vole_size_; + + YACL_ENFORCE(sp_vole_size_ > 1, + "The size of SpVole should be greater than 1, because " + "1-out-of-1 SpVole is meaningless"); + + require_ot_num_ = math::Log2Ceil(sp_vole_size_) * (noise_num_ - 1) + + math::Log2Ceil(last_sp_vole_size_); + } + + // [Warning] not strictly uniformly random + void GenIndexes() { + indexes_ = RandVec(noise_num_); + for (uint32_t i = 0; i < noise_num_ - 1; ++i) { + indexes_[i] %= sp_vole_size_; + } + indexes_[noise_num_ - 1] %= last_sp_vole_size_; + } + + void SetIndexes(absl::Span indexes) { + YACL_ENFORCE(indexes.size() >= noise_num_); + for (uint32_t i = 0; i < noise_num_ - 1; ++i) { + indexes_[i] = indexes[i] % sp_vole_size_; + } + indexes_[noise_num_ - 1] %= last_sp_vole_size_; + } + + // Convert index_ into choices for OT + dynamic_bitset GenChoices() { + YACL_ENFORCE(indexes_.size() == noise_num_); + + auto choices = dynamic_bitset(require_ot_num_); + + uint64_t pos = 0; + auto sp_vole_length = math::Log2Ceil(sp_vole_size_); + auto last_length = math::Log2Ceil(last_sp_vole_size_); + for (size_t i = 0; i < noise_num_; ++i) { + auto this_length = (i == noise_num_ - 1) ? last_length : sp_vole_length; + uint32_t bound = 1 << this_length; + for (uint32_t mask = 1; mask < bound; mask <<= 1) { + choices.set(pos, indexes_[i] & mask); + ++pos; + } + } + + return choices; + } +}; + +template +class MpfssOp { + public: + std::function add = std::bit_xor(); + std::function sub = std::bit_xor(); + + // Maybe, we could define a function to convert uint128_t to T + // - std::function convert = static_cast(); + // - std::function convert = PRF(); + + // default ctor + MpfssOp() { + add = std::bit_xor(); + sub = std::bit_xor(); + } + + // standard ctor + MpfssOp(std::function op1, + std::function op2) { + add = op1; + sub = op2; + } +}; + +template +MpfssOp MakeMpfssOp(std::function op1, + std::function op2) { + return MpfssOp(op1, op2); +} + +// Multi-point functional secret share with Regular Noise Distribution (GYWZ-OTe +// based) [Warning] low efficiency, too much send action + +// GF(2^128) or Ring(2^128) +void MpfssSend(const std::shared_ptr& ctx, + const OtSendStore& /*cot*/ send_ot, const MpFssParam& param, + absl::Span w, absl::Span output, + const MpfssOp& op = MpfssOp()); + +void MpfssRecv(const std::shared_ptr& ctx, + const OtRecvStore& /*cot*/ recv_ot, const MpFssParam& param, + absl::Span output, + const MpfssOp& op = MpfssOp()); + +// GF(2^64) or Ring(2^64) +void MpfssSend(const std::shared_ptr& ctx, + const OtSendStore& /*cot*/ send_ot, const MpFssParam& param, + absl::Span w, absl::Span output, + const MpfssOp& op = MpfssOp()); + +void MpfssRecv(const std::shared_ptr& ctx, + const OtRecvStore& /*cot*/ recv_ot, const MpFssParam& param, + absl::Span output, + const MpfssOp& op = MpfssOp()); +// +// -------------------------- +// Customized +// -------------------------- +// +// Multi-point functional secret share with Regular Noise Distribution (GYWZ-OTe +// based) Most efficiency! Punctured indexes would be determined by the choices +// of OtStore. But "MpfssSend_fixed_index/MpfssRecv_fixed_index" would not check +// whether the indexes determined by OtStore and the indexes provided by +// MpFssParam are same. + +// GF(2^128) or Ring(2^128) +void MpfssSend_fixed_index(const std::shared_ptr& ctx, + const OtSendStore& /*cot*/ send_ot, + MpFssParam& param, absl::Span w, + absl::Span output, + const MpfssOp& op = MpfssOp()); + +void MpfssRecv_fixed_index(const std::shared_ptr& ctx, + const OtRecvStore& /*cot*/ recv_ot, + MpFssParam& param, absl::Span output, + const MpfssOp& op = MpfssOp()); + +// GF(2^64) or Ring(2^64) +void MpfssSend_fixed_index(const std::shared_ptr& ctx, + const OtSendStore& /*cot*/ send_ot, + MpFssParam& param, absl::Span w, + absl::Span output, + const MpfssOp& op = MpfssOp()); + +void MpfssRecv_fixed_index(const std::shared_ptr& ctx, + const OtRecvStore& /*cot*/ recv_ot, + MpFssParam& param, absl::Span output, + const MpfssOp& op = MpfssOp()); + +} // namespace yacl::crypto diff --git a/yacl/crypto/primitives/dpf/mpfss_test.cc b/yacl/crypto/primitives/dpf/mpfss_test.cc new file mode 100644 index 00000000..2925c0b4 --- /dev/null +++ b/yacl/crypto/primitives/dpf/mpfss_test.cc @@ -0,0 +1,229 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mpfss.h" + +#include + +#include +#include +#include +#include +#include + +#include "yacl/base/exception.h" +#include "yacl/crypto/utils/rand.h" +#include "yacl/link/test_util.h" +#include "yacl/math/f2k/f2k.h" +#include "yacl/math/gadget.h" + +namespace yacl::crypto { + +struct TestParam { + size_t num; + size_t index_num; +}; + +class Mpfss64Test + : public ::testing::TestWithParam< + std::tuple> {}; + +class Mpfss128Test + : public ::testing::TestWithParam< + std::tuple> {}; + +template +MpfssOp CreateMpfssOp(bool xor_mode) { + MpfssOp ret; + if (xor_mode) { + ret = MakeMpfssOp(std::bit_xor(), std::bit_xor()); + } else { + ret = MakeMpfssOp(std::plus(), std::minus()); + } + return ret; +} + +TEST_P(Mpfss64Test, Work) { + auto lctxs = link::test::SetupWorld(2); // setup network + const auto op = CreateMpfssOp(std::get<0>(GetParam())); + const auto is_fixed = std::get<1>(GetParam()); + const uint64_t num = std::get<2>(GetParam()).num; + const uint64_t index_num = std::get<2>(GetParam()).index_num; + + MpFssParam param(index_num, num); + param.GenIndexes(); + + auto choices = RandBits>(param.require_ot_num_); + if (is_fixed) { + choices = param.GenChoices(); + } + auto cot = MockCots(param.require_ot_num_, FastRandU128(), choices); + + std::vector s_output(num); + std::vector r_output(num); + + auto w = RandVec(index_num); + + auto sender = std::async([&] { + if (is_fixed) { + MpfssSend_fixed_index(lctxs[0], cot.send, param, absl::MakeSpan(w), + absl::MakeSpan(s_output), op); + } else { + MpfssSend(lctxs[0], cot.send, param, absl::MakeSpan(w), + absl::MakeSpan(s_output), op); + } + }); + + auto receiver = std::async([&] { + if (is_fixed) { + MpfssRecv_fixed_index(lctxs[1], cot.recv, param, absl::MakeSpan(r_output), + op); + } else { + MpfssRecv(lctxs[1], cot.recv, param, absl::MakeSpan(r_output), op); + } + }); + + sender.get(); + receiver.get(); + + std::set indexes; + for (size_t i = 0; i < param.noise_num_; ++i) { + indexes.insert(i * param.sp_vole_size_ + param.indexes_[i]); + } + + uint64_t j = 0; + uint64_t i = 0; + for (; i < num && j < index_num; ++i) { + if (s_output[i] != r_output[i]) { + EXPECT_EQ(w[j], op.sub(s_output[i], r_output[i])); + EXPECT_TRUE(indexes.count(i)); + j++; + } + } + for (; i < num; ++i) { + EXPECT_EQ(s_output[i], r_output[i]); + } + EXPECT_EQ(j, index_num); +} + +TEST_P(Mpfss128Test, Work) { + auto lctxs = link::test::SetupWorld(2); // setup network + const auto op = CreateMpfssOp(std::get<0>(GetParam())); + const auto is_fixed = std::get<1>(GetParam()); + const uint64_t num = std::get<2>(GetParam()).num; + const uint64_t index_num = std::get<2>(GetParam()).index_num; + + MpFssParam param(index_num, num); + param.GenIndexes(); + + auto choices = RandBits>(param.require_ot_num_); + if (is_fixed) { + choices = param.GenChoices(); + } + auto cot = MockCots(param.require_ot_num_, FastRandU128(), choices); + + std::vector s_output(num); + std::vector r_output(num); + + auto w = RandVec(index_num); + + auto sender = std::async([&] { + if (is_fixed) { + MpfssSend_fixed_index(lctxs[0], cot.send, param, absl::MakeSpan(w), + absl::MakeSpan(s_output), op); + } else { + MpfssSend(lctxs[0], cot.send, param, absl::MakeSpan(w), + absl::MakeSpan(s_output), op); + } + }); + + auto receiver = std::async([&] { + if (is_fixed) { + MpfssRecv_fixed_index(lctxs[1], cot.recv, param, absl::MakeSpan(r_output), + op); + } else { + MpfssRecv(lctxs[1], cot.recv, param, absl::MakeSpan(r_output), op); + } + }); + + sender.get(); + receiver.get(); + + std::set indexes; + for (size_t i = 0; i < param.noise_num_; ++i) { + indexes.insert(i * param.sp_vole_size_ + param.indexes_[i]); + } + + uint64_t j = 0; + uint64_t i = 0; + for (; i < num && j < index_num; ++i) { + if (s_output[i] != r_output[i]) { + EXPECT_EQ(w[j], op.sub(s_output[i], r_output[i])); + EXPECT_TRUE(indexes.count(i)); + j++; + } + } + for (; i < num; ++i) { + EXPECT_EQ(s_output[i], r_output[i]); + } + EXPECT_EQ(j, index_num); +} + +INSTANTIATE_TEST_SUITE_P( + VoleInternal, Mpfss64Test, + testing::Combine( + testing::Values(true, // true for xor_mode, GF(2^64) + false // false for add_mode, Ring(2^64) + ), + testing::Values(true, // true for fix index (determined by OT) + false // false for selected index + ), + testing::Values(TestParam{4, 2}, // edge + TestParam{5, 2}, // edge + TestParam{7, 2}, // edge + TestParam{1 << 8, 64}, TestParam{1 << 10, 257}, + TestParam{1 << 20, 1024})), + [](const testing::TestParamInfo& p) { + return fmt::format("{}_{}_t{}xn{}", + std::get<1>(p.param) ? "FixedIndex" : "SelectedIndex", + std::get<0>(p.param) ? "XOR" : "ADD", + std::get<2>(p.param).index_num, + std::get<2>(p.param).num); + }); + +INSTANTIATE_TEST_SUITE_P( + VoleInternal, Mpfss128Test, + testing::Combine( + testing::Values(true, // true for xor_mode, GF(2^128) + false // false for add_mode, Ring(2^128) + ), + testing::Values(true, // true for fix index (determined by OT) + false // false for selected index + ), + testing::Values(TestParam{4, 2}, // edge + TestParam{5, 2}, // edge + TestParam{7, 2}, // edge + TestParam{1 << 8, 64}, TestParam{1 << 10, 257}, + TestParam{1 << 20, 1024})), + [](const testing::TestParamInfo& p) { + return fmt::format("{}_{}_t{}xn{}", + std::get<1>(p.param) ? "FixedIndex" : "SelectedIndex", + std::get<0>(p.param) ? "XOR" : "ADD", + std::get<2>(p.param).index_num, + std::get<2>(p.param).num); + }); + +} // namespace yacl::crypto diff --git a/yacl/crypto/primitives/ot/sgrr_ote.cc b/yacl/crypto/primitives/ot/sgrr_ote.cc index 0ead682d..20e02e71 100644 --- a/yacl/crypto/primitives/ot/sgrr_ote.cc +++ b/yacl/crypto/primitives/ot/sgrr_ote.cc @@ -45,17 +45,6 @@ const std::array kPrfKey = {AES_set_encrypt_key(0), // https://github.com/emp-toolkit/emp-ot/blob/master/emp-ot/ferret/twokeyprp.h // -// std::array SplitSeed(const std::array& keys, -// uint128_t seed) { -// std::array tmp = {seed, seed}; -// // Uncomment the following if you want to use CrHash: -// // kDefaltRp.Gen({seed ^ 1, seed ^ 2}, absl::MakeSpan(tmp)); - -// // Use two-key prf -// ParaEnc<2, 1>(tmp.data(), keys.data()); -// return {tmp[0] ^ seed, tmp[1] ^ seed}; -// } - inline dynamic_bitset MakeDynamicBitset(uint128_t input, size_t bits) { dynamic_bitset out; @@ -92,6 +81,62 @@ std::vector SplitAllSeeds(absl::Span seeds) { return out; } +struct CheckMsg { + std::array t; + std::array s; + + void Pack(absl::Span out) { + YACL_ENFORCE(out.size() >= 64); + memcpy(out.data(), t.data(), 32); + memcpy(out.data() + 32, s.data(), 32); + } + + Buffer Pack() { + auto ret = Buffer(64); + Pack(absl::MakeSpan(ret.data(), ret.size())); + return ret; + } + + void Unpack(ByteContainerView in) { + YACL_ENFORCE(in.size() == 64); + memcpy(t.data(), in.data(), 32); + memcpy(s.data(), in.data() + 32, 32); + } +}; + +CheckMsg GenCheckMsg(uint32_t n, absl::Span output) { + auto t = std::array(); + + std::vector> tmp; + for (uint32_t i = 0; i < n; ++i) { + tmp.emplace_back(Blake3(ByteContainerView(&output[i], sizeof(uint128_t)))); + // t = t xor tmp + std::transform(tmp[i].cbegin(), tmp[i].cend(), t.cbegin(), t.begin(), + std::bit_xor()); + } + auto s = Blake3(ByteContainerView(tmp.data(), tmp.size() * 32)); + return {t, s}; +} + +bool VerifyCheckMsg(uint32_t n, uint32_t index, absl::Span output, + const CheckMsg& proof) { + auto t = proof.t; + auto& s = proof.s; + + std::vector> tmp; + for (uint32_t i = 0; i < n; ++i) { + tmp.emplace_back(Blake3(ByteContainerView(&output[i], sizeof(uint128_t)))); + // t = t xor tmp + std::transform(tmp[i].cbegin(), tmp[i].cend(), t.cbegin(), t.begin(), + std::bit_xor()); + } + std::transform(t.cbegin(), t.cend(), tmp[index].cbegin(), tmp[index].begin(), + std::bit_xor()); + + auto hash = Blake3(ByteContainerView(tmp.data(), tmp.size() * 32)); + return ByteContainerView(hash) == ByteContainerView(s); +} + } // namespace void SgrrOtExtRecv(const std::shared_ptr& ctx, @@ -156,33 +201,13 @@ void SgrrOtExtRecv(const std::shared_ptr& ctx, // check consistency if (mal) { - size_t size = n; - - std::vector> s; - std::array, 2> t = {}; // set zeros - - for (size_t i = 0; i < size; ++i) { - s.emplace_back(Blake3(ByteContainerView(&output[i], sizeof(uint128_t)))); - // t[0] = t[0] xor s[i] - std::transform(s[i].cbegin(), s[i].cend(), t[0].cbegin(), t[0].begin(), - std::bit_xor()); - } - // t[0] = t[0] xor s[index] - std::transform(s[index].cbegin(), s[index].cend(), t[0].cbegin(), - t[0].begin(), std::bit_xor()); - - auto buff = ctx->Recv(ctx->NextRank(), "SGRR_OTE:RECV-PROOF"); - YACL_ENFORCE(buff.size() == 64); - std::array, 2> recv_t; - memcpy(recv_t.data(), buff.data(), buff.size()); - - // s[index] = t[0] xor recv_t[index] - std::transform(recv_t[0].cbegin(), recv_t[0].cend(), t[0].cbegin(), - s[index].begin(), std::bit_xor()); - - t[1] = Blake3(ByteContainerView(s.data(), s.size() * 32)); - YACL_ENFORCE(ByteContainerView(t[1]) == ByteContainerView(recv_t[1])); + auto recv_buf = ctx->Recv(ctx->NextRank(), "SGRR:PROOF"); + YACL_ENFORCE(recv_buf.size() == 64); + CheckMsg proof; + proof.Unpack(recv_buf); + YACL_ENFORCE(VerifyCheckMsg(n, index, output, proof), + "Malicious SgrrOtExt Consistency check: fail!"); // refresh output ParaCrHashInplace_128(output.subspan(0, n)); output[index] = 0; @@ -233,61 +258,54 @@ void SgrrOtExtSend(const std::shared_ptr& ctx, // prove consistency if (mal) { - size_t size = n; - std::vector> s; - std::array, 2> t = {}; // set zeros - for (size_t i = 0; i < size; ++i) { - s.emplace_back(Blake3(ByteContainerView(&output[i], sizeof(uint128_t)))); - // t[0] = t[0] xor s[i] - std::transform(s[i].cbegin(), s[i].cend(), t[0].cbegin(), t[0].begin(), - std::bit_xor()); - } - t[1] = Blake3(ByteContainerView(s.data(), s.size() * 32)); - ctx->SendAsync(ctx->NextRank(), ByteContainerView(t.data(), 64), - "SGRR_OTE:SEND-PROOF"); - // Refresh output + auto proof = GenCheckMsg(n, output); + ctx->SendAsync(ctx->NextRank(), proof.Pack(), "SGRR:PROOF"); + // refresh output ParaCrHashInplace_128(output.subspan(0, n)); } } -// Notice that: -// > In such case, punctured index would be the choice of cot -// > punctured index might be greater than n -// So, please do NOT use "FixIndexSgrrOtExtRecv" and "FixIndexSgrrOtExtSend", -// unless you are certainly sure how do these algorithms work. +// Notice that: In such case, punctured index would be the choice of cot, which +// means punctured index might be greater than n. So, please do NOT use +// "FixIndexSgrrOtExtRecv" and "FixIndexSgrrOtExtSend", unless you are certainly +// sure how do these algorithms work. void SgrrOtExtRecv_fixed_index(const std::shared_ptr& ctx, const OtRecvStore& base_ot, uint32_t n, - absl::Span output) { - uint32_t ot_num = math::Log2Ceil(n); + absl::Span output, bool mal) { + const uint64_t buf_size = SgrrOtExtHelper(n, mal); auto recv_buf = ctx->Recv(ctx->NextRank(), "SGRR_OTE:RECV-CORR"); - YACL_ENFORCE(recv_buf.size() >= - static_cast(ot_num * 2 * sizeof(uint128_t))); - auto recv_msgs = absl::MakeSpan( - reinterpret_cast*>(recv_buf.data()), ot_num); - SgrrOtExtRecv_fixed_index(base_ot, n, output, absl::MakeSpan(recv_msgs)); + YACL_ENFORCE_EQ(static_cast(recv_buf.size()), buf_size); + SgrrOtExtRecv_fixed_index( + base_ot, n, output, + absl::MakeSpan(recv_buf.data(), buf_size), mal); } void SgrrOtExtSend_fixed_index(const std::shared_ptr& ctx, const OtSendStore& base_ot, uint32_t n, - absl::Span output) { - uint32_t ot_num = math::Log2Ceil(n); - std::vector> send_msgs(ot_num); - SgrrOtExtSend_fixed_index(base_ot, n, output, absl::MakeSpan(send_msgs)); - - ctx->SendAsync( - ctx->NextRank(), - ByteContainerView(send_msgs.data(), ot_num * 2 * sizeof(uint128_t)), - "SGRR_OTE:SEND-CORR"); + absl::Span output, bool mal) { + const uint64_t buf_size = SgrrOtExtHelper(n, mal); + auto send_buf = Buffer(buf_size); + SgrrOtExtSend_fixed_index(base_ot, n, output, + absl::MakeSpan(send_buf.data(), buf_size), + mal); + + ctx->SendAsync(ctx->NextRank(), ByteContainerView(send_buf), + "SGRR_OTE:SEND-CORR"); } void SgrrOtExtRecv_fixed_index(const OtRecvStore& base_ot, uint32_t n, absl::Span output, - absl::Span> recv_msgs) { - uint32_t ot_num = math::Log2Ceil(n); + absl::Span recv_buf, bool mal) { + const uint32_t ot_num = math::Log2Ceil(n); + const uint64_t buf_size = SgrrOtExtHelper(n, mal); YACL_ENFORCE_GE(n, (uint32_t)1); // range should > 1 YACL_ENFORCE_GE((uint32_t)128, base_ot.Size()); // base ot num < 128 YACL_ENFORCE_GE(base_ot.Size(), ot_num); // - YACL_ENFORCE_GE(recv_msgs.size(), ot_num); + YACL_ENFORCE_EQ(static_cast(recv_buf.size()), buf_size); + + auto recv_msgs = absl::MakeConstSpan( + reinterpret_cast*>(recv_buf.data()), + ot_num); // we need log(n) 1-2 OTs from log(n) ROTs // most significant bit first @@ -323,20 +341,36 @@ void SgrrOtExtRecv_fixed_index(const OtRecvStore& base_ot, uint32_t n, output[inserted_idx] = insert_val; } } + + if (mal) { + auto index = GetPuncturedIndex(choice, ot_num - 1); + CheckMsg proof; + proof.Unpack(absl::MakeConstSpan(recv_buf.data() + buf_size - 64, 64)); + + YACL_ENFORCE(VerifyCheckMsg(n, index, output, proof), + "Malicious SgrrOtExt Consistency check: fail!"); + // refresh output + ParaCrHashInplace_128(output.subspan(0, n)); + output[index] = 0; + } } void SgrrOtExtSend_fixed_index(const OtSendStore& base_ot, uint32_t n, absl::Span output, - absl::Span> send_msgs) { - uint32_t ot_num = math::Log2Ceil(n); + absl::Span send_buf, bool mal) { + const uint32_t ot_num = math::Log2Ceil(n); + const uint64_t buf_size = SgrrOtExtHelper(n, mal); YACL_ENFORCE_GE(base_ot.Size(), ot_num); YACL_ENFORCE_GE(n, (uint32_t)1); - YACL_ENFORCE_GE(send_msgs.size(), ot_num); + YACL_ENFORCE_EQ(static_cast(send_buf.size()), buf_size); output[0] = SecureRandSeed(); - + auto send_msgs = absl::MakeSpan( + reinterpret_cast*>(send_buf.data()), ot_num); // generate the final level seeds based on master_seed for (uint32_t i = 0; i < ot_num; ++i) { + send_msgs[i][0] = base_ot.GetBlock(i, 1); + send_msgs[i][1] = base_ot.GetBlock(i, 0); // for each seeds in level i const uint32_t iter_num = 1 << i; auto splits = SplitAllSeeds(output.subspan(0, iter_num)); @@ -350,10 +384,11 @@ void SgrrOtExtSend_fixed_index(const OtSendStore& base_ot, uint32_t n, std::min(2 * iter_num, n) * sizeof(uint128_t)); } - // mask the ROT messages and send back - for (uint32_t i = 0; i < ot_num; ++i) { - send_msgs[i][0] ^= base_ot.GetBlock(i, 1); - send_msgs[i][1] ^= base_ot.GetBlock(i, 0); + if (mal) { + auto proof = GenCheckMsg(n, output); + proof.Pack(absl::MakeSpan(send_buf.data() + buf_size - 64, 64)); + // refresh output + ParaCrHashInplace_128(output.subspan(0, n)); } } diff --git a/yacl/crypto/primitives/ot/sgrr_ote.h b/yacl/crypto/primitives/ot/sgrr_ote.h index da9bb54b..31f04e85 100644 --- a/yacl/crypto/primitives/ot/sgrr_ote.h +++ b/yacl/crypto/primitives/ot/sgrr_ote.h @@ -38,10 +38,12 @@ namespace yacl::crypto { // Implementation of (n-1)-out-of-n Random OT (also called oblivious punctured // vector), paper: https://eprint.iacr.org/2019/1084. // -// This implementation requires at least n pre-generated Random OTs, and outputs -// n/n-1 64 bits seeds (but we are defining it as 128 bits), also, currently n -// should be 2^i, in test, we use n = 2^5, 2^10 ,2^15, plus n needs to at least -// be 4 +// This implementation requires at least log2(n) pre-generated Random OTs, and +// outputs n/n-1 64 bits seeds (but we are defining it as 128 bits), also, +// currently n should be 2^i, in test, we use n = 2^5, 2^10 ,2^15, plus n needs +// to at least be 4. +// We adopt the newer consistency check of Softspoken, see +// https://eprint.iacr.org/2022/192.pdf Fig.14 for more detail. // // Does the size in bits matter when seeding a pseudo-random number generator? // The rationale behind this is that a PRG's seed is understood as (some kind @@ -52,9 +54,15 @@ namespace yacl::crypto { // // Therefore, if we want 128-bit security, we can set seed length = 128. // +// Security assumptions: +// - Correlation-robust Hash, but here we use two-key PRF with AES key +// scheduling to optimize CrHash, see yacl/crypto/base/aes/aes_opt.h for more +// details. +// // Some Discussions in the community: -// https://crypto.stackexchange.com/questions/38039 -// https://stackoverflow.com/questions/50402168 +// - https://crypto.stackexchange.com/questions/38039 +// - https://stackoverflow.com/questions/50402168 +// void SgrrOtExtRecv(const std::shared_ptr& ctx, const OtRecvStore& base_ot, uint32_t n, uint32_t index, @@ -69,9 +77,21 @@ void SgrrOtExtSend(const std::shared_ptr& ctx, // Customized // -------------------------- // + +// SgrrOtExtHelper would return the size of Buffer used in +// `SgrrOtExtRecv_fixed_index` and `SgrrOtExtRecv_fixed_index`. +uint64_t inline SgrrOtExtHelper(uint32_t n, bool mal = false) { + const uint32_t ot_num = math::Log2Ceil(n); + const uint64_t ot_msg_size = ot_num * sizeof(uint128_t) * 2; + const uint64_t check_size = (mal ? 32 * 2 : 0); + return ot_msg_size + check_size; +} + // Notice that: -// > In such cases, punctured index would be the choice of cot -// > punctured index might be greater than n +// > In such cases, punctured index would be the choice of cot, which means +// punctured index might be greater than n. +// > Before call `SgrrOtExtRecv_fixed_index` and `SgrrOtExtSend_fixed_index`, +// it would be better to get Buffer's size by invoking `SgrrOtExtHelper`. void SgrrOtExtRecv_fixed_index(const std::shared_ptr& ctx, const OtRecvStore& base_ot, uint32_t n, absl::Span output, bool mal = false); @@ -81,13 +101,16 @@ void SgrrOtExtSend_fixed_index(const std::shared_ptr& ctx, absl::Span output, bool mal = false); // non-interactive function, Receiver should receive "recv_msgs" from Sender +// TODO: use `ByteContainerView` instead. void SgrrOtExtRecv_fixed_index(const OtRecvStore& base_ot, uint32_t n, absl::Span output, - absl::Span> recv_msg); + absl::Span recv_buf, + bool mal = false); // non-interactive function, Sender should send "send_msg" to Receiver +// TODO: void SgrrOtExtSend_fixed_index(const OtSendStore& base_ot, uint32_t n, absl::Span output, - absl::Span> send_msg); + absl::Span send_buf, bool mal = false); } // namespace yacl::crypto diff --git a/yacl/crypto/primitives/ot/sgrr_ote_test.cc b/yacl/crypto/primitives/ot/sgrr_ote_test.cc index e6ffa20a..6eed738e 100644 --- a/yacl/crypto/primitives/ot/sgrr_ote_test.cc +++ b/yacl/crypto/primitives/ot/sgrr_ote_test.cc @@ -45,11 +45,11 @@ TEST_P(SgrrParamTest, SemiHonestWorks) { std::vector send_out(n); std::vector recv_out(n); - std::future sender = std::async([&] { + std::future receiver = std::async([&] { SgrrOtExtRecv(lctxs[0], std::move(base_ot.recv), n, index, absl::MakeSpan(recv_out), false); }); - std::future receiver = std::async([&] { + std::future sender = std::async([&] { SgrrOtExtSend(lctxs[1], std::move(base_ot.send), n, absl::MakeSpan(send_out), false); }); @@ -79,11 +79,11 @@ TEST_P(SgrrParamTest, MaliciousWorks) { std::vector send_out(n); std::vector recv_out(n); - std::future sender = std::async([&] { + std::future receiver = std::async([&] { SgrrOtExtRecv(lctxs[0], std::move(base_ot.recv), n, index, absl::MakeSpan(recv_out), true); }); - std::future receiver = std::async([&] { + std::future sender = std::async([&] { SgrrOtExtSend(lctxs[1], std::move(base_ot.send), n, absl::MakeSpan(send_out), true); }); @@ -103,6 +103,88 @@ TEST_P(SgrrParamTest, MaliciousWorks) { } } +TEST_P(SgrrParamTest, SemiHonestFixedIndextWorks) { + size_t n = GetParam().n; + + auto lctxs = link::test::SetupWorld(2); + auto ot_num = math::Log2Ceil(n); + auto index = RandInRange(n); + dynamic_bitset choices; + choices.append(index); + choices.resize(ot_num); + auto base_ot = MockRots(ot_num, choices); // mock many base OTs + + // SPDLOG_INFO("index is {}", index); + + std::vector send_out(n); + std::vector recv_out(n); + + std::future receiver = std::async([&] { + SgrrOtExtRecv_fixed_index(lctxs[0], std::move(base_ot.recv), n, + absl::MakeSpan(recv_out)); + }); + std::future sender = std::async([&] { + SgrrOtExtSend_fixed_index(lctxs[1], std::move(base_ot.send), n, + absl::MakeSpan(send_out)); + }); + sender.get(); + receiver.get(); + + for (size_t i = 0; i < n; ++i) { + if (index != i) { + EXPECT_NE(recv_out[i], 0); + EXPECT_EQ(send_out[i], recv_out[i]); + } else { + EXPECT_EQ(0, recv_out[i]); + } + } +} + +TEST_P(SgrrParamTest, MaliciousFixedIndextWorks) { + size_t n = GetParam().n; + + auto lctxs = link::test::SetupWorld(2); + auto ot_num = math::Log2Ceil(n); + auto index = RandInRange(n); + dynamic_bitset choices; + choices.append(index); + choices.resize(ot_num); + auto base_ot = MockRots(ot_num, choices); // mock many base OTs + + // SPDLOG_INFO("index is {}", index); + + std::vector send_out(n); + std::vector recv_out(n); + + std::future receiver = std::async([&] { + auto recv_buf = lctxs[0]->Recv(lctxs[0]->NextRank(), "SGRR_OTE:RECV-CORR"); + YACL_ENFORCE(recv_buf.size() == + static_cast(SgrrOtExtHelper(n, true))); + SgrrOtExtRecv_fixed_index( + std::move(base_ot.recv), n, absl::MakeSpan(recv_out), + absl::MakeSpan(recv_buf.data(), recv_buf.size()), true); + }); + std::future sender = std::async([&] { + auto send_buf = Buffer(SgrrOtExtHelper(n, true)); + SgrrOtExtSend_fixed_index( + std::move(base_ot.send), n, absl::MakeSpan(send_out), + absl::MakeSpan(send_buf.data(), send_buf.size()), true); + lctxs[1]->SendAsync(lctxs[1]->NextRank(), ByteContainerView(send_buf), + "SGRR_OTE:SEND-CORR"); + }); + sender.get(); + receiver.get(); + + for (size_t i = 0; i < n; ++i) { + if (index != i) { + EXPECT_NE(recv_out[i], 0); + EXPECT_EQ(send_out[i], recv_out[i]); + } else { + EXPECT_EQ(0, recv_out[i]); + } + } +} + INSTANTIATE_TEST_SUITE_P(Works_Instances, SgrrParamTest, testing::Values(TestParams{4}, TestParams{5}, // TestParams{7}, // diff --git a/yacl/crypto/primitives/ot/softspoken_ote.cc b/yacl/crypto/primitives/ot/softspoken_ote.cc index f27a784e..a8efc1cd 100644 --- a/yacl/crypto/primitives/ot/softspoken_ote.cc +++ b/yacl/crypto/primitives/ot/softspoken_ote.cc @@ -235,6 +235,11 @@ void SoftspokenOtExtSender::OneTimeSetup( // set delta delta_ = base_ot.CopyChoice().data()[0]; + auto recv_size = 128 * 2 * sizeof(uint128_t) + pprf_num_ * (mal_ ? 64 : 0); + auto recv_buf = ctx->Recv(ctx->NextRank(), "SGRR_OTE:RECV-CORR"); + YACL_ENFORCE((uint64_t)recv_buf.size() == recv_size); + auto recv_span = absl::MakeSpan((recv_buf.data()), recv_size); + auto single_buf_size = SgrrOtExtHelper(pprf_range_, mal_); // One-time Setup for Softspoken // k 1-out-of-2 ROT to (2^k-1)-out-of-(2^k) ROT for (uint64_t i = 0; i < pprf_num_; ++i) { @@ -248,8 +253,11 @@ void SoftspokenOtExtSender::OneTimeSetup( // punctured leaves for the i-th pprf auto leaves = absl::MakeSpan(punctured_leaves_.data() + i * pprf_range_, range_limit); - SgrrOtExtRecv(ctx, sub_ot, range_limit, punctured_idx_[i], leaves, mal_); - + // prepare for cur_recv_buf + auto cur_buf_size = SgrrOtExtHelper(range_limit, mal_); + auto cur_recv_buf = recv_span.subspan(single_buf_size * i, cur_buf_size); + // SgrrOtExt + SgrrOtExtRecv_fixed_index(sub_ot, range_limit, leaves, cur_recv_buf, mal_); // if the j-th bit of punctured index is 1, set mask as all one; // set mask as all zero otherwise. for (uint64_t j = 0; j < k_limit; ++j) { @@ -291,6 +299,11 @@ void SoftspokenOtExtReceiver::OneTimeSetup( } // FIXME: Copy base_ot, since NextSlice is not const auto dup_base_ot = base_ot; + // Send Message Buffer + auto send_size = 128 * 2 * sizeof(uint128_t) + pprf_num_ * (mal_ ? 64 : 0); + auto send_buf = Buffer(send_size); + auto send_span = absl::MakeSpan(send_buf.data(), send_size); + auto single_buf_size = SgrrOtExtHelper(pprf_range_, mal_); // One-time Setup for Softspoken // k 1-out-of-2 ROT to (2^k-1)-out-of-(2^k) ROT for (uint64_t i = 0; i < pprf_num_; ++i) { @@ -301,8 +314,14 @@ void SoftspokenOtExtReceiver::OneTimeSetup( // leaves in i-th pprf auto leaves = absl::MakeSpan(all_leaves_.data() + i * pprf_range_, range_limit); - SgrrOtExtSend(ctx, sub_ot, range_limit, leaves, mal_); + // prepare cur_send_buf + auto cur_buf_size = SgrrOtExtHelper(range_limit, mal_); + auto cur_send_span = send_span.subspan(i * single_buf_size, cur_buf_size); + // SgrrOtExt + SgrrOtExtSend_fixed_index(sub_ot, range_limit, leaves, cur_send_span, mal_); } + ctx->SendAsync(ctx->NextRank(), ByteContainerView(send_buf), + "SGRR_OTE:SEND-CORR"); inited_ = true; } diff --git a/yacl/crypto/primitives/vole/f2k/BUILD.bazel b/yacl/crypto/primitives/vole/BUILD.bazel similarity index 87% rename from yacl/crypto/primitives/vole/f2k/BUILD.bazel rename to yacl/crypto/primitives/vole/BUILD.bazel index a92ba422..63d7ea2a 100644 --- a/yacl/crypto/primitives/vole/f2k/BUILD.bazel +++ b/yacl/crypto/primitives/vole/BUILD.bazel @@ -29,7 +29,6 @@ yacl_cc_library( "//yacl/crypto/utils:rand", "//yacl/crypto/utils:secparam", "//yacl/math:gadget", - "//yacl/math/f2k", "//yacl/utils:serialize", ], ) @@ -48,18 +47,17 @@ yacl_cc_test( ) yacl_cc_library( - name = "sparse_vole", - srcs = ["sparse_vole.cc"], - hdrs = ["sparse_vole.h"], + name = "mp_vole", + hdrs = ["mp_vole.h"], copts = AES_COPT_FLAGS, deps = [ "//yacl/base:aligned_vector", "//yacl/base:dynamic_bitset", "//yacl/base:int128", - "//yacl/crypto/primitives/ot:ferret_ote", + "//yacl/crypto/base/hash:hash_utils", + "//yacl/crypto/primitives/dpf:mpfss", "//yacl/crypto/primitives/ot:ot_store", - "//yacl/crypto/primitives/ot:sgrr_ote", - "//yacl/crypto/primitives/ot:softspoken_ote", + "//yacl/crypto/tools:common", "//yacl/crypto/utils:rand", "//yacl/crypto/utils:secparam", "//yacl/math:gadget", @@ -69,11 +67,11 @@ yacl_cc_library( ) yacl_cc_test( - name = "sparse_vole_test", - srcs = ["sparse_vole_test.cc"], + name = "mp_vole_test", + srcs = ["mp_vole_test.cc"], copts = AES_COPT_FLAGS, deps = [ - ":sparse_vole", + ":mp_vole", "//yacl/crypto/utils:rand", "//yacl/link:test_util", "//yacl/math:gadget", @@ -88,19 +86,17 @@ yacl_cc_library( copts = AES_COPT_FLAGS, deps = [ ":base_vole", - ":sparse_vole", + ":mp_vole", "//yacl/base:aligned_vector", "//yacl/base:dynamic_bitset", "//yacl/base:int128", "//yacl/crypto/primitives/code:code_interface", "//yacl/crypto/primitives/code:ea_code", "//yacl/crypto/primitives/code:silver_code", - "//yacl/crypto/primitives/ot:ferret_ote", "//yacl/crypto/primitives/ot:ot_store", "//yacl/crypto/primitives/ot:softspoken_ote", "//yacl/crypto/utils:secparam", "//yacl/link:context", - "//yacl/math:gadget", ], ) @@ -126,7 +122,6 @@ yacl_cc_binary( deps = [ ":base_vole", ":silent_vole", - ":sparse_vole", "//yacl/crypto/utils:rand", "//yacl/link:test_util", "@com_github_google_benchmark//:benchmark_main", diff --git a/yacl/crypto/primitives/vole/f2k/base_vole.h b/yacl/crypto/primitives/vole/base_vole.h similarity index 86% rename from yacl/crypto/primitives/vole/f2k/base_vole.h rename to yacl/crypto/primitives/vole/base_vole.h index b42b294f..2aee5d94 100644 --- a/yacl/crypto/primitives/vole/f2k/base_vole.h +++ b/yacl/crypto/primitives/vole/base_vole.h @@ -14,37 +14,19 @@ #pragma once -#include -#include - #include "yacl/base/exception.h" #include "yacl/base/int128.h" -#include "yacl/crypto/primitives/ot/ot_store.h" -#include "yacl/crypto/utils/secparam.h" #include "yacl/math/f2k/f2k.h" #include "yacl/math/gadget.h" /* submodules */ +#include "yacl/crypto/primitives/ot/ot_store.h" #include "yacl/crypto/primitives/ot/softspoken_ote.h" -#include "yacl/crypto/utils/rand.h" +#include "yacl/crypto/utils/secparam.h" -/* security parameter declaration */ YACL_MODULE_DECLARE("base_vole", SecParam::C::INF, SecParam::S::INF); - namespace yacl::crypto { -namespace vole::internal { - -uint128_t inline GfMul(absl::Span a, absl::Span b) { - return GfMul128(a, b); -} - -uint64_t inline GfMul(absl::Span a, absl::Span b) { - return GfMul64(a, b); -} - -} // namespace vole::internal - // Convert OT to f2k-VOLE (non-interactive) // the type of ot_store must be COT // w = u * delta + v, where delta = send_ot.delta @@ -78,7 +60,7 @@ void inline Ot2VoleSend(OtSendStore& send_ot, absl::Span w) { for (size_t j = 0; j < T_bits; ++j) { w_buff[j] = send_ot.GetBlock(i * T_bits + j, 0); } - w[i] = vole::internal::GfMul(absl::MakeSpan(w_buff), absl::MakeSpan(basis)); + w[i] = math::GfMul(absl::MakeSpan(w_buff), absl::MakeSpan(basis)); } } @@ -109,7 +91,7 @@ void inline Ot2VoleRecv(OtRecvStore& recv_ot, absl::Span u, for (size_t j = 0; j < T_bits; ++j) { v_buff[j] = recv_ot.GetBlock(i * T_bits + j); } - v[i] = vole::internal::GfMul(absl::MakeSpan(v_buff), absl::MakeSpan(basis)); + v[i] = math::GfMul(absl::MakeSpan(v_buff), absl::MakeSpan(basis)); } } @@ -125,11 +107,12 @@ void inline Ot2VoleRecv(OtRecvStore& recv_ot, absl::Span u, // GilboaVoleSend / GilboaVoleRecv template void inline GilboaVoleSend(const std::shared_ptr& ctx, - const OtRecvStore& base_ot, absl::Span w) { + const OtRecvStore& base_ot, absl::Span w, + bool mal = false) { constexpr size_t T_bits = sizeof(T) * 8; const size_t size = w.size(); - auto sender = SoftspokenOtExtSender(2); + auto sender = SoftspokenOtExtSender(2, mal); // setup Softspoken by base_ot sender.OneTimeSetup(ctx, base_ot); auto send_ot = sender.GenCot(ctx, size * T_bits); @@ -139,16 +122,16 @@ void inline GilboaVoleSend(const std::shared_ptr& ctx, template void inline GilboaVoleRecv(const std::shared_ptr& ctx, const OtSendStore& base_ot, absl::Span u, - absl::Span v) { + absl::Span v, bool mal = false) { constexpr size_t T_bits = sizeof(T) * 8; const size_t size = u.size(); YACL_ENFORCE(size == v.size()); - auto receiver = SoftspokenOtExtReceiver(2); + auto receiver = SoftspokenOtExtReceiver(2, mal); // setup Softspoken by base_ot receiver.OneTimeSetup(ctx, base_ot); auto recv_ot = receiver.GenCot(ctx, size * T_bits); Ot2VoleRecv(recv_ot, u, v); } -} // namespace yacl::crypto \ No newline at end of file +} // namespace yacl::crypto diff --git a/yacl/crypto/primitives/vole/f2k/base_vole_test.cc b/yacl/crypto/primitives/vole/base_vole_test.cc similarity index 75% rename from yacl/crypto/primitives/vole/f2k/base_vole_test.cc rename to yacl/crypto/primitives/vole/base_vole_test.cc index 63d8cd63..3f30ce79 100644 --- a/yacl/crypto/primitives/vole/f2k/base_vole_test.cc +++ b/yacl/crypto/primitives/vole/base_vole_test.cc @@ -12,11 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "yacl/crypto/primitives/vole/f2k/base_vole.h" +#include "yacl/crypto/primitives/vole/base_vole.h" #include -#include #include #include #include @@ -31,11 +30,6 @@ namespace yacl::crypto { -namespace { -uint64_t GfMul(uint64_t lhs, uint64_t rhs) { return GfMul64(lhs, rhs); } -uint128_t GfMul(uint128_t lhs, uint128_t rhs) { return GfMul128(lhs, rhs); } -} // namespace - struct TestParams { size_t num; }; @@ -45,6 +39,9 @@ class BaseVoleTest : public ::testing::TestWithParam {}; using GF64 = uint64_t; using GF128 = uint128_t; +// Semi-honst / Malicious +enum SM : bool { Semi = false, Mal = true }; + #define DECLARE_OT2VOLE_TEST(type0, type1) \ TEST_P(BaseVoleTest, Ot2Vole_##type0##x##type1##_Work) { \ const uint64_t vole_num = GetParam().num; \ @@ -65,7 +62,7 @@ using GF128 = uint128_t; type1 delta = delta128; \ for (uint64_t i = 0; i < vole_num; ++i) { \ type1 ui = u[i]; \ - EXPECT_EQ(GfMul(ui, delta), w[i] ^ v[i]); \ + EXPECT_EQ(math::GfMul(ui, delta), w[i] ^ v[i]); \ } \ } @@ -73,36 +70,45 @@ DECLARE_OT2VOLE_TEST(GF64, GF64); // Vole: GF(2^64) x GF(2^64) DECLARE_OT2VOLE_TEST(GF64, GF128); // subfield Vole: GF(2^64) x GF(2^128) DECLARE_OT2VOLE_TEST(GF128, GF128); // Vole: GF(2^128) x GF(2^128) -#define DECLARE_GILBOAVOLE_TEST(type0, type1) \ - TEST_P(BaseVoleTest, GilboaVole_##type0##x##type1##_Work) { \ - auto lctxs = link::test::SetupWorld(2); \ - const uint64_t vole_num = GetParam().num; \ - auto rot = MockRots(128); \ - auto delta128 = rot.recv.CopyChoice().data()[0]; \ - std::vector u(vole_num); \ - std::vector v(vole_num); \ - std::vector w(vole_num); \ - auto sender = std::async([&] { \ - GilboaVoleSend(lctxs[0], rot.recv, absl::MakeSpan(w)); \ - }); \ - auto receiver = std::async([&] { \ - GilboaVoleRecv(lctxs[1], rot.send, absl::MakeSpan(u), \ - absl::MakeSpan(v)); \ - }); \ - sender.get(); \ - receiver.get(); \ - type1 delta = delta128; \ - for (uint64_t i = 0; i < vole_num; ++i) { \ - type1 ui = u[i]; \ - EXPECT_EQ(GfMul(ui, delta), w[i] ^ v[i]); \ - } \ +#define DECLARE_GILBOAVOLE_TEST(kase, type0, type1) \ + TEST_P(BaseVoleTest, kase##_GilboaVole_##type0##x##type1##_Work) { \ + auto lctxs = link::test::SetupWorld(2); \ + const uint64_t vole_num = GetParam().num; \ + auto rot = MockRots(128); \ + auto delta128 = rot.recv.CopyChoice().data()[0]; \ + std::vector u(vole_num); \ + std::vector v(vole_num); \ + std::vector w(vole_num); \ + auto sender = std::async([&] { \ + GilboaVoleSend(lctxs[0], rot.recv, absl::MakeSpan(w), \ + SM::kase); \ + }); \ + auto receiver = std::async([&] { \ + GilboaVoleRecv(lctxs[1], rot.send, absl::MakeSpan(u), \ + absl::MakeSpan(v), SM::kase); \ + }); \ + sender.get(); \ + receiver.get(); \ + type1 delta = delta128; \ + for (uint64_t i = 0; i < vole_num; ++i) { \ + type1 ui = u[i]; \ + EXPECT_EQ(math::GfMul(ui, delta), w[i] ^ v[i]); \ + } \ } -DECLARE_GILBOAVOLE_TEST(GF64, GF64); // Vole: GF(2^64) x GF(2^64) -DECLARE_GILBOAVOLE_TEST(GF64, GF128); // subfield Vole: GF(2^64) x GF(2^128) -DECLARE_GILBOAVOLE_TEST(GF128, GF128); // Vole: GF(2^128) x GF(2^128) +// Semi-honest Base Vole +DECLARE_GILBOAVOLE_TEST(Semi, GF64, GF64); // Vole: GF(2^64) x GF(2^64) +DECLARE_GILBOAVOLE_TEST(Semi, GF64, + GF128); // subfield Vole: GF(2^64) x GF(2^128) +DECLARE_GILBOAVOLE_TEST(Semi, GF128, GF128); // Vole: GF(2^128) x GF(2^128) + +// Malicious Base Vole +DECLARE_GILBOAVOLE_TEST(Mal, GF64, GF64); // Vole: GF(2^64) x GF(2^64) +DECLARE_GILBOAVOLE_TEST(Mal, GF64, + GF128); // subfield Vole: GF(2^64) x GF(2^128) +DECLARE_GILBOAVOLE_TEST(Mal, GF128, GF128); // Vole: GF(2^128) x GF(2^128) -INSTANTIATE_TEST_SUITE_P(Works_Instances, BaseVoleTest, +INSTANTIATE_TEST_SUITE_P(f2kVOLE, BaseVoleTest, testing::Values(TestParams{4}, TestParams{5}, // TestParams{7}, // TestParams{1 << 8}, diff --git a/yacl/crypto/primitives/vole/benchmark.cc b/yacl/crypto/primitives/vole/benchmark.cc new file mode 100644 index 00000000..4378721a --- /dev/null +++ b/yacl/crypto/primitives/vole/benchmark.cc @@ -0,0 +1,268 @@ +// Copyright 2022 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "benchmark/benchmark.h" + +#include +#include +#include + +#include "yacl/base/exception.h" +#include "yacl/crypto/primitives/vole/base_vole.h" +#include "yacl/crypto/primitives/vole/silent_vole.h" +#include "yacl/link/test_util.h" + +// +// bazel run //yacl/crypto/primitives/vole/f2k:benchmark -c opt -- +// --benchmark_counters_tabular=true +// +// User Counters: +// 1. recv would record the average time (ms) VoleReceiver is needed. +// 2. send would record the average time (ms) VoleSender is needed. + +namespace yacl::crypto { + +using GF64 = uint64_t; +using GF128 = uint128_t; + +namespace decorator { + +// return elapse time (ms) +template +std::future inline async(Function&& fn, Args&&... args) { + return std::async([fn, args = std::tuple(std::move(args)...)]() mutable { + auto start = std::chrono::high_resolution_clock::now(); + std::apply([fn](auto&&... args) { (fn)(std::move(args)...); }, + std::move(args)); + auto end = std::chrono::high_resolution_clock::now(); + auto second = + std::chrono::duration_cast>(end - start) + .count(); + return second * 1000; + }); +} + +} // namespace decorator + +// Wrapper +template +void SendWrapper(SilentVoleSender& sender, std::shared_ptr& lctx, + absl::Span c) { + sender.Send(lctx, c); +} + +template <> +void SendWrapper(SilentVoleSender& sender, + std::shared_ptr& lctx, + absl::Span c) { + sender.SfSend(lctx, c); +} + +template +void RecvWrapper(SilentVoleReceiver& receiver, + std::shared_ptr& lctx, absl::Span a, + absl::Span b) { + receiver.Recv(lctx, a, b); +} + +template <> +void RecvWrapper(SilentVoleReceiver& receiver, + std::shared_ptr& lctx, + absl::Span a, absl::Span b) { + receiver.SfRecv(lctx, a, b); +} + +class StaticLink { + public: + static std::vector> GetLinks() { + if (lctxs_.empty()) { + // lctxs_ = link::test::SetupBrpcWorld(2); + lctxs_ = link::test::SetupWorld(2); + } + return lctxs_; + } + static std::vector> lctxs_; +}; + +std::vector> StaticLink::lctxs_ = {}; + +// Gilboa Vole (Semi-honset/Mal, Type0 , Type1) +template +void GilboaVoleBench(benchmark::State& state, Args&&... args) { + auto lctxs = StaticLink::GetLinks(); + YACL_ENFORCE(lctxs.size() == 2); + + auto param = std::forward_as_tuple(args...); + bool mal = std::get<0>(param); + + using T = std::decay_t(param))>; + using K = std::decay_t(param))>; + for (auto _ : state) { + state.PauseTiming(); + { + const size_t num_vole = state.range(0); + auto rot = MockRots(128); + std::vector u(num_vole); + std::vector v(num_vole); + std::vector w(num_vole); + uint64_t send_byte = 0; + uint64_t recv_byte = 0; + send_byte -= lctxs[0]->GetStats()->sent_bytes; + recv_byte -= lctxs[0]->GetStats()->recv_bytes; + state.ResumeTiming(); + auto sender = decorator::async([&] { + GilboaVoleSend(lctxs[0], rot.recv, absl::MakeSpan(w), mal); + }); + auto receiver = decorator::async([&] { + GilboaVoleRecv(lctxs[1], rot.send, absl::MakeSpan(u), + absl::MakeSpan(v), mal); + }); + state.counters["send"] += sender.get(); + state.counters["recv"] += receiver.get(); + state.PauseTiming(); + send_byte += lctxs[0]->GetStats()->sent_bytes; + recv_byte += lctxs[0]->GetStats()->recv_bytes; + state.counters["send_byte"] += send_byte; + state.counters["recv_byte"] += recv_byte; + } + state.ResumeTiming(); + } + state.counters["send"] /= state.iterations(); + state.counters["recv"] /= state.iterations(); + state.counters["send_byte"] /= state.iterations(); + state.counters["recv_byte"] /= state.iterations(); +} + +// Silent Vole (Codetype, Semi-honset/Mal, Type0 , Type1) +template +void SilentVoleBench(benchmark::State& state, Args&&... args) { + auto lctxs = StaticLink::GetLinks(); + YACL_ENFORCE(lctxs.size() == 2); + auto param = std::forward_as_tuple(args...); + CodeType codetype = std::get<0>(param); + bool mal = std::get<1>(param); + + using T = std::decay_t(param))>; + using K = std::decay_t(param))>; + + for (auto _ : state) { + state.PauseTiming(); + { + const size_t num_vole = state.range(0); + std::vector a(num_vole); + std::vector b(num_vole); + std::vector c(num_vole); + auto sender_init = std::async([&] { + auto sender = SilentVoleSender(codetype, mal); + return sender; + }); + auto receiver_init = std::async([&] { + auto receiver = SilentVoleReceiver(codetype, mal); + return receiver; + }); + auto sender = sender_init.get(); + auto receiver = receiver_init.get(); + + uint64_t send_byte = 0; + uint64_t recv_byte = 0; + send_byte -= lctxs[0]->GetStats()->sent_bytes; + recv_byte -= lctxs[0]->GetStats()->recv_bytes; + + state.ResumeTiming(); + auto sender_task = decorator::async( + [&] { SendWrapper(sender, lctxs[0], absl::MakeSpan(c)); }); + auto receiver_task = decorator::async([&] { + RecvWrapper(receiver, lctxs[1], absl::MakeSpan(a), + absl::MakeSpan(b)); + }); + state.counters["send"] += sender_task.get(); + state.counters["recv"] += receiver_task.get(); + state.PauseTiming(); + send_byte += lctxs[0]->GetStats()->sent_bytes; + recv_byte += lctxs[0]->GetStats()->recv_bytes; + state.counters["send_byte"] += send_byte; + state.counters["recv_byte"] += recv_byte; + } + state.ResumeTiming(); + } + state.counters["send"] /= state.iterations(); + state.counters["recv"] /= state.iterations(); + state.counters["send_byte"] /= state.iterations(); + state.counters["recv_byte"] /= state.iterations(); +} + +#define Zero(name) name(0) +#define GF64 uint64_t +#define GF128 uint128_t + +enum SM : bool { Semi = false, Mal = true }; + +#define GILBOA_VOLE_BM_TEMPLATE(kase, type0, type1, Arguments) \ + BENCHMARK_CAPTURE(GilboaVoleBench, kase##_BaseVole_##type0##x##type1, \ + SM::kase, Zero(type0), Zero(type1)) \ + ->Apply(Arguments); + +#define SM_GILBOA_VOLE_BM_TEMPLATE(kase, Arguments) \ + GILBOA_VOLE_BM_TEMPLATE(kase, GF64, GF64, Arguments) \ + GILBOA_VOLE_BM_TEMPLATE(kase, GF64, GF128, Arguments) \ + GILBOA_VOLE_BM_TEMPLATE(kase, GF128, GF128, Arguments) + +#define DECLARE_GILBOA_VOLE_BM(Arguments) \ + SM_GILBOA_VOLE_BM_TEMPLATE(Semi, Arguments) \ + SM_GILBOA_VOLE_BM_TEMPLATE(Mal, Arguments) + +#define SILENT_VOLE_BM_TEMPLATE(Code, kase, type0, type1, Arguments) \ + BENCHMARK_CAPTURE(SilentVoleBench, kase##_##Code##_##type0##x##type1, \ + CodeType::Code, SM::kase, Zero(type0), Zero(type1)) \ + ->Apply(Arguments); + +#define SM_SILENT_VOLE_BM_TEMPLATE(Code, kase, Arguments) \ + SILENT_VOLE_BM_TEMPLATE(Code, kase, GF64, GF64, Arguments) \ + SILENT_VOLE_BM_TEMPLATE(Code, kase, GF64, GF128, Arguments) \ + SILENT_VOLE_BM_TEMPLATE(Code, kase, GF128, GF128, Arguments) + +#define DECLARE_SPECIFIC_SILENT_VOLE_BM(Code, Arguments) \ + SM_SILENT_VOLE_BM_TEMPLATE(Code, Semi, Arguments) \ + SM_SILENT_VOLE_BM_TEMPLATE(Code, Mal, Arguments) + +#define DECLARE_SILVER_VOLE_BM(Arguments) \ + DECLARE_SPECIFIC_SILENT_VOLE_BM(Silver5, Arguments) \ + DECLARE_SPECIFIC_SILENT_VOLE_BM(Silver11, Arguments) + +#define DECLARE_EXACC_VOLE_BM(Arguments) \ + DECLARE_SPECIFIC_SILENT_VOLE_BM(ExAcc7, Arguments) \ + DECLARE_SPECIFIC_SILENT_VOLE_BM(ExAcc11, Arguments) \ + DECLARE_SPECIFIC_SILENT_VOLE_BM(ExAcc21, Arguments) \ + DECLARE_SPECIFIC_SILENT_VOLE_BM(ExAcc40, Arguments) + +void BM_DefaultArguments(benchmark::internal::Benchmark* b) { + b->Arg(8192)->Unit(benchmark::kMillisecond); +} + +void BM_PerfArguments(benchmark::internal::Benchmark* b) { + b->Arg(1 << 18) + ->Arg(1 << 20) // 1048576, one million + ->Arg(1 << 22) + ->Arg(1 << 24) + ->Arg(10000000) // ten million + ->Arg(22437250) + ->Unit(benchmark::kMillisecond) + ->Iterations(10); +} + +DECLARE_GILBOA_VOLE_BM(BM_DefaultArguments) +DECLARE_SILVER_VOLE_BM(BM_PerfArguments) +DECLARE_EXACC_VOLE_BM(BM_PerfArguments) + +} // namespace yacl::crypto diff --git a/yacl/crypto/primitives/vole/f2k/benchmark.cc b/yacl/crypto/primitives/vole/f2k/benchmark.cc deleted file mode 100644 index fa512526..00000000 --- a/yacl/crypto/primitives/vole/f2k/benchmark.cc +++ /dev/null @@ -1,264 +0,0 @@ -// Copyright 2022 Ant Group Co., Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "benchmark/benchmark.h" - -#include -#include -#include - -#include "yacl/base/exception.h" -#include "yacl/crypto/primitives/vole/f2k/base_vole.h" -#include "yacl/crypto/primitives/vole/f2k/silent_vole.h" -#include "yacl/crypto/primitives/vole/f2k/sparse_vole.h" -#include "yacl/link/test_util.h" - -// -// bazel run //yacl/crypto/primitives/vole/f2k:benchmark -c opt -- -// --benchmark_counters_tabular=true -// -// User Counters: -// 1. recv would record the average time (ms) VoleReceiver is needed. -// 2. send would record the average time (ms) VoleSender is needed. - -namespace yacl::crypto { - -using GF64 = uint64_t; -using GF128 = uint128_t; - -namespace decorator { - -// return elapse time (ms) -template -std::future inline async(Function&& fn, Args&&... args) { - return std::async([fn, args = std::tuple(std::move(args)...)]() mutable { - auto start = std::chrono::high_resolution_clock::now(); - std::apply([fn](auto&&... args) { (fn)(std::move(args)...); }, - std::move(args)); - auto end = std::chrono::high_resolution_clock::now(); - auto second = - std::chrono::duration_cast>(end - start) - .count(); - return second * 1000; - }); -} - -} // namespace decorator - -// VoleBench -class VoleBench : public benchmark::Fixture { - public: - void SetUp(const ::benchmark::State&) override { - if (lctxs_.empty()) { - // lctxs_ = link::test::SetupBrpcWorld(2); - lctxs_ = link::test::SetupWorld(2); - } - } - - static std::vector> lctxs_; -}; - -std::vector> VoleBench::lctxs_ = {}; - -#define DECLARE_GIBLOA_VOLE_BENCH(type0, type1) \ - BENCHMARK_DEFINE_F(VoleBench, GilboaVole_##type0##x##type1) \ - (benchmark::State & state) { \ - YACL_ENFORCE(lctxs_.size() == 2); \ - for (auto _ : state) { \ - state.PauseTiming(); \ - { \ - const size_t num_vole = state.range(0); \ - auto rot = MockRots(128); \ - std::vector u(num_vole); \ - std::vector v(num_vole); \ - std::vector w(num_vole); \ - state.ResumeTiming(); \ - auto sender = decorator::async([&] { \ - GilboaVoleSend(lctxs_[0], rot.recv, \ - absl::MakeSpan(w)); \ - }); \ - auto receiver = decorator::async([&] { \ - GilboaVoleRecv(lctxs_[1], rot.send, absl::MakeSpan(u), \ - absl::MakeSpan(v)); \ - }); \ - state.counters["send"] += sender.get(); \ - state.counters["recv"] += receiver.get(); \ - state.PauseTiming(); \ - } \ - state.ResumeTiming(); \ - } \ - state.counters["send"] /= state.iterations(); \ - state.counters["recv"] /= state.iterations(); \ - } - -DECLARE_GIBLOA_VOLE_BENCH(GF64, GF64); -DECLARE_GIBLOA_VOLE_BENCH(GF64, GF128); -DECLARE_GIBLOA_VOLE_BENCH(GF128, GF128); - -#define BM_REGISTER_GILBOA_VOLE(Arguments) \ - BENCHMARK_REGISTER_F(VoleBench, GilboaVole_GF64xGF64)->Apply(Arguments); \ - BENCHMARK_REGISTER_F(VoleBench, GilboaVole_GF64xGF128)->Apply(Arguments); \ - BENCHMARK_REGISTER_F(VoleBench, GilboaVole_GF128xGF128)->Apply(Arguments); - -#define DELCARE_SILENT_VOLE_BENCH(CODE, type) \ - BENCHMARK_DEFINE_F(VoleBench, CODE##Vole_##type)(benchmark::State & state) { \ - YACL_ENFORCE(lctxs_.size() == 2); \ - for (auto _ : state) { \ - state.PauseTiming(); \ - { \ - const size_t num_vole = state.range(0); \ - std::vector a(num_vole); \ - std::vector b(num_vole); \ - std::vector c(num_vole); \ - auto sender_init = std::async([&] { \ - auto sender = SilentVoleSender(CodeType::CODE); \ - /* Execute OneTime Setup */ \ - /* sender.OneTimeSetup(lctxs_[0]); */ \ - return sender; \ - }); \ - auto receiver_init = std::async([&] { \ - auto receiver = SilentVoleReceiver(CodeType::CODE); \ - /* Execute OneTime Setup */ \ - /* receiver.OneTimeSetup(lctxs_[1]); */ \ - return receiver; \ - }); \ - auto sender = sender_init.get(); \ - auto receiver = receiver_init.get(); \ - state.ResumeTiming(); \ - auto sender_task = decorator::async( \ - [&] { sender.Send(lctxs_[0], absl::MakeSpan(c)); }); \ - auto receiver_task = decorator::async([&] { \ - receiver.Recv(lctxs_[1], absl::MakeSpan(a), absl::MakeSpan(b)); \ - }); \ - state.counters["send"] += sender_task.get(); \ - state.counters["recv"] += receiver_task.get(); \ - state.PauseTiming(); \ - } \ - state.ResumeTiming(); \ - } \ - state.counters["send"] /= state.iterations(); \ - state.counters["recv"] /= state.iterations(); \ - } - -DELCARE_SILENT_VOLE_BENCH(Silver5, GF64); -DELCARE_SILENT_VOLE_BENCH(Silver11, GF64); -DELCARE_SILENT_VOLE_BENCH(ExAcc7, GF64); -DELCARE_SILENT_VOLE_BENCH(ExAcc11, GF64); -DELCARE_SILENT_VOLE_BENCH(ExAcc21, GF64); -DELCARE_SILENT_VOLE_BENCH(ExAcc40, GF64); - -DELCARE_SILENT_VOLE_BENCH(Silver5, GF128); -DELCARE_SILENT_VOLE_BENCH(Silver11, GF128); -DELCARE_SILENT_VOLE_BENCH(ExAcc7, GF128); -DELCARE_SILENT_VOLE_BENCH(ExAcc11, GF128); -DELCARE_SILENT_VOLE_BENCH(ExAcc21, GF128); -DELCARE_SILENT_VOLE_BENCH(ExAcc40, GF128); - -#define DELCARE_SILENT_SUBFIELDVOLE_BENCH(CODE) \ - BENCHMARK_DEFINE_F(VoleBench, CODE##SubfieldVole) \ - (benchmark::State & state) { \ - YACL_ENFORCE(lctxs_.size() == 2); \ - for (auto _ : state) { \ - state.PauseTiming(); \ - { \ - const size_t num_vole = state.range(0); \ - std::vector a(num_vole); \ - std::vector b(num_vole); \ - std::vector c(num_vole); \ - auto sender_init = std::async([&] { \ - auto sender = SilentVoleSender(CodeType::CODE); \ - /* Execute OneTime Setup */ \ - /* sender.OneTimeSetup(lctxs_[0]); */ \ - return sender; \ - }); \ - auto receiver_init = std::async([&] { \ - auto receiver = SilentVoleReceiver(CodeType::CODE); \ - /* Execute OneTime Setup */ \ - /* receiver.OneTimeSetup(lctxs_[1]); */ \ - return receiver; \ - }); \ - auto sender = sender_init.get(); \ - auto receiver = receiver_init.get(); \ - state.ResumeTiming(); \ - auto sender_task = decorator::async( \ - [&] { sender.SfSend(lctxs_[0], absl::MakeSpan(c)); }); \ - auto receiver_task = decorator::async([&] { \ - receiver.SfRecv(lctxs_[1], absl::MakeSpan(a), absl::MakeSpan(b)); \ - }); \ - state.counters["send"] += sender_task.get(); \ - state.counters["recv"] += receiver_task.get(); \ - state.PauseTiming(); \ - } \ - state.ResumeTiming(); \ - } \ - state.counters["send"] /= state.iterations(); \ - state.counters["recv"] /= state.iterations(); \ - } - -DELCARE_SILENT_SUBFIELDVOLE_BENCH(Silver5); -DELCARE_SILENT_SUBFIELDVOLE_BENCH(Silver11); -DELCARE_SILENT_SUBFIELDVOLE_BENCH(ExAcc7); -DELCARE_SILENT_SUBFIELDVOLE_BENCH(ExAcc11); -DELCARE_SILENT_SUBFIELDVOLE_BENCH(ExAcc21); -DELCARE_SILENT_SUBFIELDVOLE_BENCH(ExAcc40); - -#define BM_REGISTER_SILVER_SILENT_VOLE(Arguments) \ - BENCHMARK_REGISTER_F(VoleBench, Silver5Vole_GF64)->Apply(Arguments); \ - BENCHMARK_REGISTER_F(VoleBench, Silver5Vole_GF128)->Apply(Arguments); \ - BENCHMARK_REGISTER_F(VoleBench, Silver5SubfieldVole)->Apply(Arguments); \ - BENCHMARK_REGISTER_F(VoleBench, Silver11Vole_GF64)->Apply(Arguments); \ - BENCHMARK_REGISTER_F(VoleBench, Silver11Vole_GF128)->Apply(Arguments); \ - BENCHMARK_REGISTER_F(VoleBench, Silver11SubfieldVole)->Apply(Arguments); - -#define BM_REGISTER_EXACC_SILENT_VOLE(Arguments) \ - BENCHMARK_REGISTER_F(VoleBench, ExAcc7Vole_GF64)->Apply(Arguments); \ - BENCHMARK_REGISTER_F(VoleBench, ExAcc7Vole_GF128)->Apply(Arguments); \ - BENCHMARK_REGISTER_F(VoleBench, ExAcc7SubfieldVole)->Apply(Arguments); \ - BENCHMARK_REGISTER_F(VoleBench, ExAcc11Vole_GF64)->Apply(Arguments); \ - BENCHMARK_REGISTER_F(VoleBench, ExAcc11Vole_GF128)->Apply(Arguments); \ - BENCHMARK_REGISTER_F(VoleBench, ExAcc11SubfieldVole)->Apply(Arguments); \ - BENCHMARK_REGISTER_F(VoleBench, ExAcc21Vole_GF64)->Apply(Arguments); \ - BENCHMARK_REGISTER_F(VoleBench, ExAcc21Vole_GF128)->Apply(Arguments); \ - BENCHMARK_REGISTER_F(VoleBench, ExAcc21SubfieldVole)->Apply(Arguments); \ - BENCHMARK_REGISTER_F(VoleBench, ExAcc40Vole_GF64)->Apply(Arguments); \ - BENCHMARK_REGISTER_F(VoleBench, ExAcc40Vole_GF128)->Apply(Arguments); \ - BENCHMARK_REGISTER_F(VoleBench, ExAcc40SubfieldVole)->Apply(Arguments); - -#define BM_REGISTER_ALL_VOLE(Arguments) \ - BM_REGISTER_GILBOA_VOLE(Arguments); \ - BM_REGISTER_SILVER_SILENT_VOLE(Arguments); \ - BM_REGISTER_EXACC_SILENT_VOLE(Arguments); - -void BM_DefaultArguments(benchmark::internal::Benchmark* b) { - b->Arg(8192)->Unit(benchmark::kMillisecond); -} - -void BM_PerfArguments(benchmark::internal::Benchmark* b) { - b->Arg(1 << 18) - ->Arg(1 << 20) // 1048576, one million - ->Arg(1 << 22) - ->Arg(1 << 24) - ->Arg(10000000) // ten million - ->Arg(22437250) - ->Unit(benchmark::kMillisecond) - ->Iterations(10); -} - -// BM_REGISTER_ALL_VOLE(BM_DefaultArguments); - -BM_REGISTER_GILBOA_VOLE(BM_DefaultArguments); -BM_REGISTER_SILVER_SILENT_VOLE(BM_PerfArguments); -BM_REGISTER_EXACC_SILENT_VOLE(BM_PerfArguments); - -} // namespace yacl::crypto diff --git a/yacl/crypto/primitives/vole/f2k/silent_vole_test.cc b/yacl/crypto/primitives/vole/f2k/silent_vole_test.cc deleted file mode 100644 index b4166c36..00000000 --- a/yacl/crypto/primitives/vole/f2k/silent_vole_test.cc +++ /dev/null @@ -1,159 +0,0 @@ -// Copyright 2023 Ant Group Co., Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "yacl/crypto/primitives/vole/f2k/silent_vole.h" - -#include - -#include -#include -#include -#include -#include - -#include "yacl/base/int128.h" -#include "yacl/link/test_util.h" -#include "yacl/math/f2k/f2k.h" - -namespace yacl::crypto { - -struct TestParams { - CodeType codetype; - size_t num; -}; - -class VoleTest : public ::testing::TestWithParam {}; - -// VOLE over GF(2^64) x GF(2^64) -TEST_P(VoleTest, SlientVole_GF64_Test) { - auto lctxs = link::test::SetupWorld(2); // setup network - - const auto codetype = GetParam().codetype; - const uint64_t vole_num = GetParam().num; - - std::vector a(vole_num); - std::vector b(vole_num); - std::vector c(vole_num); - uint64_t delta = 0; - - auto sender = std::async([&] { - auto sv_sender = SilentVoleSender(codetype); - sv_sender.Send(lctxs[0], absl::MakeSpan(c)); - delta = sv_sender.GetDelta64(); - }); - - auto receiver = std::async([&] { - auto sv_receiver = SilentVoleReceiver(codetype); - sv_receiver.Recv(lctxs[1], absl::MakeSpan(a), absl::MakeSpan(b)); - }); - - sender.get(); - receiver.get(); - - for (uint64_t i = 0; i < vole_num; ++i) { - EXPECT_EQ(GfMul64(a[i], delta) ^ b[i], c[i]); - } -} - -// VOLE over GF(2^128) x GF(2^128) -TEST_P(VoleTest, SlientVole_GF128_Test) { - auto lctxs = link::test::SetupWorld(2); // setup network - - const auto codetype = GetParam().codetype; - const uint64_t vole_num = GetParam().num; - - std::vector a(vole_num); - std::vector b(vole_num); - std::vector c(vole_num); - uint128_t delta = 0; - - auto sender = std::async([&] { - auto sv_sender = SilentVoleSender(codetype); - sv_sender.Send(lctxs[0], absl::MakeSpan(c)); - delta = sv_sender.GetDelta(); - }); - - auto receiver = std::async([&] { - auto sv_receiver = SilentVoleReceiver(codetype); - sv_receiver.Recv(lctxs[1], absl::MakeSpan(a), absl::MakeSpan(b)); - }); - - sender.get(); - receiver.get(); - - for (uint64_t i = 0; i < vole_num; ++i) { - EXPECT_EQ(GfMul128(a[i], delta) ^ b[i], c[i]); - } -} - -// subfield VOLE over GF(2^64) x GF(2^128) -TEST_P(VoleTest, SlientVole_GF64xGF128_Test) { - auto lctxs = link::test::SetupWorld(2); // setup network - - const auto codetype = GetParam().codetype; - const uint64_t vole_num = GetParam().num; - - std::vector a(vole_num); - std::vector b(vole_num); - std::vector c(vole_num); - uint128_t delta = 0; - - auto sender = std::async([&] { - auto sv_sender = SilentVoleSender(codetype); - sv_sender.SfSend(lctxs[0], absl::MakeSpan(c)); - delta = sv_sender.GetDelta(); - }); - - auto receiver = std::async([&] { - auto sv_receiver = SilentVoleReceiver(codetype); - sv_receiver.SfRecv(lctxs[1], absl::MakeSpan(a), absl::MakeSpan(b)); - }); - - sender.get(); - receiver.get(); - - for (uint64_t i = 0; i < vole_num; ++i) { - auto ai = yacl::MakeUint128(0, a[i]); - EXPECT_EQ(GfMul128(ai, delta) ^ b[i], c[i]); - } -} - -INSTANTIATE_TEST_SUITE_P( - Works_Instances, VoleTest, - testing::Values(TestParams{CodeType::Silver5, 64}, // edge test - TestParams{CodeType::Silver5, 1 << 10}, - TestParams{CodeType::Silver5, 1 << 14}, - TestParams{CodeType::Silver5, 1 << 18}, - TestParams{CodeType::Silver11, 64}, // edge test - TestParams{CodeType::Silver11, 1 << 10}, - TestParams{CodeType::Silver11, 1 << 14}, - TestParams{CodeType::Silver11, 1 << 18}, - TestParams{CodeType::ExAcc7, 64}, // edge test - TestParams{CodeType::ExAcc7, 1 << 10}, - TestParams{CodeType::ExAcc7, 1 << 14}, - TestParams{CodeType::ExAcc7, 1 << 18}, - TestParams{CodeType::ExAcc11, 64}, // edge test - TestParams{CodeType::ExAcc11, 1 << 10}, - TestParams{CodeType::ExAcc11, 1 << 14}, - TestParams{CodeType::ExAcc11, 1 << 18}, - TestParams{CodeType::ExAcc21, 64}, // edge test - TestParams{CodeType::ExAcc21, 1 << 10}, - TestParams{CodeType::ExAcc21, 1 << 14}, - TestParams{CodeType::ExAcc21, 1 << 18}, - TestParams{CodeType::ExAcc40, 64}, // edge test - TestParams{CodeType::ExAcc40, 1 << 10}, - TestParams{CodeType::ExAcc40, 1 << 14}, - TestParams{CodeType::ExAcc40, 1 << 18})); - -} // namespace yacl::crypto \ No newline at end of file diff --git a/yacl/crypto/primitives/vole/f2k/sparse_vole.h b/yacl/crypto/primitives/vole/f2k/sparse_vole.h deleted file mode 100644 index 57ac5cc1..00000000 --- a/yacl/crypto/primitives/vole/f2k/sparse_vole.h +++ /dev/null @@ -1,163 +0,0 @@ -// Copyright 2023 Ant Group Co., Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include "yacl/crypto/primitives/ot/ot_store.h" -#include "yacl/crypto/utils/secparam.h" -#include "yacl/math/gadget.h" - -/* submodules */ -#include "yacl/crypto/primitives/ot/ferret_ote.h" -#include "yacl/crypto/primitives/ot/gywz_ote.h" -#include "yacl/crypto/primitives/ot/sgrr_ote.h" -#include "yacl/crypto/primitives/ot/softspoken_ote.h" -#include "yacl/crypto/tools/crhash.h" -#include "yacl/crypto/tools/rp.h" -#include "yacl/crypto/utils/rand.h" - -/* security parameter declaration */ -YACL_MODULE_DECLARE("sparse_vole", SecParam::C::INF, SecParam::S::INF); - -namespace yacl::crypto { - -// Implementation about sparse-VOLE over GF(2^128), including base-VOLE, -// single-point VOLE and multi-point VOLE. For more detail, see -// https://eprint.iacr.org/2019/1084.pdf protocol 4 & protocol 5. - -// Single-point f2k-Vole (by SGRR-OTe) -// the type of ot_store must be ROT -// void SpVoleSend(const std::shared_ptr& ctx, -// const OtSendStore& /*rot*/ send_ot, uint32_t n, uint128_t w, -// absl::Span output); - -// void SpVoleRecv(const std::shared_ptr& ctx, -// const OtRecvStore& /*rot*/ recv_ot, uint32_t n, uint32_t -// index, uint128_t v, absl::Span output); - -// Single-point GF(2^128)-Vole (by GYWZ-OTe) -// the type of ot_store must be COT -void SpVoleSend(const std::shared_ptr& ctx, - const OtSendStore& /*cot*/ send_ot, uint32_t n, uint128_t w, - absl::Span output); - -void SpVoleRecv(const std::shared_ptr& ctx, - const OtRecvStore& /*cot*/ recv_ot, uint32_t n, uint32_t index, - uint128_t v, absl::Span output); - -struct MpVoleParam { - uint64_t noise_num_; - uint64_t sp_vole_size_; - uint64_t last_sp_vole_size_; - - uint64_t mp_vole_size_; - uint64_t require_ot_num_; // total ot num - - std::vector indexes_; - LpnNoiseAsm assumption_; - - MpVoleParam() : MpVoleParam(1, 2) {} - - MpVoleParam(uint64_t noise_num, uint64_t mp_vole_size, - LpnNoiseAsm assumption = LpnNoiseAsm::RegularNoise) { - YACL_ENFORCE(assumption == LpnNoiseAsm::RegularNoise); - YACL_ENFORCE(noise_num > 0); - noise_num_ = noise_num; - mp_vole_size_ = mp_vole_size; - assumption_ = assumption; - - sp_vole_size_ = mp_vole_size_ / noise_num_; - last_sp_vole_size_ = mp_vole_size_ - (noise_num_ - 1) * sp_vole_size_; - - YACL_ENFORCE(sp_vole_size_ > 1, - "The size of SpVole should be greater than 1, because " - "1-out-of-1 SpVole is meaningless"); - - require_ot_num_ = math::Log2Ceil(sp_vole_size_) * (noise_num_ - 1) + - math::Log2Ceil(last_sp_vole_size_); - } - - // [Warning] not strictly uniformly random - void GenIndexes() { - indexes_ = RandVec(noise_num_); - for (uint32_t i = 0; i < noise_num_ - 1; ++i) { - indexes_[i] %= sp_vole_size_; - } - indexes_[noise_num_ - 1] %= last_sp_vole_size_; - } - - void SetIndexes(absl::Span indexes) { - for (uint32_t i = 0; i < noise_num_ - 1; ++i) { - indexes_[i] = indexes[i] % sp_vole_size_; - } - indexes_[noise_num_ - 1] %= last_sp_vole_size_; - } -}; - -// Multi-point f2k-Vole with Regular Noise (SGRR-OTe based) -// void MpVoleSend(const std::shared_ptr& ctx, -// const OtSendStore& /*rot*/ send_ot, const MpVoleParam& param, -// absl::Span w, absl::Span output); - -// void MpVoleRecv(const std::shared_ptr& ctx, -// const OtRecvStore& /*rot*/ recv_ot, const MpVoleParam& param, -// absl::Span v, absl::Span output); - -// Multi-point GF(2^128)-Vole with Regular Noise (GYWZ-OTe based) -void MpVoleSend(const std::shared_ptr& ctx, - const OtSendStore& /*cot*/ send_ot, const MpVoleParam& param, - absl::Span w, absl::Span output); - -void MpVoleRecv(const std::shared_ptr& ctx, - const OtRecvStore& /*cot*/ recv_ot, const MpVoleParam& param, - absl::Span v, absl::Span output); - -// -// -------------------------- -// Customized -// -------------------------- -// -// Multi-point f2k-Vole with Regular Noise (GYWZ-OTe based) -// Most efficiency! Punctured indexes would be determined by the choices of -// OtStore. But "MpVoleSend_fixed_index/MpVoleRecv_fixed_index" would not check -// whether the indexes determined by OtStore and the indexes provided by -// MpVoleParam are same.、 -// -// GF(2^128) -void MpVoleSend_fixed_index(const std::shared_ptr& ctx, - const OtSendStore& /*cot*/ send_ot, - const MpVoleParam& param, - absl::Span w, - absl::Span output); - -void MpVoleRecv_fixed_index(const std::shared_ptr& ctx, - const OtRecvStore& /*cot*/ recv_ot, - const MpVoleParam& param, - absl::Span v, - absl::Span output); - -// GF(2^64) -void MpVoleSend_fixed_index(const std::shared_ptr& ctx, - const OtSendStore& /*cot*/ send_ot, - const MpVoleParam& param, - absl::Span w, - absl::Span output); - -void MpVoleRecv_fixed_index(const std::shared_ptr& ctx, - const OtRecvStore& /*cot*/ recv_ot, - const MpVoleParam& param, - absl::Span v, - absl::Span output); - -} // namespace yacl::crypto diff --git a/yacl/crypto/primitives/vole/f2k/sparse_vole_test.cc b/yacl/crypto/primitives/vole/f2k/sparse_vole_test.cc deleted file mode 100644 index 5d0b072c..00000000 --- a/yacl/crypto/primitives/vole/f2k/sparse_vole_test.cc +++ /dev/null @@ -1,269 +0,0 @@ -// Copyright 2023 Ant Group Co., Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "yacl/crypto/primitives/vole/f2k/sparse_vole.h" - -#include - -#include -#include -#include -#include - -#include "sparse_vole.h" - -#include "yacl/base/exception.h" -#include "yacl/crypto/utils/rand.h" -#include "yacl/link/test_util.h" -#include "yacl/math/f2k/f2k.h" -#include "yacl/math/gadget.h" - -namespace yacl::crypto { - -struct TestParams { - size_t num; -}; - -struct TestParams2 { - size_t num; - size_t index_num; -}; - -class SpVoleTest : public ::testing::TestWithParam {}; -class MpVoleTest : public ::testing::TestWithParam {}; - -TEST_P(SpVoleTest, SpVoleWork) { - auto lctxs = link::test::SetupWorld(2); // setup network - const uint64_t num = GetParam().num; - auto cot = MockCots(math::Log2Ceil(num), FastRandU128()); - - // auto delta = rot.recv.CopyChoice().data()[0]; - - std::vector v(num); - std::vector w(num); - - uint128_t single_v = FastRandU128(); - uint128_t single_w = FastRandU128(); - uint32_t index = FastRandU64() % num; - - auto sender = std::async([&] { - SpVoleSend(lctxs[0], cot.send, num, single_w, absl::MakeSpan(w)); - }); - - auto receiver = std::async([&] { - SpVoleRecv(lctxs[1], cot.recv, num, index, single_v, absl::MakeSpan(v)); - }); - - sender.get(); - receiver.get(); - - for (uint64_t i = 0; i < num; ++i) { - if (i == index) { - EXPECT_EQ(v[i] ^ w[i], single_v ^ single_w); - } else { - EXPECT_EQ(v[i], w[i]); - } - } -} - -TEST_P(MpVoleTest, MpVoleWork) { - auto lctxs = link::test::SetupWorld(2); // setup network - const uint64_t num = GetParam().num; - const uint64_t index_num = GetParam().index_num; - - MpVoleParam param(index_num, num); - param.GenIndexes(); - - auto cot = MockCots(param.require_ot_num_, FastRandU128()); - - std::vector s_output(num); - std::vector r_output(num); - - auto v = RandVec(index_num); - auto w = RandVec(index_num); - - auto sender = std::async([&] { - MpVoleSend(lctxs[0], cot.send, param, absl::MakeSpan(w), - absl::MakeSpan(s_output)); - }); - - auto receiver = std::async([&] { - MpVoleRecv(lctxs[1], cot.recv, param, absl::MakeSpan(v), - absl::MakeSpan(r_output)); - }); - - sender.get(); - receiver.get(); - - std::set indexes; - for (size_t i = 0; i < param.noise_num_; ++i) { - indexes.insert(i * param.sp_vole_size_ + param.indexes_[i]); - } - uint64_t j = 0; - uint64_t i = 0; - for (; i < num && j < index_num; ++i) { - if (s_output[i] != r_output[i]) { - EXPECT_EQ(v[j] ^ w[j], s_output[i] ^ r_output[i]); - EXPECT_TRUE(indexes.count(i)); - j++; - } - } - for (; i < num; ++i) { - EXPECT_EQ(s_output[i], r_output[i]); - } -} - -TEST_P(MpVoleTest, MpVole128_fixed_index_Work) { - auto lctxs = link::test::SetupWorld(2); // setup network - const uint64_t num = GetParam().num; - const uint64_t index_num = GetParam().index_num; - - MpVoleParam param(index_num, num); - - auto choices = RandBits>(param.require_ot_num_); - // dynamic_bitset choices; - // generate the choices for MpVole - param.GenIndexes(); - uint64_t pos = 0; - for (size_t i = 0; i < param.noise_num_; ++i) { - auto this_size = (i == param.noise_num_ - 1) - ? math::Log2Ceil(param.last_sp_vole_size_) - : math::Log2Ceil(param.sp_vole_size_); - uint32_t bound = 1 << this_size; - for (uint32_t mask = 1; mask < bound; mask <<= 1) { - choices.set(pos, param.indexes_[i] & mask); - ++pos; - } - } - - YACL_ENFORCE(pos == param.require_ot_num_); - - auto cot = MockCots(param.require_ot_num_, FastRandU128(), choices); - - std::vector s_output(num); - std::vector r_output(num); - - auto v = RandVec(index_num); - auto w = RandVec(index_num); - - auto sender = std::async([&] { - MpVoleSend_fixed_index(lctxs[0], cot.send, param, absl::MakeSpan(w), - absl::MakeSpan(s_output)); - }); - - auto receiver = std::async([&] { - MpVoleRecv_fixed_index(lctxs[1], cot.recv, param, absl::MakeSpan(v), - absl::MakeSpan(r_output)); - }); - - sender.get(); - receiver.get(); - - std::set indexes; - for (size_t i = 0; i < param.noise_num_; ++i) { - indexes.insert(i * param.sp_vole_size_ + param.indexes_[i]); - } - uint64_t j = 0; - uint64_t i = 0; - for (; i < num && j < index_num; ++i) { - if (s_output[i] != r_output[i]) { - EXPECT_EQ(v[j] ^ w[j], s_output[i] ^ r_output[i]); - EXPECT_TRUE(indexes.count(i)); - j++; - } - } - for (; i < num; ++i) { - EXPECT_EQ(s_output[i], r_output[i]); - } -} - -TEST_P(MpVoleTest, MpVole64_fixed_index_Work) { - auto lctxs = link::test::SetupWorld(2); // setup network - const uint64_t num = GetParam().num; - const uint64_t index_num = GetParam().index_num; - - MpVoleParam param(index_num, num); - - auto choices = RandBits>(param.require_ot_num_); - // dynamic_bitset choices; - // generate the choices for MpVole - param.GenIndexes(); - uint64_t pos = 0; - for (size_t i = 0; i < param.noise_num_; ++i) { - auto this_size = (i == param.noise_num_ - 1) - ? math::Log2Ceil(param.last_sp_vole_size_) - : math::Log2Ceil(param.sp_vole_size_); - uint32_t bound = 1 << this_size; - for (uint32_t mask = 1; mask < bound; mask <<= 1) { - choices.set(pos, param.indexes_[i] & mask); - ++pos; - } - } - - YACL_ENFORCE(pos == param.require_ot_num_); - - auto cot = MockCots(param.require_ot_num_, FastRandU128(), choices); - - std::vector s_output(num); - std::vector r_output(num); - - auto v = RandVec(index_num); - auto w = RandVec(index_num); - - auto sender = std::async([&] { - MpVoleSend_fixed_index(lctxs[0], cot.send, param, absl::MakeSpan(w), - absl::MakeSpan(s_output)); - }); - - auto receiver = std::async([&] { - MpVoleRecv_fixed_index(lctxs[1], cot.recv, param, absl::MakeSpan(v), - absl::MakeSpan(r_output)); - }); - - sender.get(); - receiver.get(); - - std::set indexes; - for (size_t i = 0; i < param.noise_num_; ++i) { - indexes.insert(i * param.sp_vole_size_ + param.indexes_[i]); - } - uint64_t j = 0; - uint64_t i = 0; - for (; i < num && j < index_num; ++i) { - if (s_output[i] != r_output[i]) { - EXPECT_EQ(v[j] ^ w[j], s_output[i] ^ r_output[i]); - EXPECT_TRUE(indexes.count(i)); - j++; - } - } - for (; i < num; ++i) { - EXPECT_EQ(s_output[i], r_output[i]); - } -} - -INSTANTIATE_TEST_SUITE_P(Works_Instances, SpVoleTest, - testing::Values(TestParams{4}, TestParams{5}, // - TestParams{7}, // - TestParams{1 << 8}, - TestParams{1 << 10})); - -INSTANTIATE_TEST_SUITE_P(Works_Instances, MpVoleTest, - testing::Values(TestParams2{4, 2}, - TestParams2{5, 2}, // - TestParams2{7, 2}, // - TestParams2{1 << 8, 64}, - TestParams2{1 << 10, 257}, - TestParams2{1 << 20, 1024})); - -} // namespace yacl::crypto diff --git a/yacl/crypto/primitives/vole/mp_vole.h b/yacl/crypto/primitives/vole/mp_vole.h new file mode 100644 index 00000000..4010e169 --- /dev/null +++ b/yacl/crypto/primitives/vole/mp_vole.h @@ -0,0 +1,215 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +/* submodules */ +#include "yacl/crypto/base/hash/hash_utils.h" +#include "yacl/crypto/primitives/dpf/mpfss.h" +#include "yacl/crypto/primitives/ot/ot_store.h" +#include "yacl/crypto/tools/common.h" +#include "yacl/crypto/utils/rand.h" +#include "yacl/crypto/utils/secparam.h" +#include "yacl/math/gadget.h" + +YACL_MODULE_DECLARE("mp_vole", SecParam::C::INF, SecParam::S::INF); +namespace yacl::crypto { + +using MpVoleParam = MpFssParam; + +// +// Implementation about multi-point VOLE over GF(2^128). For more detail, see +// https://eprint.iacr.org/2019/1084.pdf protocol 4 & protocol 5. +// +// +-----------+ +-----------+ +// | MP-FSS | => | Mp-VOLE | +// +-----------+ +-----------+ +// num = m num = m +// len = kappa len = kappa +// +// > kappa: computation security parameter (128 for example) +// +// Consistency check is adapted from Ferret/Wolverine: +// 1) Ferret Consistency check: https://eprint.iacr.org/2020/924.pdf Fig.6 and +// Appendix C +// 2) Wolverine Cosisntency check: https://eprint.iacr.org/2020/925.pdf +// Fig.7 +// +// Notice: +// - MpVoleSender would get vector c; MpVoleReceiver would get vector a and +// vector b, where a is a sparse vector (t-weight). +// - Send && Recv would consume base-VOLE (pre_c_ = delta * pre_a_ + pre_b_ ). +// Before invoking MpVoleSender::Send and MpVoleReceiver::Recv, caller needs +// provide t base-VOLE by calling OneTimeSetup. +// + +template +class MpVoleSender { + public: + MpVoleSender(const MpVoleParam& param) : param_(param) { + is_mal_ = param_.is_mal_; + } + + MpVoleSender(uint64_t noise_num, uint64_t mp_vole_size, bool mal = false) + : param_(noise_num, mp_vole_size, mal), is_mal_(mal) {} + + void OneTimeSetup(K delta, absl::Span pre_c) { + YACL_ENFORCE(param_.base_vole_num_ == pre_c.size()); + + delta_ = delta; + pre_c_ = std::vector(pre_c.begin(), pre_c.end()); + is_setup_ = true; + is_finish_ = false; + } + + void OneTimeSetup(K delta, std::vector&& pre_c) { + YACL_ENFORCE(param_.base_vole_num_ == pre_c.size()); + + delta_ = delta; + pre_c_ = std::move(pre_c); + is_setup_ = true; + is_finish_ = false; + } + + void Send(const std::shared_ptr& ctx, + const OtSendStore& /*cot*/ send_ot, absl::Span c, + bool fixed_index = false) { + YACL_ENFORCE(is_setup_ == true); + YACL_ENFORCE(is_finish_ == false); + YACL_ENFORCE(c.size() >= param_.mp_vole_size_); + // Call MPFSS + if (fixed_index) { + MpfssSend_fixed_index(ctx, send_ot, param_, absl::MakeSpan(pre_c_), c); + } else { + MpfssSend(ctx, send_ot, param_, absl::MakeSpan(pre_c_), c); + } + + if (is_mal_) { + K seed = SyncSeedRecv(ctx); + auto uhash = + math::UniversalHash(seed, c.subspan(0, param_.mp_vole_size_)); + + auto buf = ctx->Recv(ctx->NextRank(), "MpVole:Malicious"); + YACL_ENFORCE(buf.size() == sizeof(K)); + auto payload = *reinterpret_cast(buf.data()); + + uhash = uhash ^ math::GfMul(payload, delta_) ^ + pre_c_[param_.base_vole_num_ - 1]; + + auto hash = Blake3(ByteContainerView(&uhash, sizeof(uhash))); + ctx->SendAsync(ctx->NextRank(), ByteContainerView(hash), + "MpVole:Hash Value"); + } + is_finish_ = true; + } + + private: + MpVoleParam param_; + K delta_{0}; + std::vector pre_c_; + + bool is_mal_{false}; + bool is_setup_{false}; + bool is_finish_{false}; +}; + +template +class MpVoleReceiver { + public: + MpVoleReceiver(const MpVoleParam& param) : param_(param) { + is_mal_ = param_.is_mal_; + } + + MpVoleReceiver(uint64_t noise_num, uint64_t mp_vole_size, bool mal = false) + : param_(noise_num, mp_vole_size, mal), is_mal_(mal) {} + + void OneTimeSetup(absl::Span pre_a, absl::Span pre_b) { + YACL_ENFORCE(param_.base_vole_num_ == pre_a.size()); + YACL_ENFORCE(param_.base_vole_num_ == pre_b.size()); + + pre_a_ = std::vector(pre_a.begin(), pre_a.end()); + pre_b_ = std::vector(pre_b.begin(), pre_b.end()); + is_setup_ = true; + } + + void OneTimeSetup(std::vector&& pre_a, std::vector&& pre_b) { + YACL_ENFORCE(param_.base_vole_num_ == pre_a.size()); + YACL_ENFORCE(param_.base_vole_num_ == pre_b.size()); + + pre_a_ = std::move(pre_a); + pre_b_ = std::move(pre_b); + is_setup_ = true; + } + + void Recv(const std::shared_ptr& ctx, + const OtRecvStore& /*cot*/ recv_ot, absl::Span a, + absl::Span b, bool fixed_index = false) { + YACL_ENFORCE(is_setup_ == true); + YACL_ENFORCE(is_finish_ == false); + YACL_ENFORCE(a.size() >= param_.mp_vole_size_); + YACL_ENFORCE(b.size() >= param_.mp_vole_size_); + // Call MPFSS + if (fixed_index) { + MpfssRecv_fixed_index(ctx, recv_ot, param_, b); + } else { + MpfssRecv(ctx, recv_ot, param_, b); + } + + std::vector indexes(param_.noise_num_); + for (size_t i = 0; i < indexes.size(); ++i) { + auto index = i * param_.sp_vole_size_ + param_.indexes_[i]; + indexes[i] = index; + a[index] = pre_a_[i]; + b[index] = b[index] ^ pre_b_[i]; + } + + if (is_mal_) { + K seed = SyncSeedSend(ctx); + auto uhash = + math::UniversalHash(seed, b.subspan(0, param_.mp_vole_size_)); + auto coef = math::ExtractHashCoef(seed, indexes); + // Notice that: Sender.uhash + Receiver.uhash = payload * delta + auto payload = math::GfMul(absl::MakeSpan(coef), + absl::MakeSpan(pre_a_.data(), indexes.size())); + + // mask uhash && payload by extra one VOLE correlation + payload ^= pre_a_[param_.base_vole_num_ - 1]; + uhash ^= pre_b_[param_.base_vole_num_ - 1]; + + ctx->SendAsync(ctx->NextRank(), + ByteContainerView(&payload, sizeof(payload)), + "MpVole:Malicious"); + + auto hash = Blake3(ByteContainerView(&uhash, sizeof(uhash))); + auto buf = ctx->Recv(ctx->NextRank(), "MpVole: Hash Value"); + YACL_ENFORCE(ByteContainerView(hash) == ByteContainerView(buf)); + } + + is_finish_ = true; + } + + private: + MpVoleParam param_; + + std::vector pre_a_; + std::vector pre_b_; + + bool is_mal_{false}; + bool is_setup_{false}; + bool is_finish_{false}; +}; + +} // namespace yacl::crypto diff --git a/yacl/crypto/primitives/vole/mp_vole_test.cc b/yacl/crypto/primitives/vole/mp_vole_test.cc new file mode 100644 index 00000000..b15a597f --- /dev/null +++ b/yacl/crypto/primitives/vole/mp_vole_test.cc @@ -0,0 +1,116 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mp_vole.h" + +#include + +#include +#include +#include +#include + +#include "yacl/base/exception.h" +#include "yacl/crypto/utils/rand.h" +#include "yacl/link/test_util.h" +#include "yacl/math/f2k/f2k.h" +#include "yacl/math/gadget.h" + +namespace yacl::crypto { + +struct TestParams { + size_t num; + size_t index_num; +}; + +class MpVoleTest + : public ::testing::TestWithParam> {}; + +using GF64 = uint64_t; +using GF128 = uint128_t; + +#define DECLEAR_MPVOLE_TEST(Type0, Type1) \ + TEST_P(MpVoleTest, Type0##x##Type1) { \ + /* setup network */ \ + auto lctxs = link::test::SetupWorld(2); \ + const auto mal = std::get<0>(GetParam()); \ + const auto fixed_index = std::get<1>(GetParam()); \ + const uint64_t num = std::get<2>(GetParam()).num; \ + const uint64_t index_num = std::get<2>(GetParam()).index_num; \ + MpVoleParam param(index_num, num, mal); \ + const auto base_vole_num = param.base_vole_num_; \ + param.GenIndexes(); \ + auto delta = static_cast(FastRandU128()); \ + auto pre_a = RandVec(base_vole_num); \ + auto pre_b = RandVec(base_vole_num); \ + auto pre_c = RandVec(base_vole_num); \ + for (size_t i = 0; i < base_vole_num; ++i) { \ + pre_c[i] = math::GfMul(delta, pre_a[i]) ^ pre_b[i]; \ + } \ + auto choices = RandBits>(param.require_ot_num_); \ + param.GenIndexes(); \ + if (fixed_index) { \ + choices = param.GenChoices(); \ + } \ + YACL_ENFORCE(choices.size() == param.require_ot_num_); \ + auto cot = MockCots(param.require_ot_num_, FastRandU128(), choices); \ + std::vector a(num, 0); \ + std::vector b(num, 0); \ + std::vector c(num, 0); \ + MpVoleSender mp_sender(param); \ + MpVoleReceiver mp_receiver(param); \ + mp_sender.OneTimeSetup(delta, std::move(pre_c)); \ + mp_receiver.OneTimeSetup(std::move(pre_a), std::move(pre_b)); \ + auto sender = std::async([&] { \ + mp_sender.Send(lctxs[0], cot.send, absl::MakeSpan(c), fixed_index); \ + }); \ + auto receiver = std::async([&] { \ + mp_receiver.Recv(lctxs[1], cot.recv, absl::MakeSpan(a), \ + absl::MakeSpan(b), fixed_index); \ + }); \ + sender.get(); \ + receiver.get(); \ + std::set indexes; \ + for (size_t i = 0; i < param.noise_num_; ++i) { \ + indexes.insert(i* param.sp_vole_size_ + param.indexes_[i]); \ + } \ + for (uint64_t i = 0; i < num; ++i) { \ + EXPECT_EQ(math::GfMul(delta, a[i]), b[i] ^ c[i]); \ + } \ + } + +DECLEAR_MPVOLE_TEST(GF64, GF64); +DECLEAR_MPVOLE_TEST(GF64, GF128); +DECLEAR_MPVOLE_TEST(GF128, GF128); + +INSTANTIATE_TEST_SUITE_P( + f2kVOLE, MpVoleTest, + testing::Combine(testing::Values(false, // false for semi-honest + true), // true for malicious + testing::Values(false, // false for selected-index + true), // true for fixed-index + testing::Values(TestParams{4, 2}, // edge + TestParams{5, 2}, // edge + TestParams{7, 2}, // edge + TestParams{1 << 8, 64}, + TestParams{1 << 10, 257}, + TestParams{1 << 20, 1024})), + [](const testing::TestParamInfo& p) { + return fmt::format( + "{}_{}_t{}xn{}", std::get<0>(p.param) ? "Mal" : "Semi", + std::get<1>(p.param) ? "Fixed_index" : "Selected_index", + std::get<2>(p.param).index_num, std::get<2>(p.param).num); + }); + +} // namespace yacl::crypto diff --git a/yacl/crypto/primitives/vole/f2k/silent_vole.cc b/yacl/crypto/primitives/vole/silent_vole.cc similarity index 70% rename from yacl/crypto/primitives/vole/f2k/silent_vole.cc rename to yacl/crypto/primitives/vole/silent_vole.cc index 783fa92e..5092bedf 100644 --- a/yacl/crypto/primitives/vole/f2k/silent_vole.cc +++ b/yacl/crypto/primitives/vole/silent_vole.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "yacl/crypto/primitives/vole/f2k/silent_vole.h" +#include "yacl/crypto/primitives/vole/silent_vole.h" #include #include @@ -20,34 +20,17 @@ #include "yacl/base/aligned_vector.h" #include "yacl/base/dynamic_bitset.h" #include "yacl/base/int128.h" -#include "yacl/math/gadget.h" namespace yacl::crypto { namespace { -// Linear Test, more details could be found in -// https://eprint.iacr.org/2022/1014.pdf Definition 2.5 bias( Reg_t^N ) equal or -// less than e^{-td/N} where t is the number of noise in dual-LPN problem, d is -// the minimum weight of vectors in dual-LPN matrix. Thus, we can view d/N as -// the minimum distance ratio for dual-LPN matrix. -// -// Implementation of GenRegNoiseWeight is mostly from: -// https://github.com/osu-crypto/libOTe/blob/master/libOTe/TwoChooseOne/ConfigureCode.cpp -// which would return the number of noise in MpVole -// -uint64_t GenRegNoiseWeight(double min_dist_ratio, uint64_t security_param) { - if (min_dist_ratio > 0.5 || min_dist_ratio <= 0) { - YACL_THROW("mini distance too small, rate {}", min_dist_ratio); - } - - auto d = -std::log2(1 - 2 * min_dist_ratio); - auto t = std::max(128, double(security_param) / d); - - return math::RoundUpTo(t, 8); -} +// minimum distance for dual-LPN code +static std::map kMinDistanceRatio = { + {CodeType::Silver5, 0.2}, {CodeType::Silver11, 0.2}, + {CodeType::ExAcc7, 0.05}, {CodeType::ExAcc11, 0.1}, + {CodeType::ExAcc21, 0.1}, {CodeType::ExAcc40, 0.2}}; -// Silent Vole internal parameters template struct VoleParam { uint64_t vole_num_; // vole num @@ -62,16 +45,25 @@ struct VoleParam { uint64_t mp_vole_ot_num_; // mp vole (cot/rot-based) uint64_t require_ot_num_; // total ot num + bool is_mal_{false}; + // Constructor - VoleParam(CodeType code, uint64_t vole_num) - : VoleParam(code, vole_num, YACL_MODULE_SECPARAM_C_UINT("silent_vole")) {} + VoleParam(CodeType code, uint64_t vole_num, bool mal = false) {} + + VoleParam(CodeType code, uint64_t vole_num, uint64_t sec, bool mal = false) { + codetype_ = code; + is_mal_ = mal; + vole_num_ = vole_num; + assumption_ = LpnNoiseAsm::RegularNoise; - VoleParam(CodeType code, uint64_t vole_num, uint64_t sec) { // default - uint64_t gap = 0; + uint64_t gap = 0; // Silver Parameters uint64_t code_scaler = 2; - double min_dist_ratio = 0.2; - codetype_ = code; + // check + YACL_ENFORCE( + kMinDistanceRatio.count(code), + "Error: could not found the minimum distance for current code."); + double min_dist_ratio = kMinDistanceRatio[code]; switch (codetype_) { case CodeType::Silver5: @@ -80,25 +72,10 @@ struct VoleParam { case CodeType::Silver11: gap = 32; break; - case CodeType::ExAcc7: - min_dist_ratio = 0.05; - break; - case CodeType::ExAcc11: - case CodeType::ExAcc21: - min_dist_ratio = 0.1; - break; - case CodeType::ExAcc40: - min_dist_ratio = 0.2; - break; - // TODO(@wenfan) - // support ExConv Code default: break; } - vole_num_ = vole_num; - assumption_ = LpnNoiseAsm::RegularNoise; - auto noise_num = GenRegNoiseWeight(min_dist_ratio, sec); // Note that: the size of SpVole must be greater than one. // because 1-out-of-1 Vole/OT is meaningless @@ -107,13 +84,14 @@ struct VoleParam { static_cast(2)); auto mp_vole_size = sp_vole_size * noise_num + gap; - mp_param_ = MpVoleParam(noise_num, mp_vole_size, assumption_); + // initialize parameters for MpVole + mp_param_ = MpVoleParam(noise_num, mp_vole_size, assumption_, is_mal_); code_size_ = mp_param_.mp_vole_size_ / code_scaler; // base_vole + mp_vole base_vole_ot_num_ = - mp_param_.noise_num_ * sizeof(T) * 8; // base_vole (cot-based) - mp_vole_ot_num_ = mp_param_.require_ot_num_; // mp_vole (cot/rot-based) + mp_param_.base_vole_num_ * sizeof(T) * 8; // base_vole (cot-based) + mp_vole_ot_num_ = mp_param_.require_ot_num_; // mp_vole (cot/rot-based) require_ot_num_ = base_vole_ot_num_ + mp_vole_ot_num_; } }; @@ -229,7 +207,8 @@ void SilentVoleSender::SendImpl(const std::shared_ptr& ctx, } const auto vole_num = c.size(); - auto param = VoleParam(codetype_, vole_num); + auto param = VoleParam( + codetype_, vole_num, YACL_MODULE_SECPARAM_C_UINT("silent_vole"), is_mal_); auto& mp_param = param.mp_param_; // [Warning] copy, low efficiency @@ -238,16 +217,23 @@ void SilentVoleSender::SendImpl(const std::shared_ptr& ctx, auto base_vole_cot = all_cot.NextSlice(param.base_vole_ot_num_); // base-vole // base vole, w = u * delta + v - AlignedVector w(mp_param.noise_num_); + std::vector w(mp_param.base_vole_num_); Ot2VoleSend(base_vole_cot, absl::MakeSpan(w)); // mp vole - AlignedVector mp_vole_output(mp_param.mp_vole_size_); - MpVoleSend_fixed_index(ctx, mp_vole_cot, mp_param, absl::MakeSpan(w), - absl::MakeSpan(mp_vole_output)); + auto mpvole = MpVoleSender(mp_param); + // w would be moved into mpvole + mpvole.OneTimeSetup(static_cast(delta_), std::move(w)); + // mp_vole output + // AlignedVector mp_vole_output(mp_param.mp_vole_size_); + auto buf = Buffer(mp_param.mp_vole_size_ * sizeof(K)); + auto mp_vole_output = absl::MakeSpan(buf.data(), mp_param.mp_vole_size_); + // mpvole with fixed index + // which means punctured index would be determined by mp_vole_cot choices + mpvole.Send(ctx, mp_vole_cot, mp_vole_output, true); // dual LPN // compressing mp_vole_output into c - DualLpnEncode(param, absl::MakeSpan(mp_vole_output), c); + DualLpnEncode(param, mp_vole_output, c); } template @@ -259,27 +245,19 @@ void SilentVoleReceiver::RecvImpl(const std::shared_ptr& ctx, const auto vole_num = a.size(); YACL_ENFORCE(vole_num == b.size()); - auto param = VoleParam(codetype_, vole_num); + auto param = VoleParam( + codetype_, vole_num, YACL_MODULE_SECPARAM_C_UINT("silent_vole"), is_mal_); auto& mp_param = param.mp_param_; - auto choices = RandBits>(param.require_ot_num_); // generate punctured indexes for MpVole mp_param.GenIndexes(); - // set mp-cot choices by punctured indexes - { - uint64_t pos = 0; - auto sp_vole_length = math::Log2Ceil(mp_param.sp_vole_size_); - auto last_length = math::Log2Ceil(mp_param.last_sp_vole_size_); - for (size_t i = 0; i < mp_param.noise_num_; ++i) { - auto this_length = - (i == mp_param.noise_num_ - 1) ? last_length : sp_vole_length; - uint32_t bound = 1 << this_length; - for (uint32_t mask = 1; mask < bound; mask <<= 1) { - choices.set(pos, mp_param.indexes_[i] & mask); - ++pos; - } - } - } + // convert MpVole indexes to ot choices + auto choices = mp_param.GenChoices(); // size param.mp_vole_ot_num_ + // generate the choices of base VOLE + auto base_choices = + RandBits>(param.base_vole_ot_num_); + // append choices and base_vole_choices + choices.append(base_choices); // [Warning] copy, low efficiency auto all_cot = ss_receiver_.GenCot(ctx, choices); // generate Cot by choices @@ -287,26 +265,28 @@ void SilentVoleReceiver::RecvImpl(const std::shared_ptr& ctx, auto base_vole_cot = all_cot.NextSlice(param.base_vole_ot_num_); // base vole // base vole, w = u * delta + v - AlignedVector u(mp_param.noise_num_); - AlignedVector v(mp_param.noise_num_); + std::vector u(mp_param.base_vole_num_); + std::vector v(mp_param.base_vole_num_); - // VOLE or subfield VOLE + // base (subfield) VOLE Ot2VoleRecv(base_vole_cot, absl::MakeSpan(u), absl::MakeSpan(v)); // mp vole - // construct sparse noise - auto sparse_noise = AlignedVector(mp_param.mp_vole_size_); - for (uint32_t i = 0; i < mp_param.noise_num_; ++i) { - sparse_noise[i * mp_param.sp_vole_size_ + mp_param.indexes_[i]] = u[i]; - } - AlignedVector mp_vole_output(mp_param.mp_vole_size_); - MpVoleRecv_fixed_index(ctx, mp_vole_cot, mp_param, absl::MakeSpan(v), - absl::MakeSpan(mp_vole_output)); - + auto mpvole = MpVoleReceiver(mp_param); + // u && v would be moved into mpvole + mpvole.OneTimeSetup(std::move(u), std::move(v)); + // sparse_noise && mp_vole output + AlignedVector sparse_noise(mp_param.mp_vole_size_); + // AlignedVector mp_vole_output(mp_param.mp_vole_size_); + auto buf = Buffer(mp_param.mp_vole_size_ * sizeof(K)); + auto mp_vole_output = absl::MakeSpan(buf.data(), mp_param.mp_vole_size_); + // mpvole with fixed index + // which means punctured index would be determined by mp_vole_cot choices + mpvole.Recv(ctx, mp_vole_cot, absl::MakeSpan(sparse_noise), mp_vole_output, + true); // dual LPN // compressing sparse_noise into a, mp_vole_output into b - DualLpnEncode2(param, absl::MakeSpan(sparse_noise), a, - absl::MakeSpan(mp_vole_output), b); + DualLpnEncode2(param, absl::MakeSpan(sparse_noise), a, mp_vole_output, b); } } // namespace yacl::crypto diff --git a/yacl/crypto/primitives/vole/f2k/silent_vole.h b/yacl/crypto/primitives/vole/silent_vole.h similarity index 89% rename from yacl/crypto/primitives/vole/f2k/silent_vole.h rename to yacl/crypto/primitives/vole/silent_vole.h index 4420ca51..f0efdb2a 100644 --- a/yacl/crypto/primitives/vole/f2k/silent_vole.h +++ b/yacl/crypto/primitives/vole/silent_vole.h @@ -14,9 +14,6 @@ #pragma once -#include -#include - #include "yacl/base/exception.h" #include "yacl/base/int128.h" #include "yacl/crypto/utils/secparam.h" @@ -26,30 +23,14 @@ #include "yacl/crypto/primitives/code/code_interface.h" #include "yacl/crypto/primitives/code/ea_code.h" #include "yacl/crypto/primitives/code/silver_code.h" -#include "yacl/crypto/primitives/ot/ferret_ote.h" -#include "yacl/crypto/primitives/vole/f2k/base_vole.h" -#include "yacl/crypto/primitives/vole/f2k/sparse_vole.h" +#include "yacl/crypto/primitives/vole/base_vole.h" +#include "yacl/crypto/primitives/vole/mp_vole.h" #include "yacl/crypto/utils/rand.h" /* security parameter declaration */ YACL_MODULE_DECLARE("silent_vole", SecParam::C::k128, SecParam::S::INF); - namespace yacl::crypto { -enum class CodeType { - // Support Silver & ExAcc only - Silver5, - Silver11, - ExAcc7, - ExAcc11, - ExAcc21, - ExAcc40, - // TODO(@wenfan) - // Support ExConv Code - ExConv7x24, - ExConv21x24 -}; - // Silent Vector OLE Implementation // // Silent VOLE is a "framework" to generate Vector OLE correlation. First of @@ -70,7 +51,7 @@ enum class CodeType { // > OT extension functionality, for more details about its implementation, see // `yacl/crypto/primitives/ot/softspoken_ote.h` // > base VOLE and multi-point VOLE functionalities, for more details about its -// implementation, see `yacl/crypto/primitives/vole/f2k/sparse_vole.h` +// implementation, see `yacl/crypto/primitives/vole/mp_vole.h` // > Dual LPN problem, for more details, please see the original papers // 1) Silver (https://eprint.iacr.org/2021/1150.pdf) Most // efficiency, but not recommended to use due to its security flaw. @@ -82,16 +63,30 @@ enum class CodeType { // would get delta and vector c, such that c = a * delta + b // > Silent Vole aims to generate large amount of VOLE correlation, thus the // length of a,b,c should be greater than 256 at least. -// > When small amount of VOLE correlation is needed (less than 256), see -// `yacl/crypto/primitives/vole/f2k/sparse_vole.h` and use +// > When small amount of VOLE correlation is needed (less than 256), use // `GilboaVoleSend/GilboaVoleRecv` instead. +// dual-LPN code type +enum class CodeType { + // Support Silver & ExAcc only + Silver5, + Silver11, + ExAcc7, + ExAcc11, + ExAcc21, + ExAcc40, + // TODO: @wenfan + // Support ExConv Code + ExConv7x24, + ExConv21x24 +}; + class SilentVoleSender { public: - explicit SilentVoleSender(CodeType code) { - ss_sender_ = SoftspokenOtExtSender(2); + explicit SilentVoleSender(CodeType code, bool mal = false) { codetype_ = code; - delta_ = MakeUint128(0, 0); // init delta_ + is_mal_ = mal; + ss_sender_ = SoftspokenOtExtSender(2, is_mal_); } void OneTimeSetup(const std::shared_ptr& ctx) { @@ -129,8 +124,9 @@ class SilentVoleSender { private: bool is_inited_{false}; + bool is_mal_{false}; CodeType codetype_; - uint128_t delta_; + uint128_t delta_{0}; SoftspokenOtExtSender ss_sender_; template @@ -139,9 +135,10 @@ class SilentVoleSender { class SilentVoleReceiver { public: - explicit SilentVoleReceiver(CodeType code) { - ss_receiver_ = SoftspokenOtExtReceiver(2); + explicit SilentVoleReceiver(CodeType code, bool mal = false) { codetype_ = code; + is_mal_ = mal; + ss_receiver_ = SoftspokenOtExtReceiver(2, is_mal_); } void OneTimeSetup(const std::shared_ptr& ctx) { @@ -167,6 +164,7 @@ class SilentVoleReceiver { private: bool is_inited_{false}; + bool is_mal_{false}; CodeType codetype_; SoftspokenOtExtReceiver ss_receiver_; diff --git a/yacl/crypto/primitives/vole/silent_vole_test.cc b/yacl/crypto/primitives/vole/silent_vole_test.cc new file mode 100644 index 00000000..42e65779 --- /dev/null +++ b/yacl/crypto/primitives/vole/silent_vole_test.cc @@ -0,0 +1,117 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "yacl/crypto/primitives/vole/silent_vole.h" + +#include + +#include +#include +#include +#include +#include +#include + +#include "yacl/base/int128.h" +#include "yacl/link/test_util.h" +#include "yacl/math/f2k/f2k.h" +#include "yacl/math/gadget.h" + +namespace yacl::crypto { + +class SilentVoleTest + : public ::testing::TestWithParam> {}; + +using GF64 = uint64_t; +using GF128 = uint128_t; + +template +void SendWarpper(SilentVoleSender& sender, std::shared_ptr& lctx, + absl::Span c) { + sender.Send(lctx, c); +} + +template <> +void SendWarpper(SilentVoleSender& sender, + std::shared_ptr& lctx, + absl::Span c) { + sender.SfSend(lctx, c); +} + +template +void RecvWrapper(SilentVoleReceiver& receiver, + std::shared_ptr& lctx, absl::Span a, + absl::Span b) { + receiver.Recv(lctx, a, b); +} + +template <> +void RecvWrapper(SilentVoleReceiver& receiver, + std::shared_ptr& lctx, + absl::Span a, absl::Span b) { + receiver.SfRecv(lctx, a, b); +} + +#define DECLARE_SILENT_VOLE_TEST(Type0, Type1) \ + TEST_P(SilentVoleTest, Type0##x##Type1) { \ + auto lctxs = link::test::SetupWorld(2); \ + const auto codetype = std::get<0>(GetParam()); \ + const auto vole_num = std::get<1>(GetParam()); \ + const auto is_mal = std::get<2>(GetParam()); \ + std::vector a(vole_num); \ + std::vector b(vole_num); \ + std::vector c(vole_num); \ + Type1 delta = 0; \ + auto sender = std::async([&] { \ + auto sv_sender = SilentVoleSender(codetype, is_mal); \ + SendWarpper(sv_sender, lctxs[0], absl::MakeSpan(c)); \ + delta = sv_sender.GetDelta(); \ + }); \ + auto receiver = std::async([&] { \ + auto sv_receiver = SilentVoleReceiver(codetype, is_mal); \ + RecvWrapper(sv_receiver, lctxs[1], absl::MakeSpan(a), \ + absl::MakeSpan(b)); \ + }); \ + sender.get(); \ + receiver.get(); \ + for (uint64_t i = 0; i < vole_num; ++i) { \ + EXPECT_EQ(math::GfMul(a[i], delta) ^ b[i], c[i]); \ + } \ + } + +DECLARE_SILENT_VOLE_TEST(GF64, GF64) +DECLARE_SILENT_VOLE_TEST(GF64, GF128) +DECLARE_SILENT_VOLE_TEST(GF128, GF128) + +static std::map kCodeName = { + {CodeType::Silver5, "Silver5"}, {CodeType::Silver11, "Silver11"}, + {CodeType::ExAcc7, "ExAcc7"}, {CodeType::ExAcc11, "ExAcc11"}, + {CodeType::ExAcc21, "ExAcc21"}, {CodeType::ExAcc40, "ExAcc40"}}; + +INSTANTIATE_TEST_SUITE_P( + f2kVole, SilentVoleTest, + testing::Combine(testing::Values(CodeType::Silver5, CodeType::Silver11, + CodeType::ExAcc7, CodeType::ExAcc11, + CodeType::ExAcc21, + CodeType::ExAcc40), // Dual LPN code type + testing::Values(64, 1 << 10, 1 << 14, + 1 << 18), // Vole num + testing::Values(false, true)), // Semi-honest or Malicious + [](const testing::TestParamInfo& p) { + return fmt::format( + "{}_{}_{}", std::get<2>(p.param) == true ? "Mal" : "Semi", + kCodeName[std::get<0>(p.param)], (int)std::get<1>(p.param)); + }); + +} // namespace yacl::crypto diff --git a/yacl/crypto/tools/BUILD.bazel b/yacl/crypto/tools/BUILD.bazel index d5dbddfb..594d7e62 100644 --- a/yacl/crypto/tools/BUILD.bazel +++ b/yacl/crypto/tools/BUILD.bazel @@ -16,6 +16,15 @@ load("//bazel:yacl.bzl", "yacl_cc_binary", "yacl_cc_library", "yacl_cc_test") package(default_visibility = ["//visibility:public"]) +yacl_cc_library( + name = "common", + hdrs = ["common.h"], + deps = [ + "//yacl/link", + "//yacl/utils:serialize", + ], +) + yacl_cc_library( name = "prg", srcs = ["prg.cc"], diff --git a/yacl/crypto/tools/common.h b/yacl/crypto/tools/common.h new file mode 100644 index 00000000..1d178be5 --- /dev/null +++ b/yacl/crypto/tools/common.h @@ -0,0 +1,34 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "yacl/link/link.h" +#include "yacl/utils/serialize.h" + +namespace yacl::crypto { +// ----------------------- +// sync seed protocol +// ----------------------- +uint128_t inline SyncSeedSend(const std::shared_ptr& ctx) { + uint128_t seed = SecureRandSeed(); + ctx->SendAsync(ctx->NextRank(), SerializeUint128(seed), "SyncSeed"); + return seed; +} + +uint128_t inline SyncSeedRecv(const std::shared_ptr& ctx) { + auto buf = ctx->Recv(ctx->NextRank(), "SyncSeed"); + return DeserializeUint128(buf); +} + +} // namespace yacl::crypto diff --git a/yacl/crypto/utils/BUILD.bazel b/yacl/crypto/utils/BUILD.bazel index cb5139be..9f40739c 100644 --- a/yacl/crypto/utils/BUILD.bazel +++ b/yacl/crypto/utils/BUILD.bazel @@ -57,6 +57,7 @@ yacl_cc_library( deps = [ "//yacl/base:exception", "//yacl/base:int128", + "//yacl/math:gadget", ], ) diff --git a/yacl/crypto/utils/secparam.h b/yacl/crypto/utils/secparam.h index 1f5c6d08..79793b83 100644 --- a/yacl/crypto/utils/secparam.h +++ b/yacl/crypto/utils/secparam.h @@ -15,6 +15,7 @@ #pragma once #include +#include #include #include @@ -22,12 +23,14 @@ #include "yacl/base/exception.h" #include "yacl/crypto/utils/compile_time_utils.h" +#include "yacl/math/gadget.h" namespace yacl::crypto { // ------------------ // Security parameter // ------------------ + class SecParam { public: // Computational Security Parameter: A number associated with the amount of @@ -154,6 +157,32 @@ class LpnParam { return {10485760, 452000, 1280, LpnNoiseAsm::RegularNoise}; } }; + +// -------------------------------------- +// dual LPN Parameter (security related) +// -------------------------------------- + +// Linear Test, more details could be found in +// https://eprint.iacr.org/2022/1014.pdf Definition 2.5 bias( Reg_t^N ) equal or +// less than e^{-td/N} where t is the number of noise in dual-LPN problem, d is +// the minimum weight of vectors in dual-LPN matrix. Thus, we can view d/N as +// the minimum distance ratio for dual-LPN matrix. +// +// Implementation of GenRegNoiseWeight is mostly from: +// https://github.com/osu-crypto/libOTe/blob/master/libOTe/TwoChooseOne/ConfigureCode.cpp +// which would return the number of noise in MpVole +// +uint64_t inline GenRegNoiseWeight(double min_dist_ratio, uint64_t sec) { + if (min_dist_ratio > 0.5 || min_dist_ratio <= 0) { + YACL_THROW("mini distance too small, rate {}", min_dist_ratio); + } + + auto d = std::log2(1 - 2 * min_dist_ratio); + auto t = std::max(128, -double(sec) / d); + + return math::RoundUpTo(t, 8); +} + } // namespace yacl::crypto // ------------------ diff --git a/yacl/math/BUILD.bazel b/yacl/math/BUILD.bazel index d3be6754..6fbd8575 100644 --- a/yacl/math/BUILD.bazel +++ b/yacl/math/BUILD.bazel @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -load("//bazel:yacl.bzl", "yacl_cc_binary", "yacl_cc_library") +load("//bazel:yacl.bzl", "yacl_cc_library") package(default_visibility = ["//visibility:public"]) @@ -20,6 +20,8 @@ yacl_cc_library( name = "gadget", hdrs = ["gadget.h"], deps = [ + "//yacl/base:aligned_vector", + "//yacl/math/f2k", "@com_google_absl//absl/strings", ], ) diff --git a/yacl/math/f2k/f2k.h b/yacl/math/f2k/f2k.h index 41354047..f7da1d1e 100644 --- a/yacl/math/f2k/f2k.h +++ b/yacl/math/f2k/f2k.h @@ -140,8 +140,8 @@ inline uint64_t GfMul64(uint64_t x, uint64_t y) { } // Inner product -inline std::pair ClMul128(absl::Span x, - absl::Span y) { +inline std::pair ClMul128(absl::Span x, + absl::Span y) { YACL_ENFORCE(x.size() == y.size()); const uint64_t size = x.size(); @@ -156,13 +156,15 @@ inline std::pair ClMul128(absl::Span x, return std::make_pair(toU128(ret_high), toU128(ret_low)); } -inline uint128_t GfMul128(absl::Span x, absl::Span y) { +inline uint128_t GfMul128(absl::Span x, + absl::Span y) { YACL_ENFORCE(x.size() == y.size()); auto [high, low] = ClMul128(x, y); return Reduce128(high, low); } -inline uint128_t ClMul64(absl::Span x, absl::Span y) { +inline uint128_t ClMul64(absl::Span x, + absl::Span y) { YACL_ENFORCE(x.size() == y.size()); const uint64_t size = x.size(); @@ -190,7 +192,8 @@ inline uint128_t ClMul64(absl::Span x, absl::Span y) { return toU128(ret); } -inline uint64_t GfMul64(absl::Span x, absl::Span y) { +inline uint64_t GfMul64(absl::Span x, + absl::Span y) { YACL_ENFORCE(x.size() == y.size()); return Reduce64(ClMul64(x, y)); } diff --git a/yacl/math/gadget.h b/yacl/math/gadget.h index 59cfe764..19040dcd 100644 --- a/yacl/math/gadget.h +++ b/yacl/math/gadget.h @@ -16,7 +16,9 @@ #include "absl/strings/numbers.h" +#include "yacl/base/aligned_vector.h" #include "yacl/base/exception.h" +#include "yacl/math/f2k/f2k.h" namespace yacl::math { @@ -38,4 +40,86 @@ constexpr uint64_t RoundUpTo(uint64_t x, uint64_t y) { return DivCeil(x, y) * y; } +// ------------------------ +// f2k-field operation +// ------------------------ + +// inner-product +uint128_t inline GfMul(absl::Span a, + absl::Span b) { + return GfMul128(a, b); +} + +uint64_t inline GfMul(absl::Span a, + absl::Span b) { + return GfMul64(a, b); +} + +uint128_t inline GfMul(absl::Span a, + absl::Span b) { + AlignedVector tmp(b.size()); + std::transform(b.cbegin(), b.cend(), tmp.begin(), [](const uint64_t& val) { + return static_cast(val); + }); + return GfMul128(a, absl::MakeSpan(tmp)); +} + +uint128_t inline GfMul(absl::Span a, + absl::Span b) { + return GfMul(b, a); +} + +// element-wise +uint128_t inline GfMul(uint128_t a, uint128_t b) { return GfMul128(a, b); } + +uint64_t inline GfMul(uint64_t a, uint64_t b) { return GfMul64(a, b); } + +uint128_t inline GfMul(uint128_t a, uint64_t b) { + return GfMul128(a, static_cast(b)); +} + +uint128_t inline GfMul(uint64_t a, uint128_t b) { + return GfMul128(static_cast(a), b); +} + +// ------------------------ +// f2k-Universal Hash +// ------------------------ + +template +T UniversalHash(T seed, absl::Span data) { + T ret = 0; + for_each(data.rbegin(), data.rend(), [&ret, &seed](const T& val) { + ret ^= val; + ret = GfMul(seed, ret); + }); + return ret; +} + +template +std::vector ExtractHashCoef(T seed, + absl::Span indexes /*sorted*/) { + std::array buff = {}; + auto max_bits = math::Log2Ceil(indexes.back()); + buff[0] = seed; + for (size_t i = 1; i <= max_bits; ++i) { + buff[i] = GfMul(buff[i - 1], buff[i - 1]); + } + + std::vector ret; + for (const auto& index : indexes) { + auto index_plus_one = index + 1; + uint64_t mask = 1; + T coef = 1; + for (size_t i = 0; i < 64 && mask <= index_plus_one; ++i) { + if (mask & index_plus_one) { + coef = GfMul(coef, buff[i]); + } + mask <<= 1; + } + ret.push_back(coef); + } + return ret; +} + } // namespace yacl::math