Skip to content

Commit

Permalink
repo-sync-2024-01-12T14:46:04+0800
Browse files Browse the repository at this point in the history
  • Loading branch information
anakinxc committed Jan 12, 2024
1 parent 43a9e3e commit 8dc91d6
Show file tree
Hide file tree
Showing 13 changed files with 64 additions and 31 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/scorecard.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,12 @@ jobs:

steps:
- name: "Checkout code"
uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1
uses: actions/checkout@f43a0e5ff2bd294095638e18286ca9a3d1956744 # v3.6.0
with:
persist-credentials: false

- name: "Run analysis"
uses: ossf/scorecard-action@0864cf19026789058feabb7e87baa5f140aac736 # v2.3.1
uses: ossf/scorecard-action@e38b1902ae4f44df626f11ba0734b14fb91f8f86 # v2.1.2
with:
results_file: results.sarif
results_format: sarif
Expand Down Expand Up @@ -67,6 +67,6 @@ jobs:

# Upload the results to GitHub's code scanning dashboard.
- name: "Upload to code-scanning"
uses: github/codeql-action/upload-sarif@8b7fcbfac2aae0e6c24d9f9ebd5830b1290b18e4 # v2.23.0
uses: github/codeql-action/upload-sarif@17573ee1cc1b9d061760f3a006fc4aac4f944fd5 # v2.2.4
with:
sarif_file: results.sarif
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
> please add your unreleased change here.
- [Improvement] Optimize one-time setup for yacl ot
- [Improvement] Optimize sort performance

## 20240105

Expand All @@ -21,7 +22,7 @@
- [Feature] Add equal support for SEMI2K and ABY3
- [Improvement] Optimize sort memory usage
- [Improvement] Improve compatibility with latest Jax
- [Bugfix] Fix compilation cache collision under certian cases
- [Bugfix] Fix compilation cache collision under certain cases
- [Deprecated] macOS 11.x is no longer supported

## 20231108
Expand Down
4 changes: 2 additions & 2 deletions bazel/repositories.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,9 @@ def _com_github_openxla_xla():
maybe(
http_archive,
name = "bazel_skylib",
sha256 = "cd55a062e763b9349921f0f5db8c3933288dc8ba4f76dd9416aac68acee3cb94",
sha256 = "74d544d96f4a5bb630d465ca8bbcfe231e3594e5aae57e1edbf17a6eb3ca2506",
urls = [
"https://github.com/bazelbuild/bazel-skylib/releases/download/1.5.0/bazel-skylib-1.5.0.tar.gz",
"https://github.com/bazelbuild/bazel-skylib/releases/download/1.3.0/bazel-skylib-1.3.0.tar.gz",
],
)

