diff --git a/yacl/crypto/primitives/psu/BUILD.bazel b/yacl/crypto/primitives/psu/BUILD.bazel
new file mode 100644
index 00000000..3c46ea68
--- /dev/null
+++ b/yacl/crypto/primitives/psu/BUILD.bazel
@@ -0,0 +1,46 @@
+# Copyright 2024 zhangwfjh
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("//bazel:yacl.bzl", "AES_COPT_FLAGS", "yacl_cc_library", "yacl_cc_test")
+
+package(default_visibility = ["//visibility:public"])
+
+yacl_cc_library(  # Private set union protocol: KrtwPsuSend / KrtwPsuRecv.
+    name = "krtw19_psu",
+    srcs = [
+        "krtw19_psu.cc",
+    ],
+    hdrs = [
+        "krtw19_psu.h",
+    ],
+    copts = AES_COPT_FLAGS,  # Compiler flags provided by //bazel:yacl.bzl (loaded above).
+    deps = [  # NOTE(review): krtw19_psu.cc includes "yacl/utils/serialize.h" and krtw19_psu.h includes "yacl/crypto/utils/secparam.h"; no dep below obviously provides them - confirm.
+        "//yacl/base:exception",
+        "//yacl/base:int128",
+        "//yacl/crypto/base/hash:hash_utils",
+        "//yacl/crypto/primitives/ot:base_ot",
+        "//yacl/crypto/primitives/ot:iknp_ote",
+        "//yacl/crypto/primitives/ot:kkrt_ote",
+        "//yacl/crypto/utils:rand",
+        "//yacl/link",
+        "//yacl/math/f2k",
+        "@com_google_absl//absl/types:span",
+    ],
+)
+
+yacl_cc_test(  # Runs both parties of the protocol against each other.
+    name = "krtw19_psu_test",
+    srcs = ["krtw19_psu_test.cc"],
+    deps = [":krtw19_psu"],  # NOTE(review): the test also includes "gtest/gtest.h" and "yacl/link/test_util.h"; presumably supplied by yacl_cc_test or transitively - confirm.
+)
diff --git a/yacl/crypto/primitives/psu/krtw19_psu.cc b/yacl/crypto/primitives/psu/krtw19_psu.cc
new file mode 100644
index 00000000..7a3f69ec
--- /dev/null
+++ b/yacl/crypto/primitives/psu/krtw19_psu.cc
@@ -0,0 +1,222 @@
+// Copyright 2024 zhangwfjh
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "yacl/crypto/primitives/psu/krtw19_psu.h"
+
+#include <algorithm>
+#include <array>
+#include <cmath>
+#include <cstring>
+#include <iterator>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "yacl/utils/serialize.h"
+
+namespace yacl::crypto {
+
+namespace {
+
+// reference: https://eprint.iacr.org/2019/1234.pdf (Figure 2)
+// NOTE(review): krtw19_psu.h cites https://eprint.iacr.org/2019/776.pdf
+// (Figure 10) for this protocol - confirm which reference is intended.
+
+// Ratio of hash bins to set elements: num_bins = ceil(count * kZeta).
+constexpr float kZeta{0.06f};
+// Slots per bin (m+1); every bin is processed as exactly kBinSize entries.
+constexpr size_t kBinSize{64ul};
+// Padding value for unused bin slots; it is dropped from the final union.
+constexpr uint128_t kBot{};
+// OPRF corrections are exchanged once per kNumBinsPerBatch bins.
+constexpr size_t kNumBinsPerBatch{16ul};
+constexpr size_t kBatchSize{kNumBinsPerBatch * kBinSize};
+
+// Sizes of the base-OT and IKNP-extension stages used for OT/OPRF setup.
+constexpr size_t kNumBaseOT{128ul};
+constexpr size_t kNumIknpOT{512ul};
+
+// Maps a 128-bit value to size_t by copying the first sizeof(size_t) bytes of
+// its Blake3_128 digest. Also used as the hasher of the result set.
+static auto HashToSizeT = [](const uint128_t& x) {
+  auto hash = Blake3_128({&x, sizeof x});
+  size_t ret;
+  std::memcpy(&ret, &hash, sizeof ret);
+  return ret;
+};
+
+// Distributes `elem_hashes` over ceil(count * kZeta) bins; the bin of an
+// element is HashToSizeT(element) modulo the number of bins. `count` must be
+// positive when `elem_hashes` is non-empty (both callers return early on 0).
+auto HashInputs(const std::vector<uint128_t>& elem_hashes, size_t count) {
+  size_t num_bins = std::ceil(count * kZeta);
+  std::vector<std::vector<uint128_t>> hashing(num_bins);
+  for (auto elem : elem_hashes) {
+    hashing[HashToSizeT(elem) % num_bins].emplace_back(elem);
+  }
+  return hashing;
+}
+
+// Evaluates at `x` the polynomial over GF(2^64) whose coefficients are
+// `coeffs` (lowest degree first), using Horner's rule.
+uint64_t Evaluate(const std::vector<uint64_t>& coeffs, uint64_t x) {
+  YACL_ENFORCE(!coeffs.empty(), "Polynomial has no coefficients.");
+  uint64_t y{coeffs.back()};
+  for (auto it = std::next(coeffs.rbegin()); it != coeffs.rend(); ++it) {
+    y = GfMul64(y, x) ^ *it;
+  }
+  return y;
+}
+
+// Lagrange interpolation over GF(2^64): returns the coefficients (lowest
+// degree first) of the polynomial of degree < xs.size() that takes the value
+// ys[i] at xs[i]. The abscissas are expected to be pairwise distinct.
+auto Interpolate(const std::vector<uint64_t>& xs,
+                 const std::vector<uint64_t>& ys) {
+  YACL_ENFORCE_EQ(xs.size(), ys.size(), "Sizes mismatch.");
+  size_t size{xs.size()};
+  std::vector<uint64_t> L_coeffs(size);
+  for (size_t i{}; i != size; ++i) {
+    // Over all j with xs[j] != xs[i], Li_coeffs accumulates
+    // ys[i] * prod (x + xs[j]) and prod accumulates prod (xs[i] + xs[j]).
+    std::vector<uint64_t> Li_coeffs(size);
+    Li_coeffs[0] = ys[i];
+    uint64_t prod{1};
+    for (size_t j{}; j != size; ++j) {
+      if (xs[i] != xs[j]) {
+        prod = GfMul64(prod, xs[i] ^ xs[j]);
+        // Multiplies Li_coeffs by (x + xs[j]); `sum` carries the previous
+        // coefficient of degree k - 1.
+        uint64_t sum{};
+        for (size_t k{}; k != size; ++k) {
+          sum = std::exchange(Li_coeffs[k],
+                              GfMul64(Li_coeffs[k], xs[j]) ^ sum);
+        }
+      }
+    }
+    // The denominator is shared by all coefficients: invert it only once.
+    const uint64_t inv_prod{Inv64(prod)};
+    for (size_t k{}; k != size; ++k) {
+      L_coeffs[k] ^= GfMul64(Li_coeffs[k], inv_prod);
+    }
+  }
+  return L_coeffs;
+}
+
+}  // namespace
+
+// Sender side of the private set union. Exchanges set sizes with the peer,
+// hashes `elem_hashes` into bins, and for every bin slot evaluates the
+// polynomial received from the peer and returns the masked result. Finally
+// it sends every slot encrypted under the first key of an OT pair.
+// Throws if a bin overflows or the peer sends a malformed message.
+void KrtwPsuSend(const std::shared_ptr<yacl::link::Context>& ctx,
+                 const std::vector<uint128_t>& elem_hashes) {
+  ctx->SendAsync(ctx->NextRank(), SerializeUint128(elem_hashes.size()),
+                 "Send set size");
+  size_t peer_count =
+      DeserializeUint128(ctx->Recv(ctx->PrevRank(), "Receive set size"));
+  auto count = std::max(elem_hashes.size(), peer_count);
+  if (count == 0) {
+    return;
+  }
+  // Step 1. Hashes inputs
+  auto hashing = HashInputs(elem_hashes, count);
+
+  // Step 2. Prepares OPRF
+  KkrtOtExtReceiver receiver;
+  size_t num_ot{hashing.size() * kBinSize};
+  auto choice = SecureRandBits(kNumBaseOT);
+  auto base_ot = BaseOtRecv(ctx, choice, kNumBaseOT);
+  auto store = IknpOtExtSend(ctx, base_ot, kNumIknpOT);
+  receiver.Init(ctx, store, num_ot);
+  receiver.SetBatchSize(kBatchSize);
+
+  std::vector<uint128_t> elems;
+  elems.reserve(num_ot);
+  std::vector<uint64_t> coeffs(kBinSize);
+  size_t oprf_idx{};
+  for (size_t bin_idx{}; bin_idx != hashing.size(); ++bin_idx) {
+    if (bin_idx % kNumBinsPerBatch == 0) {
+      receiver.SendCorrection(
+          ctx, std::min(kBatchSize, (hashing.size() - bin_idx) * kBinSize));
+    }
+    auto& bin = hashing[bin_idx];
+    // resize() below would silently drop elements of an overfull bin and
+    // they would then be missing from the union.
+    YACL_ENFORCE(bin.size() <= kBinSize,
+                 "Bin {} holds {} elements, capacity is {}.", bin_idx,
+                 bin.size(), kBinSize);
+    bin.resize(kBinSize);  // pads with kBot
+    std::sort(bin.begin(), bin.end());
+    // Step 3. For each bin element, invokes PSU(1, m+1)
+    for (auto elem : bin) {
+      elems.emplace_back(elem);
+      uint64_t eval{};
+      receiver.Encode(oprf_idx++, elem,
+                      {reinterpret_cast<uint8_t*>(&eval), sizeof eval});
+      auto buf = ctx->Recv(ctx->PrevRank(), "Receive coefficients");
+      // The peer controls the message length: never copy past `coeffs`.
+      YACL_ENFORCE_EQ(static_cast<size_t>(buf.size()),
+                      coeffs.size() * sizeof(uint64_t),
+                      "Unexpected size of coefficients.");
+      std::memcpy(coeffs.data(), buf.data(), buf.size());
+      auto y = Evaluate(coeffs, HashToSizeT(elem)) ^ eval;
+      ctx->SendAsync(ctx->NextRank(), SerializeUint128(y), "Send evaluation");
+    }
+  }
+
+  // Step 4. Sends new elements through OT
+  std::vector<std::array<uint128_t, 2>> keys(num_ot);
+  choice = SecureRandBits(kNumBaseOT);
+  base_ot = BaseOtRecv(ctx, choice, kNumBaseOT);
+  IknpOtExtSend(ctx, base_ot, absl::MakeSpan(keys));
+  std::vector<uint128_t> ciphers(num_ot);
+  for (size_t i{}; i != num_ot; ++i) {
+    ciphers[i] = elems[i] ^ keys[i][0];
+  }
+  ctx->SendAsync(
+      ctx->NextRank(),
+      yacl::Buffer{ciphers.data(), ciphers.size() * sizeof(uint128_t)},
+      "Send ciphertexts");
+}
+
+// Receiver side of the private set union. For every bin slot it sends an
+// interpolated polynomial and compares the peer's answer with a random seed.
+// Slots whose answer differs from the seed are then decrypted through OT and
+// added to `elem_hashes`. Returns the union in unspecified order.
+// Throws if a bin overflows or the peer sends a malformed message.
+std::vector<uint128_t> KrtwPsuRecv(
+    const std::shared_ptr<yacl::link::Context>& ctx,
+    const std::vector<uint128_t>& elem_hashes) {
+  size_t peer_count =
+      DeserializeUint128(ctx->Recv(ctx->PrevRank(), "Receive set size"));
+  ctx->SendAsync(ctx->NextRank(), SerializeUint128(elem_hashes.size()),
+                 "Send set size");
+  auto count = std::max(elem_hashes.size(), peer_count);
+  if (count == 0) {
+    return {};
+  }
+  // Step 1. Hashes inputs
+  auto hashing = HashInputs(elem_hashes, count);
+
+  // Step 2. Prepares OPRF
+  KkrtOtExtSender sender;
+  size_t num_ot{hashing.size() * kBinSize};
+  auto base_ot = BaseOtSend(ctx, kNumBaseOT);
+  auto choice = SecureRandBits(kNumIknpOT);
+  auto store = IknpOtExtRecv(ctx, base_ot, choice, kNumIknpOT);
+  sender.Init(ctx, store, num_ot);
+  sender.SetBatchSize(kBatchSize);
+  auto oprf = sender.GetOprf();
+
+  yacl::dynamic_bitset<> ot_choice(num_ot);
+  size_t oprf_idx{};
+  // Step 3. For each bin, invokes PSU(1, m+1)
+  for (size_t bin_idx{}; bin_idx != hashing.size(); ++bin_idx) {
+    if (bin_idx % kNumBinsPerBatch == 0) {
+      sender.RecvCorrection(
+          ctx, std::min(kBatchSize, (hashing.size() - bin_idx) * kBinSize));
+    }
+    const auto& bin = hashing[bin_idx];
+    const size_t bin_size{bin.size()};
+    // Elements beyond kBinSize would silently be left out of the polynomial.
+    YACL_ENFORCE(bin_size <= kBinSize,
+                 "Bin {} holds {} elements, capacity is {}.", bin_idx,
+                 bin_size, kBinSize);
+    // The hashes of the real elements are shared by all polynomials of a bin.
+    std::vector<uint64_t> bin_xs(bin_size);
+    std::transform(bin.begin(), bin.end(), bin_xs.begin(), HashToSizeT);
+    for (size_t elem_idx{}; elem_idx != kBinSize; ++elem_idx, ++oprf_idx) {
+      auto seed = FastRandU64();
+      std::vector<uint64_t> xs(kBinSize), ys(kBinSize);
+      for (size_t i{}; i != kBinSize; ++i) {
+        // Slot layout: real elements, one kBot slot, then random fillers.
+        if (i < bin_size) {
+          xs[i] = bin_xs[i];
+        } else if (i > bin_size) {
+          xs[i] = FastRandU64();
+        } else {
+          xs[i] = static_cast<uint64_t>(kBot);
+        }
+        ys[i] = oprf->Eval(oprf_idx, xs[i]) ^ seed;
+      }
+      auto coeffs = Interpolate(xs, ys);
+      yacl::Buffer buf(coeffs.data(), coeffs.size() * sizeof(uint64_t));
+      ctx->SendAsync(ctx->NextRank(), buf, "Send coefficients");
+      auto eval =
+          DeserializeUint128(ctx->Recv(ctx->PrevRank(), "Receive evaluation"));
+      // Choice bit 1 marks a slot whose answer unmasked to the seed.
+      ot_choice[oprf_idx] = eval == seed;
+    }
+  }
+
+  // Step 4. Receives new elements through OT
+  std::vector<uint128_t> keys(num_ot);
+  base_ot = BaseOtSend(ctx, kNumBaseOT);
+  IknpOtExtRecv(ctx, base_ot, ot_choice, absl::MakeSpan(keys));
+  std::vector<uint128_t> ciphers(num_ot);
+  auto buf = ctx->Recv(ctx->PrevRank(), "Receive ciphertexts");
+  // The peer controls the message length: never copy past `ciphers`.
+  YACL_ENFORCE_EQ(static_cast<size_t>(buf.size()),
+                  ciphers.size() * sizeof(uint128_t),
+                  "Unexpected size of ciphertexts.");
+  std::memcpy(ciphers.data(), buf.data(), buf.size());
+  std::unordered_set<uint128_t, decltype(HashToSizeT)> set_union(
+      elem_hashes.begin(), elem_hashes.end(), count, HashToSizeT);
+  for (size_t i{}; i != num_ot; ++i) {
+    if (!ot_choice[i]) {
+      // Padding slots decrypt to kBot and are not part of the union.
+      if (auto new_elem = ciphers[i] ^ keys[i]; new_elem != kBot) {
+        set_union.emplace(new_elem);
+      }
+    }
+  }
+  return std::vector<uint128_t>(set_union.begin(), set_union.end());
+}
+
+}  // namespace yacl::crypto
diff --git a/yacl/crypto/primitives/psu/krtw19_psu.h b/yacl/crypto/primitives/psu/krtw19_psu.h
new file mode 100644
index 00000000..c06535b1
--- /dev/null
+++ b/yacl/crypto/primitives/psu/krtw19_psu.h
@@ -0,0 +1,46 @@
+// Copyright 2024 zhangwfjh
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <memory>
+#include <vector>
+
+#include "yacl/base/int128.h"
+#include "yacl/crypto/utils/secparam.h"
+#include "yacl/link/link.h"
+#include "yacl/math/f2k/f2k.h"
+
+/* submodules */
+#include "yacl/crypto/base/hash/hash_utils.h"
+#include "yacl/crypto/primitives/ot/base_ot.h"
+#include "yacl/crypto/primitives/ot/iknp_ote.h"
+#include "yacl/crypto/primitives/ot/kkrt_ote.h"
+#include "yacl/crypto/utils/rand.h"
+
+/* security parameter declaration */
+YACL_MODULE_DECLARE("krtw_psu", SecParam::C::k128, SecParam::S::k40);
+
+namespace yacl::crypto {
+
+// Scalable Private Set Union from Symmetric-Key Techniques
+// https://eprint.iacr.org/2019/776.pdf (Figure 10)
+//
+// Two-party protocol: one party calls KrtwPsuSend and the other calls
+// KrtwPsuRecv over the same link; only the receiver obtains the union.
+
+// Sender side.
+//   ctx          link context used to talk to the peer
+//   elem_hashes  the sender's set, given as 128-bit element hashes
+void KrtwPsuSend(const std::shared_ptr<yacl::link::Context>& ctx,
+                 const std::vector<uint128_t>& elem_hashes);
+
+// Receiver side.
+//   ctx          link context used to talk to the peer
+//   elem_hashes  the receiver's set, given as 128-bit element hashes
+// Returns the union of both sets, in unspecified order.
+std::vector<uint128_t> KrtwPsuRecv(
+    const std::shared_ptr<yacl::link::Context>& ctx,
+    const std::vector<uint128_t>& elem_hashes);
+
+}  // namespace yacl::crypto
diff --git a/yacl/crypto/primitives/psu/krtw19_psu_test.cc b/yacl/crypto/primitives/psu/krtw19_psu_test.cc
new file mode 100644
index 00000000..c498b930
--- /dev/null
+++ b/yacl/crypto/primitives/psu/krtw19_psu_test.cc
@@ -0,0 +1,82 @@
+// Copyright 2024 zhangwfjh
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "yacl/crypto/primitives/psu/krtw19_psu.h"
+
+#include <algorithm>
+#include <future>
+#include <set>
+#include <string>
+#include <vector>
+
+#include "gtest/gtest.h"
+
+#include "yacl/base/exception.h"
+#include "yacl/crypto/base/hash/hash_utils.h"
+#include "yacl/crypto/utils/secparam.h"
+#include "yacl/link/test_util.h"
+
+// Input sets of the two parties: items_a goes to the sender and items_b to
+// the receiver.
+struct TestParams {
+  std::vector<uint128_t> items_a;
+  std::vector<uint128_t> items_b;
+};
+
+namespace yacl::crypto {
+
+class KrtwPsuTest : public testing::TestWithParam<TestParams> {};
+
+// Runs both parties concurrently and checks that the receiver's output,
+// once sorted, equals the sorted union of the two input sets.
+TEST_P(KrtwPsuTest, Works) {
+  const auto& params = GetParam();
+  const int kWorldSize = 2;
+  auto contexts = yacl::link::test::SetupWorld(kWorldSize);
+
+  // Both parties block on each other's messages, so each needs its own
+  // thread: a deferred launch would run them one after the other and hang.
+  std::future<void> krtwpsu_sender = std::async(
+      std::launch::async,
+      [&] { return KrtwPsuSend(contexts[0], params.items_a); });
+  std::future<std::vector<uint128_t>> krtwpsu_receiver = std::async(
+      std::launch::async,
+      [&] { return KrtwPsuRecv(contexts[1], params.items_b); });
+
+  krtwpsu_sender.get();
+  auto psu_result = krtwpsu_receiver.get();
+  std::sort(psu_result.begin(), psu_result.end());
+
+  std::set<uint128_t> union_set;
+  union_set.insert(params.items_a.begin(), params.items_a.end());
+  union_set.insert(params.items_b.begin(), params.items_b.end());
+  std::vector<uint128_t> union_vec(union_set.begin(), union_set.end());
+
+  EXPECT_EQ(psu_result, union_vec);
+}
+
+// Returns the Blake3_128 hashes of the decimal strings of `begin`,
+// `begin + 1`, ..., `begin + size - 1`.
+std::vector<uint128_t> CreateRangeItems(size_t begin, size_t size) {
+  std::vector<uint128_t> ret;
+  ret.reserve(size);
+  for (size_t i = 0; i < size; i++) {
+    ret.push_back(Blake3_128(std::to_string(begin + i)));
+  }
+  return ret;
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    Works_Instances, KrtwPsuTest,
+    testing::Values(
+        TestParams{{}, {}},                 //
+        TestParams{{}, {Blake3_128("a")}},  //
+        TestParams{{Blake3_128("a")}, {}},  //
+        // No overlap
+        TestParams{CreateRangeItems(0, 1024), CreateRangeItems(1024, 1024)},  //
+        // Partial overlap
+        TestParams{CreateRangeItems(0, 1024), CreateRangeItems(512, 1024)},  //
+        // Complete overlap
+        TestParams{CreateRangeItems(0, 1024), CreateRangeItems(0, 1024)}  //
+        ));
+
+}  // namespace yacl::crypto
diff --git a/yacl/math/f2k/f2k.h b/yacl/math/f2k/f2k.h
index f7da1d1e..9896aec5 100644
--- a/yacl/math/f2k/f2k.h
+++ 
b/yacl/math/f2k/f2k.h
@@ -139,6 +139,37 @@ inline uint64_t GfMul64(uint64_t x, uint64_t y) {
   return Reduce64(ClMul64(x, y));
 }
 
