From e08ae392f7e30021bffc950e35ca71faabf06c38 Mon Sep 17 00:00:00 2001 From: anakinxc <103552181+anakinxc@users.noreply.github.com> Date: Fri, 1 Dec 2023 08:32:35 +0000 Subject: [PATCH] Port sort fix --- libspu/kernel/hal/permute.cc | 16 +++++++++++++--- libspu/mpc/cheetah/yacl_ot/yacl_ote_adapter.cc | 8 ++++---- 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/libspu/kernel/hal/permute.cc b/libspu/kernel/hal/permute.cc index ab0e1062..88804827 100644 --- a/libspu/kernel/hal/permute.cc +++ b/libspu/kernel/hal/permute.cc @@ -108,7 +108,12 @@ std::vector BitonicSort(SPUContext *ctx, // make a copy for inplace sort std::vector ret; for (auto const &input : inputs) { - ret.emplace_back(input.clone()); + if (input.isPublic()) { + // we can not linear_scatter a secret value to a public operand + ret.emplace_back(_p2s(ctx, input).setDtype(input.dtype())); + } else { + ret.emplace_back(input.clone()); + } } // sort by per network layer for memory optimizations, sorting N elements @@ -515,7 +520,7 @@ std::vector sort1d(SPUContext *ctx, for (const auto &input : inputs) { ret.push_back(Permute1D(ctx, input, indices_to_sort)); } - } else { + } else if (comparator_ret_vis == VIS_SECRET) { SPU_ENFORCE(!is_stable, "Stable sort is unsupported if comparator return is secret."); @@ -591,7 +596,12 @@ std::vector simple_sort1d(SPUContext *ctx, return result; }; - auto ret = sort1d(ctx, inputs, comp_fn, inputs[0].vtype(), false); + Visibility vis = + std::all_of(inputs.begin(), inputs.begin() + num_keys, + [](const spu::Value &v) { return v.isPublic(); }) + ? VIS_PUBLIC + : VIS_SECRET; + auto ret = sort1d(ctx, inputs, comp_fn, vis, false); return ret; } } diff --git a/libspu/mpc/cheetah/yacl_ot/yacl_ote_adapter.cc b/libspu/mpc/cheetah/yacl_ot/yacl_ote_adapter.cc index b1c7e912..cfb19990 100644 --- a/libspu/mpc/cheetah/yacl_ot/yacl_ote_adapter.cc +++ b/libspu/mpc/cheetah/yacl_ot/yacl_ote_adapter.cc @@ -211,13 +211,13 @@ void YaclFerretOTeAdapter::Bootstrap() { yacl::AlignedVector send_ot(ot_buff_.begin(), ot_buff_.begin() + reserve_num_); auto send_ot_store = yc::MakeCompactOtSendStore(std::move(send_ot), Delta); - yc::FerretOtExtSend_Cheetah(ctx_, send_ot_store, lpn_param_, lpn_param_.n, + yc::FerretOtExtSend_cheetah(ctx_, send_ot_store, lpn_param_, lpn_param_.n, absl::MakeSpan(ot_buff_.data(), lpn_param_.n)); } else { yacl::AlignedVector recv_ot(ot_buff_.begin(), ot_buff_.begin() + reserve_num_); auto recv_ot_store = yc::MakeCompactOtRecvStore(std::move(recv_ot)); - yc::FerretOtExtRecv_Cheetah(ctx_, recv_ot_store, lpn_param_, lpn_param_.n, + yc::FerretOtExtRecv_cheetah(ctx_, recv_ot_store, lpn_param_, lpn_param_.n, absl::MakeSpan(ot_buff_.data(), lpn_param_.n)); } auto end = std::chrono::high_resolution_clock::now(); @@ -244,11 +244,11 @@ void YaclFerretOTeAdapter::BootstrapInplace(absl::Span ot, auto begin = std::chrono::high_resolution_clock::now(); if (is_sender_) { auto send_ot_store = yc::MakeCompactOtSendStore(std::move(ot_tmp), Delta); - yc::FerretOtExtSend_Cheetah(ctx_, send_ot_store, lpn_param_, lpn_param_.n, + yc::FerretOtExtSend_cheetah(ctx_, send_ot_store, lpn_param_, lpn_param_.n, data); } else { auto recv_ot_store = yc::MakeCompactOtRecvStore(std::move(ot_tmp)); - yc::FerretOtExtRecv_Cheetah(ctx_, recv_ot_store, lpn_param_, lpn_param_.n, + yc::FerretOtExtRecv_cheetah(ctx_, recv_ot_store, lpn_param_, lpn_param_.n, data); } auto end = std::chrono::high_resolution_clock::now();