Expand Down
3 changes: 1 addition & 2 deletions libspu/compiler/tests/enum_conversion_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@ TEST(EnumConversion, Public) {
mlir::pphlo::symbolizeEnum<mlir::pphlo::Visibility>(Visibility_Name(v)); \
EXPECT_EQ(mlir_v, mlir::pphlo::Visibility::T);

{ CHECK(VIS_PUBLIC) }
{ CHECK(VIS_SECRET) }
{CHECK(VIS_PUBLIC)} { CHECK(VIS_SECRET) }

#undef CHECK
}
Expand Down
20 changes: 10 additions & 10 deletions libspu/mpc/ab_api_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,10 @@ bool verifyCost(Kernel* kernel, std::string_view name, FieldType field,
/* WHEN */ \
auto a0 = p2a(obj.get(), p0); \
auto a1 = p2a(obj.get(), p1); \
auto prev = obj->prot() -> getState<Communicator>() -> getStats(); \
auto prev = obj->prot()->getState<Communicator>()->getStats(); \
auto tmp = OP##_aa(obj.get(), a0, a1); \
auto cost = \
obj->prot() -> getState<Communicator>() -> getStats() - prev; \
obj->prot()->getState<Communicator>()->getStats() - prev; \
auto re = a2p(obj.get(), tmp); \
auto rp = OP##_pp(obj.get(), p0, p1); \
\
Expand All @@ -131,10 +131,10 @@ bool verifyCost(Kernel* kernel, std::string_view name, FieldType field,
\
/* WHEN */ \
auto a0 = p2a(obj.get(), p0); \
auto prev = obj->prot() -> getState<Communicator>() -> getStats(); \
auto prev = obj->prot()->getState<Communicator>()->getStats(); \
auto tmp = OP##_ap(obj.get(), a0, p1); \
auto cost = \
obj->prot() -> getState<Communicator>() -> getStats() - prev; \
obj->prot()->getState<Communicator>()->getStats() - prev; \
auto re = a2p(obj.get(), tmp); \
auto rp = OP##_pp(obj.get(), p0, p1); \
\
Expand Down Expand Up @@ -480,10 +480,10 @@ TEST_P(ArithmeticTest, A2P) {
/* WHEN */ \
auto b0 = p2b(obj.get(), p0); \
auto b1 = p2b(obj.get(), p1); \
auto prev = obj->prot() -> getState<Communicator>() -> getStats(); \
auto prev = obj->prot()->getState<Communicator>()->getStats(); \
auto tmp = OP##_bb(obj.get(), b0, b1); \
auto cost = \
obj->prot() -> getState<Communicator>() -> getStats() - prev; \
obj->prot()->getState<Communicator>()->getStats() - prev; \
auto re = b2p(obj.get(), tmp); \
auto rp = OP##_pp(obj.get(), p0, p1); \
\
Expand All @@ -510,10 +510,10 @@ TEST_P(ArithmeticTest, A2P) {
\
/* WHEN */ \
auto b0 = p2b(obj.get(), p0); \
auto prev = obj->prot() -> getState<Communicator>() -> getStats(); \
auto prev = obj->prot()->getState<Communicator>()->getStats(); \
auto tmp = OP##_bp(obj.get(), b0, p1); \
auto cost = \
obj->prot() -> getState<Communicator>() -> getStats() - prev; \
obj->prot()->getState<Communicator>()->getStats() - prev; \
auto re = b2p(obj.get(), tmp); \
auto rp = OP##_pp(obj.get(), p0, p1); \
\
Expand Down Expand Up @@ -550,10 +550,10 @@ TEST_BOOLEAN_BINARY_OP(xor)
continue; \
} \
/* WHEN */ \
auto prev = obj->prot() -> getState<Communicator>() -> getStats(); \
auto prev = obj->prot()->getState<Communicator>()->getStats(); \
auto tmp = OP##_b(obj.get(), b0, bits); \
auto cost = \
obj->prot() -> getState<Communicator>() -> getStats() - prev; \
obj->prot()->getState<Communicator>()->getStats() - prev; \
auto r_b = b2p(obj.get(), tmp); \
auto r_p = OP##_p(obj.get(), p0, bits); \
\
Expand Down
10 changes: 6 additions & 4 deletions libspu/mpc/cheetah/arith/cheetah_dot.cc
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,6 @@ struct CheetahDot::Impl : public EnableCPRNG {
std::shared_ptr<yacl::link::Context> lctx_;
bool disable_pack_ = false;

mutable std::shared_mutex context_lock_;
// field_bitlen -> functor mapping
std::unordered_map<size_t, std::shared_ptr<seal::SEALContext>> seal_cntxts_;
std::unordered_map<size_t, seal::SEALContext> galoi_cntxts_;
Expand All @@ -197,7 +196,6 @@ struct CheetahDot::Impl : public EnableCPRNG {
};

void CheetahDot::Impl::LazyInitGaloisKey(size_t field_bitlen) {
// NOTE: make sure context_lock_ is obtained.
if (galoi_cntxts_.find(field_bitlen) != galoi_cntxts_.end()) {
return;
}
Expand Down Expand Up @@ -233,7 +231,6 @@ void CheetahDot::Impl::LazyInitGaloisKey(size_t field_bitlen) {
}

void CheetahDot::Impl::LazyInit(size_t field_bitlen, bool need_galois_keys) {
std::unique_lock guard(context_lock_);
if (seal_cntxts_.find(field_bitlen) != seal_cntxts_.end()) {
if (need_galois_keys) {
LazyInitGaloisKey(field_bitlen);
Expand Down Expand Up @@ -760,6 +757,11 @@ CheetahDot::CheetahDot(const std::shared_ptr<yacl::link::Context> &lctx,

CheetahDot::~CheetahDot() = default;

void CheetahDot::LazyInitKeys(FieldType field) {
SPU_ENFORCE(impl_ != nullptr);
return impl_->LazyInit(SizeOf(field) * 8, /*create_galois*/ true);
}

NdArrayRef CheetahDot::DotOLE(const NdArrayRef &inp, yacl::link::Context *conn,
const Shape3D &dim3, bool is_self_lhs) {
SPU_ENFORCE(impl_ != nullptr);
Expand All @@ -780,4 +782,4 @@ NdArrayRef CheetahDot::BatchDotOLE(const NdArrayRef &inp,
return impl_->BatchDotOLE(inp, conn, dim4, is_self_lhs);
}

} // namespace spu::mpc::cheetah
} // namespace spu::mpc::cheetah
5 changes: 5 additions & 0 deletions libspu/mpc/cheetah/arith/cheetah_dot.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,19 @@ class CheetahDot {

CheetahDot(CheetahDot&&) = delete;

void LazyInitKeys(FieldType field);

// make sure to call InitKeys first
NdArrayRef DotOLE(const NdArrayRef& inp, const Shape3D& dim3,
bool is_self_lhs);

// LHS.shape MxK, RHS.shape KxL => MxL
// make sure to call InitKeys first
NdArrayRef DotOLE(const NdArrayRef& inp, yacl::link::Context* conn,
const Shape3D& dim3, bool is_self_lhs);

// LHS.shape BxMxK, RHS.shape BxKxL => BxMxL
// make sure to call InitKeys first
NdArrayRef BatchDotOLE(const NdArrayRef& inp, yacl::link::Context* conn,
const Shape4D& dim4, bool is_self_lhs);

Expand Down
19 changes: 15 additions & 4 deletions libspu/mpc/cheetah/arith/cheetah_mul.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
#include "seal/keygenerator.h"
#include "seal/publickey.h"
#include "seal/secretkey.h"
#include "seal/util/locks.h"
#include "seal/util/polyarithsmallmod.h"
#include "seal/valcheck.h"
#include "spdlog/spdlog.h"
Expand Down Expand Up @@ -103,6 +102,15 @@ struct CheetahMul::Impl : public EnableCPRNG {

int64_t num_slots() const { return parms_.poly_modulus_degree(); }

void LazyInit(FieldType field, uint32_t msg_width_hint) {
Options options;
options.ring_bitlen = SizeOf(field) * 8;
options.msg_bitlen =
msg_width_hint == 0 ? options.ring_bitlen : msg_width_hint;
LazyExpandSEALContexts(options);
LazyInitModSwitchHelper(options);
}

void LazyExpandSEALContexts(const Options &options,
yacl::link::Context *conn = nullptr);

Expand Down Expand Up @@ -189,7 +197,6 @@ struct CheetahMul::Impl : public EnableCPRNG {
uint32_t current_crt_plain_bitlen_{0};

// SEAL's contexts for ZZ_{2^k}
mutable std::mutex context_lock_;
std::vector<seal::SEALContext> seal_cntxts_;

// own secret key
Expand All @@ -206,7 +213,6 @@ struct CheetahMul::Impl : public EnableCPRNG {
};

void CheetahMul::Impl::LazyInitModSwitchHelper(const Options &options) {
std::lock_guard guard(context_lock_);
if (ms_helpers_.count(options) > 0) {
return;
}
Expand Down Expand Up @@ -269,7 +275,6 @@ void CheetahMul::Impl::LocalExpandSEALContexts(size_t target) {
void CheetahMul::Impl::LazyExpandSEALContexts(const Options &options,
yacl::link::Context *conn) {
uint32_t target_plain_bitlen = TotalCRTBitLen(options);
std::lock_guard guard(context_lock_);
if (current_crt_plain_bitlen_ >= target_plain_bitlen) {
return;
}
Expand Down Expand Up @@ -719,4 +724,10 @@ NdArrayRef CheetahMul::MulOLE(const NdArrayRef &inp, bool is_evaluator,
return impl_->MulOLE(inp, nullptr, is_evaluator, msg_width_hint);
}

void CheetahMul::LazyInitKeys(FieldType field, uint32_t msg_width_hint) {
SPU_ENFORCE(impl_ != nullptr);
SPU_ENFORCE(msg_width_hint <= SizeOf(field) * 8);
return impl_->LazyInit(field, msg_width_hint);
}

} // namespace spu::mpc::cheetah
4 changes: 4 additions & 0 deletions libspu/mpc/cheetah/arith/cheetah_mul.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,13 @@ class CheetahMul {

CheetahMul(CheetahMul&&) = delete;

void LazyInitKeys(FieldType field, uint32_t msg_width_hint = 0);

// NOTE: make sure to call InitKeys first
NdArrayRef MulOLE(const NdArrayRef& inp, yacl::link::Context* conn,
bool is_evaluator, uint32_t msg_width_hint = 0);

// NOTE: make sure to call InitKeys first
NdArrayRef MulOLE(const NdArrayRef& inp, bool is_evaluator,
uint32_t msg_width_hint = 0);

Expand Down
6 changes: 6 additions & 0 deletions libspu/mpc/cheetah/arithmetic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,8 @@ NdArrayRef MulAA::mulDirectly(KernelEvalContext* ctx, const NdArrayRef& x,
// Compute the cross terms x0*y1, x1*y0 homomorphically
auto* comm = ctx->getState<Communicator>();
auto* mul_prot = ctx->getState<CheetahMulState>()->get();
mul_prot->LazyInitKeys(x.eltype().as<Ring2k>()->field());

const int rank = comm->getRank();
auto fx = x.reshape({x.numel()});
auto fy = y.reshape({y.numel()});
Expand Down Expand Up @@ -311,6 +313,8 @@ NdArrayRef MatMulAA::proc(KernelEvalContext* ctx, const NdArrayRef& x,

auto* comm = ctx->getState<Communicator>();
auto* dot_prot = ctx->getState<CheetahDotState>()->get();
dot_prot->LazyInitKeys(x.eltype().as<Ring2k>()->field());

const int rank = comm->getRank();

// (x0 + x1) * (y0 + y1)
Expand Down Expand Up @@ -347,6 +351,8 @@ NdArrayRef MatMulAV::proc(KernelEvalContext* ctx, const NdArrayRef& x,
}
auto* comm = ctx->getState<Communicator>();
auto* dot_prot = ctx->getState<CheetahDotState>()->get();
dot_prot->LazyInitKeys(x.eltype().as<Ring2k>()->field());

const int rank = comm->getRank();
const auto* ptype = y.eltype().as<Priv2kTy>();
SPU_ENFORCE(ptype != nullptr, "rhs should be a private type");
Expand Down
3 changes: 2 additions & 1 deletion libspu/mpc/cheetah/state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

namespace spu::mpc::cheetah {
size_t InitOTState(KernelEvalContext* ctx, size_t njobs) {
constexpr size_t kMinWorkSize = 1500;
constexpr size_t kMinWorkSize = 5000;
if (njobs == 0) {
return 0;
}
Expand Down Expand Up @@ -70,6 +70,7 @@ void CheetahMulState::makeSureCacheSize(FieldType field, int64_t numel) {
// Then the beaver (a0, b0, c0) and (a1, b1, c1)
// where c0 = a0*b0 + <a0*b1> + <a1*b0>
// c1 = a1*b1 + <a0*b1> + <a1*b0>
mul_prot_->LazyInitKeys(field);
const int rank = mul_prot_->Rank();
const int64_t ole_sze = mul_prot_->OLEBatchSize();
const int64_t num_ole = CeilDiv<size_t>(2 * numel, ole_sze);
Expand Down
8 changes: 6 additions & 2 deletions libspu/mpc/cheetah/state.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,12 @@ class CheetahOTState : public State {
if (basic_ot_prot_[idx]) {
return;
}
// NOTE: create a separated link for OT
auto _comm = std::make_shared<Communicator>(comm->lctx()->Spawn());
// NOTE(lwj): create a separated link for OT
// We **do not** block on the OT link since the message volume is small for
// LPN-based OTe
auto link = comm->lctx()->Spawn();
link->SetThrottleWindowSize(0);
auto _comm = std::make_shared<Communicator>(std::move(link));
basic_ot_prot_[idx] = std::make_shared<BasicOTProtocols>(std::move(_comm));
}

Expand Down
4 changes: 2 additions & 2 deletions libspu/mpc/utils/ring_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ constexpr char kModule[] = "RingOps";
#define DEF_UNARY_RING_OP(NAME, OP) \
void NAME##_impl(NdArrayRef& ret, const NdArrayRef& x) { \
ENFORCE_EQ_ELSIZE_AND_SHAPE(ret, x); \
const auto field = x.eltype().as<Ring2k>() -> field(); \
const auto field = x.eltype().as<Ring2k>()->field(); \
const int64_t numel = ret.numel(); \
return DISPATCH_ALL_FIELDS(field, kModule, [&]() { \
using T = std::make_signed_t<ring2k_t>; \
Expand All @@ -65,7 +65,7 @@ DEF_UNARY_RING_OP(ring_neg, -);
const NdArrayRef& y) { \
ENFORCE_EQ_ELSIZE_AND_SHAPE(ret, x); \
ENFORCE_EQ_ELSIZE_AND_SHAPE(ret, y); \
const auto field = x.eltype().as<Ring2k>() -> field(); \
const auto field = x.eltype().as<Ring2k>()->field(); \
const int64_t numel = ret.numel(); \
return DISPATCH_ALL_FIELDS(field, kModule, [&]() { \
NdArrayView<ring2k_t> _x(x); \
Expand Down

0 comments on commit 8dc91d6

Please sign in to comment.