Skip to content

Commit

Permalink
[fix] several improvements
Browse files Browse the repository at this point in the history
1. fix `memcpy` bug
2. replace hasher by lambda
3. use hash sort as shuffle
4. rename constants
  • Loading branch information
zhangwfjh committed Jan 25, 2024
1 parent 33965b5 commit 513a8fd
Showing 1 changed file with 45 additions and 51 deletions.
96 changes: 45 additions & 51 deletions yacl/crypto/primitives/psu/krtw19_psu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
#include <algorithm>
#include <array>
#include <iterator>
#include <random>
#include <unordered_set>
#include <utility>

Expand All @@ -28,32 +27,27 @@ namespace yacl::crypto {
namespace {

// reference: https://eprint.iacr.org/2019/1234.pdf (Figure 2)
constexpr float ZETA{0.06f};
constexpr size_t BIN_SIZE{64ul}; // m+1
constexpr uint128_t BOT{};
constexpr size_t NUM_BINS_PER_BATCH{16ul};
constexpr size_t BATCH_SIZE{NUM_BINS_PER_BATCH * BIN_SIZE};

constexpr size_t NUM_BASE_OT{128ul};
constexpr size_t NUM_INKP_OT{512ul};

static std::random_device rd;
static std::mt19937 gen(rd());

struct U128Hasher {
size_t operator()(const uint128_t& x) const {
auto hash = Blake3_128({&x, sizeof x});
size_t ret;
std::memcpy(&ret, &hash, 1);
return ret;
}
constexpr float kZeta{0.06f};
constexpr size_t kBinSize{64ul}; // m+1
constexpr uint128_t kBot{};
constexpr size_t kNumBinsPerBatch{16ul};
constexpr size_t kBatchSize{kNumBinsPerBatch * kBinSize};

constexpr size_t kNumBaseOT{128ul};
constexpr size_t kNumInkpOT{512ul};

static auto HashToSizeT = [](const uint128_t& x) {
auto hash = Blake3_128({&x, sizeof x});
size_t ret;
std::memcpy(&ret, &hash, sizeof ret);
return ret;
};

auto HashInputs(const std::vector<uint128_t>& elem_hashes, size_t count) {
size_t num_bins = std::ceil(count * ZETA);
size_t num_bins = std::ceil(count * kZeta);
std::vector<std::vector<uint128_t>> hashing(num_bins);
for (auto elem : elem_hashes) {
auto hash = U128Hasher{}(elem);
auto hash = HashToSizeT(elem);
hashing[hash % num_bins].emplace_back(elem);
}
return hashing;
Expand Down Expand Up @@ -109,41 +103,41 @@ void KrtwPsuSend(const std::shared_ptr<yacl::link::Context>& ctx,

// Step 2. Prepares OPRF
KkrtOtExtReceiver receiver;
size_t num_ot{hashing.size() * BIN_SIZE};
auto choice = RandBits(NUM_BASE_OT);
auto base_ot = BaseOtRecv(ctx, choice, NUM_BASE_OT);
auto store = IknpOtExtSend(ctx, base_ot, NUM_INKP_OT);
size_t num_ot{hashing.size() * kBinSize};
auto choice = RandBits(kNumBaseOT);
auto base_ot = BaseOtRecv(ctx, choice, kNumBaseOT);
auto store = IknpOtExtSend(ctx, base_ot, kNumInkpOT);
receiver.Init(ctx, store, num_ot);
receiver.SetBatchSize(BATCH_SIZE);
receiver.SetBatchSize(kBatchSize);

std::vector<uint128_t> elems;
elems.reserve(num_ot);
size_t oprf_idx{};
for (size_t bin_idx{}; bin_idx != hashing.size(); ++bin_idx) {
if (bin_idx % NUM_BINS_PER_BATCH == 0) {
if (bin_idx % kNumBinsPerBatch == 0) {
receiver.SendCorrection(
ctx, std::min(BATCH_SIZE, (hashing.size() - bin_idx) * BIN_SIZE));
ctx, std::min(kBatchSize, (hashing.size() - bin_idx) * kBinSize));
}
hashing[bin_idx].resize(BIN_SIZE);
std::shuffle(hashing[bin_idx].begin(), hashing[bin_idx].end(), gen);
hashing[bin_idx].resize(kBinSize);
std::sort(hashing[bin_idx].begin(), hashing[bin_idx].end());
// Step 3. For each bin element, invokes PSU(1, m+1)
for (auto elem : hashing[bin_idx]) {
elems.emplace_back(elem);
uint64_t eval;
receiver.Encode(oprf_idx++, elem,
{reinterpret_cast<uint8_t*>(&eval), sizeof eval});
std::vector<uint64_t> coeffs(BIN_SIZE);
std::vector<uint64_t> coeffs(kBinSize);
auto buf = ctx->Recv(ctx->PrevRank(), "Receive coefficients");
std::memcpy(coeffs.data(), buf.data(), buf.size());
auto y = Evaluate(coeffs, U128Hasher{}(elem)) ^ eval;
auto y = Evaluate(coeffs, HashToSizeT(elem)) ^ eval;
ctx->SendAsync(ctx->NextRank(), SerializeUint128(y), "Send evaluation");
}
}

// Step 4. Sends new elements through OT
std::vector<std::array<uint128_t, 2>> keys(num_ot);
choice = SecureRandBits(NUM_BASE_OT);
base_ot = BaseOtRecv(ctx, choice, NUM_BASE_OT);
choice = SecureRandBits(kNumBaseOT);
base_ot = BaseOtRecv(ctx, choice, kNumBaseOT);
IknpOtExtSend(ctx, base_ot, absl::MakeSpan(keys));
std::vector<uint128_t> ciphers(num_ot);
for (size_t i{}; i != num_ot; ++i) {
Expand Down Expand Up @@ -171,30 +165,30 @@ std::vector<uint128_t> KrtwPsuRecv(

// Step 2. Prepares OPRF
KkrtOtExtSender sender;
size_t num_ot{hashing.size() * BIN_SIZE};
auto base_ot = BaseOtSend(ctx, NUM_BASE_OT);
auto choice = RandBits(NUM_INKP_OT);
auto store = IknpOtExtRecv(ctx, base_ot, choice, NUM_INKP_OT);
size_t num_ot{hashing.size() * kBinSize};
auto base_ot = BaseOtSend(ctx, kNumBaseOT);
auto choice = RandBits(kNumInkpOT);
auto store = IknpOtExtRecv(ctx, base_ot, choice, kNumInkpOT);
sender.Init(ctx, store, num_ot);
sender.SetBatchSize(BATCH_SIZE);
sender.SetBatchSize(kBatchSize);
auto oprf = sender.GetOprf();

yacl::dynamic_bitset<> ot_choice(num_ot);
size_t oprf_idx{};
// Step 3. For each bin, invokes PSU(1, m+1)
for (size_t bin_idx{}; bin_idx != hashing.size(); ++bin_idx) {
if (bin_idx % NUM_BINS_PER_BATCH == 0) {
if (bin_idx % kNumBinsPerBatch == 0) {
sender.RecvCorrection(
ctx, std::min(BATCH_SIZE, (hashing.size() - bin_idx) * BIN_SIZE));
ctx, std::min(kBatchSize, (hashing.size() - bin_idx) * kBinSize));
}
auto bin_size = hashing[bin_idx].size();
for (size_t elem_idx{}; elem_idx != BIN_SIZE; ++elem_idx, ++oprf_idx) {
for (size_t elem_idx{}; elem_idx != kBinSize; ++elem_idx, ++oprf_idx) {
auto seed = FastRandU64();
std::vector<uint64_t> xs(BIN_SIZE), ys(BIN_SIZE);
for (size_t i{}; i != BIN_SIZE; ++i) {
xs[i] = (i < bin_size ? U128Hasher{}(hashing[bin_idx][i])
std::vector<uint64_t> xs(kBinSize), ys(kBinSize);
for (size_t i{}; i != kBinSize; ++i) {
xs[i] = (i < bin_size ? HashToSizeT(hashing[bin_idx][i])
: i > bin_size ? FastRandU64()
: BOT);
: kBot);
ys[i] = oprf->Eval(oprf_idx, xs[i]) ^ seed;
}
auto coeffs = Interpolate(xs, ys);
Expand All @@ -208,16 +202,16 @@ std::vector<uint128_t> KrtwPsuRecv(

// Step 4. Receives new elements through OT
std::vector<uint128_t> keys(num_ot);
base_ot = BaseOtSend(ctx, NUM_BASE_OT);
base_ot = BaseOtSend(ctx, kNumBaseOT);
IknpOtExtRecv(ctx, base_ot, ot_choice, absl::MakeSpan(keys));
std::vector<uint128_t> ciphers(num_ot);
auto buf = ctx->Recv(ctx->PrevRank(), "Receive ciphertexts");
std::memcpy(ciphers.data(), buf.data(), buf.size());
std::unordered_set<uint128_t, U128Hasher> set_union(elem_hashes.begin(),
elem_hashes.end());
std::unordered_set<uint128_t, decltype(HashToSizeT)> set_union(
elem_hashes.begin(), elem_hashes.end(), count, HashToSizeT);
for (size_t i{}; i != num_ot; ++i) {
if (!ot_choice[i]) {
if (auto new_elem = ciphers[i] ^ keys[i]; new_elem != BOT) {
if (auto new_elem = ciphers[i] ^ keys[i]; new_elem != kBot) {
set_union.emplace(ciphers[i] ^ keys[i]);
}
}
Expand Down

0 comments on commit 513a8fd

Please sign in to comment.