+// Multiplicative inverse over the Galois field F_{2^64}.
+//
+// Computes x^(2^64 - 2) with an addition chain of GfMul64 squarings and
+// multiplications that doubles the width of the exponent at every stage.
+inline uint64_t Inv64(uint64_t x) {
+  // Returns v squared `times` times in a row, i.e. v^(2^times).
+  const auto square_n = [](uint64_t v, int times) {
+    for (int i = 0; i < times; i++) {
+      v = GfMul64(v, v);
+    }
+    return v;
+  };
+
+  uint64_t even = GfMul64(x, x);     // x^(2^2 - 2)
+  uint64_t ones = GfMul64(even, x);  // x^(2^2 - 1)
+  // Invariant on entry: even == x^(2^w - 2) and ones == x^(2^w - 1).
+  for (int w = 2; w <= 8; w *= 2) {
+    const uint64_t shifted = square_n(ones, w);  // x^(2^(2w) - 2^w)
+    even = GfMul64(even, shifted);               // x^(2^(2w) - 2)
+    ones = GfMul64(ones, shifted);               // x^(2^(2w) - 1)
+  }
+
+  // Last doubling, from width 16 to width 32.
+  uint64_t acc = square_n(ones, 16);  // x^(2^32 - 2^16)
+  even = GfMul64(even, acc);          // x^(2^32 - 2)
+  acc = GfMul64(acc, ones);           // x^(2^32 - 1)
+
+  // x^((2^32 - 1) * 2^32) * x^(2^32 - 2) == x^(2^64 - 2).
+  return GfMul64(square_n(acc, 32), even);
+}
+
 // Inner product
 inline std::pair ClMul128(absl::Span x,
                           absl::Span y) {