From 513a8fd23f48c981c57bdf439cd720509bca5200 Mon Sep 17 00:00:00 2001 From: Shaun Zhang Date: Thu, 25 Jan 2024 19:29:56 +0800 Subject: [PATCH] [fix] several improvements 1. fix `memcpy` bug 2. replace hasher by lambda 3. use hash sort as shuffle 4. rename constants --- yacl/crypto/primitives/psu/krtw19_psu.cc | 96 +++++++++++------------- 1 file changed, 45 insertions(+), 51 deletions(-) diff --git a/yacl/crypto/primitives/psu/krtw19_psu.cc b/yacl/crypto/primitives/psu/krtw19_psu.cc index 3f4961c2..7a3f69ec 100644 --- a/yacl/crypto/primitives/psu/krtw19_psu.cc +++ b/yacl/crypto/primitives/psu/krtw19_psu.cc @@ -17,7 +17,6 @@ #include #include #include -#include #include #include @@ -28,32 +27,27 @@ namespace yacl::crypto { namespace { // reference: https://eprint.iacr.org/2019/1234.pdf (Figure 2) -constexpr float ZETA{0.06f}; -constexpr size_t BIN_SIZE{64ul}; // m+1 -constexpr uint128_t BOT{}; -constexpr size_t NUM_BINS_PER_BATCH{16ul}; -constexpr size_t BATCH_SIZE{NUM_BINS_PER_BATCH * BIN_SIZE}; - -constexpr size_t NUM_BASE_OT{128ul}; -constexpr size_t NUM_INKP_OT{512ul}; - -static std::random_device rd; -static std::mt19937 gen(rd()); - -struct U128Hasher { - size_t operator()(const uint128_t& x) const { - auto hash = Blake3_128({&x, sizeof x}); - size_t ret; - std::memcpy(&ret, &hash, 1); - return ret; - } +constexpr float kZeta{0.06f}; +constexpr size_t kBinSize{64ul}; // m+1 +constexpr uint128_t kBot{}; +constexpr size_t kNumBinsPerBatch{16ul}; +constexpr size_t kBatchSize{kNumBinsPerBatch * kBinSize}; + +constexpr size_t kNumBaseOT{128ul}; +constexpr size_t kNumInkpOT{512ul}; + +static auto HashToSizeT = [](const uint128_t& x) { + auto hash = Blake3_128({&x, sizeof x}); + size_t ret; + std::memcpy(&ret, &hash, sizeof ret); + return ret; }; auto HashInputs(const std::vector& elem_hashes, size_t count) { - size_t num_bins = std::ceil(count * ZETA); + size_t num_bins = std::ceil(count * kZeta); std::vector> hashing(num_bins); for (auto elem : elem_hashes) { - auto hash = U128Hasher{}(elem); + auto hash = HashToSizeT(elem); hashing[hash % num_bins].emplace_back(elem); } return hashing; @@ -109,41 +103,41 @@ void KrtwPsuSend(const std::shared_ptr& ctx, // Step 2. Prepares OPRF KkrtOtExtReceiver receiver; - size_t num_ot{hashing.size() * BIN_SIZE}; - auto choice = RandBits(NUM_BASE_OT); - auto base_ot = BaseOtRecv(ctx, choice, NUM_BASE_OT); - auto store = IknpOtExtSend(ctx, base_ot, NUM_INKP_OT); + size_t num_ot{hashing.size() * kBinSize}; + auto choice = RandBits(kNumBaseOT); + auto base_ot = BaseOtRecv(ctx, choice, kNumBaseOT); + auto store = IknpOtExtSend(ctx, base_ot, kNumInkpOT); receiver.Init(ctx, store, num_ot); - receiver.SetBatchSize(BATCH_SIZE); + receiver.SetBatchSize(kBatchSize); std::vector elems; elems.reserve(num_ot); size_t oprf_idx{}; for (size_t bin_idx{}; bin_idx != hashing.size(); ++bin_idx) { - if (bin_idx % NUM_BINS_PER_BATCH == 0) { + if (bin_idx % kNumBinsPerBatch == 0) { receiver.SendCorrection( - ctx, std::min(BATCH_SIZE, (hashing.size() - bin_idx) * BIN_SIZE)); + ctx, std::min(kBatchSize, (hashing.size() - bin_idx) * kBinSize)); } - hashing[bin_idx].resize(BIN_SIZE); - std::shuffle(hashing[bin_idx].begin(), hashing[bin_idx].end(), gen); + hashing[bin_idx].resize(kBinSize); + std::sort(hashing[bin_idx].begin(), hashing[bin_idx].end()); // Step 3. For each bin element, invokes PSU(1, m+1) for (auto elem : hashing[bin_idx]) { elems.emplace_back(elem); uint64_t eval; receiver.Encode(oprf_idx++, elem, {reinterpret_cast(&eval), sizeof eval}); - std::vector coeffs(BIN_SIZE); + std::vector coeffs(kBinSize); auto buf = ctx->Recv(ctx->PrevRank(), "Receive coefficients"); std::memcpy(coeffs.data(), buf.data(), buf.size()); - auto y = Evaluate(coeffs, U128Hasher{}(elem)) ^ eval; + auto y = Evaluate(coeffs, HashToSizeT(elem)) ^ eval; ctx->SendAsync(ctx->NextRank(), SerializeUint128(y), "Send evaluation"); } } // Step 4. Sends new elements through OT std::vector> keys(num_ot); - choice = SecureRandBits(NUM_BASE_OT); - base_ot = BaseOtRecv(ctx, choice, NUM_BASE_OT); + choice = SecureRandBits(kNumBaseOT); + base_ot = BaseOtRecv(ctx, choice, kNumBaseOT); IknpOtExtSend(ctx, base_ot, absl::MakeSpan(keys)); std::vector ciphers(num_ot); for (size_t i{}; i != num_ot; ++i) { @@ -171,30 +165,30 @@ std::vector KrtwPsuRecv( // Step 2. Prepares OPRF KkrtOtExtSender sender; - size_t num_ot{hashing.size() * BIN_SIZE}; - auto base_ot = BaseOtSend(ctx, NUM_BASE_OT); - auto choice = RandBits(NUM_INKP_OT); - auto store = IknpOtExtRecv(ctx, base_ot, choice, NUM_INKP_OT); + size_t num_ot{hashing.size() * kBinSize}; + auto base_ot = BaseOtSend(ctx, kNumBaseOT); + auto choice = RandBits(kNumInkpOT); + auto store = IknpOtExtRecv(ctx, base_ot, choice, kNumInkpOT); sender.Init(ctx, store, num_ot); - sender.SetBatchSize(BATCH_SIZE); + sender.SetBatchSize(kBatchSize); auto oprf = sender.GetOprf(); yacl::dynamic_bitset<> ot_choice(num_ot); size_t oprf_idx{}; // Step 3. For each bin, invokes PSU(1, m+1) for (size_t bin_idx{}; bin_idx != hashing.size(); ++bin_idx) { - if (bin_idx % NUM_BINS_PER_BATCH == 0) { + if (bin_idx % kNumBinsPerBatch == 0) { sender.RecvCorrection( - ctx, std::min(BATCH_SIZE, (hashing.size() - bin_idx) * BIN_SIZE)); + ctx, std::min(kBatchSize, (hashing.size() - bin_idx) * kBinSize)); } auto bin_size = hashing[bin_idx].size(); - for (size_t elem_idx{}; elem_idx != BIN_SIZE; ++elem_idx, ++oprf_idx) { + for (size_t elem_idx{}; elem_idx != kBinSize; ++elem_idx, ++oprf_idx) { auto seed = FastRandU64(); - std::vector xs(BIN_SIZE), ys(BIN_SIZE); - for (size_t i{}; i != BIN_SIZE; ++i) { - xs[i] = (i < bin_size ? U128Hasher{}(hashing[bin_idx][i]) + std::vector xs(kBinSize), ys(kBinSize); + for (size_t i{}; i != kBinSize; ++i) { + xs[i] = (i < bin_size ? HashToSizeT(hashing[bin_idx][i]) : i > bin_size ? FastRandU64() - : BOT); + : kBot); ys[i] = oprf->Eval(oprf_idx, xs[i]) ^ seed; } auto coeffs = Interpolate(xs, ys); @@ -208,16 +202,16 @@ std::vector KrtwPsuRecv( // Step 4. Receives new elements through OT std::vector keys(num_ot); - base_ot = BaseOtSend(ctx, NUM_BASE_OT); + base_ot = BaseOtSend(ctx, kNumBaseOT); IknpOtExtRecv(ctx, base_ot, ot_choice, absl::MakeSpan(keys)); std::vector ciphers(num_ot); auto buf = ctx->Recv(ctx->PrevRank(), "Receive ciphertexts"); std::memcpy(ciphers.data(), buf.data(), buf.size()); - std::unordered_set set_union(elem_hashes.begin(), - elem_hashes.end()); + std::unordered_set set_union( + elem_hashes.begin(), elem_hashes.end(), count, HashToSizeT); for (size_t i{}; i != num_ot; ++i) { if (!ot_choice[i]) { - if (auto new_elem = ciphers[i] ^ keys[i]; new_elem != BOT) { + if (auto new_elem = ciphers[i] ^ keys[i]; new_elem != kBot) { set_union.emplace(ciphers[i] ^ keys[i]); } }