diff --git a/CHANGELOG.md b/CHANGELOG.md index 8c6bc9bb..051366ce 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,8 @@ > > please add your unreleased change here. +- [API] Change IO interface, split share into chunk + ## 20230705 - [SPU] 0.4.1 release - [Improvement] Improve tanh performance diff --git a/bazel/repositories.bzl b/bazel/repositories.bzl index d0fcd1e1..86eaad02 100644 --- a/bazel/repositories.bzl +++ b/bazel/repositories.bzl @@ -18,7 +18,7 @@ load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository") SECRETFLOW_GIT = "https://github.com/secretflow" -YACL_COMMIT_ID = "fff880ff9e1e50d5260c5a881ff7abc051cc27d5" +YACL_COMMIT_ID = "a85ab1f6cb642c8e36dc429714a108d5368df762" def spu_deps(): _bazel_platform() diff --git a/bazel/seal.BUILD b/bazel/seal.BUILD index 78f9c194..629b1fa9 100644 --- a/bazel/seal.BUILD +++ b/bazel/seal.BUILD @@ -59,5 +59,5 @@ spu_cmake_external( out_static_libs = ["libseal-4.1.a"], deps = [ "@com_github_facebook_zstd//:zstd", - ] + ], ) diff --git a/benchmark/setup_dockers_and_run.sh b/benchmark/setup_dockers_and_run.sh index 14bba654..c22312ac 100644 --- a/benchmark/setup_dockers_and_run.sh +++ b/benchmark/setup_dockers_and_run.sh @@ -42,7 +42,7 @@ sleep 10 echo -e "${COLOR_GREEN}Run benchmark${COLOR_END}" docker run --rm --mount type=bind,source="$(pwd)",target=/home/admin/dev/ --network nn-benchmark spu-build:v1 \ - sh -c "cd /home/admin/dev && bash benchmark/run_bench.sh $@" | tee benchmark_results.log; + sh -c "cd /home/admin/dev && bash benchmark/run_bench.sh $@" | tee benchmark_results.log; echo -e "${COLOR_GREEN}Shutdown docker compose${COLOR_END}" docker-compose -f .circleci/benchmark.yml down diff --git a/build_wheel_entrypoint.sh b/build_wheel_entrypoint.sh index 4e29a6ef..17d70cf4 100755 --- a/build_wheel_entrypoint.sh +++ b/build_wheel_entrypoint.sh @@ -15,7 +15,6 @@ # limitations under the License. 
# - pip install numpy python setup.py bdist_wheel diff --git a/docs/development/add_protocols.rst b/docs/development/add_protocols.rst index 320c5d14..91d85afc 100644 --- a/docs/development/add_protocols.rst +++ b/docs/development/add_protocols.rst @@ -167,7 +167,7 @@ Inside the **makeAby3Protocol** function, it does three things. - The third is to register the protocol kernels (functions). We can see that three types of kernels are registered. \ The first type is the kernels implemented in the `pv2k.cc `_ \ file, using **Pub2k** as the naming prefix of kernel classes. The second type is the kernels implemented in the \ - `ab_api.cc `_ file, using **ABProt** as the \ + `ab_api.cc `_ file, using **ABProt** as the \ naming prefix of kernel classes. The third type is implemented in `arithmetic.cc `_, \ `boolean.cc `_ and other files under the aby3 directory. diff --git a/libspu/compiler/core/core.cc b/libspu/compiler/core/core.cc index a9680a7c..093d78be 100644 --- a/libspu/compiler/core/core.cc +++ b/libspu/compiler/core/core.cc @@ -83,6 +83,8 @@ void Core::buildPipeline(mlir::PassManager *pm) { optPM.addPass(mlir::createLoopInvariantCodeMotionPass()); optPM.addPass(mlir::createCSEPass()); + + optPM.addPass(mlir::pphlo::createInsertDeallocationOp()); } } // namespace spu::compiler diff --git a/libspu/compiler/passes/BUILD.bazel b/libspu/compiler/passes/BUILD.bazel index c04e7159..96088b94 100644 --- a/libspu/compiler/passes/BUILD.bazel +++ b/libspu/compiler/passes/BUILD.bazel @@ -234,6 +234,18 @@ spu_cc_library( ], ) +spu_cc_library( + name = "insert_deallocation", + srcs = ["insert_deallocation.cc"], + hdrs = ["passes.h"], + deps = [ + ":pass_details", + "//libspu/dialect:pphlo_dialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:TransformUtils", + ], +) + spu_cc_library( name = "all_passes", hdrs = ["register_passes.h"], @@ -242,6 +254,7 @@ spu_cc_library( ":decompose_minmax", ":expand_secret_gather", ":hlo_legalize_to_pphlo", + ":insert_deallocation", 
":lower_conversion_cast", ":lower_mixed_type_op", ":optimize_denominator_with_broadcast", diff --git a/libspu/compiler/passes/insert_deallocation.cc b/libspu/compiler/passes/insert_deallocation.cc new file mode 100644 index 00000000..7b27fd4e --- /dev/null +++ b/libspu/compiler/passes/insert_deallocation.cc @@ -0,0 +1,127 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "mlir/Analysis/Liveness.h" +#include "mlir/Pass/Pass.h" + +#include "libspu/compiler/passes/pass_details.h" +#include "libspu/dialect/pphlo_ops.h" + +#ifdef ENABLE_LIVENESS_DEBUG + +#include "spdlog/spdlog.h" + +void printLiveness(mlir::Liveness *liveness) { + std::string buf; + llvm::raw_string_ostream os(buf); + + liveness->print(os); + + SPDLOG_INFO("liveness = {}", os.str()); +} + +#endif + +namespace mlir::pphlo { + +namespace { + +struct Deallocator { +private: + std::unique_ptr top_liveness_; + +public: + LogicalResult transformOp(Operation *op, + const LivenessBlockInfo *block_liveness) { + for (const auto &operand : op->getOperands()) { + if (block_liveness->isLiveOut(operand) || + mlir::isa(operand)) { + // skip live out values and block args + continue; + } + + if (operand.getDefiningOp()->getParentRegion() != op->getParentRegion()) { + // This value is captured by current region, right now we do not handle + // cross region ownership.. 
skip + continue; + } + + if (top_liveness_->isDeadAfter(operand, op)) { + OpBuilder builder(op->getContext()); + builder.setInsertionPointAfter(op); + builder.create(op->getLoc(), operand); + } + } + + for (int64_t idx = 0; idx < op->getNumRegions(); ++idx) { + if (failed(transformRegion(op->getRegion(idx)))) { + return failure(); + } + } + + return success(); + } + + LogicalResult transformBlock(Block &block) { + const auto *block_liveness = top_liveness_->getLiveness(&block); + for (auto &op : llvm::make_early_inc_range(block.without_terminator())) { + auto opResult = transformOp(&op, block_liveness); + if (failed(opResult)) { + return failure(); + } + } + return success(); + } + + LogicalResult transformRegion(Region &r) { + for (auto &b : r.getBlocks()) { + if (failed(transformBlock(b))) { + return failure(); + } + } + return success(); + } + + LogicalResult transformFuncOp(func::FuncOp op) { + if (op->getNumRegions() == 0) { + return success(); + } + + top_liveness_ = std::make_unique(op); + + // Transform function body. 
+ if (failed(transformRegion(op.getBody()))) { + return failure(); + } + + return success(); + } +}; + +struct InsertDeallocation : public InsertDeallocationBase { + void runOnOperation() override { + if (failed(Deallocator().transformFuncOp(getOperation()))) { + signalPassFailure(); + } + } +}; +} // namespace + +std::unique_ptr> createInsertDeallocationOp() { + return std::make_unique(); +} + +} // namespace mlir::pphlo diff --git a/libspu/compiler/passes/passes.h b/libspu/compiler/passes/passes.h index c8eeeea7..ddcc32ba 100644 --- a/libspu/compiler/passes/passes.h +++ b/libspu/compiler/passes/passes.h @@ -71,6 +71,8 @@ std::unique_ptr> createRewriteDivSqrtPatterns(); std::unique_ptr> createOptimizeDenominatorWithBroadcast(); +std::unique_ptr> createInsertDeallocationOp(); + } // namespace pphlo } // namespace mlir diff --git a/libspu/compiler/passes/passes.td b/libspu/compiler/passes/passes.td index 4292b5c5..95581327 100644 --- a/libspu/compiler/passes/passes.td +++ b/libspu/compiler/passes/passes.td @@ -92,3 +92,9 @@ def OptimizeDenominatorWithBcast: Pass<"optimize-denominator-with-broadcast", "f let constructor = "createOptimizeDenominatorWithBroadcast()"; let dependentDialects = ["pphlo::PPHloDialect"]; } + +def InsertDeallocation: Pass<"insert-deallocation", "func::FuncOp"> { + let summary = "Insert deallocation ops"; + let constructor = "createInsertDeallocationOp()"; + let dependentDialects = ["pphlo::PPHloDialect"]; +} diff --git a/libspu/compiler/tests/pphlo_simple_dealloc.mlir b/libspu/compiler/tests/pphlo_simple_dealloc.mlir new file mode 100644 index 00000000..cbf1b957 --- /dev/null +++ b/libspu/compiler/tests/pphlo_simple_dealloc.mlir @@ -0,0 +1,13 @@ +// RUN: mlir-pphlo-opt --insert-deallocation --split-input-file %s | FileCheck %s + +func.func @main() -> (tensor>) { + %0 = "pphlo.constant"() {value = dense<0xFF800000> : tensor} : () -> tensor> + %1 = "pphlo.constant"() {value = dense<0.000000e+00> : tensor} : () -> tensor> + %2 = "pphlo.less"(%0, 
%1): (tensor>, tensor>) -> tensor> + %3 = "pphlo.select"(%2, %0, %1): (tensor>, tensor>, tensor>) -> tensor> + //CHECK: "pphlo.free"(%1) + //CHECK: "pphlo.free"(%0) + //CHECK: "pphlo.free"(%2) + //CHECK-NOT: "pphlo.free"(%3) + return %3: tensor> +} diff --git a/libspu/core/BUILD.bazel b/libspu/core/BUILD.bazel index 17113531..b6c26d18 100644 --- a/libspu/core/BUILD.bazel +++ b/libspu/core/BUILD.bazel @@ -38,6 +38,7 @@ spu_cc_library( deps = [ "//libspu/core:prelude", "@com_google_absl//absl/types:span", + "@yacl//yacl/link", "@yacl//yacl/utils:scope_guard", ], ) diff --git a/libspu/core/context.h b/libspu/core/context.h index d73f190d..af5309d2 100644 --- a/libspu/core/context.h +++ b/libspu/core/context.h @@ -33,7 +33,7 @@ namespace spu { class SPUContext final { RuntimeConfig config_; - // A dynamic object for polymophic(multi-stage) operations. + // A dynamic object for polymorphic(multi-stage) operations. std::unique_ptr prot_; // TODO(jint): do we really need a link here? how about a FHE context. 
@@ -107,6 +107,10 @@ class KernelEvalContext final { SPUContext* sctx() { return sctx_; } + const std::shared_ptr& lctx() const { + return sctx_->lctx(); + } + const std::string& id() { return sctx_->id(); } const std::string& pid() { return sctx_->pid(); } diff --git a/libspu/core/ndarray_ref.cc b/libspu/core/ndarray_ref.cc index 01b5af57..0be7f436 100644 --- a/libspu/core/ndarray_ref.cc +++ b/libspu/core/ndarray_ref.cc @@ -250,6 +250,11 @@ NdArrayRef NdArrayRef::reshape(absl::Span to_shape) const { SPU_ENFORCE(calcNumel(shape()) == calcNumel(to_shape), "reshape from {} to {} is changing numel", shape(), to_shape); + // Reshape empty is always a noop + if (calcNumel(to_shape) == 0) { + return NdArrayRef(buf(), eltype(), to_shape, strides(), offset()); + } + std::vector new_strides(to_shape.size(), 0); if (attempt_nocopy_reshape(*this, to_shape, new_strides)) { // No copy reshape @@ -548,11 +553,16 @@ NdArrayRef unflatten(const ArrayRef& arr, absl::Span shape) { std::vector(shape.size(), 0), arr.offset()); } - ArrayRef compact = arr.isCompact() ? arr : arr.clone(); - auto strides = makeCompactStrides(shape); - return NdArrayRef(compact.buf(), compact.eltype(), shape, std::move(strides), - compact.offset()); + + if (arr.stride() != 1) { + for (auto& s : strides) { + s *= arr.stride(); + } + } + + return NdArrayRef(arr.buf(), arr.eltype(), shape, std::move(strides), + arr.offset()); } namespace { diff --git a/libspu/core/trace.h b/libspu/core/trace.h index 0f0efebf..be82c660 100644 --- a/libspu/core/trace.h +++ b/libspu/core/trace.h @@ -25,6 +25,7 @@ #include "fmt/format.h" #include "fmt/ostream.h" #include "spdlog/spdlog.h" +#include "yacl/link/link.h" namespace std { @@ -149,6 +150,9 @@ struct ActionRecord final { // the action timing information. TimePoint start; TimePoint end; + // the communication bytes information. + size_t send_bytes_start; + size_t send_bytes_end; }; class ProfState final { @@ -203,6 +207,9 @@ class TraceAction final { // The tracer. 
std::shared_ptr const tracer_; + // The link context. + std::shared_ptr const lctx_; + // The static expected behavior of this action. int64_t const flag_; @@ -225,12 +232,18 @@ class TraceAction final { TimePoint start_; TimePoint end_; + // the action communication information. + size_t send_bytes_start_; + size_t send_bytes_end_; + int64_t saved_tracer_flag_; template void begin(Args&&... args) { start_ = std::chrono::high_resolution_clock::now(); - + if (lctx_) { + send_bytes_start_ = lctx_->GetStats()->sent_bytes.load(); + } const auto flag = flag_ & tracer_->getFlag(); if ((flag & TR_LOGB) != 0) { detail_ = internal::variadicToString(std::forward(args)...); @@ -249,7 +262,9 @@ class TraceAction final { // end_ = std::chrono::high_resolution_clock::now(); - + if (lctx_) { + send_bytes_end_ = lctx_->GetStats()->sent_bytes.load(); + } const auto flag = flag_ & tracer_->getFlag(); if ((flag & TR_LOGE) != 0) { tracer_->decDepth(); @@ -257,7 +272,8 @@ class TraceAction final { } if ((flag & TR_REC) != 0 && (flag & TR_MODALL) != 0) { tracer_->getProfState()->addRecord( - ActionRecord{id_, name_, std::move(detail_), flag_, start_, end_}); + ActionRecord{id_, name_, std::move(detail_), flag_, start_, end_, + send_bytes_start_, send_bytes_end_}); } } @@ -275,12 +291,14 @@ class TraceAction final { // mask = ~TR_MOD2, means disable further TR_MOD2 tracing. template explicit TraceAction( - std::shared_ptr tracer, // + std::shared_ptr tracer, // + std::shared_ptr lctx, // int64_t flag, // the static expected behaviour flag of action. int64_t mask, // the suppress mask of the action. std::string name, // name of this action. Args&&... 
args) : tracer_(std::move(tracer)), + lctx_(std::move(lctx)), flag_(flag), mask_(mask), name_(std::move(name)) { @@ -320,8 +338,8 @@ std::shared_ptr getTracer(const std::string& id, // Why add `##` to __VA_ARGS__, please see // https://stackoverflow.com/questions/5891221/variadic-macros-with-zero-arguments -#define SPU_TRACE_ACTION(TRACER, FLAG, MASK, NAME, ...) \ - TraceAction __trace_action(TRACER, FLAG, MASK, NAME, ##__VA_ARGS__); +#define SPU_TRACE_ACTION(TRACER, LINK, FLAG, MASK, NAME, ...) \ + TraceAction __trace_action(TRACER, LINK, FLAG, MASK, NAME, ##__VA_ARGS__); #else @@ -330,34 +348,34 @@ std::shared_ptr getTracer(const std::string& id, #endif // trace an hlo layer dispatch -#define SPU_TRACE_HLO_DISP(CTX, ...) \ - SPU_TRACE_ACTION(GET_TRACER(CTX), (TR_HLO | TR_LOG), (~0), __func__, \ - ##__VA_ARGS__) +#define SPU_TRACE_HLO_DISP(CTX, ...) \ + SPU_TRACE_ACTION(GET_TRACER(CTX), (CTX)->lctx(), (TR_HLO | TR_LOG), (~0), \ + __func__, ##__VA_ARGS__) // trace an hlo layer leaf -#define SPU_TRACE_HLO_LEAF(CTX, ...) \ - SPU_TRACE_ACTION(GET_TRACER(CTX), (TR_HLO | TR_LAR), (~TR_HLO), __func__, \ - ##__VA_ARGS__) +#define SPU_TRACE_HLO_LEAF(CTX, ...) \ + SPU_TRACE_ACTION(GET_TRACER(CTX), (CTX)->lctx(), (TR_HLO | TR_LAR), \ + (~TR_HLO), __func__, ##__VA_ARGS__) // trace an hal layer dispatch -#define SPU_TRACE_HAL_DISP(CTX, ...) \ - SPU_TRACE_ACTION(GET_TRACER(CTX), (TR_HAL | TR_LOG), (~0), __func__, \ - ##__VA_ARGS__) +#define SPU_TRACE_HAL_DISP(CTX, ...) \ + SPU_TRACE_ACTION(GET_TRACER(CTX), (CTX)->lctx(), (TR_HAL | TR_LOG), (~0), \ + __func__, ##__VA_ARGS__) // trace an hal layer leaf -#define SPU_TRACE_HAL_LEAF(CTX, ...) \ - SPU_TRACE_ACTION(GET_TRACER(CTX), (TR_HAL | TR_LAR), (~TR_HAL), __func__, \ - ##__VA_ARGS__) +#define SPU_TRACE_HAL_LEAF(CTX, ...) \ + SPU_TRACE_ACTION(GET_TRACER(CTX), (CTX)->lctx(), (TR_HAL | TR_LAR), \ + (~TR_HAL), __func__, ##__VA_ARGS__) // trace an mpc layer dispatch -#define SPU_TRACE_MPC_DISP(CTX, ...) 
\ - SPU_TRACE_ACTION(GET_TRACER(CTX), (TR_MPC | TR_LOG), (~0), __func__, \ - ##__VA_ARGS__) +#define SPU_TRACE_MPC_DISP(CTX, ...) \ + SPU_TRACE_ACTION(GET_TRACER(CTX), (CTX)->lctx(), (TR_MPC | TR_LOG), (~0), \ + __func__, ##__VA_ARGS__) // trace an mpc layer leaf -#define SPU_TRACE_MPC_LEAF(CTX, ...) \ - SPU_TRACE_ACTION(GET_TRACER(CTX), (TR_MPC | TR_LAR), (~TR_MPC), __func__, \ - ##__VA_ARGS__) +#define SPU_TRACE_MPC_LEAF(CTX, ...) \ + SPU_TRACE_ACTION(GET_TRACER(CTX), (CTX)->lctx(), (TR_MPC | TR_LAR), \ + (~TR_MPC), __func__, ##__VA_ARGS__) // Debug purpose only. class MemProfilingGuard { diff --git a/libspu/core/trace_test.cc b/libspu/core/trace_test.cc index e058c71c..a2b9324d 100644 --- a/libspu/core/trace_test.cc +++ b/libspu/core/trace_test.cc @@ -49,10 +49,11 @@ TEST(TraceTest, ActionWorks) { auto tracer = std::make_shared(TR_MODALL | TR_LAR); { - TraceAction ta0(tracer, (TR_MOD1 | TR_LOG), ~0, "f"); - TraceAction ta1(tracer, (TR_MOD1 | TR_LAR), ~TR_MOD1, "g", 10); - TraceAction ta2(tracer, (TR_MOD1 | TR_LAR), ~TR_MOD1, "ignored", 10); - TraceAction ta3(tracer, (TR_MOD2 | TR_LAR), ~0, "h", 10, 20); + TraceAction ta0(tracer, nullptr, (TR_MOD1 | TR_LOG), ~0, "f"); + TraceAction ta1(tracer, nullptr, (TR_MOD1 | TR_LAR), ~TR_MOD1, "g", 10); + TraceAction ta2(tracer, nullptr, (TR_MOD1 | TR_LAR), ~TR_MOD1, "ignored", + 10); + TraceAction ta3(tracer, nullptr, (TR_MOD2 | TR_LAR), ~0, "h", 10, 20); } ASSERT_EQ(tracer->getProfState()->getRecords().size(), 2); @@ -64,6 +65,7 @@ TEST(TraceTest, ActionWorks) { struct Context { static std::string id() { return "id"; } static std::string pid() { return ""; } + static std::shared_ptr lctx() { return nullptr; } }; void g(Context* ctx) { SPU_TRACE_HAL_LEAF(ctx); } diff --git a/libspu/core/value.cc b/libspu/core/value.cc index 56e02882..b8e73207 100644 --- a/libspu/core/value.cc +++ b/libspu/core/value.cc @@ -59,59 +59,106 @@ Value& Value::setDtype(DataType new_dtype, bool force) { return *this; } -ValueProto Value::toProto() 
const { +size_t Value::chunksCount(size_t max_chunk_size) const { + size_t total = numel() * data_.elsize(); + size_t num_chunks = (total + max_chunk_size - 1) / max_chunk_size; + return num_chunks; +} + +ValueProto Value::toProto(size_t max_chunk_size) const { + SPU_ENFORCE(max_chunk_size > 0); SPU_ENFORCE(dtype_ != DT_INVALID && vtype() != VIS_INVALID); - ValueProto proto; - proto.set_data_type(dtype_); - proto.set_visibility(vtype()); - proto.set_storage_type(data_.eltype().toString()); - for (const auto& d : shape()) { - proto.mutable_shape()->add_dims(d); - } + ValueProto ret; + + auto build_chunk = [&](const void* data, size_t size, size_t num_chunks) { + if (size == 0) { + return; + } + ret.chunks.reserve(num_chunks); + for (size_t i = 0; i < num_chunks; i++) { + size_t chunk_size = std::min(max_chunk_size, size - i * max_chunk_size); + + size_t offset = i * max_chunk_size; + ValueChunkProto chunk; + chunk.set_total_bytes(size); + chunk.set_chunk_offset(offset); + if (chunk_size > 0) { + chunk.set_content(static_cast(data) + offset, + chunk_size); + } + ret.chunks.emplace_back(std::move(chunk)); + } + }; + + const size_t num_chunks = chunksCount(max_chunk_size); + if (data_.isCompact()) { - proto.set_content(data_.data(), numel() * data_.elsize()); + build_chunk(data_.data(), numel() * data_.elsize(), num_chunks); } else { // Make a compact clone auto copy = data_.clone(); SPU_ENFORCE(copy.isCompact(), "Must be a compact copy."); - proto.set_content(copy.data(), copy.buf()->size()); + build_chunk(copy.data(), copy.buf()->size(), num_chunks); } - return proto; + + ret.meta.CopyFrom(toMetaProto()); + + return ret; } -ValueMeta Value::toMetaProto() const { +ValueMetaProto Value::toMetaProto() const { SPU_ENFORCE(dtype_ != DT_INVALID && vtype() != VIS_INVALID); - ValueMeta proto; + ValueMetaProto proto; proto.set_data_type(dtype_); proto.set_visibility(vtype()); for (const auto& d : shape()) { proto.mutable_shape()->add_dims(d); } + 
proto.set_storage_type(data_.eltype().toString()); return proto; } -Value Value::fromProto(const ValueProto& proto) { - const auto eltype = Type::fromString(proto.storage_type()); +Value Value::fromProto(const ValueProto& value) { + const auto& meta = value.meta; + const auto eltype = Type::fromString(meta.storage_type()); - SPU_ENFORCE(proto.data_type() != DT_INVALID, "invalid data type={}", - proto.data_type()); + SPU_ENFORCE(meta.data_type() != DT_INVALID, "invalid data type={}", + meta.data_type()); // vtype is deduced from storage_type. - SPU_ENFORCE(proto.visibility() == getVisibilityFromType(eltype), - "visibility {} does not match storage_type {}", - proto.visibility(), eltype); + SPU_ENFORCE(meta.visibility() == getVisibilityFromType(eltype), + "visibility {} does not match storage_type {}", meta.visibility(), + eltype); + + std::vector shape(meta.shape().dims().begin(), + meta.shape().dims().end()); - std::vector shape(proto.shape().dims().begin(), - proto.shape().dims().end()); + const auto& chunks = value.chunks; + const size_t total_bytes = chunks.empty() ? 
0 : chunks[0].total_bytes(); + + std::map ordered_chunks; + for (const auto& s : chunks) { + SPU_ENFORCE(ordered_chunks.insert({s.chunk_offset(), &s}).second, + "Repeated chunk_offset {} found", s.chunk_offset()); + } NdArrayRef data(eltype, shape); - SPU_ENFORCE(static_cast(data.buf()->size()) == - proto.content().size()); - memcpy(data.data(), proto.content().c_str(), data.buf()->size()); + SPU_ENFORCE(static_cast(data.buf()->size()) == total_bytes); + + size_t chunk_end_pos = 0; + for (const auto& [offset, chunk] : ordered_chunks) { + SPU_ENFORCE(offset == chunk_end_pos, + "offset {} is not match to last chunk's end pos", offset); + memcpy(static_cast(data.data()) + offset, chunk->content().data(), + chunk->content().size()); + chunk_end_pos += chunk->content().size(); + } + + SPU_ENFORCE(total_bytes == chunk_end_pos); - return Value(data, proto.data_type()); + return Value(data, meta.data_type()); } Value Value::clone() const { return Value(data_.clone(), dtype()); } diff --git a/libspu/core/value.h b/libspu/core/value.h index 547ba92e..785138d0 100644 --- a/libspu/core/value.h +++ b/libspu/core/value.h @@ -25,6 +25,15 @@ namespace spu { +// In order to prevent a single protobuf from being larger than 2gb, a spu +// runtime value is represented by multiple chunked protobuf + meta, and +// std::vector is used to organize multiple chunks instead of repeated in +// protobuf. +struct ValueProto { + ValueMetaProto meta; + std::vector chunks; +}; + class Value final { NdArrayRef data_; DataType dtype_ = DT_INVALID; @@ -64,11 +73,12 @@ class Value final { Value& setDtype(DataType new_dtype, bool force = false); // Serialize to protobuf. - ValueProto toProto() const; - ValueMeta toMetaProto() const; + ValueProto toProto(size_t max_chunk_size) const; + size_t chunksCount(size_t max_chunk_size) const; + ValueMetaProto toMetaProto() const; // Deserialize from protobuf. 
- static Value fromProto(const ValueProto& proto); + static Value fromProto(const ValueProto& value); Value clone() const; }; diff --git a/libspu/device/BUILD.bazel b/libspu/device/BUILD.bazel index 671b6ce7..012998cf 100644 --- a/libspu/device/BUILD.bazel +++ b/libspu/device/BUILD.bazel @@ -26,25 +26,11 @@ spu_cc_library( ], ) -proto_library( - name = "device_proto", - srcs = ["device.proto"], - deps = [ - "//libspu:spu_proto", - ], -) - -cc_proto_library( - name = "device_cc_proto", - deps = [":device_proto"], -) - spu_cc_library( name = "io", srcs = ["io.cc"], hdrs = ["io.h"], deps = [ - ":device_cc_proto", ":symbol_table", "//libspu:spu_cc_proto", "//libspu/core:context", @@ -70,7 +56,6 @@ spu_cc_library( srcs = ["executor.cc"], hdrs = ["executor.h"], deps = [ - ":device_cc_proto", ":symbol_table", "//libspu:spu_cc_proto", "//libspu/core:context", @@ -111,7 +96,6 @@ spu_cc_library( srcs = ["symbol_table.cc"], hdrs = ["symbol_table.h"], deps = [ - ":device_cc_proto", "//libspu/core:pt_buffer_view", "//libspu/core:value", ], diff --git a/libspu/device/api.cc b/libspu/device/api.cc index bc780fbc..bb6d187d 100644 --- a/libspu/device/api.cc +++ b/libspu/device/api.cc @@ -27,8 +27,6 @@ #include "libspu/device/pphlo/pphlo_executor.h" #include "libspu/dialect/pphlo_dialect.h" -#include "libspu/device/device.pb.h" - namespace spu::device { namespace { @@ -99,6 +97,8 @@ struct ActionStats { size_t count = 0; // total duration time. Duration total_time = {}; + // total send bytes. 
+ size_t send_bytes = 0; inline double getTotalTimeInSecond() const { return std::chrono::duration_cast>(total_time) @@ -106,8 +106,12 @@ struct ActionStats { } }; +/* + @shantang / @wuju + TODO: temporary remove, need to adapt value slice change void takeSnapshot(size_t rank, const RuntimeConfig &rt_config, const ExecutableProto &executable, const SymbolTable &env) { + const std::string &dump_dir = rt_config.processor_dump_dir(); // Naming convention for dumped files must align with debug runner. std::filesystem::path dump_folder(dump_dir); @@ -123,6 +127,7 @@ void takeSnapshot(size_t rank, const RuntimeConfig &rt_config, std::ofstream dump_file(dump_path, std::ios::binary | std::ios::out); dump_file << snapshot.SerializeAsString(); } +*/ void printProfilingData(spu::SPUContext *sctx, const std::string &name, const ExecutionStats &exec_stats, @@ -147,6 +152,7 @@ void printProfilingData(spu::SPUContext *sctx, const std::string &name, stat.count++; stat.total_time += std::chrono::duration_cast(rec.end - rec.start); + stat.send_bytes += (rec.send_bytes_end - rec.send_bytes_start); } static std::map kModules = { @@ -163,8 +169,9 @@ void printProfilingData(spu::SPUContext *sctx, const std::string &name, SPDLOG_INFO("{} profiling: total time {}", mod_name, total_time); for (const auto &[key, stat] : stats) { if ((key.flag & mod_flag) != 0) { - SPDLOG_INFO("- {}, executed {} times, duration {}s", key.name, - stat.count, stat.getTotalTimeInSecond()); + SPDLOG_INFO("- {}, executed {} times, duration {}s, send bytes {}", + key.name, stat.count, stat.getTotalTimeInSecond(), + stat.send_bytes); } } } @@ -217,11 +224,16 @@ void executeImpl(OpExecutor *executor, spu::SPUContext *sctx, // TODO: rename this flag, enable_execution_dump? 
const RuntimeConfig rt_config = sctx->config(); + /* + @shantang / @wuju + TODO: temporary remove, need to adapt value slice change if (rt_config.enable_processor_dump()) { const bool isRefHal = sctx->lctx() == nullptr; const size_t rank = isRefHal ? 0 : sctx->lctx()->Rank(); takeSnapshot(rank, rt_config, executable, *env); + } + */ // execution std::vector outputs; diff --git a/libspu/device/device.proto b/libspu/device/device.proto deleted file mode 100644 index 19186554..00000000 --- a/libspu/device/device.proto +++ /dev/null @@ -1,52 +0,0 @@ -// -// Copyright 2021 Ant Group Co., Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -syntax = "proto3"; - -package spu.device; - -import "libspu/spu.proto"; - -message SymbolTableProto { - // - map symbols = 1; -} - -// A snapshot represents a pending-execution moment of SPU. -message SnapshotProto { - // The party's rank. - uint64 rank = 1; - - // The runtime configuration - RuntimeConfig runtime_cfg = 2; - - // The executable file. - ExecutableProto executable = 3; - - // The global symbols. - SymbolTableProto environ = 4; -} - -message RevealedSnapshotProto { - // The runtime configuration - RuntimeConfig runtime_cfg = 1; - - // The executable file. - ExecutableProto executable = 2; - - // The global symbols. 
- repeated SymbolTableProto parties = 3; -} diff --git a/libspu/device/executor.cc b/libspu/device/executor.cc index f8ed0f10..7fe01aa8 100644 --- a/libspu/device/executor.cc +++ b/libspu/device/executor.cc @@ -97,6 +97,11 @@ void SymbolScope::addValue(mlir::Value key, spu::Value &&val) { symbols_[key] = std::move(val); } +void SymbolScope::removeValue(mlir::Value key) { + std::lock_guard lk(mu_); + symbols_.erase(key); +} + std::vector runRegion(OpExecutor *executor, // SPUContext *sctx, // SymbolScope *parent_scope, // diff --git a/libspu/device/executor.h b/libspu/device/executor.h index 9c0e21d1..ccc04983 100644 --- a/libspu/device/executor.h +++ b/libspu/device/executor.h @@ -48,6 +48,7 @@ class SymbolScope final { spu::Value lookupValue(mlir::Value key) const; void addValue(::mlir::Value key, const spu::Value &val); void addValue(::mlir::Value key, spu::Value &&val); + void removeValue(::mlir::Value key); protected: bool hasValueUnsafe(mlir::Value key) const; diff --git a/libspu/device/io.cc b/libspu/device/io.cc index d4e356be..498a8683 100644 --- a/libspu/device/io.cc +++ b/libspu/device/io.cc @@ -30,6 +30,17 @@ IoClient::IoClient(size_t world_size, const RuntimeConfig &config) base_io_ = mpc::Factory::CreateIO(config_, world_size_); } +size_t IoClient::getShareSize(const PtBufferView &bv, Visibility vtype, + int owner_rank) { + if (bv.pt_type == PT_BOOL && vtype == VIS_SECRET && + base_io_->hasBitSecretSupport()) { + return base_io_->getBitSecretShareSize(calcNumel(bv.shape)); + } else { + return base_io_->getShareType(vtype, owner_rank).size() * + calcNumel(bv.shape); + } +} + std::vector IoClient::makeShares(const PtBufferView &bv, Visibility vtype, int owner_rank) { const size_t fxp_bits = config_.fxp_fraction_bits(); @@ -149,18 +160,41 @@ bool ColocatedIo::deviceHasVar(const std::string &name) const { // Alice: {x0, y0, z0} // Bob: {x1, y1, z1} // Carol: {x2, y2, z2} + +using SymbolTableProto = std::unordered_map; + static std::vector all2all( const 
std::shared_ptr &lctx, const std::vector &rows) { - // TODO: implement all2all in yacl::link + std::vector party_var_count; + { + const auto party_var_count_str = yacl::link::AllGather( + lctx, std::to_string(rows[0].size()), "all2all_var_count"); + + for (const auto &c : party_var_count_str) { + size_t count = 0; + SPU_ENFORCE(absl::SimpleAtoi(c, &count)); + party_var_count.push_back(count); + } + } + for (size_t idx = 0; idx < lctx->WorldSize(); idx++) { if (idx == lctx->Rank()) { continue; } - yacl::Buffer buf; - buf.resize(rows[idx].ByteSizeLong()); - SPU_ENFORCE(rows[idx].SerializeToArray(buf.data(), buf.size())); - lctx->SendAsync(idx, std::move(buf), "all2all"); + for (const auto &[key, value] : rows[idx]) { + // send var key + lctx->SendAsync(idx, key, "all2all_var_key"); + // send var meta + lctx->SendAsync(idx, value.meta.SerializeAsString(), "all2all_var_meta"); + // send chunks count + lctx->SendAsync(idx, std::to_string(value.chunks.size()), + "all2all_var_chunks_count"); + for (const auto &s : value.chunks) { + // send chunks + lctx->SendAsync(idx, s.SerializeAsString(), "all2all_var_chunk"); + } + } } std::vector cols; @@ -169,10 +203,30 @@ static std::vector all2all( cols.push_back(rows[idx]); continue; } - auto data = lctx->Recv(idx, "all2all"); - SymbolTableProto vars; - SPU_ENFORCE(vars.ParseFromArray(data.data(), data.size())); - cols.push_back(std::move(vars)); + SymbolTableProto st_proto; + for (size_t msg_idx = 0; msg_idx < party_var_count[idx]; msg_idx++) { + auto key = lctx->Recv(idx, "all2all_var_key"); + ValueProto proto; + { + auto data = lctx->Recv(idx, "all2all_var_meta"); + SPU_ENFORCE(proto.meta.ParseFromArray(data.data(), data.size())); + } + size_t chunk_count = 0; + { + auto data = lctx->Recv(idx, "all2all_var_chunks_count"); + SPU_ENFORCE(absl::SimpleAtoi(data, &chunk_count)); + } + proto.chunks.resize(chunk_count); + for (size_t s_idx = 0; s_idx < chunk_count; s_idx++) { + auto data = lctx->Recv(idx, "all2all_var_chunk"); + 
SPU_ENFORCE( + proto.chunks[s_idx].ParseFromArray(data.data(), data.size())); + } + st_proto.insert( + {std::string(static_cast(key.data()), key.size()), + std::move(proto)}); + } + cols.push_back(std::move(st_proto)); } return cols; @@ -214,8 +268,8 @@ void ColocatedIo::sync() { SPU_ENFORCE(shares.size() == lctx->WorldSize()); for (size_t idx = 0; idx < shares.size(); idx++) { - shares_per_party[idx].mutable_symbols()->insert( - {name, shares[idx].toProto()}); + shares_per_party[idx].insert( + {name, shares[idx].toProto(128UL * 1024 * 1024)}); } } @@ -224,7 +278,7 @@ void ColocatedIo::sync() { std::set all_names; for (const auto &values : values_per_party) { - for (const auto &[name, _] : values.symbols()) { + for (const auto &[name, _] : values) { SPU_ENFORCE(all_names.find(name) == all_names.end(), "name duplicated {}", name); all_names.insert(name); @@ -232,7 +286,7 @@ void ColocatedIo::sync() { } for (const auto &values : values_per_party) { - for (const auto &[name, proto] : values.symbols()) { + for (const auto &[name, proto] : values) { symbols_.setVar(name, spu::Value::fromProto(proto)); } } diff --git a/libspu/device/io.h b/libspu/device/io.h index 7e810feb..630e78a5 100644 --- a/libspu/device/io.h +++ b/libspu/device/io.h @@ -109,6 +109,9 @@ class IoClient { std::vector makeShares(const PtBufferView &bv, Visibility vtype, int owner_rank = -1); + size_t getShareSize(const PtBufferView &bv, Visibility vtype, + int owner_rank = -1); + // Combine shares to a plaintext ndarray. 
NdArrayRef combineShares(absl::Span values); }; diff --git a/libspu/device/pphlo/BUILD.bazel b/libspu/device/pphlo/BUILD.bazel index dceca77c..7e3c0e48 100644 --- a/libspu/device/pphlo/BUILD.bazel +++ b/libspu/device/pphlo/BUILD.bazel @@ -75,15 +75,3 @@ spu_cc_test( "@llvm-project//mlir:Parser", ], ) - -spu_cc_binary( - name = "executor_debug_runner", - testonly = True, - srcs = ["executor_debug_runner.cc"], - deps = [ - ":pphlo_executor", - "//libspu/device:api", - "//libspu/device:test_utils", - "@llvm-project//llvm:Support", - ], -) diff --git a/libspu/device/pphlo/executor_debug_runner.cc b/libspu/device/pphlo/executor_debug_runner.cc deleted file mode 100644 index 3876deed..00000000 --- a/libspu/device/pphlo/executor_debug_runner.cc +++ /dev/null @@ -1,138 +0,0 @@ -// Copyright 2021 Ant Group Co., Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include -#include -#include -#include - -#include "absl/strings/str_split.h" -#include "llvm/Support/CommandLine.h" -#include "spdlog/spdlog.h" - -#include "libspu/core/value.h" -#include "libspu/device/api.h" -#include "libspu/device/pphlo/pphlo_executor.h" -#include "libspu/device/symbol_table.h" -#include "libspu/device/test_utils.h" -#include "libspu/kernel/hal/debug.h" -#include "libspu/mpc/utils/simulate.h" - -llvm::cl::opt SnapshotDir( - "snapshot_dir", llvm::cl::desc("folder contains core snapshot files"), - llvm::cl::init(".")); - -// Mode switch -llvm::cl::opt LocalMode("local", llvm::cl::desc("local simulation mode"), - llvm::cl::init(false)); - -// Network only settings -llvm::cl::opt Parties( - "parties", llvm::cl::init("127.0.0.1:9530,127.0.0.1:9531,127.0.0.1:9532"), - llvm::cl::desc("server list, format: host1:port1[,host2:port2, ...]")); - -llvm::cl::opt Rank("rank", llvm::cl::init(0), - llvm::cl::desc("self rank")); - -// Local simulation only settings -llvm::cl::opt NumProc( - "num_processor", - llvm::cl::desc("number of processors to create (local simulation only)"), - llvm::cl::init(3)); - -std::shared_ptr MakeLink(const std::string &parties, - size_t rank) { - yacl::link::ContextDesc lctx_desc; - std::vector hosts = absl::StrSplit(parties, ','); - for (size_t rank = 0; rank < hosts.size(); rank++) { - const auto id = fmt::format("party{}", rank); - lctx_desc.parties.push_back({id, hosts[rank]}); - } - auto lctx = yacl::link::FactoryBrpc().CreateContext(lctx_desc, rank); - lctx->ConnectToMesh(); - return lctx; -} - -std::unique_ptr MakeSPUContext( - const spu::device::SnapshotProto &snapshot) { - auto lctx = MakeLink(Parties.getValue(), Rank.getValue()); - - return std::make_unique(snapshot.runtime_cfg(), lctx); -} - -spu::device::SnapshotProto ParseSnapshotFile( - const std::filesystem::path &snapshot_file) { - spu::device::SnapshotProto snapshot; - { - SPU_ENFORCE(std::filesystem::exists(snapshot_file), - "Serialized snapshot file {} does 
not exit", - snapshot_file.c_str()); - SPDLOG_INFO("Read snapshot file from {}", snapshot_file.c_str()); - std::ifstream stream(snapshot_file, std::ios::binary); - SPU_ENFORCE(snapshot.ParseFromIstream(&stream), - "Parse serialized snapshot file {} failed", - snapshot_file.c_str()); - } - - return snapshot; -} - -void RpcBasedRunner(const std::filesystem::path &snapshot_dir) { - auto snapshot_file = - snapshot_dir / fmt::format("snapshot_{}.spu", Rank.getValue()); - spu::device::SnapshotProto snapshot = ParseSnapshotFile(snapshot_file); - auto sctx = MakeSPUContext(snapshot); - - spu::device::SymbolTable table = - spu::device::SymbolTable::fromProto(snapshot.environ()); - - spu::device::pphlo::PPHloExecutor executor; - - SPDLOG_INFO("Run with config {}", sctx->config().DebugString()); - - spu::device::execute(&executor, sctx.get(), snapshot.executable(), &table); -} - -void MemBasedRunner(const std::filesystem::path &snapshot_dir) { - auto world_size = NumProc.getValue(); - - spu::mpc::utils::simulate( - world_size, [&](const std::shared_ptr<::yacl::link::Context> &lctx) { - auto snapshot_file = - snapshot_dir / fmt::format("snapshot_{}.spu", lctx->Rank()); - - spu::device::SnapshotProto snapshot = ParseSnapshotFile(snapshot_file); - - spu::SPUContext sctx(snapshot.runtime_cfg(), lctx); - - spu::device::pphlo::PPHloExecutor executor; - spu::device::SymbolTable table = - spu::device::SymbolTable::fromProto(snapshot.environ()); - spu::device::execute(&executor, &sctx, snapshot.executable(), &table); - }); -} - -int main(int argc, char **argv) { - llvm::cl::ParseCommandLineOptions(argc, argv); - - std::filesystem::path snapshot_dir = SnapshotDir.getValue(); - - auto local = LocalMode.getValue(); - - if (local) { - MemBasedRunner(snapshot_dir); - } else { - RpcBasedRunner(snapshot_dir); - } -} diff --git a/libspu/device/pphlo/pphlo_executor.cc b/libspu/device/pphlo/pphlo_executor.cc index c1b4e896..c122f639 100644 --- a/libspu/device/pphlo/pphlo_executor.cc +++ 
b/libspu/device/pphlo/pphlo_executor.cc @@ -204,6 +204,11 @@ void addValue(SymbolScope *scope, mlir::Value key, spu::Value &&val, scope->addValue(key, val); } +void removeValue(SymbolScope *scope, mlir::Value key, + const ExecutionOptions &) { + scope->removeValue(key); +} + // #define STANDARD_UNARY_OP_EXEC_IMPL(OpName, KernelName) \ void execute(OpExecutor *executor, SPUContext *sctx, SymbolScope *sscope, \ @@ -1109,6 +1114,25 @@ void execute(OpExecutor *executor, SPUContext *sctx, SymbolScope *sscope, kernel::hal::dbg_print(sctx, lookupValue(sscope, op.getOperand(), opts)); } +void execute(OpExecutor *executor, SPUContext *sctx, SymbolScope *sscope, + mlir::pphlo::FreeOp &op, const ExecutionOptions &opts) { + if (opts.do_parallel) { + // Think about the following case + // %a = def + // use(%a) + // use(%a) + // free(%a) + // Here free is also a consider a use...so under parallel execution free + // will be invoked once a is defined. + // This will make %a randomly dealloced after defined. + // FreeOp has an implicit requirement that it needs to be invoked after all + // other uses are done. + // FIXME(xiaochen): Enable this... 
+ return; + } + removeValue(sscope, op.getOperand(), opts); +} + #define DEFINE_UNIMPLEMENTED_OP(OpName) \ void execute(OpExecutor *executor, SPUContext *sctx, SymbolScope *sscope, \ mlir::pphlo::OpName &, const ExecutionOptions &opts) { \ @@ -1149,7 +1173,8 @@ static void dispatchOp(OpExecutor *executor, SPUContext *sctx, // Execute op { const auto fn_name = op.getName().getStringRef().str(); - SPU_TRACE_ACTION(GET_TRACER(sctx), (TR_HLO | TR_LAR), ~TR_HLO, fn_name); + SPU_TRACE_ACTION(GET_TRACER(sctx), sctx->lctx(), (TR_HLO | TR_LAR), + ~TR_HLO, fn_name); execute(executor, sctx, sscope, casted, opts); } diff --git a/libspu/device/pphlo/pphlo_verifier.h b/libspu/device/pphlo/pphlo_verifier.h index 755aa81a..c9798068 100644 --- a/libspu/device/pphlo/pphlo_verifier.h +++ b/libspu/device/pphlo/pphlo_verifier.h @@ -142,6 +142,7 @@ class PPHloVerifier { NO_VERIFY_DEFN(PreferAOp) NO_VERIFY_DEFN(ArgMaxOp) NO_VERIFY_DEFN(EpsilonOp) + NO_VERIFY_DEFN(FreeOp) #undef NO_VERIFY_DEFN }; diff --git a/libspu/device/symbol_table.cc b/libspu/device/symbol_table.cc index 33128df4..12fae07a 100644 --- a/libspu/device/symbol_table.cc +++ b/libspu/device/symbol_table.cc @@ -34,10 +34,12 @@ void SymbolTable::delVar(const std::string &name) { data_.erase(name); } void SymbolTable::clear() { data_.clear(); } +/* SymbolTableProto SymbolTable::toProto() const { + const static size_t max_slice_size = 128UL * 1024 * 1024; SymbolTableProto proto; for (const auto &[name, value] : data_) { - proto.mutable_symbols()->insert({name, value.toProto()}); + proto.mutable_symbols()->insert({name, value.toProto(max_slice_size)}); } return proto; } @@ -49,5 +51,6 @@ SymbolTable SymbolTable::fromProto(const SymbolTableProto &proto) { } return st; } +*/ } // namespace spu::device diff --git a/libspu/device/symbol_table.h b/libspu/device/symbol_table.h index 2619456c..f82a2b66 100644 --- a/libspu/device/symbol_table.h +++ b/libspu/device/symbol_table.h @@ -19,8 +19,6 @@ #include "libspu/core/value.h" 
-#include "libspu/device/device.pb.h" - namespace spu::device { class SymbolTable { @@ -39,8 +37,10 @@ class SymbolTable { auto begin() { return data_.begin(); } auto end() { return data_.end(); } - SymbolTableProto toProto() const; - static SymbolTable fromProto(const SymbolTableProto &proto); + // @shantang / @wuju + // TODO: temporary remove, need to adapt value slice change + // SymbolTableProto toProto() const; + // static SymbolTable fromProto(const SymbolTableProto &proto); }; } // namespace spu::device diff --git a/libspu/dialect/pphlo_ops.td b/libspu/dialect/pphlo_ops.td index 6c5e309c..0607c8cb 100644 --- a/libspu/dialect/pphlo_ops.td +++ b/libspu/dialect/pphlo_ops.td @@ -860,6 +860,10 @@ def PPHLO_DbgPrintOp : PPHLO_Op<"dbg_print", []> { let arguments = (ins PPHLO_Tensor : $operand); } +def PPHLO_FreeOp : PPHLO_Op<"free", []> { + let arguments = (ins PPHLO_Tensor : $operand); +} + def PPHLO_ClampOp : PPHLO_Op<"clamp", [Pure, SameOperandsAndResultShape]> { let summary = "Clamp operator"; diff --git a/libspu/kernel/hal/BUILD.bazel b/libspu/kernel/hal/BUILD.bazel index bec7b44b..ff4d3d49 100644 --- a/libspu/kernel/hal/BUILD.bazel +++ b/libspu/kernel/hal/BUILD.bazel @@ -122,6 +122,7 @@ spu_cc_library( deps = [ ":fxp_base", ":fxp_cleartext", + ":shape_ops", ":type_cast", ], ) diff --git a/libspu/kernel/hal/fxp_approx.cc b/libspu/kernel/hal/fxp_approx.cc index 681726b7..6b03a7b1 100644 --- a/libspu/kernel/hal/fxp_approx.cc +++ b/libspu/kernel/hal/fxp_approx.cc @@ -15,6 +15,7 @@ #include "libspu/kernel/hal/fxp_approx.h" #include +#include #include #include @@ -22,19 +23,9 @@ #include "libspu/kernel/hal/fxp_base.h" #include "libspu/kernel/hal/fxp_cleartext.h" #include "libspu/kernel/hal/ring.h" +#include "libspu/kernel/hal/shape_ops.h" namespace spu::kernel::hal { - -namespace { - -// simple convenient function. 
-Value f_constant(SPUContext* ctx, const PtBufferView& init, DataType dtype, - absl::Span shape) { - return constant(ctx, init, dtype, shape); -} - -} // namespace - namespace detail { // Pade approximation fo x belongs to [0.5, 1]: @@ -49,19 +40,19 @@ namespace detail { // + x^3 * 0.1 *10 // log2(x) = p2524(x) / q2524(x) // -Value log2_pade_approx_for_normalized(SPUContext* ctx, const Value& x) { +Value log2_pade_normalized(SPUContext* ctx, const Value& x) { const auto x2 = f_square(ctx, x); const auto x3 = f_mul(ctx, x2, x); - const auto p0 = f_constant(ctx, -0.205466671951F * 10, x.dtype(), x.shape()); - const auto p1 = f_constant(ctx, -0.88626599391F * 10, x.dtype(), x.shape()); - const auto p2 = f_constant(ctx, 0.610585199015F * 10, x.dtype(), x.shape()); - const auto p3 = f_constant(ctx, 0.481147460989F * 10, x.dtype(), x.shape()); + const auto p0 = constant(ctx, -0.205466671951F * 10, x.dtype(), x.shape()); + const auto p1 = constant(ctx, -0.88626599391F * 10, x.dtype(), x.shape()); + const auto p2 = constant(ctx, 0.610585199015F * 10, x.dtype(), x.shape()); + const auto p3 = constant(ctx, 0.481147460989F * 10, x.dtype(), x.shape()); - const auto q0 = f_constant(ctx, 0.353553425277F, x.dtype(), x.shape()); - const auto q1 = f_constant(ctx, 0.454517087629F * 10, x.dtype(), x.shape()); - const auto q2 = f_constant(ctx, 0.642784209029F * 10, x.dtype(), x.shape()); - const auto q3 = f_constant(ctx, 0.1F * 10, x.dtype(), x.shape()); + const auto q0 = constant(ctx, 0.353553425277F, x.dtype(), x.shape()); + const auto q1 = constant(ctx, 0.454517087629F * 10, x.dtype(), x.shape()); + const auto q2 = constant(ctx, 0.642784209029F * 10, x.dtype(), x.shape()); + const auto q3 = constant(ctx, 0.1F * 10, x.dtype(), x.shape()); auto p2524 = _mul(ctx, x, p1); p2524 = _add(ctx, p2524, _mul(ctx, x2, p2)); @@ -80,7 +71,7 @@ Value log2_pade_approx_for_normalized(SPUContext* ctx, const Value& x) { // Chapter 5 Exponentiation and Logarithms // Benchmarking Privacy Preserving 
Scientific Operations // https://www.esat.kuleuven.be/cosic/publications/article-3013.pdf -Value log2_pade_approx(SPUContext* ctx, const Value& x) { +Value log2_pade(SPUContext* ctx, const Value& x) { SPU_TRACE_HAL_DISP(ctx, x); const size_t bit_width = SizeOf(ctx->config().field()) * 8; @@ -98,7 +89,7 @@ Value log2_pade_approx(SPUContext* ctx, const Value& x) { // = log2(x_norm) + log2(factor) // = log2(x_norm) + (k-fxp_bits) return _add( - ctx, log2_pade_approx_for_normalized(ctx, norm), + ctx, log2_pade_normalized(ctx, norm), _lshift(ctx, _sub(ctx, k, _constant(ctx, num_fxp_bits, x.shape())), num_fxp_bits)) .setDtype(x.dtype()); @@ -110,34 +101,33 @@ Value log2_pade_approx(SPUContext* ctx, const Value& x) { // Approximates the natural logarithm using 8th order modified // Householder iterations. This approximation is accurate within 2% relative // error on [0.0001, 250]. -Value log_householder_approx(SPUContext* ctx, const Value& x) { - Value term_1 = f_div(ctx, x, f_constant(ctx, 120.0, x.dtype(), x.shape())); +Value log_householder(SPUContext* ctx, const Value& x) { + Value term_1 = f_div(ctx, x, constant(ctx, 120.0, x.dtype(), x.shape())); Value term_2 = f_mul( ctx, f_exp(ctx, - f_negate( - ctx, - f_add(ctx, - f_mul(ctx, x, f_constant(ctx, 2.0, x.dtype(), x.shape())), - f_constant(ctx, 1.0, x.dtype(), x.shape())))), - f_constant(ctx, 20.0, x.dtype(), x.shape())); + f_negate(ctx, f_add(ctx, + f_mul(ctx, x, + constant(ctx, 2.0, x.dtype(), x.shape())), + constant(ctx, 1.0, x.dtype(), x.shape())))), + constant(ctx, 20.0, x.dtype(), x.shape())); Value y = f_add(ctx, f_sub(ctx, term_1, term_2), - f_constant(ctx, 3.0, x.dtype(), x.shape())); + constant(ctx, 3.0, x.dtype(), x.shape())); - std::vector coeffs; const size_t fxp_log_orders = ctx->config().fxp_log_orders(); SPU_ENFORCE(fxp_log_orders != 0, "fxp_log_orders should not be {}", fxp_log_orders); + std::vector coeffs; for (size_t i = 0; i < fxp_log_orders; i++) { - coeffs.emplace_back(f_constant(ctx, 1.0 / (1.0 
+ i), x.dtype(), x.shape())); + coeffs.emplace_back(1.0 / (1.0 + i)); } const size_t num_iters = ctx->config().fxp_log_iters(); SPU_ENFORCE(num_iters != 0, "fxp_log_iters should not be {}", num_iters); for (size_t i = 0; i < num_iters; i++) { - Value h = f_sub(ctx, f_constant(ctx, 1.0, x.dtype(), x.shape()), + Value h = f_sub(ctx, constant(ctx, 1.0, x.dtype(), x.shape()), f_mul(ctx, x, f_exp(ctx, f_negate(ctx, y)))); - y = f_sub(ctx, y, detail::f_polynomial(ctx, h, coeffs)); + y = f_sub(ctx, y, detail::polynomial(ctx, h, coeffs)); } return y; @@ -145,13 +135,13 @@ Value log_householder_approx(SPUContext* ctx, const Value& x) { // see https://lvdmaaten.github.io/publications/papers/crypten.pdf // exp(x) = (1 + x / n) ^ n, when n is infinite large. -Value exp_taylor_series(SPUContext* ctx, const Value& x) { +Value exp_taylor(SPUContext* ctx, const Value& x) { const size_t fxp_exp_iters = ctx->config().fxp_exp_iters(); SPU_ENFORCE(fxp_exp_iters != 0, "fxp_exp_iters should not be {}", fxp_exp_iters); Value res = f_add(ctx, _trunc(ctx, x, fxp_exp_iters).setDtype(x.dtype()), - f_constant(ctx, 1.0F, x.dtype(), x.shape())); + constant(ctx, 1.0F, x.dtype(), x.shape())); for (size_t i = 0; i < fxp_exp_iters; i++) { res = f_square(ctx, res); @@ -160,6 +150,8 @@ Value exp_taylor_series(SPUContext* ctx, const Value& x) { return res; } +namespace { + // Pade approximation of exp2(x), x is in [0, 1]. 
// p1015(x) = 0.100000007744302 * 10 // + x * 0.693147180426163 @@ -167,23 +159,18 @@ Value exp_taylor_series(SPUContext* ctx, const Value& x) { // + x^3 * 0.555040686204663 / 10 // + x^4 * 0.961834122588046 / 100 // + x^5 * 0.133273035928143 / 100 -Value exp2_pade_approx_for_positive_pure_decimal(SPUContext* ctx, - const Value& x) { +Value exp2_pade_normalized(SPUContext* ctx, const Value& x) { auto x2 = f_mul(ctx, x, x); auto x3 = f_mul(ctx, x, x2); auto x4 = f_mul(ctx, x, x3); auto x5 = f_mul(ctx, x, x4); - const auto p0 = - f_constant(ctx, 0.100000007744302F * 10, x.dtype(), x.shape()); - const auto p1 = f_constant(ctx, 0.693147180426163F, x.dtype(), x.shape()); - const auto p2 = f_constant(ctx, 0.240226510710170F, x.dtype(), x.shape()); - const auto p3 = - f_constant(ctx, 0.555040686204663F / 10, x.dtype(), x.shape()); - const auto p4 = - f_constant(ctx, 0.961834122588046F / 100, x.dtype(), x.shape()); - const auto p5 = - f_constant(ctx, 0.133273035928143F / 100, x.dtype(), x.shape()); + const auto p0 = constant(ctx, 0.100000007744302F * 10, x.dtype(), x.shape()); + const auto p1 = constant(ctx, 0.693147180426163F, x.dtype(), x.shape()); + const auto p2 = constant(ctx, 0.240226510710170F, x.dtype(), x.shape()); + const auto p3 = constant(ctx, 0.555040686204663F / 10, x.dtype(), x.shape()); + const auto p4 = constant(ctx, 0.961834122588046F / 100, x.dtype(), x.shape()); + const auto p5 = constant(ctx, 0.133273035928143F / 100, x.dtype(), x.shape()); auto res = _mul(ctx, x, p1); res = _add(ctx, res, _mul(ctx, x2, p2)); @@ -194,13 +181,15 @@ Value exp2_pade_approx_for_positive_pure_decimal(SPUContext* ctx, return _add(ctx, _trunc(ctx, res), p0).setDtype(x.dtype()); } +} // namespace + // Refer to // Chapter 5 Exponentiation and Logarithms // Benchmarking Privacy Preserving Scientific Operations // https://www.esat.kuleuven.be/cosic/publications/article-3013.pdf // NOTE(junfeng): The valid integer bits of x is 5. Otherwise, the output is // incorrect. 
-Value exp2_pade_approx(SPUContext* ctx, const Value& x) { +Value exp2_pade(SPUContext* ctx, const Value& x) { const size_t fbits = ctx->getFxpBits(); const auto k1 = _constant(ctx, 1U, x.shape()); // TODO(junfeng): Make int_bits configurable. @@ -212,7 +201,7 @@ Value exp2_pade_approx(SPUContext* ctx, const Value& x) { auto x_integer = _rshift(ctx, x_bshare, fbits); auto x_fraction = _sub(ctx, x, _lshift(ctx, x_integer, fbits)).setDtype(x.dtype()); - auto ret = exp2_pade_approx_for_positive_pure_decimal(ctx, x_fraction); + auto ret = exp2_pade_normalized(ctx, x_fraction); for (size_t idx = 0; idx < int_bits; idx++) { auto a = _and(ctx, _rshift(ctx, x_integer, idx), k1); @@ -241,17 +230,17 @@ Value exp2_pade_approx(SPUContext* ctx, const Value& x) { _mul(ctx, x_msb, f_sub(ctx, ret_reciprocal, ret)).setDtype(ret.dtype())); } -Value exp_pade_approx(SPUContext* ctx, const Value& x) { +Value exp_pade(SPUContext* ctx, const Value& x) { return f_exp2(ctx, f_mul(ctx, x, - f_constant(ctx, std::log2(std::exp(1.0F)), x.dtype(), - x.shape()))); + constant(ctx, std::log2(std::exp(1.0F)), x.dtype(), + x.shape()))); } // Refer to // https://www.wolframalpha.com/input?i=Pade+approximation+tanh%28x%29+order+5%2C5. 
// tanh(x) = (x + x^3 / 9.0 + x^5 /945.0) / // (1 + 4 * x^2 / 9.0 + x^4 / 63.0) -Value tanh_pade_approx(SPUContext* ctx, const Value& x) { +Value tanh_pade(SPUContext* ctx, const Value& x) { const auto x_2 = f_square(ctx, x); const auto x_4 = f_square(ctx, x_2); @@ -261,7 +250,7 @@ Value tanh_pade_approx(SPUContext* ctx, const Value& x) { // = x * (945 + 105 * x^2 + x^4) / (945 + 420 * x^2 + 15 * x^4) // This can save some truncations - const auto c_945 = f_constant(ctx, 945.0F, x.dtype(), x.shape()); + const auto c_945 = constant(ctx, 945.0F, x.dtype(), x.shape()); const auto c_105 = constant(ctx, 105, DT_I32, x.shape()); const auto c_420 = constant(ctx, 420, DT_I32, x.shape()); const auto c_15 = constant(ctx, 15, DT_I32, x.shape()); @@ -278,6 +267,66 @@ Value tanh_pade_approx(SPUContext* ctx, const Value& x) { return f_div(ctx, nominator, denominator); } +// Reference: +// https://github.com/facebookresearch/CrypTen/blob/6ef151101668591bcfb2bbf7e7ebd39ab6db0413/crypten/common/functions/approximations.py#L365 +Value compute_chebyshev_polynomials(SPUContext* ctx, const Value& x, + int64_t terms) { + // Ref: + // https://en.wikipedia.org/wiki/Chebyshev_polynomials#Recurrence_definition + // Chebyshev Polynomials of the first kind are defined as + //.. 
math:: + // P_0(x) = 1 + // P_1(x) = x + // P_{n+1}(x) = 2xP_{n}(x) - P_{n-1}(x) + std::vector poly = {x}; + + // y = 4*x^2 - 2 + auto four = constant(ctx, 4, DT_I32, x.shape()); + auto two = constant(ctx, 2.0F, x.dtype(), x.shape()); + auto y = + f_sub(ctx, _mul(ctx, four, f_square(ctx, x)).setDtype(x.dtype()), two); + // z = y - 1 + auto one = constant(ctx, 1.0F, x.dtype(), x.shape()); + auto z = f_sub(ctx, y, one); + + poly.emplace_back(f_mul(ctx, x, z)); + + for (int64_t idx = 2; idx < terms; ++idx) { + // next_polynomial = y * polynomials[k - 1] - polynomials[k - 2] + auto next = f_sub(ctx, f_mul(ctx, y, poly[idx - 1]), poly[idx - 2]); + poly.emplace_back(std::move(next)); + } + + return concatenate(ctx, poly, 0); +} + +Value tanh_chebyshev(SPUContext* ctx, const Value& x) { + // Cheb coeff, deg = 17, domain = [-5,5] + static const std::array kCoeffs = { + 1.2514045938932097, -0.3655987797163166, 0.17253141478140663, + -0.08943445792774211, 0.047703017901250824, -0.025830290571688078, + 0.014338801903468182, -0.008541730970059077, 0.0061230685785789475}; + + auto coeff_value = constant(ctx, kCoeffs, x.dtype(), + {1, static_cast(kCoeffs.size())}); + + auto normalized_x = reshape(ctx, x, {1, x.numel()}); + + normalized_x = + _clamp(ctx, normalized_x, + constant(ctx, -5.0F, normalized_x.dtype(), normalized_x.shape()), + constant(ctx, 5.0F, normalized_x.dtype(), normalized_x.shape())) + .setDtype(x.dtype()); + + normalized_x = f_mul( + ctx, constant(ctx, 0.2F, x.dtype(), normalized_x.shape()), normalized_x); + auto poly = compute_chebyshev_polynomials(ctx, normalized_x, kCoeffs.size()); + + auto ret = f_mmul(ctx, coeff_value, poly); + + return reshape(ctx, ret, x.shape()); +} + } // namespace detail Value f_exp(SPUContext* ctx, const Value& x) { @@ -292,16 +341,16 @@ Value f_exp(SPUContext* ctx, const Value& x) { switch (ctx->config().fxp_exp_mode()) { case RuntimeConfig::EXP_DEFAULT: case RuntimeConfig::EXP_TAYLOR: - return detail::exp_taylor_series(ctx, x); + 
return detail::exp_taylor(ctx, x); case RuntimeConfig::EXP_PADE: { - // The valid input for exp_pade_approx is [-kInputLimit, kInputLimit]. - // TODO(junfeng): should merge clamp into exp_pade_approx to save msb ops. + // The valid input for exp_pade is [-kInputLimit, kInputLimit]. + // TODO(junfeng): should merge clamp into exp_pade to save msb ops. const float kInputLimit = 32 / std::log2(std::exp(1)); const auto clamped_x = - _clamp(ctx, x, f_constant(ctx, -kInputLimit, x.dtype(), x.shape()), - f_constant(ctx, kInputLimit, x.dtype(), x.shape())) + _clamp(ctx, x, constant(ctx, -kInputLimit, x.dtype(), x.shape()), + constant(ctx, kInputLimit, x.dtype(), x.shape())) .setDtype(x.dtype()); - return detail::exp_pade_approx(ctx, clamped_x); + return detail::exp_pade(ctx, clamped_x); } default: SPU_THROW("unexpected exp approximation method {}", @@ -321,10 +370,10 @@ Value f_log(SPUContext* ctx, const Value& x) { switch (ctx->config().fxp_log_mode()) { case RuntimeConfig::LOG_DEFAULT: case RuntimeConfig::LOG_PADE: - return f_mul(ctx, f_constant(ctx, std::log(2.0F), x.dtype(), x.shape()), + return f_mul(ctx, constant(ctx, std::log(2.0F), x.dtype(), x.shape()), f_log2(ctx, x)); case RuntimeConfig::LOG_NEWTON: - return detail::log_householder_approx(ctx, x); + return detail::log_householder(ctx, x); default: SPU_THROW("unexpected log approximation method {}", ctx->config().fxp_log_mode()); @@ -336,7 +385,7 @@ Value f_log1p(SPUContext* ctx, const Value& x) { SPU_ENFORCE(x.isFxp()); - return f_log(ctx, f_add(ctx, f_constant(ctx, 1.0F, x.dtype(), x.shape()), x)); + return f_log(ctx, f_add(ctx, constant(ctx, 1.0F, x.dtype(), x.shape()), x)); } Value f_log2(SPUContext* ctx, const Value& x) { @@ -344,18 +393,21 @@ Value f_log2(SPUContext* ctx, const Value& x) { SPU_ENFORCE(x.isFxp()); - return detail::log2_pade_approx(ctx, x).setDtype(x.dtype()); + return detail::log2_pade(ctx, x).setDtype(x.dtype()); } Value f_exp2(SPUContext* ctx, const Value& x) { SPU_TRACE_HAL_LEAF(ctx, x); 
- return detail::exp2_pade_approx(ctx, x); + return detail::exp2_pade(ctx, x); } Value f_tanh(SPUContext* ctx, const Value& x) { SPU_TRACE_HAL_LEAF(ctx, x); +#ifndef TANH_USE_PADE + return detail::tanh_chebyshev(ctx, x); +#elif // For tanh inputs beyond [-3, 3], result is infinitely close to -1, 1 // pade approximation has a relative ok result between [-3, 3], so clamp // inputs to this range. @@ -363,7 +415,8 @@ Value f_tanh(SPUContext* ctx, const Value& x) { constant(ctx, 3.F, x.dtype(), x.shape())) .setDtype(x.dtype()); - return detail::tanh_pade_approx(ctx, normalized_x); + return detail::tanh_pade(ctx, normalized_x); +#endif } static Value rsqrt_init_guess(SPUContext* ctx, const Value& x, const Value& z) { @@ -380,19 +433,13 @@ static Value rsqrt_init_guess(SPUContext* ctx, const Value& x, const Value& z) { // - 15.47994394 * u + 4.14285016 spu::Value r; if (!ctx->config().enable_lower_accuracy_rsqrt()) { - std::vector coeffs = { - f_constant(ctx, -15.47994394F, x.dtype(), x.shape()), - f_constant(ctx, 38.4714796F, x.dtype(), x.shape()), - f_constant(ctx, -49.86605845F, x.dtype(), x.shape()), - f_constant(ctx, 26.02942339F, x.dtype(), x.shape())}; - r = f_add(ctx, detail::f_polynomial(ctx, u, coeffs), - f_constant(ctx, 4.14285016F, x.dtype(), x.shape())); + auto coeffs = {-15.47994394F, 38.4714796F, -49.86605845F, 26.02942339F}; + r = f_add(ctx, detail::polynomial(ctx, u, coeffs), + constant(ctx, 4.14285016F, x.dtype(), x.shape())); } else { - std::vector coeffs = { - f_constant(ctx, -5.9417F, x.dtype(), x.shape()), - f_constant(ctx, 4.7979F, x.dtype(), x.shape())}; - r = f_add(ctx, detail::f_polynomial(ctx, u, coeffs), - f_constant(ctx, 3.1855F, x.dtype(), x.shape())); + auto coeffs = {-5.9417F, 4.7979F}; + r = f_add(ctx, detail::polynomial(ctx, u, coeffs), + constant(ctx, 3.1855F, x.dtype(), x.shape())); } return r; @@ -494,8 +541,8 @@ Value f_rsqrt(SPUContext* ctx, const Value& x) { Value f_sqrt(SPUContext* ctx, const Value& x) { SPU_TRACE_HAL_LEAF(ctx, x); 
- const auto c0 = f_constant(ctx, 0.5F, x.dtype(), x.shape()); - const auto c1 = f_constant(ctx, 1.5F, x.dtype(), x.shape()); + const auto c0 = constant(ctx, 0.5F, x.dtype(), x.shape()); + const auto c1 = constant(ctx, 1.5F, x.dtype(), x.shape()); Value y0 = f_rsqrt(ctx, x); Value g = f_mul(ctx, x, y0); @@ -515,30 +562,30 @@ Value f_sqrt(SPUContext* ctx, const Value& x) { namespace { -Value sigmiod_real(SPUContext* ctx, const Value& x) { +Value sigmoid_real(SPUContext* ctx, const Value& x) { // f(x) = 1/(1+exp(-x)) - const auto c1 = f_constant(ctx, 1.0F, x.dtype(), x.shape()); + const auto c1 = constant(ctx, 1.0F, x.dtype(), x.shape()); return f_reciprocal(ctx, f_add(ctx, c1, f_exp(ctx, f_negate(ctx, x)))); } -Value sigmiod_mm1(SPUContext* ctx, const Value& x) { +Value sigmoid_mm1(SPUContext* ctx, const Value& x) { // SigmoidMM1: f(x) = 0.5 + 0.125 * x - const auto c1 = f_constant(ctx, 0.5F, x.dtype(), x.shape()); - const auto c2 = f_constant(ctx, 0.125F, x.dtype(), x.shape()); + const auto c1 = constant(ctx, 0.5F, x.dtype(), x.shape()); + const auto c2 = constant(ctx, 0.125F, x.dtype(), x.shape()); return f_add(ctx, c1, f_mul(ctx, c2, x)); } -Value sigmiod_seg3(SPUContext* ctx, const Value& x) { +Value sigmoid_seg3(SPUContext* ctx, const Value& x) { // f(x) = 0.5 + 0.125x if -4 <= x <= 4 // 1 if x > 4 // 0 if -4 > x // Rounds = Gt + Mux*2 = 4 + Log(K) - auto upper = f_constant(ctx, 1.0F, x.dtype(), x.shape()); - auto lower = f_constant(ctx, 0.0F, x.dtype(), x.shape()); - auto middle = sigmiod_mm1(ctx, x); + auto upper = constant(ctx, 1.0F, x.dtype(), x.shape()); + auto lower = constant(ctx, 0.0F, x.dtype(), x.shape()); + auto middle = sigmoid_mm1(ctx, x); - auto upper_bound = f_constant(ctx, 4.0F, x.dtype(), x.shape()); - auto lower_bound = f_constant(ctx, -4.0F, x.dtype(), x.shape()); + auto upper_bound = constant(ctx, 4.0F, x.dtype(), x.shape()); + auto lower_bound = constant(ctx, -4.0F, x.dtype(), x.shape()); auto ret = _mux(ctx, f_less(ctx, upper_bound, x), 
upper, middle); return _mux(ctx, f_less(ctx, x, lower_bound), lower, ret).setDtype(x.dtype()); @@ -554,13 +601,13 @@ Value f_sigmoid(SPUContext* ctx, const Value& x) { switch (ctx->config().sigmoid_mode()) { case RuntimeConfig::SIGMOID_DEFAULT: case RuntimeConfig::SIGMOID_MM1: { - return sigmiod_mm1(ctx, x); + return sigmoid_mm1(ctx, x); } case RuntimeConfig::SIGMOID_SEG3: { - return sigmiod_seg3(ctx, x); + return sigmoid_seg3(ctx, x); } case RuntimeConfig::SIGMOID_REAL: { - return sigmiod_real(ctx, x); + return sigmoid_real(ctx, x); } default: { SPU_THROW("Should not hit"); diff --git a/libspu/kernel/hal/fxp_approx.h b/libspu/kernel/hal/fxp_approx.h index 84febd9a..c9179b05 100644 --- a/libspu/kernel/hal/fxp_approx.h +++ b/libspu/kernel/hal/fxp_approx.h @@ -21,17 +21,19 @@ namespace spu::kernel::hal { namespace detail { -Value log2_pade_approx(SPUContext* ctx, const Value& x); +Value log2_pade(SPUContext* ctx, const Value& x); -Value log_householder_approx(SPUContext* ctx, const Value& x); +Value log_householder(SPUContext* ctx, const Value& x); // Works for range [-500, 2.1] -Value exp_taylor_series(SPUContext* ctx, const Value& x); +Value exp_taylor(SPUContext* ctx, const Value& x); -Value exp2_pade_approx(SPUContext* ctx, const Value& x); +Value exp2_pade(SPUContext* ctx, const Value& x); // Works for range [-12.0, 18.0] -Value exp_pade_approx(SPUContext* ctx, const Value& x); +Value exp_pade(SPUContext* ctx, const Value& x); + +Value tanh_chebyshev(SPUContext* ctx, const Value& x); } // namespace detail diff --git a/libspu/kernel/hal/fxp_approx_test.cc b/libspu/kernel/hal/fxp_approx_test.cc index f8b9da39..df4269ae 100644 --- a/libspu/kernel/hal/fxp_approx_test.cc +++ b/libspu/kernel/hal/fxp_approx_test.cc @@ -54,7 +54,7 @@ TEST(FxpTest, ExponentialTaylorSeries) { }; Value a = test::makeValue(&ctx, x, VIS_SECRET); - Value c = detail::exp_taylor_series(&ctx, a); + Value c = detail::exp_taylor(&ctx, a); EXPECT_EQ(c.dtype(), DT_F32); auto y = dump_public_as(&ctx, 
reveal(&ctx, c)); @@ -69,7 +69,7 @@ TEST(FxpTest, ExponentialPade) { xt::xarray x = xt::linspace(-22., 22., 4000); Value a = test::makeValue(&ctx, x, VIS_SECRET); - Value c = detail::exp_pade_approx(&ctx, a); + Value c = detail::exp_pade(&ctx, a); EXPECT_EQ(c.dtype(), DT_F32); auto y = dump_public_as(&ctx, reveal(&ctx, c)); diff --git a/libspu/kernel/hal/fxp_base.cc b/libspu/kernel/hal/fxp_base.cc index 8b80e15a..cf2ce251 100644 --- a/libspu/kernel/hal/fxp_base.cc +++ b/libspu/kernel/hal/fxp_base.cc @@ -29,8 +29,8 @@ namespace detail { // // Coefficients should be ordered from the order 1 (linear) term first, ending // with the highest order term. (Constant is not included). -Value f_polynomial(SPUContext* ctx, const Value& x, - const std::vector& coeffs) { +Value polynomial(SPUContext* ctx, const Value& x, + absl::Span coeffs) { SPU_TRACE_HAL_DISP(ctx, x); SPU_ENFORCE(x.isFxp()); SPU_ENFORCE(!coeffs.empty()); @@ -39,7 +39,7 @@ Value f_polynomial(SPUContext* ctx, const Value& x, Value res = _mul(ctx, x_pow, coeffs[0]); for (size_t i = 1; i < coeffs.size(); i++) { - if (i & 1) { + if ((i & 1) != 0U) { // x^{even order} is always positive x_pow = _trunc_with_sign(ctx, _mul(ctx, x_pow, x), ctx->getFxpBits(), true); @@ -52,6 +52,16 @@ Value f_polynomial(SPUContext* ctx, const Value& x, return _trunc(ctx, res).setDtype(x.dtype()); } +Value polynomial(SPUContext* ctx, const Value& x, + absl::Span coeffs) { + std::vector cs; + cs.reserve(coeffs.size()); + for (const auto& c : coeffs) { + cs.push_back(constant(ctx, c, x.dtype(), x.shape())); + } + return polynomial(ctx, x, cs); +} + Value highestOneBit(SPUContext* ctx, const Value& x) { auto y = _prefix_or(ctx, x); auto y1 = _rshift(ctx, y, 1); diff --git a/libspu/kernel/hal/fxp_base.h b/libspu/kernel/hal/fxp_base.h index 44029b22..c9db4a2a 100644 --- a/libspu/kernel/hal/fxp_base.h +++ b/libspu/kernel/hal/fxp_base.h @@ -34,8 +34,11 @@ Value reciprocal_goldschmidt_positive(SPUContext* ctx, const Value& b_abs); Value 
reciprocal_goldschmidt(SPUContext* ctx, const Value& b); -Value f_polynomial(SPUContext* ctx, const Value& x, - const std::vector& coeffs); +Value polynomial(SPUContext* ctx, const Value& x, + absl::Span coeffs); + +Value polynomial(SPUContext* ctx, const Value& x, + absl::Span coeffs); } // namespace detail diff --git a/libspu/mpc/aby3/io.cc b/libspu/mpc/aby3/io.cc index f1afd224..a51fdbf6 100644 --- a/libspu/mpc/aby3/io.cc +++ b/libspu/mpc/aby3/io.cc @@ -25,6 +25,20 @@ namespace spu::mpc::aby3 { +Type Aby3Io::getShareType(Visibility vis, int owner_rank) const { + if (vis == VIS_PUBLIC) { + return makeType(field_); + } else if (vis == VIS_SECRET) { + if (owner_rank >= 0 && owner_rank <= 2) { + return makeType(field_, owner_rank); + } else { + return makeType(field_); + } + } + + SPU_THROW("unsupported vis type {}", vis); +} + std::vector Aby3Io::toShares(const ArrayRef& raw, Visibility vis, int owner_rank) const { SPU_ENFORCE(raw.eltype().isa(), "expected RingTy, got {}", @@ -74,6 +88,11 @@ std::vector Aby3Io::toShares(const ArrayRef& raw, Visibility vis, SPU_THROW("unsupported vis type {}", vis); } +size_t Aby3Io::getBitSecretShareSize(size_t numel) const { + const auto type = makeType(PT_U8, 1); + return numel * type.size(); +} + std::vector Aby3Io::makeBitSecret(const ArrayRef& in) const { SPU_ENFORCE(in.eltype().isa(), "expected PtType, got {}", in.eltype()); PtType in_pt_type = in.eltype().as()->pt_type(); diff --git a/libspu/mpc/aby3/io.h b/libspu/mpc/aby3/io.h index d2c46907..ebc6d689 100644 --- a/libspu/mpc/aby3/io.h +++ b/libspu/mpc/aby3/io.h @@ -25,9 +25,12 @@ class Aby3Io final : public BaseIo { std::vector toShares(const ArrayRef& raw, Visibility vis, int owner_rank) const override; + Type getShareType(Visibility vis, int owner_rank = -1) const override; + ArrayRef fromShares(const std::vector& shares) const override; std::vector makeBitSecret(const ArrayRef& in) const override; + size_t getBitSecretShareSize(size_t numel) const override; bool 
hasBitSecretSupport() const override { return true; } }; diff --git a/libspu/mpc/cheetah/ot/ferret.cc b/libspu/mpc/cheetah/ot/ferret.cc index 130925ea..b1bec43d 100644 --- a/libspu/mpc/cheetah/ot/ferret.cc +++ b/libspu/mpc/cheetah/ot/ferret.cc @@ -194,11 +194,10 @@ struct FerretOT::Impl { } public: - Impl(std::shared_ptr conn, bool is_sender) + Impl(std::shared_ptr conn, bool is_sender, bool malicious) : is_sender_(is_sender) { SPU_ENFORCE(conn != nullptr); constexpr int thread = 1; - constexpr bool malicious = false; constexpr bool run_setup = true; int role = is_sender ? emp::ALICE : emp::BOB; io_ = std::make_shared(conn); @@ -675,10 +674,46 @@ struct FerretOT::Impl { } } } + + template + void SendRMCC(absl::Span output0, absl::Span output1, + size_t bit_width) { + size_t n = output0.size(); + SPU_ENFORCE(n > 0); + SPU_ENFORCE_EQ(n, output1.size()); + + std::vector rm_data(2 * n); + auto* rm_data0 = rm_data.data(); + auto* rm_data1 = rm_data.data() + n; + SendRandMsgChosenChoice(rm_data0, rm_data1, n); + + const T msg_mask = makeBitsMask(bit_width); + for (size_t i = 0; i < n; ++i) { + output0[i] = ConvFromBlock(rm_data[i]) & msg_mask; + output1[i] = ConvFromBlock(rm_data1[i]) & msg_mask; + } + } + + template + void RecvRMCC(absl::Span choices, absl::Span output, + size_t bit_width) { + size_t n = choices.size(); + SPU_ENFORCE(n > 0); + SPU_ENFORCE_EQ(n, output.size()); + + std::vector rm_data(n); + RecvRandMsgChosenChoice(choices, absl::MakeSpan(rm_data)); + + const T msg_mask = makeBitsMask(bit_width); + for (size_t i = 0; i < n; ++i) { + output[i] = ConvFromBlock(rm_data[i]) & msg_mask; + } + } }; -FerretOT::FerretOT(std::shared_ptr conn, bool is_sender) { - impl_ = std::make_shared(conn, is_sender); +FerretOT::FerretOT(std::shared_ptr conn, bool is_sender, + bool malicious) { + impl_ = std::make_shared(conn, is_sender, malicious); } int FerretOT::Rank() const { return impl_->Rank(); } @@ -725,6 +760,16 @@ size_t CheckBitWidth(size_t bw) { absl::Span output, 
size_t bit_width) { \ bit_width = CheckBitWidth(bit_width); \ impl_->RecvChosenMsgChosenChoice(choices, N, output, bit_width); \ + } \ + void FerretOT::SendRMCC(absl::Span output0, absl::Span output1, \ + size_t bit_width) { \ + bit_width = CheckBitWidth(bit_width); \ + impl_->SendRMCC(output0, output1, bit_width); \ + } \ + void FerretOT::RecvRMCC(absl::Span choices, \ + absl::Span output, size_t bit_width) { \ + bit_width = CheckBitWidth(bit_width); \ + impl_->RecvRMCC(choices, output, bit_width); \ } DEF_SEND_RECV(uint8_t) diff --git a/libspu/mpc/cheetah/ot/ferret.h b/libspu/mpc/cheetah/ot/ferret.h index 9f1b2645..01b61d98 100644 --- a/libspu/mpc/cheetah/ot/ferret.h +++ b/libspu/mpc/cheetah/ot/ferret.h @@ -30,7 +30,8 @@ class FerretOT { std::shared_ptr impl_; public: - FerretOT(std::shared_ptr conn, bool is_sender); + FerretOT(std::shared_ptr conn, bool is_sender, + bool malicious = false); ~FerretOT(); @@ -98,6 +99,25 @@ class FerretOT { absl::Span output, int bit_width = 0); void RecvCAMCC(absl::Span binary_choices, absl::Span output, int bit_width = 0); + + // Random Message Chosen Choice + void SendRMCC(absl::Span output0, absl::Span output1, + size_t bit_width = 0); + void SendRMCC(absl::Span output0, absl::Span output1, + size_t bit_width = 0); + void SendRMCC(absl::Span output0, absl::Span output1, + size_t bit_width = 0); + void SendRMCC(absl::Span output0, absl::Span output1, + size_t bit_width = 0); + + void RecvRMCC(absl::Span binary_choices, + absl::Span output, size_t bit_width = 0); + void RecvRMCC(absl::Span binary_choices, + absl::Span output, size_t bit_width = 0); + void RecvRMCC(absl::Span binary_choices, + absl::Span output, size_t bit_width = 0); + void RecvRMCC(absl::Span binary_choices, + absl::Span output, size_t bit_width = 0); }; } // namespace spu::mpc::cheetah diff --git a/libspu/mpc/cheetah/ot/ferret_test.cc b/libspu/mpc/cheetah/ot/ferret_test.cc index 6239a431..df27ed9a 100644 --- a/libspu/mpc/cheetah/ot/ferret_test.cc +++ 
b/libspu/mpc/cheetah/ot/ferret_test.cc @@ -110,6 +110,48 @@ TEST_P(FerretCOTTest, RndMsgRndChoice) { }); } +TEST_P(FerretCOTTest, RndMsgChosenChoice) { + size_t kWorldSize = 2; + auto field = GetParam(); + constexpr size_t bw = 2; + + size_t n = 10; + DISPATCH_ALL_FIELDS(field, "", [&]() { + std::vector msg0(n); + std::vector msg1(n); + ring2k_t max = static_cast(1) << bw; + + std::vector choices(n); + std::default_random_engine rdv; + std::uniform_int_distribution uniform(0, -1); + std::generate_n(choices.begin(), n, [&]() -> uint8_t { + return static_cast(uniform(rdv) & 1); + }); + + std::vector selected(n); + + utils::simulate(kWorldSize, [&](std::shared_ptr ctx) { + auto conn = std::make_shared(ctx); + int rank = ctx->Rank(); + FerretOT ferret(conn, rank == 0); + if (rank == 0) { + ferret.SendRMCC(absl::MakeSpan(msg0), absl::MakeSpan(msg1), bw); + ferret.Flush(); + } else { + ferret.RecvRMCC(absl::MakeSpan(choices), absl::MakeSpan(selected), bw); + } + }); + + for (size_t i = 0; i < n; ++i) { + ring2k_t e = choices[i] ? msg1[i] : msg0[i]; + ring2k_t c = selected[i]; + EXPECT_LT(e, max); + EXPECT_LT(c, max); + EXPECT_EQ(e, c); + } + }); +} + TEST_P(FerretCOTTest, ChosenMsgChosenChoice) { size_t kWorldSize = 2; size_t n = 106; diff --git a/libspu/mpc/io_interface.h b/libspu/mpc/io_interface.h index 9ee1fe1a..8c2aa336 100644 --- a/libspu/mpc/io_interface.h +++ b/libspu/mpc/io_interface.h @@ -40,12 +40,14 @@ class IoInterface { // resposibility to encode to it to ring. virtual std::vector toShares(const ArrayRef& raw, Visibility vis, int owner_rank = -1) const = 0; + virtual Type getShareType(Visibility vis, int owner_rank = -1) const = 0; // Make a secret from a bit array, if the element type is large than one bit, // only the lsb is considered. // // @param raw, with type as PtType. 
virtual std::vector makeBitSecret(const ArrayRef& raw) const = 0; + virtual size_t getBitSecretShareSize(size_t numel) const = 0; virtual bool hasBitSecretSupport() const = 0; // Reconstruct shares into a RingTy value. @@ -68,6 +70,9 @@ class BaseIo : public IoInterface { std::vector makeBitSecret(const ArrayRef& raw) const override { SPU_THROW("should not be here"); } + size_t getBitSecretShareSize(size_t numel) const override { + SPU_THROW("should not be here"); + } bool hasBitSecretSupport() const override { return false; } }; diff --git a/libspu/mpc/ref2k/ref2k.cc b/libspu/mpc/ref2k/ref2k.cc index 9a42c81b..55a6e766 100644 --- a/libspu/mpc/ref2k/ref2k.cc +++ b/libspu/mpc/ref2k/ref2k.cc @@ -506,6 +506,16 @@ std::unique_ptr makeRef2kProtocol( return ctx; } +Type Ref2kIo::getShareType(Visibility vis, int owner_rank) const { + if (vis == VIS_PUBLIC) { + return makeType(field_); + } else if (vis == VIS_SECRET) { + return makeType(field_); + } + + SPU_THROW("unsupported vis type {}", vis); +} + std::vector Ref2kIo::toShares(const ArrayRef& raw, Visibility vis, int owner_rank) const { SPU_ENFORCE(raw.eltype().isa(), "expected RingTy, got {}", diff --git a/libspu/mpc/ref2k/ref2k.h b/libspu/mpc/ref2k/ref2k.h index c5d93ad1..617e56e0 100644 --- a/libspu/mpc/ref2k/ref2k.h +++ b/libspu/mpc/ref2k/ref2k.h @@ -28,6 +28,8 @@ class Ref2kIo final : public BaseIo { std::vector toShares(const ArrayRef& raw, Visibility vis, int owner_rank) const override; + Type getShareType(Visibility vis, int owner_rank = -1) const override; + ArrayRef fromShares(const std::vector& shares) const override; }; diff --git a/libspu/mpc/semi2k/io.cc b/libspu/mpc/semi2k/io.cc index 36479b5f..459b3a5a 100644 --- a/libspu/mpc/semi2k/io.cc +++ b/libspu/mpc/semi2k/io.cc @@ -20,6 +20,20 @@ namespace spu::mpc::semi2k { +Type Semi2kIo::getShareType(Visibility vis, int owner_rank) const { + if (vis == VIS_PUBLIC) { + return makeType(field_); + } else if (vis == VIS_SECRET) { + if (owner_rank >= 0 && owner_rank 
< static_cast(world_size_)) { + return makeType(field_, owner_rank); + } else { + return makeType(field_); + } + } + + SPU_THROW("unsupported vis type {}", vis); +} + std::vector Semi2kIo::toShares(const ArrayRef& raw, Visibility vis, int owner_rank) const { SPU_ENFORCE(raw.eltype().isa(), "expected RingTy, got {}", diff --git a/libspu/mpc/semi2k/io.h b/libspu/mpc/semi2k/io.h index 27233b36..3957beca 100644 --- a/libspu/mpc/semi2k/io.h +++ b/libspu/mpc/semi2k/io.h @@ -27,6 +27,8 @@ class Semi2kIo : public BaseIo { std::vector toShares(const ArrayRef& raw, Visibility vis, int owner_rank) const override; + Type getShareType(Visibility vis, int owner_rank = -1) const override; + ArrayRef fromShares(const std::vector& shares) const override; }; diff --git a/libspu/mpc/spdz2k/BUILD.bazel b/libspu/mpc/spdz2k/BUILD.bazel index e2e9e16c..57a555d2 100644 --- a/libspu/mpc/spdz2k/BUILD.bazel +++ b/libspu/mpc/spdz2k/BUILD.bazel @@ -30,6 +30,8 @@ spu_cc_library( hdrs = ["protocol.h"], deps = [ ":arithmetic", + ":boolean", + ":conversion", ":state", ":value", "//libspu/core:context", @@ -43,6 +45,8 @@ spu_cc_test( deps = [ ":abprotocol_spdz2k_test", ":protocol", + "//libspu/mpc:api_test", + "//libspu/mpc:ab_api_test", ], ) @@ -52,6 +56,7 @@ spu_cc_library( deps = [ ":commitment", "//libspu/mpc/spdz2k/beaver:beaver_tfp", + "//libspu/mpc/spdz2k/beaver:beaver_tinyot", ], ) @@ -74,11 +79,46 @@ spu_cc_library( ], ) +spu_cc_library( + name = "boolean", + srcs = ["boolean.cc"], + hdrs = ["boolean.h"], + deps = [ + ":state", + ":type", + ":value", + "//libspu/mpc:kernel", + "//libspu/mpc:ab_api", + "//libspu/mpc/common:prg_state", + "//libspu/mpc/common:communicator", + ], +) + + +spu_cc_library( + name = "conversion", + srcs = ["conversion.cc"], + hdrs = ["conversion.h"], + deps = [ + ":state", + ":type", + ":boolean", + ":arithmetic", + "//libspu/core:vectorize", + "//libspu/mpc:kernel", + "//libspu/mpc:ab_api", + "//libspu/mpc:api", + "//libspu/mpc/common:communicator", + 
"//libspu/mpc/utils:circuits", + ], +) + spu_cc_library( name = "commitment", srcs = ["commitment.cc"], hdrs = ["commitment.h"], deps = [ + "//libspu/core:prelude", "@yacl//yacl/crypto/base/hash:blake3", "@yacl//yacl/crypto/base/hash:hash_utils", "@yacl//yacl/crypto/utils:rand", @@ -131,7 +171,10 @@ spu_cc_library( deps = [ ":type", "//libspu/core", + "//libspu/mpc/common:pv2k", "//libspu/mpc/utils:ring_ops", + "//libspu/mpc/spdz2k/beaver:beaver_tfp", + "//libspu/mpc/spdz2k/beaver:beaver_tinyot", ], ) @@ -139,8 +182,8 @@ spu_cc_library( name = "abprotocol_spdz2k_test", testonly = 1, srcs = ["abprotocol_spdz2k_test.cc"], - hdrs = ["abprotocol_spdz2k_test.h"], deps = [ + "//libspu/mpc:ab_api_test", "//libspu/mpc:ab_api", "//libspu/mpc:api", "//libspu/mpc/common:communicator", diff --git a/libspu/mpc/spdz2k/abprotocol_spdz2k_test.cc b/libspu/mpc/spdz2k/abprotocol_spdz2k_test.cc index 5af107d1..3a90cf45 100644 --- a/libspu/mpc/spdz2k/abprotocol_spdz2k_test.cc +++ b/libspu/mpc/spdz2k/abprotocol_spdz2k_test.cc @@ -12,10 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "libspu/mpc/spdz2k/abprotocol_spdz2k_test.h" - #include "libspu/core/shape_util.h" #include "libspu/mpc/ab_api.h" +#include "libspu/mpc/ab_api_test.h" #include "libspu/mpc/api.h" #include "libspu/mpc/common/communicator.h" #include "libspu/mpc/kernel.h" @@ -43,7 +42,7 @@ const std::vector kShiftBits = {0, 1, 2, 31, 32, 33, 64, 1000}; EXPECT_EQ((X).shape(), (Y).shape()); \ auto [x_data, x_shape, x_dtype] = UnwrapValue(X); \ auto [y_data, y_shape, y_dtype] = UnwrapValue(Y); \ - EXPECT_TRUE(ring_all_equal(x_data, y_data, ERR)); \ + EXPECT_TRUE(ring_all_equal(x_data, y_data)); \ } bool verifyCost(Kernel* kernel, std::string_view name, FieldType field, @@ -77,278 +76,320 @@ bool verifyCost(Kernel* kernel, std::string_view name, FieldType field, } // namespace -TEST_P(ArithmeticTest, P2A) { +TEST_P(BooleanTest, NotB) { const auto factory = std::get<0>(GetParam()); const RuntimeConfig& conf = std::get<1>(GetParam()); const size_t npc = std::get<2>(GetParam()); - utils::simulate(npc, [&](const std::shared_ptr& lctx) { - auto sctx = factory(conf, lctx); + utils::simulate(npc, [&](std::shared_ptr lctx) { + auto obj = factory(conf, lctx); /* GIVEN */ - auto p0 = rand_p(sctx.get(), kShape); + auto p0 = rand_p(obj.get(), kShape); + auto b0 = p2b(obj.get(), p0); /* WHEN */ - auto prev = sctx->prot()->getState()->getStats(); - auto a0 = p2a(sctx.get(), p0); - auto cost = sctx->prot()->getState()->getStats() - prev; - auto p1 = a2p(sctx.get(), a0); + auto prev = obj->prot()->getState()->getStats(); + auto r_b = dynDispatch(obj.get(), "not_b", b0); + auto cost = obj->prot()->getState()->getStats() - prev; + auto r_p = b2p(obj.get(), r_b); + auto r_pp = not_p(obj.get(), p0); /* THEN */ - EXPECT_VALUE_EQ(p0, p1); - EXPECT_TRUE(verifyCost(sctx->prot()->getKernel("p2a"), "p2a", conf.field(), - kShape, npc, cost)); + EXPECT_VALUE_EQ(r_p, r_pp); + EXPECT_TRUE(verifyCost(obj->prot()->getKernel("not_b"), "not_b", + conf.field(), kShape, npc, cost)); }); } -TEST_P(ArithmeticTest, 
A2P) { +TEST_P(ConversionTest, AddBB) { const auto factory = std::get<0>(GetParam()); const RuntimeConfig& conf = std::get<1>(GetParam()); const size_t npc = std::get<2>(GetParam()); - utils::simulate(npc, [&](const std::shared_ptr& lctx) { - auto sctx = factory(conf, lctx); + utils::simulate(npc, [&](std::shared_ptr lctx) { + auto obj = factory(conf, lctx); + if (!obj->hasKernel("and_bb")) { + return; + } /* GIVEN */ - auto p0 = rand_p(sctx.get(), kShape); + auto p0 = rand_p(obj.get(), kShape); + auto p1 = rand_p(obj.get(), kShape); /* WHEN */ - auto a0 = p2a(sctx.get(), p0); - auto prev = sctx->prot()->getState()->getStats(); - auto p1 = a2p(sctx.get(), a0); - [[maybe_unused]] auto cost = - sctx->prot()->getState()->getStats() - prev; + auto b0 = p2b(obj.get(), p0); + auto b1 = p2b(obj.get(), p1); + auto prev = obj->prot()->getState()->getStats(); + auto tmp = add_bb(obj.get(), b0, b1); + auto cost = obj->prot()->getState()->getStats() - prev; + auto re = b2p(obj.get(), tmp); + auto rp = add_pp(obj.get(), p0, p1); /* THEN */ - EXPECT_VALUE_EQ(p0, p1); - EXPECT_TRUE(verifyCost(sctx->prot()->getKernel("a2p"), "a2p", conf.field(), + EXPECT_VALUE_EQ(re, rp); + EXPECT_TRUE(verifyCost(obj->getKernel("add_bb"), "add_bb", conf.field(), kShape, npc, cost)); }); } -#define TEST_ARITHMETIC_BINARY_OP_AA(OP) \ - TEST_P(ArithmeticTest, OP##AA) { \ - const auto factory = std::get<0>(GetParam()); \ - const RuntimeConfig& conf = std::get<1>(GetParam()); \ - const size_t npc = std::get<2>(GetParam()); \ - \ - utils::simulate(npc, [&](std::shared_ptr lctx) { \ - auto obj = factory(conf, lctx); \ - \ - /* GIVEN */ \ - auto p0 = rand_p(obj.get(), kShape); \ - auto p1 = rand_p(obj.get(), kShape); \ - \ - /* WHEN */ \ - auto a0 = p2a(obj.get(), p0); \ - auto a1 = p2a(obj.get(), p1); \ - auto prev = obj->prot()->getState()->getStats(); \ - auto tmp = OP##_aa(obj.get(), a0, a1); \ - auto cost = obj->prot()->getState()->getStats() - prev; \ - auto re = a2p(obj.get(), tmp); \ - auto rp = 
OP##_pp(obj.get(), p0, p1); \ - \ - /* THEN */ \ - EXPECT_VALUE_EQ(re, rp); \ - EXPECT_TRUE(verifyCost(obj->prot()->getKernel(#OP "_aa"), #OP "_aa", \ - conf.field(), kShape, npc, cost)); \ - }); \ - } +TEST_P(ConversionTest, AddBP) { + const auto factory = std::get<0>(GetParam()); + const RuntimeConfig& conf = std::get<1>(GetParam()); + const size_t npc = std::get<2>(GetParam()); -#define TEST_ARITHMETIC_BINARY_OP_AP(OP) \ - TEST_P(ArithmeticTest, OP##AP) { \ - const auto factory = std::get<0>(GetParam()); \ - const RuntimeConfig& conf = std::get<1>(GetParam()); \ - const size_t npc = std::get<2>(GetParam()); \ - \ - utils::simulate(npc, [&](std::shared_ptr lctx) { \ - auto obj = factory(conf, lctx); \ - \ - /* GIVEN */ \ - auto p0 = rand_p(obj.get(), kShape); \ - auto p1 = rand_p(obj.get(), kShape); \ - \ - /* WHEN */ \ - auto a0 = p2a(obj.get(), p0); \ - auto prev = obj->prot()->getState()->getStats(); \ - auto tmp = OP##_ap(obj.get(), a0, p1); \ - auto cost = obj->prot()->getState()->getStats() - prev; \ - auto re = a2p(obj.get(), tmp); \ - auto rp = OP##_pp(obj.get(), p0, p1); \ - \ - /* THEN */ \ - EXPECT_VALUE_EQ(re, rp); \ - EXPECT_TRUE(verifyCost(obj->prot()->getKernel(#OP "_ap"), #OP "_ap", \ - conf.field(), kShape, npc, cost)); \ - }); \ - } + utils::simulate(npc, [&](std::shared_ptr lctx) { + auto obj = factory(conf, lctx); + + if (!obj->hasKernel("and_bp")) { + return; + } + /* GIVEN */ + auto p0 = rand_p(obj.get(), kShape); + auto p1 = rand_p(obj.get(), kShape); -#define TEST_ARITHMETIC_BINARY_OP(OP) \ - TEST_ARITHMETIC_BINARY_OP_AA(OP) \ - TEST_ARITHMETIC_BINARY_OP_AP(OP) + /* WHEN */ + auto b0 = p2b(obj.get(), p0); + auto prev = obj->prot()->getState()->getStats(); + // Not a common test!!!! 
+ auto tmp = dynDispatch(obj.get(), "add_bp", b0, p1); + auto cost = obj->prot()->getState()->getStats() - prev; + auto re = b2p(obj.get(), tmp); + auto rp = add_pp(obj.get(), p0, p1); -TEST_ARITHMETIC_BINARY_OP(add) -TEST_ARITHMETIC_BINARY_OP(mul) + /* THEN */ + EXPECT_VALUE_EQ(re, rp); + EXPECT_TRUE(verifyCost(obj->getKernel("add_bp"), "add_bp", conf.field(), + kShape, npc, cost)); + }); +} -TEST_P(ArithmeticTest, NotA) { +TEST_P(ConversionTest, Bit2A) { const auto factory = std::get<0>(GetParam()); const RuntimeConfig& conf = std::get<1>(GetParam()); const size_t npc = std::get<2>(GetParam()); - utils::simulate(npc, [&](const std::shared_ptr& lctx) { + utils::simulate(npc, [&](std::shared_ptr lctx) { auto obj = factory(conf, lctx); + if (!obj->hasKernel("bit2a")) { + return; + } /* GIVEN */ auto p0 = rand_p(obj.get(), kShape); - auto a0 = p2a(obj.get(), p0); + p0 = msb_p(obj.get(), p0); /* WHEN */ + auto b = p2b(obj.get(), p0); auto prev = obj->prot()->getState()->getStats(); - auto r_a = not_a(obj.get(), a0); + auto a = dynDispatch(obj.get(), "bit2a", b); auto cost = obj->prot()->getState()->getStats() - prev; - auto r_p = a2p(obj.get(), r_a); - auto r_pp = a2p(obj.get(), not_a(obj.get(), a0)); - + auto p1 = a2p(obj.get(), a); /* THEN */ - EXPECT_VALUE_EQ(r_p, r_pp); - EXPECT_TRUE(verifyCost(obj->prot()->getKernel("not_a"), "not_a", - conf.field(), kShape, npc, cost)); + EXPECT_VALUE_EQ(p0, p1); + EXPECT_TRUE(verifyCost(obj->getKernel("bit2a"), "bit2a", conf.field(), + kShape, npc, cost)); }); } -TEST_P(ArithmeticTest, MatMulAP) { +TEST_P(ConversionTest, A2Bit) { const auto factory = std::get<0>(GetParam()); const RuntimeConfig& conf = std::get<1>(GetParam()); const size_t npc = std::get<2>(GetParam()); - const int64_t M = 3; - const int64_t K = 4; - const int64_t N = 3; - const Shape shape_A = {M, K}; - const Shape shape_B = {K, N}; - const Shape shape_C = {M, N}; - - utils::simulate(npc, [&](const std::shared_ptr& lctx) { + utils::simulate(npc, 
[&](std::shared_ptr lctx) { auto obj = factory(conf, lctx); + if (!obj->hasKernel("a2bit")) { + return; + } /* GIVEN */ - auto p0 = rand_p(obj.get(), shape_A); - auto p1 = rand_p(obj.get(), shape_B); - auto a0 = p2a(obj.get(), p0); + auto p0 = rand_p(obj.get(), kShape); + p0 = msb_p(obj.get(), p0); /* WHEN */ + auto a = p2a(obj.get(), p0); + p2b(obj.get(), p0); auto prev = obj->prot()->getState()->getStats(); - auto tmp = mmul_ap(obj.get(), a0, p1, M, N, K); + auto b = dynDispatch(obj.get(), "a2bit", a); auto cost = obj->prot()->getState()->getStats() - prev; - - auto r_aa = a2p(obj.get(), tmp); - auto r_pp = mmul_pp(obj.get(), p0, p1, M, N, K); - + // reserve the least significant bit only + auto p1 = b2p(obj.get(), b); /* THEN */ - EXPECT_VALUE_EQ(r_aa, r_pp); - EXPECT_TRUE(verifyCost(obj->prot()->getKernel("mmul_ap"), "mmul_ap", - conf.field(), shape_C, npc, cost)); + EXPECT_VALUE_EQ(p0, p1); + EXPECT_TRUE(verifyCost(obj->getKernel("a2bit"), "a2bit", conf.field(), + kShape, npc, cost)); }); } -TEST_P(ArithmeticTest, MatMulAA) { +TEST_P(ConversionTest, BitLT) { const auto factory = std::get<0>(GetParam()); const RuntimeConfig& conf = std::get<1>(GetParam()); const size_t npc = std::get<2>(GetParam()); - const int64_t M = 3; - const int64_t K = 4; - const int64_t N = 3; - const Shape shape_A = {M, K}; - const Shape shape_B = {K, N}; - const Shape shape_C = {M, N}; - - utils::simulate(npc, [&](const std::shared_ptr& lctx) { + utils::simulate(npc, [&](std::shared_ptr lctx) { auto obj = factory(conf, lctx); + if (!obj->hasKernel("bitlt_bb")) { + return; + } /* GIVEN */ - auto p0 = rand_p(obj.get(), shape_A); - auto p1 = rand_p(obj.get(), shape_B); - auto a0 = p2a(obj.get(), p0); - auto a1 = p2a(obj.get(), p1); + auto p0 = rand_p(obj.get(), kShape); + auto p1 = rand_p(obj.get(), kShape); /* WHEN */ + auto b0 = p2b(obj.get(), p0); + auto b1 = p2b(obj.get(), p1); auto prev = obj->prot()->getState()->getStats(); - auto tmp = mmul_aa(obj.get(), a0, a1, M, N, K); + // Not 
a common test!!!! + auto tmp = dynDispatch(obj.get(), "bitlt_bb", b0, b1); auto cost = obj->prot()->getState()->getStats() - prev; - - auto r_aa = a2p(obj.get(), tmp); - auto r_pp = mmul_pp(obj.get(), p0, p1, M, N, K); + auto re = b2p(obj.get(), tmp); + + const auto field = p0.storage_type().as()->field(); + DISPATCH_ALL_FIELDS(field, "_", [&]() { + using U = std::make_unsigned::type; + size_t numel = kShape.numel(); + + auto p0_data = ArrayView(flatten(p0.data())); + auto p1_data = ArrayView(flatten(p1.data())); + auto re_data = ArrayView(flatten(re.data())); + for (size_t i = 0; i < numel; ++i) { + if ((p0_data[i] < p1_data[i])) { + SPU_ENFORCE((re_data[i] == 1)); + } else { + SPU_ENFORCE((re_data[i] == 0)); + } + } + }); /* THEN */ - EXPECT_VALUE_EQ(r_aa, r_pp); - EXPECT_TRUE(verifyCost(obj->prot()->getKernel("mmul_aa"), "mmul_aa", - conf.field(), shape_C, npc, cost)); + EXPECT_TRUE(verifyCost(obj->getKernel("bitlt_bb"), "bitlt_bb", conf.field(), + kShape, npc, cost)); }); } -TEST_P(ArithmeticTest, LShiftA) { +TEST_P(ConversionTest, BitLE) { const auto factory = std::get<0>(GetParam()); const RuntimeConfig& conf = std::get<1>(GetParam()); const size_t npc = std::get<2>(GetParam()); - utils::simulate(npc, [&](const std::shared_ptr& lctx) { + utils::simulate(npc, [&](std::shared_ptr lctx) { auto obj = factory(conf, lctx); + if (!obj->hasKernel("bitle_bb")) { + return; + } /* GIVEN */ auto p0 = rand_p(obj.get(), kShape); - auto a0 = p2a(obj.get(), p0); + auto p1 = p0.clone(); - for (auto bits : kShiftBits) { - if (bits >= p0.elsize() * 8) { - // Shift more than elsize is a UB - continue; + /* WHEN */ + auto b0 = p2b(obj.get(), p0); + auto b1 = p2b(obj.get(), p1); + auto prev = obj->prot()->getState()->getStats(); + // Not a common test!!!! 
+ auto tmp = dynDispatch(obj.get(), "bitle_bb", b0, b1); + auto cost = obj->prot()->getState()->getStats() - prev; + auto re = b2p(obj.get(), tmp); + + const auto field = p0.storage_type().as()->field(); + DISPATCH_ALL_FIELDS(field, "_", [&]() { + using U = std::make_unsigned::type; + size_t numel = kShape.numel(); + auto p0_data = ArrayView(flatten(p0.data())); + auto p1_data = ArrayView(flatten(p1.data())); + auto re_data = ArrayView(flatten(re.data())); + for (size_t i = 0; i < numel; ++i) { + if ((p0_data[i] <= p1_data[i])) { + SPU_ENFORCE((re_data[i] == 1)); + } else { + SPU_ENFORCE((re_data[i] == 0)); + } } - /* WHEN */ - auto prev = obj->prot()->getState()->getStats(); - auto tmp = lshift_a(obj.get(), a0, bits); - auto cost = obj->prot()->getState()->getStats() - prev; - auto r_b = a2p(obj.get(), tmp); - auto r_p = lshift_p(obj.get(), p0, bits); - - /* THEN */ - EXPECT_VALUE_EQ(r_b, r_p); - EXPECT_TRUE(verifyCost(obj->prot()->getKernel("lshift_a"), "lshift_a", - conf.field(), kShape, npc, cost)); - } + }); + + /* THEN */ + EXPECT_TRUE(verifyCost(obj->getKernel("bitlt_bb"), "bitlt_bb", conf.field(), + kShape, npc, cost)); }); } -TEST_P(ArithmeticTest, TruncA) { +TEST_P(BooleanTest, BitIntl) { const auto factory = std::get<0>(GetParam()); const RuntimeConfig& conf = std::get<1>(GetParam()); const size_t npc = std::get<2>(GetParam()); + size_t stride = 0; - utils::simulate(npc, [&](const std::shared_ptr& lctx) { + utils::simulate(npc, [&](std::shared_ptr lctx) { auto obj = factory(conf, lctx); + /* GIVEN */ auto p0 = rand_p(obj.get(), kShape); - EXPECT_TRUE(static_cast(obj->prot()->getKernel("trunc_a")) - ->hasMsbError()); - // has msb error, only use lowest 10 bits. 
- p0 = arshift_p(obj.get(), p0, SizeOf(conf.field()) * 8 - 10); + /* WHEN */ + auto b = p2b(obj.get(), p0); + auto prev = obj->prot()->getState()->getStats(); + auto tmp = bitintl_b(obj.get(), b, stride); + auto cost = obj->prot()->getState()->getStats() - prev; + + auto p1 = b2p(obj.get(), tmp); + auto pp1 = bitintl_b(obj.get(), p0, stride); + /* THEN */ + EXPECT_VALUE_EQ(p1, pp1); + EXPECT_TRUE(verifyCost(obj->getKernel("bitintl_b"), "bitintl_b", + conf.field(), kShape, npc, cost)); + }); +} + +TEST_P(BooleanTest, BitDeintl) { + const auto factory = std::get<0>(GetParam()); + const RuntimeConfig& conf = std::get<1>(GetParam()); + const size_t npc = std::get<2>(GetParam()); + size_t stride = 0; + utils::simulate(npc, [&](std::shared_ptr lctx) { + auto obj = factory(conf, lctx); /* GIVEN */ - const size_t bits = 2; - auto a0 = p2a(obj.get(), p0); + auto p0 = rand_p(obj.get(), kShape); /* WHEN */ + auto b = p2b(obj.get(), p0); auto prev = obj->prot()->getState()->getStats(); - auto a1 = trunc_a(obj.get(), a0, bits); + auto tmp = bitdeintl_b(obj.get(), b, stride); auto cost = obj->prot()->getState()->getStats() - prev; - auto r_a = a2p(obj.get(), a1); - auto r_p = arshift_p(obj.get(), p0, bits); + auto p1 = b2p(obj.get(), tmp); + auto pp1 = bitdeintl_b(obj.get(), p0, stride); + /* THEN */ + EXPECT_VALUE_EQ(p1, pp1); + EXPECT_TRUE(verifyCost(obj->getKernel("bitintl_b"), "bitintl_b", + conf.field(), kShape, npc, cost)); + }); +} +TEST_P(BooleanTest, BitIntlAndDeintl) { + const auto factory = std::get<0>(GetParam()); + const RuntimeConfig& conf = std::get<1>(GetParam()); + const size_t npc = std::get<2>(GetParam()); + size_t stride = 0; + + utils::simulate(npc, [&](std::shared_ptr lctx) { + auto obj = factory(conf, lctx); + /* GIVEN */ + auto p0 = rand_p(obj.get(), kShape); + + /* WHEN */ + auto b = p2b(obj.get(), p0); + auto prev = obj->prot()->getState()->getStats(); + auto b0 = bitintl_b(obj.get(), b, stride); + auto cost = obj->prot()->getState()->getStats() - prev; 
+ + auto b1 = bitdeintl_b(obj.get(), b0, stride); + auto p1 = b2p(obj.get(), b1); /* THEN */ - EXPECT_VALUE_ALMOST_EQ(r_a, r_p, npc); - EXPECT_TRUE(verifyCost(obj->prot()->getKernel("trunc_a"), "trunc_a", + EXPECT_VALUE_EQ(p0, p1); + EXPECT_TRUE(verifyCost(obj->getKernel("bitintl_b"), "bitintl_b", conf.field(), kShape, npc, cost)); }); } diff --git a/libspu/mpc/spdz2k/abprotocol_spdz2k_test.h b/libspu/mpc/spdz2k/abprotocol_spdz2k_test.h deleted file mode 100644 index b3d9698d..00000000 --- a/libspu/mpc/spdz2k/abprotocol_spdz2k_test.h +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright 2021 Ant Group Co., Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "gtest/gtest.h" -#include "yacl/link/link.h" - -#include "libspu/core/context.h" - -namespace spu::mpc::test { - -using CreateObjectFn = std::function( - const RuntimeConfig& conf, - const std::shared_ptr& lctx)>; - -class ArithmeticTest : public ::testing::TestWithParam< - std::tuple> { -}; - -} // namespace spu::mpc::test diff --git a/libspu/mpc/spdz2k/arithmetic.cc b/libspu/mpc/spdz2k/arithmetic.cc index d6a8c8b2..2fffa849 100644 --- a/libspu/mpc/spdz2k/arithmetic.cc +++ b/libspu/mpc/spdz2k/arithmetic.cc @@ -39,35 +39,85 @@ namespace spu::mpc::spdz2k { namespace { +// Input a plaintext +// Output the B-share without MAC +// LSB first, MSB last +// ArrayRef CastToLargeRing(const ArrayRef& in, FieldType out_field) { +ArrayRef CastRing(const ArrayRef& in, FieldType out_field) { + const auto* in_ty = in.eltype().as(); + const auto in_field = in_ty->field(); + return DISPATCH_ALL_FIELDS(in_field, "_", [&]() { + auto _in = ArrayView(in); + + const size_t out_numel = in.numel(); + auto out = ring_zeros(out_field, out_numel); + return DISPATCH_ALL_FIELDS(out_field, "_", [&]() { + auto _out = ArrayView(out); + pforeach(0, in.numel(), [&](int64_t idx) { + _out[idx] = static_cast(_in[idx]); + }); + + return out; + }); + }); +} + ArrayRef zero_a_impl(KernelEvalContext* ctx, size_t size) { auto* prg_state = ctx->getState(); - const auto field = ctx->getState()->getDefaultField(); + const auto field = ctx->getState()->getDefaultField(); auto [r0, r1] = prg_state->genPrssPair(field, size); auto [r2, r3] = prg_state->genPrssPair(field, size); - // NOTES for ring_rshift to 2 bits. - // Refer to: - // New Primitives for Actively-Secure MPC over Rings with Applications to - // Private Machine Learning - // - https://eprint.iacr.org/2019/599.pdf - // It's safer to keep the number within [-2**(k-2), 2**(k-2)) for comparison - // operations. 
auto x = ring_sub(r0, r1); auto x_mac = ring_sub(r2, r3); return makeAShare(x, x_mac, field); } - } // namespace +ArrayRef GetMacShare(KernelEvalContext* ctx, const ArrayRef& in) { + const auto field = in.eltype().as()->field(); + auto* beaver = ctx->getState()->beaver(); + const size_t k = ctx->getState()->k(); + const size_t s = ctx->getState()->s(); + + const auto& x = getValueShare(in); + ArrayRef x_mac; + if (in.eltype().as()->hasMac()) { + x_mac = getMacShare(in); + } else { + SPDLOG_DEBUG("generate mac share"); + x_mac = beaver->AuthArrayRef(x, field, k, s); + } + return x_mac; +} + ArrayRef RandA::proc(KernelEvalContext* ctx, size_t size) const { SPU_TRACE_MPC_LEAF(ctx, size); - SPU_THROW("NotImplemented"); + + const auto field = ctx->getState()->getDefaultField(); + auto* prg_state = ctx->getState(); + auto* beaver = ctx->getState()->beaver(); + const auto k = ctx->getState()->k(); + const auto s = ctx->getState()->s(); + + // NOTES for ring_rshift to 2 bits. + // Refer to: + // New Primitives for Actively-Secure MPC over Rings with Applications to + // Private Machine Learning + // - https://eprint.iacr.org/2019/599.pdf + // It's safer to keep the number within [-2**(k-2), 2**(k-2)) for comparison + // operations. 
+ auto x = ring_rshift(prg_state->genPriv(field, size), 2) + .as(makeType(field)); + auto x_mac = beaver->AuthArrayRef(x, field, k, s); + return makeAShare(x, x_mac, field); } ArrayRef P2A::proc(KernelEvalContext* ctx, const ArrayRef& in) const { SPU_TRACE_MPC_LEAF(ctx, in); + const auto field = ctx->getState()->getDefaultField(); auto* comm = ctx->getState(); const auto key = ctx->getState()->key(); @@ -75,11 +125,12 @@ ArrayRef P2A::proc(KernelEvalContext* ctx, const ArrayRef& in) const { auto z = getValueShare(res); auto z_mac = getMacShare(res); + auto t_in = CastRing(in, field); if (comm->getRank() == 0) { - ring_add_(z, in); + ring_add_(z, t_in); } - ring_add_(z_mac, ring_mul(in, key)); + ring_add_(z_mac, ring_mul(t_in, key)); return res; } @@ -87,28 +138,86 @@ ArrayRef P2A::proc(KernelEvalContext* ctx, const ArrayRef& in) const { ArrayRef A2P::proc(KernelEvalContext* ctx, const ArrayRef& in) const { SPU_TRACE_MPC_LEAF(ctx, in); - const auto field = in.eltype().as()->field(); + const auto out_field = ctx->getState()->getDefaultField(); + auto* beaver = ctx->getState()->beaver(); + const auto k = ctx->getState()->k(); + const auto s = ctx->getState()->s(); + + // in + const auto& x = getValueShare(in); + const auto& x_mac = getMacShare(in); + auto [t, check_mac] = beaver->BatchOpen(x, x_mac, k, s); + SPU_ENFORCE(beaver->BatchMacCheck(t, check_mac, k, s)); + + // Notice that only the last sth bits is correct + ring_bitmask_(t, 0, k); + + auto res = CastRing(t, out_field); + return res.as(makeType(out_field)); +} + +ArrayRef A2V::proc(KernelEvalContext* ctx, const ArrayRef& in, + size_t rank) const { + SPU_TRACE_MPC_LEAF(ctx, in); + + const auto field = ctx->getState()->getDefaultField(); + const auto out_field = ctx->getState()->getDefaultField(); auto* comm = ctx->getState(); - auto* arr_ref_v = ctx->getState()->arr_ref_v(); - arr_ref_v->emplace_back(in); - -// #define SINGLE_CHECK -#ifdef SINGLE_CHECK - for (auto& x : *arr_ref_v) { - bool success = 
SingleCheck(ctx, x); - SPU_ENFORCE(success, "single check fail"); + auto* beaver = ctx->getState()->beaver(); + const auto k = ctx->getState()->k(); + const auto s = ctx->getState()->s(); + + // generate mask + auto zero = zero_a_impl(ctx, in.numel()); + auto z = getValueShare(zero); + auto z_mac = getMacShare(zero); + auto r = ring_rand(field, in.numel()); + if (comm->getRank() == rank) { + ring_add_(z, r); } -#else - bool success = BatchCheck(ctx, *arr_ref_v); - arr_ref_v->clear(); - SPU_ENFORCE(success, "batch check fail"); -#endif + ring_add_(z_mac, beaver->AuthArrayRef(z, field, k, s)); - // in + // add mask const auto& x = getValueShare(in); - auto t = comm->allReduce(ReduceOp::ADD, x, kBindName); - auto ty = makeType(field); - return t.as(ty); + const auto& x_mac = getMacShare(in); + auto mask_x = ring_add(x, z); + auto mask_x_mac = ring_add(x_mac, z_mac); + + auto [t, check_mac] = beaver->BatchOpen(mask_x, mask_x_mac, k, s); + SPU_ENFORCE(beaver->BatchMacCheck(t, check_mac, k, s)); + + // Notice that only the last s bits is correct + if (comm->getRank() == rank) { + auto t_r = ring_bitmask(ring_sub(t, r), 0, k); + auto res = CastRing(t_r, out_field); + return res.as(makeType(out_field, rank)); + } else { + auto out_ty = makeType(out_field, rank); + return makeConstantArrayRef(out_ty, in.numel()); + } +} + +ArrayRef V2A::proc(KernelEvalContext* ctx, const ArrayRef& in) const { + const auto* in_ty = in.eltype().as(); + const size_t owner_rank = in_ty->owner(); + const auto field = ctx->getState()->getDefaultField(); + auto* comm = ctx->getState(); + auto* beaver = ctx->getState()->beaver(); + const size_t k = ctx->getState()->k(); + const size_t s = ctx->getState()->s(); + + auto res = zero_a_impl(ctx, in.numel()); + auto z = getValueShare(res); + auto z_mac = getMacShare(res); + + auto t_in = CastRing(in, field); + if (comm->getRank() == owner_rank) { + ring_add_(z, t_in); + } + + ring_add_(z_mac, beaver->AuthArrayRef(z, field, k, s)); + + return res; } 
ArrayRef NotA::proc(KernelEvalContext* ctx, const ArrayRef& in) const { @@ -120,7 +229,7 @@ ArrayRef NotA::proc(KernelEvalContext* ctx, const ArrayRef& in) const { // in const auto& x = getValueShare(in); - const auto& x_mac = getMacShare(in); + const auto& x_mac = GetMacShare(ctx, in); // compute neg_x, neg_x_mac auto neg_x = ring_neg(x); @@ -150,14 +259,16 @@ ArrayRef AddAP::proc(KernelEvalContext* ctx, const ArrayRef& lhs, // lhs const auto& x = getValueShare(lhs); - const auto& x_mac = getMacShare(lhs); + const auto& x_mac = GetMacShare(ctx, lhs); + + auto t_rhs = CastRing(rhs, field); // remember that rhs is public auto z = x.clone(); if (comm->getRank() == 0) { - ring_add_(z, rhs); + ring_add_(z, t_rhs); } - auto z_mac = ring_add(x_mac, ring_mul(rhs, key)); + auto z_mac = ring_add(x_mac, ring_mul(t_rhs, key)); return makeAShare(z, z_mac, field); } @@ -170,11 +281,13 @@ ArrayRef AddAA::proc(KernelEvalContext* ctx, const ArrayRef& lhs, // lhs const auto& x = getValueShare(lhs); - const auto& x_mac = getMacShare(lhs); + const auto& x_mac = GetMacShare(ctx, lhs); + // const auto& x_mac = getMacShare(lhs); // rhs const auto& y = getValueShare(rhs); - const auto& y_mac = getMacShare(rhs); + const auto& y_mac = GetMacShare(ctx, rhs); + // const auto& y_mac = getMacShare(rhs); // ret const auto& z = ring_add(x, y); @@ -195,14 +308,13 @@ bool SingleCheck(KernelEvalContext* ctx, const ArrayRef& in) { const auto field = in.eltype().as()->field(); auto* comm = ctx->getState(); - const auto& lctx = comm->lctx(); auto* beaver = ctx->getState()->beaver(); const auto key = ctx->getState()->key(); const size_t k = ctx->getState()->k(); const size_t s = ctx->getState()->s(); // 1. Generate a random, shared value [r] - auto [r, r_mac] = beaver->AuthCoinTossing(field, in.numel(), s); + auto [r, r_mac] = beaver->AuthCoinTossing(field, in.numel(), k, s); // 2. 
Locally construct [y] const auto& x = getValueShare(in); @@ -217,8 +329,8 @@ bool SingleCheck(KernelEvalContext* ctx, const ArrayRef& in) { auto z = ring_sub(y_mac, ring_mul(plain_y, key)); std::string z_str(reinterpret_cast(z.data()), z.numel() * z.elsize()); std::vector z_strs; - YACL_ENFORCE(commit_and_open(lctx, z_str, &z_strs)); - YACL_ENFORCE(z_strs.size() == comm->getWorldSize()); + SPU_ENFORCE(commit_and_open(comm->lctx(), z_str, &z_strs)); + SPU_ENFORCE(z_strs.size() == comm->getWorldSize()); auto plain_z = ring_zeros(field, in.numel()); for (size_t i = 0; i < comm->getWorldSize(); ++i) { @@ -229,6 +341,7 @@ bool SingleCheck(KernelEvalContext* ctx, const ArrayRef& in) { } auto ret = spu::mpc::ring_all_equal(plain_z, ring_zeros(field, in.numel())); + SPU_ENFORCE(ret, "single check fail"); return ret; } @@ -274,7 +387,7 @@ bool BatchCheck(KernelEvalContext* ctx, const std::vector& ins) { for (const auto& in : ins) { // 1. get random r and r_mac - auto [r, r_mac] = beaver->AuthCoinTossing(field, numel, s); + auto [r, r_mac] = beaver->AuthCoinTossing(field, numel, k, s); auto rmac = makeAShare(r, r_mac, field); // 2. [x_hat] = [x] + 2^k * [r] @@ -295,7 +408,7 @@ bool BatchCheck(KernelEvalContext* ctx, const std::vector& ins) { }); // 5. 
get l public random values, compute plain y - auto pub_r = ctx->getState()->genPublCoin(field, size); + auto pub_r = beaver->genPublCoin(field, size); std::vector rv; uint128_t mask = (static_cast(1) << s) - 1; for (size_t i = 0; i < size; ++i) { @@ -350,17 +463,19 @@ ArrayRef MulAP::proc(KernelEvalContext* ctx, const ArrayRef& lhs, // lhs const auto& x = getValueShare(lhs); - const auto& x_mac = getMacShare(lhs); + const auto& x_mac = GetMacShare(ctx, lhs); // ret - const auto& z = ring_mul(x, rhs); - const auto& z_mac = ring_mul(x_mac, rhs); + auto t_rhs = CastRing(rhs, field); + const auto& z = ring_mul(x, t_rhs); + const auto& z_mac = ring_mul(x_mac, t_rhs); return makeAShare(z, z_mac, field); } // Refer to: -// 4 Online Phase, SPDZ2k: Efficient MPC mod 2k for Dishonest Majority +// 3.3 Reducing the Number of Masks && 4 Online Phase +// SPDZ2k: Efficient MPC mod 2k for Dishonest Majority // - https://eprint.iacr.org/2018/482.pdf // // TODO: use DISPATCH_ALL_FIELDS instead of ring ops to improve performance @@ -371,17 +486,19 @@ ArrayRef MulAA::proc(KernelEvalContext* ctx, const ArrayRef& lhs, const auto field = lhs.eltype().as()->field(); auto* comm = ctx->getState(); auto* beaver = ctx->getState()->beaver(); - auto* arr_ref_v = ctx->getState()->arr_ref_v(); const auto key = ctx->getState()->key(); + const auto k = ctx->getState()->k(); + const auto s = ctx->getState()->s(); // in const auto& x = getValueShare(lhs); - const auto& x_mac = getMacShare(lhs); + const auto& x_mac = GetMacShare(ctx, lhs); const auto& y = getValueShare(rhs); - const auto& y_mac = getMacShare(rhs); + const auto& y_mac = GetMacShare(ctx, rhs); // e = x - a, f = y - b - auto [vec, mac_vec] = beaver->AuthMul(field, lhs.numel()); + auto [vec, mac_vec] = beaver->AuthMul(field, lhs.numel(), k, s); + auto [a, b, c] = vec; auto [a_mac, b_mac, c_mac] = mac_vec; @@ -390,17 +507,19 @@ ArrayRef MulAA::proc(KernelEvalContext* ctx, const ArrayRef& lhs, auto f = ring_sub(y, b); auto f_mac = 
ring_sub(y_mac, b_mac); - // add to check array - arr_ref_v->emplace_back(makeAShare(e, e_mac, field)); - arr_ref_v->emplace_back(makeAShare(f, f_mac, field)); - // open e, f auto res = vectorize({e, f}, [&](const ArrayRef& s) { return comm->allReduce(ReduceOp::ADD, s, kBindName); }); - auto p_e = std::move(res[0]); auto p_f = std::move(res[1]); + + // don't use BatchOpen to reduce the number of masks + // auto [p_e, masked_e_mac] = beaver->BatchOpen(e, e_mac, k, s); + // auto [p_f, masked_f_mac] = beaver->BatchOpen(f, f_mac, k, s); + SPU_ENFORCE(beaver->BatchMacCheck(p_e, e_mac, k, s)); + SPU_ENFORCE(beaver->BatchMacCheck(p_f, f_mac, k, s)); + auto p_ef = ring_mul(p_e, p_f); // z = p_e * b + p_f * a + c; @@ -431,8 +550,8 @@ ArrayRef MatMulAP::proc(KernelEvalContext* ctx, const ArrayRef& lhs, // in const auto& x = getValueShare(lhs); - const auto& x_mac = getMacShare(lhs); - const auto& y = rhs; + const auto& x_mac = GetMacShare(ctx, lhs); + const auto& y = CastRing(rhs, field); // ret auto z = ring_mmul(x, y, m, n, k); @@ -449,14 +568,14 @@ ArrayRef MatMulAA::proc(KernelEvalContext* ctx, const ArrayRef& lhs, auto* comm = ctx->getState(); auto* beaver = ctx->getState()->beaver(); const auto key = ctx->getState()->key(); + const auto k_bits = ctx->getState()->k(); + const auto s_bits = ctx->getState()->s(); const auto& x = getValueShare(lhs); const auto& y = getValueShare(rhs); - // const auto& x_mac = getMacShare(lhs); - // const auto& y_mac = getMacShare(rhs); // generate beaver multiple triple. 
- auto [vec, mac_vec] = beaver->AuthDot(field, m, n, k); + auto [vec, mac_vec] = beaver->AuthDot(field, m, n, k, k_bits, s_bits); auto [a, b, c] = vec; auto [a_mac, b_mac, c_mac] = mac_vec; @@ -495,7 +614,7 @@ ArrayRef LShiftA::proc(KernelEvalContext* ctx, const ArrayRef& in, // in const auto& x = getValueShare(in); - const auto& x_mac = getMacShare(in); + const auto& x_mac = GetMacShare(ctx, in); // ret const auto& z = ring_lshift(x, bits); @@ -505,8 +624,6 @@ ArrayRef LShiftA::proc(KernelEvalContext* ctx, const ArrayRef& in, // ABY3, truncation pair method. // Ref: Section 5.1.2 https://eprint.iacr.org/2018/403.pdf -// -// TODO: optimize for 2pc. ArrayRef TruncA::proc(KernelEvalContext* ctx, const ArrayRef& in, size_t bits) const { SPU_TRACE_MPC_LEAF(ctx, in, bits); @@ -515,15 +632,22 @@ ArrayRef TruncA::proc(KernelEvalContext* ctx, const ArrayRef& in, const auto field = in.eltype().as()->field(); auto* comm = ctx->getState(); auto* beaver = ctx->getState()->beaver(); + const auto k = ctx->getState()->k(); + const auto s = ctx->getState()->s(); const auto& x = getValueShare(in); - const auto& [vec, mac_vec] = beaver->AuthTrunc(field, x.numel(), bits); + const auto& x_mac = getMacShare(in); + const auto& [vec, mac_vec] = beaver->AuthTrunc(field, x.numel(), bits, k, s); const auto& [r, rb] = vec; const auto& [r_mac, rb_mac] = mac_vec; // open x - r - auto x_r = comm->allReduce(ReduceOp::ADD, ring_sub(x, r), kBindName); - auto tr_x_r = ring_arshift(x_r, bits); + auto [x_r, check_mac] = + beaver->BatchOpen(ring_sub(x, r), ring_sub(x_mac, r_mac), k, s); + SPU_ENFORCE(beaver->BatchMacCheck(x_r, check_mac, k, s)); + size_t bit_len = SizeOf(field) * 8; + auto tr_x_r = ring_arshift(ring_lshift(x_r, bit_len - k), bit_len - k + bits); + ring_bitmask_(tr_x_r, 0, k); // res = [x-r] + [r], which [*] is truncation operation. 
auto res = rb; diff --git a/libspu/mpc/spdz2k/arithmetic.h b/libspu/mpc/spdz2k/arithmetic.h index 8fa69ca0..32aad802 100644 --- a/libspu/mpc/spdz2k/arithmetic.h +++ b/libspu/mpc/spdz2k/arithmetic.h @@ -18,6 +18,8 @@ namespace spu::mpc::spdz2k { +ArrayRef GetMacShare(KernelEvalContext* ctx, const ArrayRef& in); + class RandA : public RandKernel { public: static constexpr char kBindName[] = "rand_a"; @@ -44,9 +46,34 @@ class A2P : public UnaryKernel { public: static constexpr char kBindName[] = "a2p"; - ce::CExpr latency() const override { return ce::Const(4); } + Kind kind() const override { return Kind::Dynamic; } + + ArrayRef proc(KernelEvalContext* ctx, const ArrayRef& in) const override; +}; + +class A2V : public RevealToKernel { + public: + static constexpr char kBindName[] = "a2v"; - ce::CExpr comm() const override { return ce::K() * 3 * (ce::N() - 1); } + Kind kind() const override { return Kind::Dynamic; } + + ce::CExpr latency() const override { return ce::Const(1); } + + ce::CExpr comm() const override { return ce::K(); } + + ArrayRef proc(KernelEvalContext* ctx, const ArrayRef& in, + size_t rank) const override; +}; + +class V2A : public UnaryKernel { + public: + static constexpr char kBindName[] = "v2a"; + + Kind kind() const override { return Kind::Dynamic; } + + ce::CExpr latency() const override { return ce::Const(1); } + + ce::CExpr comm() const override { return ce::K(); } ArrayRef proc(KernelEvalContext* ctx, const ArrayRef& in) const override; }; @@ -92,9 +119,6 @@ class AddAA : public BinaryKernel { //////////////////////////////////////////////////////////////////// // multiply family //////////////////////////////////////////////////////////////////// -bool SingleCheck(KernelEvalContext* ctx, const ArrayRef& in); -bool BatchCheck(KernelEvalContext* ctx, const std::vector& ins); - class MulAP : public BinaryKernel { public: static constexpr char kBindName[] = "mul_ap"; @@ -111,9 +135,7 @@ class MulAA : public BinaryKernel { public: static 
constexpr char kBindName[] = "mul_aa"; - ce::CExpr latency() const override { return ce::Const(1); } - - ce::CExpr comm() const override { return ce::K() * 2 * (ce::N() - 1); } + Kind kind() const override { return Kind::Dynamic; } ArrayRef proc(KernelEvalContext* ctx, const ArrayRef& lhs, const ArrayRef& rhs) const override; @@ -163,22 +185,21 @@ class LShiftA : public ShiftKernel { size_t bits) const override; }; +// Refer to: +// New Primitives for Actively-Secure MPC over Rings with Applications to +// Private Machine Learning, Appendix C. Probabilistic Truncation +// - https://eprint.iacr.org/2019/599.pdf class TruncA : public TruncAKernel { public: static constexpr char kBindName[] = "trunc_a"; Kind kind() const override { return Kind::Dynamic; } - ce::CExpr latency() const override { return ce::Const(0); } - - ce::CExpr comm() const override { return ce::Const(0); } - ArrayRef proc(KernelEvalContext* ctx, const ArrayRef& in, size_t bits) const override; bool hasMsbError() const override { return true; } - // FIXME(shangqi) what the type? 
TruncLsbRounding lsbRounding() const override { return TruncLsbRounding::Random; } diff --git a/libspu/mpc/spdz2k/beaver/BUILD.bazel b/libspu/mpc/spdz2k/beaver/BUILD.bazel index 8f835c00..c5ad223f 100644 --- a/libspu/mpc/spdz2k/beaver/BUILD.bazel +++ b/libspu/mpc/spdz2k/beaver/BUILD.bazel @@ -16,13 +16,24 @@ load("//bazel:spu.bzl", "spu_cc_library", "spu_cc_test") package(default_visibility = ["//visibility:public"]) +spu_cc_library( + name = "beaver_interface", + hdrs = ["beaver_interface.h"], + deps = [ + "//libspu/core", + ], +) + spu_cc_library( name = "beaver_tfp", srcs = ["beaver_tfp.cc"], hdrs = ["beaver_tfp.h"], deps = [ + ":beaver_interface", ":trusted_party", + "//libspu/mpc/common:communicator", "//libspu/mpc/common:prg_tensor", + "//libspu/mpc/spdz2k:commitment", "//libspu/mpc/utils:ring_ops", "@com_github_microsoft_seal//:seal", "@yacl//yacl/link", @@ -31,10 +42,11 @@ spu_cc_library( ) spu_cc_test( - name = "beaver_tfp_test", - srcs = ["beaver_tfp_test.cc"], + name = "beaver_test", + srcs = ["beaver_test.cc"], deps = [ ":beaver_tfp", + ":beaver_tinyot", "//libspu/mpc/utils:simulate", "@com_google_googletest//:gtest", ], @@ -50,3 +62,24 @@ spu_cc_library( "//libspu/mpc/utils:ring_ops", ], ) + +spu_cc_library( + name = "beaver_tinyot", + srcs = ["beaver_tinyot.cc"], + hdrs = ["beaver_tinyot.h"], + deps = [ + ":beaver_interface", + ":trusted_party", + "//libspu/mpc/spdz2k:commitment", + "//libspu/mpc/spdz2k/ot:basic_ot_prot", + "//libspu/mpc/spdz2k/ot:kos_ote", + "//libspu/mpc/spdz2k/ot:tiny_ot", + "//libspu/mpc/common:prg_state", + "//libspu/mpc/utils:ring_ops", + "@yacl//yacl/link", + "@yacl//yacl/utils:serialize", + "@yacl//yacl/crypto/primitives/ot:base_ot", + "@yacl//yacl/crypto/tools:prg", + "@yacl//yacl/utils:matrix_utils", + ] +) diff --git a/libspu/mpc/spdz2k/beaver/beaver_interface.h b/libspu/mpc/spdz2k/beaver/beaver_interface.h new file mode 100644 index 00000000..414f3e35 --- /dev/null +++ b/libspu/mpc/spdz2k/beaver/beaver_interface.h @@ -0,0 
+1,68 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "libspu/core/array_ref.h" + +namespace spu::mpc::spdz2k { + +class Beaver { + public: + using Triple = std::tuple; + using Pair = std::pair; + using Pair_Pair = std::pair; + using Triple_Pair = std::pair; + + virtual ~Beaver() = default; + + virtual uint128_t InitSpdzKey(FieldType field, size_t s) = 0; + + virtual ArrayRef AuthArrayRef(const ArrayRef& value, FieldType field, + size_t k, size_t s) = 0; + + virtual Pair AuthCoinTossing(FieldType field, size_t size, size_t k, + size_t s) = 0; + + virtual Triple_Pair AuthMul(FieldType field, size_t size, size_t k, + size_t s) = 0; + + virtual Triple_Pair AuthDot(FieldType field, size_t M, size_t N, size_t K, + size_t k, size_t s) = 0; + + virtual Triple_Pair AuthAnd(FieldType field, size_t size, size_t s) = 0; + + virtual Pair_Pair AuthTrunc(FieldType field, size_t size, size_t bits, + size_t k, size_t s) = 0; + + virtual Pair AuthRandBit(FieldType field, size_t size, size_t k, + size_t s) = 0; + + // Check the opened value only + virtual bool BatchMacCheck(const ArrayRef& open_value, const ArrayRef& mac, + size_t k, size_t s) = 0; + + // Open the low k_bits of value only + virtual std::pair BatchOpen(const ArrayRef& value, + const ArrayRef& mac, size_t k, + size_t s) = 0; + + // public coin, used in malicious model, all party generate new seed, then + // get exactly the same 
random variable. + virtual ArrayRef genPublCoin(FieldType field, size_t numel) = 0; +}; + +} // namespace spu::mpc::spdz2k diff --git a/libspu/mpc/spdz2k/beaver/beaver_test.cc b/libspu/mpc/spdz2k/beaver/beaver_test.cc new file mode 100644 index 00000000..fd2e95df --- /dev/null +++ b/libspu/mpc/spdz2k/beaver/beaver_test.cc @@ -0,0 +1,394 @@ +// Copyright 2021 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "gtest/gtest.h" +#include "xtensor/xarray.hpp" +#include "yacl/link/link.h" + +#include "libspu/core/type_util.h" +#include "libspu/core/xt_helper.h" +#include "libspu/mpc/spdz2k/beaver/beaver_tfp.h" +#include "libspu/mpc/spdz2k/beaver/beaver_tinyot.h" +#include "libspu/mpc/utils/ring_ops.h" +#include "libspu/mpc/utils/simulate.h" + +namespace spu::mpc::spdz2k { + +class BeaverTest + : public ::testing::TestWithParam( + const std::shared_ptr& lctx)>, + std::string>, + size_t, FieldType, long, size_t, size_t>> { + public: + using Pair = typename Beaver::Pair; + using PairPair = typename Beaver::Pair_Pair; + using TriplePair = typename Beaver::Triple_Pair; +}; + +INSTANTIATE_TEST_SUITE_P( + BeaverTestSuite, BeaverTest, + testing::Values( + std::tuple{std::make_pair( + [](const std::shared_ptr& lctx) { + return std::make_unique(lctx); + }, + "BeaverTfpUnsafe"), + 2, FieldType::FM128, 0, 64, 64}, + std::tuple{std::make_pair( + [](const std::shared_ptr& lctx) { + return std::make_unique(lctx); + }, + "BeaverTfpUnsafe"), + 2, 
FieldType::FM64, 0, 32, 32}, + std::tuple{std::make_pair( + [](const std::shared_ptr& lctx) { + return std::make_unique(lctx); + }, + "BeaverTinyOt"), + 2, FieldType::FM64, 0, 32, 32}, + std::tuple{std::make_pair( + [](const std::shared_ptr& lctx) { + return std::make_unique(lctx); + }, + "BeaverTinyOt"), + 2, FieldType::FM128, 0, 64, 64}), + [](const testing::TestParamInfo& p) { + return fmt::format("{}x{}x{}", std::get<0>(p.param).second, + std::get<1>(p.param), std::get<2>(p.param)); + }); + +TEST_P(BeaverTest, AuthAnd) { + const auto factory = std::get<0>(GetParam()).first; + const size_t kWorldSize = std::get<1>(GetParam()); + const FieldType kField = std::get<2>(GetParam()); + const int64_t kMaxDiff = std::get<3>(GetParam()); + const size_t s = std::get<5>(GetParam()); + const size_t kNumel = 10; + + std::vector keys(kWorldSize); + std::vector triples(kWorldSize); + + utils::simulate(kWorldSize, [&](std::shared_ptr lctx) { + auto beaver = factory(lctx); + keys[lctx->Rank()] = beaver->InitSpdzKey(kField, s); + triples[lctx->Rank()] = beaver->AuthAnd(kField, kNumel, s); + }); + + uint128_t sum_key = 0; + auto sum_a = ring_zeros(kField, kNumel); + auto sum_b = ring_zeros(kField, kNumel); + auto sum_c = ring_zeros(kField, kNumel); + auto sum_a_mac = ring_zeros(kField, kNumel); + auto sum_b_mac = ring_zeros(kField, kNumel); + auto sum_c_mac = ring_zeros(kField, kNumel); + for (Rank r = 0; r < kWorldSize; r++) { + sum_key += keys[r]; + + const auto& [vec, mac_vec] = triples[r]; + const auto& [a, b, c] = vec; + const auto& [a_mac, b_mac, c_mac] = mac_vec; + EXPECT_EQ(a.numel(), kNumel); + EXPECT_EQ(b.numel(), kNumel); + EXPECT_EQ(c.numel(), kNumel); + EXPECT_EQ(a_mac.numel(), kNumel); + EXPECT_EQ(b_mac.numel(), kNumel); + EXPECT_EQ(c_mac.numel(), kNumel); + + ring_add_(sum_a, a); + ring_add_(sum_b, b); + ring_add_(sum_c, c); + ring_add_(sum_a_mac, a_mac); + ring_add_(sum_b_mac, b_mac); + ring_add_(sum_c_mac, c_mac); + } + + auto valid_a = ring_bitmask(sum_a, 0, 1); 
+ auto valid_b = ring_bitmask(sum_b, 0, 1); + auto valid_c = ring_bitmask(sum_c, 0, 1); + + EXPECT_EQ(ring_mul(valid_a, valid_b), valid_c) << sum_a << sum_b << sum_c; + EXPECT_EQ(ring_mul(sum_a, sum_key), sum_a_mac) + << sum_a << sum_key << sum_a_mac; + EXPECT_EQ(ring_mul(sum_b, sum_key), sum_b_mac) + << sum_b << sum_key << sum_b_mac; + EXPECT_EQ(ring_mul(sum_c, sum_key), sum_c_mac) + << sum_c << sum_key << sum_c_mac; + + DISPATCH_ALL_FIELDS(kField, "_", [&]() { + auto _a = ArrayView(valid_a); + auto _b = ArrayView(valid_b); + auto _c = ArrayView(valid_c); + for (auto idx = 0; idx < sum_a.numel(); idx++) { + auto t = _a[idx] * _b[idx]; + auto err = t > _c[idx] ? t - _c[idx] : _c[idx] - t; + EXPECT_LE(err, kMaxDiff); + } + }); +} + +TEST_P(BeaverTest, AuthArrayRef) { + const auto factory = std::get<0>(GetParam()).first; + const size_t kWorldSize = std::get<1>(GetParam()); + const FieldType kField = std::get<2>(GetParam()); + const size_t k = std::get<4>(GetParam()); + const size_t s = std::get<5>(GetParam()); + const size_t kNumel = 10; + + std::vector values(kWorldSize); + std::vector keys(kWorldSize); + std::vector macs(kWorldSize); + + utils::simulate(kWorldSize, [&](std::shared_ptr lctx) { + auto beaver = factory(lctx); + keys[lctx->Rank()] = beaver->InitSpdzKey(kField, s); + values[lctx->Rank()] = ring_rand(kField, kNumel); + macs[lctx->Rank()] = + beaver->AuthArrayRef(values[lctx->Rank()], kField, k, s); + }); + + uint128_t sum_key = 0; + auto sum_a = ring_zeros(kField, kNumel); + auto sum_a_mac = ring_zeros(kField, kNumel); + for (Rank r = 0; r < kWorldSize; r++) { + sum_key += keys[r]; + + const auto& a = values[r]; + const auto& a_mac = macs[r]; + EXPECT_EQ(a.numel(), kNumel); + EXPECT_EQ(a_mac.numel(), kNumel); + + ring_add_(sum_a, a); + ring_add_(sum_a_mac, a_mac); + } + + EXPECT_EQ(ring_mul(sum_a, sum_key), sum_a_mac) + << sum_a << sum_key << sum_a_mac; +} + +TEST_P(BeaverTest, AuthCoinTossing) { + const auto factory = std::get<0>(GetParam()).first; + 
const size_t kWorldSize = std::get<1>(GetParam()); + const FieldType kField = std::get<2>(GetParam()); + const size_t k = std::get<4>(GetParam()); + const size_t s = std::get<5>(GetParam()); + const size_t kNumel = 10; + + std::vector keys(kWorldSize); + std::vector pairs(kWorldSize); + + utils::simulate(kWorldSize, [&](std::shared_ptr lctx) { + auto beaver = factory(lctx); + keys[lctx->Rank()] = beaver->InitSpdzKey(kField, s); + pairs[lctx->Rank()] = beaver->AuthCoinTossing(kField, kNumel, k, s); + }); + + uint128_t sum_key = 0; + auto sum_a = ring_zeros(kField, kNumel); + auto sum_a_mac = ring_zeros(kField, kNumel); + for (Rank r = 0; r < kWorldSize; r++) { + sum_key += keys[r]; + + const auto& [a, a_mac] = pairs[r]; + EXPECT_EQ(a.numel(), kNumel); + EXPECT_EQ(a_mac.numel(), kNumel); + + ring_add_(sum_a, a); + ring_add_(sum_a_mac, a_mac); + } + + EXPECT_EQ(ring_mul(sum_a, sum_key), sum_a_mac) + << sum_a << sum_key << sum_a_mac; +} + +TEST_P(BeaverTest, AuthMul) { + const auto factory = std::get<0>(GetParam()).first; + const size_t kWorldSize = std::get<1>(GetParam()); + const FieldType kField = std::get<2>(GetParam()); + const size_t k = std::get<4>(GetParam()); + const size_t s = std::get<5>(GetParam()); + const size_t kNumel = 10; + + std::vector keys(kWorldSize); + std::vector triples(kWorldSize); + + utils::simulate(kWorldSize, [&](std::shared_ptr lctx) { + auto beaver = factory(lctx); + keys[lctx->Rank()] = beaver->InitSpdzKey(kField, s); + triples[lctx->Rank()] = beaver->AuthMul(kField, kNumel, k, s); + }); + + uint128_t sum_key = 0; + auto sum_a = ring_zeros(kField, kNumel); + auto sum_b = ring_zeros(kField, kNumel); + auto sum_c = ring_zeros(kField, kNumel); + auto sum_a_mac = ring_zeros(kField, kNumel); + auto sum_b_mac = ring_zeros(kField, kNumel); + auto sum_c_mac = ring_zeros(kField, kNumel); + for (Rank r = 0; r < kWorldSize; r++) { + sum_key += keys[r]; + + const auto& [vec, mac_vec] = triples[r]; + const auto& [a, b, c] = vec; + const auto& [a_mac, 
b_mac, c_mac] = mac_vec; + EXPECT_EQ(a.numel(), kNumel); + EXPECT_EQ(b.numel(), kNumel); + EXPECT_EQ(c.numel(), kNumel); + EXPECT_EQ(a_mac.numel(), kNumel); + EXPECT_EQ(b_mac.numel(), kNumel); + EXPECT_EQ(c_mac.numel(), kNumel); + + ring_add_(sum_a, a); + ring_add_(sum_b, b); + ring_add_(sum_c, c); + ring_add_(sum_a_mac, a_mac); + ring_add_(sum_b_mac, b_mac); + ring_add_(sum_c_mac, c_mac); + } + + EXPECT_EQ(ring_mul(sum_a, sum_b), sum_c) << sum_a << sum_b << sum_c; + EXPECT_EQ(ring_mul(sum_a, sum_key), sum_a_mac) + << sum_a << sum_key << sum_a_mac; + EXPECT_EQ(ring_mul(sum_b, sum_key), sum_b_mac) + << sum_b << sum_key << sum_b_mac; + EXPECT_EQ(ring_mul(sum_c, sum_key), sum_c_mac) + << sum_c << sum_key << sum_c_mac; +} + +TEST_P(BeaverTest, AuthTrunc) { + const auto factory = std::get<0>(GetParam()).first; + const size_t kWorldSize = std::get<1>(GetParam()); + const FieldType kField = std::get<2>(GetParam()); + const size_t k = std::get<4>(GetParam()); + const size_t s = std::get<5>(GetParam()); + const size_t kNumel = 7; + const size_t kBits = 4; + + std::vector keys(kWorldSize); + std::vector pairs(kWorldSize); + + utils::simulate( + kWorldSize, [&](const std::shared_ptr& lctx) { + auto beaver = factory(lctx); + keys[lctx->Rank()] = beaver->InitSpdzKey(kField, s); + pairs[lctx->Rank()] = beaver->AuthTrunc(kField, kNumel, kBits, k, s); + }); + + EXPECT_EQ(pairs.size(), kWorldSize); + uint128_t sum_key = 0; + auto sum_a = ring_zeros(kField, kNumel); + auto sum_b = ring_zeros(kField, kNumel); + auto sum_a_mac = ring_zeros(kField, kNumel); + auto sum_b_mac = ring_zeros(kField, kNumel); + for (Rank r = 0; r < kWorldSize; r++) { + sum_key += keys[r]; + + const auto& [vec, mac_vec] = (pairs[r]); + const auto& [a, b] = vec; + const auto& [a_mac, b_mac] = mac_vec; + + EXPECT_EQ(a.numel(), kNumel); + EXPECT_EQ(b.numel(), kNumel); + EXPECT_EQ(a_mac.numel(), kNumel); + EXPECT_EQ(b_mac.numel(), kNumel); + + ring_add_(sum_a, a); + ring_add_(sum_b, b); + ring_add_(sum_a_mac, 
a_mac); + ring_add_(sum_b_mac, b_mac); + } + + const size_t bit_len = SizeOf(kField) * 8; + auto trunc_sum_a = + ring_arshift(ring_lshift(sum_a, bit_len - k), bit_len - k + kBits); + ring_bitmask_(trunc_sum_a, 0, k); + + EXPECT_EQ(trunc_sum_a, ring_bitmask(sum_b, 0, k)) << trunc_sum_a << sum_b; + EXPECT_EQ(ring_mul(sum_a, sum_key), sum_a_mac) + << sum_a << sum_key << sum_a_mac; + EXPECT_EQ(ring_mul(sum_b, sum_key), sum_b_mac) + << sum_b << sum_key << sum_b_mac; +} + +TEST_P(BeaverTest, AuthDot) { + const auto factory = std::get<0>(GetParam()).first; + const size_t kWorldSize = std::get<1>(GetParam()); + const FieldType kField = std::get<2>(GetParam()); + const int64_t kMaxDiff = std::get<3>(GetParam()); + const size_t k = std::get<4>(GetParam()); + const size_t s = std::get<5>(GetParam()); + // M > N + const size_t M = 17; + const size_t N = 8; + const size_t K = 13; + + std::vector keys(kWorldSize); + std::vector triples(kWorldSize); + + utils::simulate( + kWorldSize, [&](const std::shared_ptr& lctx) { + auto beaver = factory(lctx); + keys[lctx->Rank()] = beaver->InitSpdzKey(kField, s); + triples[lctx->Rank()] = beaver->AuthDot(kField, M, N, K, k, s); + yacl::link::Barrier(lctx, "BeaverUT"); + }); + + EXPECT_EQ(triples.size(), kWorldSize); + uint128_t sum_key = 0; + auto sum_a = ring_zeros(kField, M * K); + auto sum_b = ring_zeros(kField, K * N); + auto sum_c = ring_zeros(kField, M * N); + auto sum_a_mac = ring_zeros(kField, M * K); + auto sum_b_mac = ring_zeros(kField, K * N); + auto sum_c_mac = ring_zeros(kField, M * N); + for (Rank r = 0; r < kWorldSize; r++) { + sum_key += keys[r]; + + const auto& [vec, mac_vec] = triples[r]; + const auto& [a, b, c] = vec; + const auto& [a_mac, b_mac, c_mac] = mac_vec; + EXPECT_EQ(a.numel(), M * K); + EXPECT_EQ(b.numel(), K * N); + EXPECT_EQ(c.numel(), M * N); + EXPECT_EQ(a_mac.numel(), M * K); + EXPECT_EQ(b_mac.numel(), K * N); + EXPECT_EQ(c_mac.numel(), M * N); + + ring_add_(sum_a, a); + ring_add_(sum_b, b); + 
ring_add_(sum_c, c); + ring_add_(sum_a_mac, a_mac); + ring_add_(sum_b_mac, b_mac); + ring_add_(sum_c_mac, c_mac); + } + + EXPECT_EQ(ring_mul(sum_a, sum_key), sum_a_mac) + << sum_a << sum_key << sum_a_mac; + EXPECT_EQ(ring_mul(sum_b, sum_key), sum_b_mac) + << sum_b << sum_key << sum_b_mac; + EXPECT_EQ(ring_mul(sum_c, sum_key), sum_c_mac) + << sum_c << sum_key << sum_c_mac; + + auto res = ring_mmul(sum_a, sum_b, M, N, K); + DISPATCH_ALL_FIELDS(kField, "_", [&]() { + auto _r = ArrayView(res); + auto _c = ArrayView(sum_c); + for (auto idx = 0; idx < _r.numel(); idx++) { + auto err = _r[idx] > _c[idx] ? _r[idx] - _c[idx] : _c[idx] - _r[idx]; + EXPECT_LE(err, kMaxDiff); + } + }); +} + +} // namespace spu::mpc::spdz2k diff --git a/libspu/mpc/spdz2k/beaver/beaver_tfp.cc b/libspu/mpc/spdz2k/beaver/beaver_tfp.cc index d4d372fa..58d378bc 100644 --- a/libspu/mpc/spdz2k/beaver/beaver_tfp.cc +++ b/libspu/mpc/spdz2k/beaver/beaver_tfp.cc @@ -22,52 +22,88 @@ #include "yacl/utils/serialize.h" #include "libspu/mpc/common/prg_tensor.h" +#include "libspu/mpc/spdz2k/commitment.h" #include "libspu/mpc/utils/ring_ops.h" namespace spu::mpc::spdz2k { BeaverTfpUnsafe::BeaverTfpUnsafe(std::shared_ptr lctx) - : lctx_(std::move(std::move(lctx))), - seed_(yacl::crypto::RandSeed(true)), - counter_(0) { + : seed_(yacl::crypto::SecureRandSeed()), counter_(0) { + comm_ = std::make_unique(lctx); auto buf = yacl::SerializeUint128(seed_); std::vector all_bufs = - yacl::link::Gather(lctx_, buf, 0, "BEAVER_TFP:SYNC_SEEDS"); + yacl::link::Gather(lctx, buf, 0, "BEAVER_TFP:SYNC_SEEDS"); - if (lctx_->Rank() == 0) { + if (comm_->getRank() == 0) { // Collects seeds from all parties. 
- for (size_t rank = 0; rank < lctx_->WorldSize(); ++rank) { + for (size_t rank = 0; rank < comm_->getWorldSize(); ++rank) { PrgSeed seed = yacl::DeserializeUint128(all_bufs[rank]); - tp_.setSeed(rank, lctx_->WorldSize(), seed); + tp_.setSeed(rank, comm_->getWorldSize(), seed); } } } -uint128_t BeaverTfpUnsafe::GetSpdzKey(FieldType field, size_t s) { +uint128_t BeaverTfpUnsafe::InitSpdzKey(FieldType field, size_t s) { PrgArrayDesc desc{}; const size_t size = 1; auto a = prgCreateArray(field, size, seed_, &counter_, &desc); - if (lctx_->Rank() == 0) { - auto t = tp_.adjustSpdzKey(desc); - global_key_ = yacl::crypto::RandSeed(true); - global_key_ &= (static_cast(1) << s) - 1; - a.at(0) += global_key_ - t.at(0); + return DISPATCH_ALL_FIELDS(field, "_", [&]() { + if (comm_->getRank() == 0) { + auto t = tp_.adjustSpdzKey(desc); + global_key_ = yacl::crypto::SecureRandSeed(); + global_key_ &= (static_cast(1) << s) - 1; + + a.at(0) += global_key_ - t.at(0); + } + + spdz_key_ = a.at(0); + + return spdz_key_; + }); +} + +ArrayRef BeaverTfpUnsafe::AuthArrayRef(const ArrayRef& value, FieldType field, + size_t k, size_t s) { + auto [r, r_mac] = AuthCoinTossing(field, value.numel(), k, s); + auto x_r = + comm_->reduce(ReduceOp::ADD, ring_sub(value, r), 0, "auth_arrayref"); + + if (comm_->getRank() == 0) { + ring_add_(r_mac, ring_mul(x_r, global_key_)); } - return a.at(0); + return r_mac; } BeaverTfpUnsafe::Pair BeaverTfpUnsafe::AuthCoinTossing(FieldType field, - size_t size, size_t s) { + size_t size, size_t k, + size_t s) { + PrgArrayDesc desc{}; + PrgArrayDesc mac_desc{}; + + auto x = prgCreateArray(field, size, seed_, &counter_, &desc); + auto x_mac = prgCreateArray(field, size, seed_, &counter_, &mac_desc); + + if (comm_->getRank() == 0) { + auto v = tp_.adjustAuthCoinTossing(desc, mac_desc, global_key_, k, s); + x = v[0]; + x_mac = v[1]; + } + + return {x, x_mac}; +} + +BeaverTfpUnsafe::Pair BeaverTfpUnsafe::AuthRandBit(FieldType field, size_t size, + size_t k, size_t s) { 
PrgArrayDesc desc{}; PrgArrayDesc mac_desc{}; auto x = prgCreateArray(field, size, seed_, &counter_, &desc); auto x_mac = prgCreateArray(field, size, seed_, &counter_, &mac_desc); - if (lctx_->Rank() == 0) { - auto v = tp_.adjustAuthCoinTossing(desc, mac_desc, global_key_, s); + if (comm_->getRank() == 0) { + auto v = tp_.adjustAuthRandBit(desc, mac_desc, global_key_, s); x = v[0]; x_mac = v[1]; } @@ -76,7 +112,8 @@ BeaverTfpUnsafe::Pair BeaverTfpUnsafe::AuthCoinTossing(FieldType field, } BeaverTfpUnsafe::Triple_Pair BeaverTfpUnsafe::AuthMul(FieldType field, - size_t size) { + size_t size, size_t k, + size_t s) { std::vector descs(3); std::vector mac_descs(3); @@ -88,7 +125,7 @@ BeaverTfpUnsafe::Triple_Pair BeaverTfpUnsafe::AuthMul(FieldType field, auto b_mac = prgCreateArray(field, size, seed_, &counter_, &mac_descs[1]); auto c_mac = prgCreateArray(field, size, seed_, &counter_, &mac_descs[2]); - if (lctx_->Rank() == 0) { + if (comm_->getRank() == 0) { auto v = tp_.adjustAuthMul(descs, mac_descs, global_key_); c = v[0]; a_mac = v[1]; @@ -100,7 +137,9 @@ BeaverTfpUnsafe::Triple_Pair BeaverTfpUnsafe::AuthMul(FieldType field, } BeaverTfpUnsafe::Triple_Pair BeaverTfpUnsafe::AuthDot(FieldType field, size_t m, - size_t n, size_t k) { + size_t n, size_t k, + size_t k_bits, + size_t s_bits) { std::vector descs(3); std::vector mac_descs(3); @@ -112,7 +151,7 @@ BeaverTfpUnsafe::Triple_Pair BeaverTfpUnsafe::AuthDot(FieldType field, size_t m, auto b_mac = prgCreateArray(field, k * n, seed_, &counter_, &mac_descs[1]); auto c_mac = prgCreateArray(field, m * n, seed_, &counter_, &mac_descs[2]); - if (lctx_->Rank() == 0) { + if (comm_->getRank() == 0) { auto v = tp_.adjustAuthDot(descs, mac_descs, m, n, k, global_key_); c = v[0]; a_mac = v[1]; @@ -123,9 +162,33 @@ BeaverTfpUnsafe::Triple_Pair BeaverTfpUnsafe::AuthDot(FieldType field, size_t m, return {{a, b, c}, {a_mac, b_mac, c_mac}}; } +BeaverTfpUnsafe::Triple_Pair BeaverTfpUnsafe::AuthAnd(FieldType field, + size_t size, size_t 
s) { + std::vector descs(3); + std::vector mac_descs(3); + + auto a = prgCreateArray(field, size, seed_, &counter_, descs.data()); + auto b = prgCreateArray(field, size, seed_, &counter_, &descs[1]); + auto c = prgCreateArray(field, size, seed_, &counter_, &descs[2]); + + auto a_mac = prgCreateArray(field, size, seed_, &counter_, mac_descs.data()); + auto b_mac = prgCreateArray(field, size, seed_, &counter_, &mac_descs[1]); + auto c_mac = prgCreateArray(field, size, seed_, &counter_, &mac_descs[2]); + + if (comm_->getRank() == 0) { + auto v = tp_.adjustAuthAnd(descs, mac_descs, global_key_); + c = v[0]; + a_mac = v[1]; + b_mac = v[2]; + c_mac = v[3]; + } + + return {{a, b, c}, {a_mac, b_mac, c_mac}}; +} + BeaverTfpUnsafe::Pair_Pair BeaverTfpUnsafe::AuthTrunc(FieldType field, - size_t size, - size_t bits) { + size_t size, size_t bits, + size_t k, size_t s) { std::vector descs(2); std::vector mac_descs(2); @@ -134,14 +197,125 @@ BeaverTfpUnsafe::Pair_Pair BeaverTfpUnsafe::AuthTrunc(FieldType field, auto a_mac = prgCreateArray(field, size, seed_, &counter_, mac_descs.data()); auto b_mac = prgCreateArray(field, size, seed_, &counter_, &mac_descs[1]); - if (lctx_->Rank() == 0) { - auto v = tp_.adjustAuthTrunc(descs, mac_descs, bits, global_key_); - b = v[0]; - a_mac = v[1]; - b_mac = v[2]; + if (comm_->getRank() == 0) { + auto v = tp_.adjustAuthTrunc(descs, mac_descs, bits, global_key_, k, s); + a = v[0]; + b = v[1]; + a_mac = v[2]; + b_mac = v[3]; } return {{a, b}, {a_mac, b_mac}}; } +ArrayRef BeaverTfpUnsafe::genPublCoin(FieldType field, size_t numel) { + ArrayRef res(makeType(field), numel); + + // generate new seed + uint128_t self_pk = yacl::crypto::SecureRandSeed(); + std::vector all_strs; + + std::string self_pk_str(reinterpret_cast(&self_pk), sizeof(self_pk)); + SPU_ENFORCE(commit_and_open(comm_->lctx(), self_pk_str, &all_strs)); + + uint128_t public_seed = 0; + for (const auto& str : all_strs) { + uint128_t seed = *(reinterpret_cast(str.data())); + public_seed 
+= seed; + } + + auto kAesType = yacl::crypto::SymmetricCrypto::CryptoType::AES128_CTR; + yacl::crypto::FillPRand( + kAesType, public_seed, 0, 0, + absl::MakeSpan(static_cast(res.data()), res.buf()->size())); + + return res; +} + +// Refer to: +// Procedure BatchCheck, 3.2 Batch MAC Checking with Random Linear +// Combinations, SPDZ2k: Efficient MPC mod 2k for Dishonest Majority +// - https://eprint.iacr.org/2018/482.pdf +// +// Open the value only +// Notice return { open_val , zero_mac = open_val * \sum spdz_key_ } +// the last kth bits in open_val is valid +std::pair BeaverTfpUnsafe::BatchOpen(const ArrayRef& value, + const ArrayRef& mac, + size_t k, size_t s) { + static constexpr char kBindName[] = "batch_open"; + SPU_ENFORCE(value.numel() == mac.numel()); + + const auto field = value.eltype().as()->field(); + const auto numel = value.numel(); + + auto [r_val, r_mac] = AuthCoinTossing(field, numel, k, s); + + // Open the low k_bits only + // value = value + r_val * 2^k + // mac = mac + r_mac * 2^k + auto masked_val = ring_add(value, ring_lshift(r_val, k)); + auto masked_mac = ring_add(mac, ring_lshift(r_mac, k)); + + auto open_val = comm_->allReduce(ReduceOp::ADD, masked_val, kBindName); + return {open_val, masked_mac}; +} + +// Refer to: +// Procedure BatchCheck, 3.2 Batch MAC Checking with Random Linear +// Combinations, SPDZ2k: Efficient MPC mod 2k for Dishonest Majority +// - https://eprint.iacr.org/2018/482.pdf +// +// Check the opened value only +bool BeaverTfpUnsafe::BatchMacCheck(const ArrayRef& open_value, + const ArrayRef& mac, size_t k, size_t s) { + const auto field = open_value.eltype().as()->field(); + const size_t numel = open_value.numel(); + const size_t mac_bits = k + s; + + auto* comm = comm_.get(); + const auto& lctx = comm_->lctx(); + const auto key = spdz_key_; + + // 1. get l public random values, compute plain y + auto pub_r = genPublCoin(field, numel); + ring_bitmask_(pub_r, 0, s); + + // 2. 
check_value = pub_r * open_value + // check_mac = pub_r * mac + auto check_value = ring_mmul(pub_r, open_value, 1, 1, numel); + auto check_mac = ring_mmul(pub_r, mac, 1, 1, numel); + + // 3. compute z, commit and open z + auto z = ring_sub(check_mac, ring_mul(check_value, key)); + + std::string z_str(reinterpret_cast(z.data()), z.numel() * z.elsize()); + std::vector z_strs; + SPU_ENFORCE(commit_and_open(lctx, z_str, &z_strs)); + SPU_ENFORCE(z_strs.size() == comm->getWorldSize()); + + // since the commit size in commit_and_open is independent with numel, we + // ignore it + comm->addCommStatsManually(1, 0); + // since the random string size in commit_and_open is independent with numel, + // we ignore it + comm->addCommStatsManually(1, + z_str.size() / numel * (comm->getWorldSize() - 1)); + + // 4. verify whether plain z is zero + auto plain_z = ring_zeros(field, 1); + for (size_t i = 0; i < comm->getWorldSize(); ++i) { + const auto& _z_str = z_strs[i]; + auto mem = std::make_shared(_z_str.data(), _z_str.size()); + ArrayRef a(mem, plain_z.eltype(), _z_str.size() / SizeOf(field), 1, 0); + ring_add_(plain_z, a); + } + + if (mac_bits != 0) { + ring_bitmask_(plain_z, 0, mac_bits); + } + + return ring_all_equal(plain_z, ring_zeros(field, 1)); +} + } // namespace spu::mpc::spdz2k diff --git a/libspu/mpc/spdz2k/beaver/beaver_tfp.h b/libspu/mpc/spdz2k/beaver/beaver_tfp.h index 75f29299..6c548733 100644 --- a/libspu/mpc/spdz2k/beaver/beaver_tfp.h +++ b/libspu/mpc/spdz2k/beaver/beaver_tfp.h @@ -18,7 +18,8 @@ #include "yacl/link/context.h" -#include "libspu/mpc/spdz2k/beaver/beaver_tfp.h" +#include "libspu/mpc/common/communicator.h" +#include "libspu/mpc/spdz2k/beaver/beaver_interface.h" #include "libspu/mpc/spdz2k/beaver/trusted_party.h" namespace spu::mpc::spdz2k { @@ -29,12 +30,12 @@ namespace spu::mpc::spdz2k { // NOT BE used in production. // // Check security implications before moving on. 
-class BeaverTfpUnsafe final { +class BeaverTfpUnsafe final : public Beaver { protected: // Only for rank0 party. TrustedParty tp_; - std::shared_ptr lctx_; + std::unique_ptr comm_; PrgSeed seed_; @@ -42,6 +43,12 @@ class BeaverTfpUnsafe final { uint128_t global_key_; + // spdz key + uint128_t spdz_key_; + + // security parameters + static constexpr int kappa_ = 128; + public: using Triple = std::tuple; using Pair = std::pair; @@ -51,17 +58,39 @@ class BeaverTfpUnsafe final { public: explicit BeaverTfpUnsafe(std::shared_ptr lctx); - std::shared_ptr GetLink() const { return lctx_; } + uint128_t InitSpdzKey(FieldType field, size_t s) override; + + ArrayRef AuthArrayRef(const ArrayRef& value, FieldType field, size_t k, + size_t s) override; + + Pair AuthCoinTossing(FieldType field, size_t size, size_t k, + size_t s) override; + + Triple_Pair AuthMul(FieldType field, size_t size, size_t k, + size_t s) override; + + Triple_Pair AuthDot(FieldType field, size_t M, size_t N, size_t K, size_t k, + size_t s) override; + + Triple_Pair AuthAnd(FieldType field, size_t size, size_t s) override; - uint128_t GetSpdzKey(FieldType field, size_t s); + Pair_Pair AuthTrunc(FieldType field, size_t size, size_t bits, size_t k, + size_t s) override; - Pair AuthCoinTossing(FieldType field, size_t size, size_t s); + Pair AuthRandBit(FieldType field, size_t size, size_t k, size_t s) override; - Triple_Pair AuthMul(FieldType field, size_t size); + // Check the opened value only + bool BatchMacCheck(const ArrayRef& open_value, const ArrayRef& mac, size_t k, + size_t s); - Triple_Pair AuthDot(FieldType field, size_t M, size_t N, size_t K); + // Open the low k_bits of value only + std::pair BatchOpen(const ArrayRef& value, + const ArrayRef& mac, size_t k, + size_t s); - Pair_Pair AuthTrunc(FieldType field, size_t size, size_t bits); + // public coin, used in malicious model, all party generate new seed, then + // get exactly the same random variable. 
+ ArrayRef genPublCoin(FieldType field, size_t numel); }; } // namespace spu::mpc::spdz2k diff --git a/libspu/mpc/spdz2k/beaver/beaver_tfp_test.cc b/libspu/mpc/spdz2k/beaver/beaver_tfp_test.cc deleted file mode 100644 index 23fbc03f..00000000 --- a/libspu/mpc/spdz2k/beaver/beaver_tfp_test.cc +++ /dev/null @@ -1,257 +0,0 @@ -// Copyright 2021 Ant Group Co., Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "libspu/mpc/spdz2k/beaver/beaver_tfp.h" - -#include "gtest/gtest.h" -#include "xtensor/xarray.hpp" -#include "yacl/link/link.h" - -#include "libspu/core/type_util.h" -#include "libspu/core/xt_helper.h" -#include "libspu/mpc/utils/ring_ops.h" -#include "libspu/mpc/utils/simulate.h" - -namespace spu::mpc::spdz2k { - -class BeaverTest - : public ::testing::TestWithParam< - std::tuple( - const std::shared_ptr& lctx)>, - size_t, FieldType, long>> { - public: - using Pair = typename BeaverTfpUnsafe::Pair; - using PairPair = typename BeaverTfpUnsafe::Pair_Pair; - using TriplePair = typename BeaverTfpUnsafe::Triple_Pair; -}; - -INSTANTIATE_TEST_SUITE_P( - BeaverTfpUnsafeTest, BeaverTest, - testing::Combine( - testing::Values([](const std::shared_ptr& lctx) { - return std::make_unique(lctx); - }), - testing::Values(4, 3, 2), - testing::Values(FieldType::FM32, FieldType::FM64, FieldType::FM128), - testing::Values(0)), // max beaver diff, - [](const testing::TestParamInfo& p) { - return fmt::format("{}x{}", std::get<1>(p.param), std::get<2>(p.param)); 
- }); - -// TODO(@zanxiaopeng.zxp): Add UT for mac and AuthCoinTossing api. -TEST_P(BeaverTest, Mul_large) { - const auto factory = std::get<0>(GetParam()); - const size_t kWorldSize = std::get<1>(GetParam()); - const FieldType kField = std::get<2>(GetParam()); - const int64_t kMaxDiff = std::get<3>(GetParam()); - const size_t kNumel = 10000; - - std::vector triples; - triples.resize(kWorldSize); - - utils::simulate(kWorldSize, - [&](const std::shared_ptr& lctx) { - auto beaver = factory(lctx); - triples[lctx->Rank()] = beaver->AuthMul(kField, kNumel); - }); - - auto sum_a = ring_zeros(kField, kNumel); - auto sum_b = ring_zeros(kField, kNumel); - auto sum_c = ring_zeros(kField, kNumel); - for (Rank r = 0; r < kWorldSize; r++) { - const auto& [a, b, c] = std::get<0>(triples[r]); - EXPECT_EQ(a.numel(), kNumel); - EXPECT_EQ(b.numel(), kNumel); - EXPECT_EQ(c.numel(), kNumel); - - ring_add_(sum_a, a); - ring_add_(sum_b, b); - ring_add_(sum_c, c); - } - - DISPATCH_ALL_FIELDS(kField, "_", [&]() { - auto _a = ArrayView(sum_a); - auto _b = ArrayView(sum_b); - auto _c = ArrayView(sum_c); - for (auto idx = 0; idx < sum_a.numel(); idx++) { - auto t = _a[idx] * _b[idx]; - auto err = t > _c[idx] ? 
t - _c[idx] : _c[idx] - t; - EXPECT_LE(err, kMaxDiff); - } - }); -} - -TEST_P(BeaverTest, Mul) { - const auto factory = std::get<0>(GetParam()); - const size_t kWorldSize = std::get<1>(GetParam()); - const FieldType kField = std::get<2>(GetParam()); - const int64_t kMaxDiff = std::get<3>(GetParam()); - const size_t kNumel = 7; - - std::vector triples; - triples.resize(kWorldSize); - - utils::simulate(kWorldSize, - [&](const std::shared_ptr& lctx) { - auto beaver = factory(lctx); - triples[lctx->Rank()] = beaver->AuthMul(kField, kNumel); - }); - - auto sum_a = ring_zeros(kField, kNumel); - auto sum_b = ring_zeros(kField, kNumel); - auto sum_c = ring_zeros(kField, kNumel); - for (Rank r = 0; r < kWorldSize; r++) { - const auto& [a, b, c] = std::get<0>(triples[r]); - EXPECT_EQ(a.numel(), kNumel); - EXPECT_EQ(b.numel(), kNumel); - EXPECT_EQ(c.numel(), kNumel); - - ring_add_(sum_a, a); - ring_add_(sum_b, b); - ring_add_(sum_c, c); - } - - DISPATCH_ALL_FIELDS(kField, "_", [&]() { - auto _a = ArrayView(sum_a); - auto _b = ArrayView(sum_b); - auto _c = ArrayView(sum_c); - for (auto idx = 0; idx < sum_a.numel(); idx++) { - auto t = _a[idx] * _b[idx]; - auto err = t > _c[idx] ? 
t - _c[idx] : _c[idx] - t; - EXPECT_LE(err, kMaxDiff); - } - }); -} - -TEST_P(BeaverTest, Dot) { - const auto factory = std::get<0>(GetParam()); - const size_t kWorldSize = std::get<1>(GetParam()); - const FieldType kField = std::get<2>(GetParam()); - const int64_t kMaxDiff = std::get<3>(GetParam()); - // M > N - const size_t M = 17; - const size_t N = 8; - const size_t K = 1024; - - std::vector triples; - triples.resize(kWorldSize); - - utils::simulate(kWorldSize, - [&](const std::shared_ptr& lctx) { - auto beaver = factory(lctx); - triples[lctx->Rank()] = beaver->AuthDot(kField, M, N, K); - }); - - EXPECT_EQ(triples.size(), kWorldSize); - auto sum_a = ring_zeros(kField, M * K); - auto sum_b = ring_zeros(kField, K * N); - auto sum_c = ring_zeros(kField, M * N); - for (Rank r = 0; r < kWorldSize; r++) { - const auto& [a, b, c] = std::get<0>(triples[r]); - EXPECT_EQ(a.numel(), M * K); - EXPECT_EQ(b.numel(), K * N); - EXPECT_EQ(c.numel(), M * N); - - ring_add_(sum_a, a); - ring_add_(sum_b, b); - ring_add_(sum_c, c); - } - - auto res = ring_mmul(sum_a, sum_b, M, N, K); - DISPATCH_ALL_FIELDS(kField, "_", [&]() { - auto _r = ArrayView(res); - auto _c = ArrayView(sum_c); - for (auto idx = 0; idx < _r.numel(); idx++) { - auto err = _r[idx] > _c[idx] ? 
_r[idx] - _c[idx] : _c[idx] - _r[idx]; - EXPECT_LE(err, kMaxDiff); - } - }); -} - -TEST_P(BeaverTest, Dot_large) { - const auto factory = std::get<0>(GetParam()); - const size_t kWorldSize = std::get<1>(GetParam()); - const FieldType kField = std::get<2>(GetParam()); - const int64_t kMaxDiff = std::get<3>(GetParam()); - // M < N - const size_t M = 11; - const size_t N = 20; - const size_t K = 1023; - - std::vector triples; - triples.resize(kWorldSize); - - utils::simulate(kWorldSize, - [&](const std::shared_ptr& lctx) { - auto beaver = factory(lctx); - triples[lctx->Rank()] = beaver->AuthDot(kField, M, N, K); - }); - - EXPECT_EQ(triples.size(), kWorldSize); - auto sum_a = ring_zeros(kField, M * K); - auto sum_b = ring_zeros(kField, K * N); - auto sum_c = ring_zeros(kField, M * N); - for (Rank r = 0; r < kWorldSize; r++) { - const auto& [a, b, c] = std::get<0>(triples[r]); - EXPECT_EQ(a.numel(), M * K); - EXPECT_EQ(b.numel(), K * N); - EXPECT_EQ(c.numel(), M * N); - - ring_add_(sum_a, a); - ring_add_(sum_b, b); - ring_add_(sum_c, c); - } - - auto res = ring_mmul(sum_a, sum_b, M, N, K); - DISPATCH_ALL_FIELDS(kField, "_", [&]() { - auto _r = ArrayView(res); - auto _c = ArrayView(sum_c); - for (auto idx = 0; idx < _r.numel(); idx++) { - auto err = _r[idx] > _c[idx] ? 
_r[idx] - _c[idx] : _c[idx] - _r[idx]; - EXPECT_LE(err, kMaxDiff); - } - }); -} - -TEST_P(BeaverTest, Trunc) { - const auto factory = std::get<0>(GetParam()); - const size_t kWorldSize = std::get<1>(GetParam()); - const FieldType kField = std::get<2>(GetParam()); - const size_t kNumel = 7; - const size_t kBits = 5; - - std::vector pairs; - pairs.resize(kWorldSize); - - utils::simulate( - kWorldSize, [&](const std::shared_ptr& lctx) { - auto beaver = factory(lctx); - pairs[lctx->Rank()] = beaver->AuthTrunc(kField, kNumel, kBits); - }); - - EXPECT_EQ(pairs.size(), kWorldSize); - auto sum_a = ring_zeros(kField, kNumel); - auto sum_b = ring_zeros(kField, kNumel); - for (Rank r = 0; r < kWorldSize; r++) { - const auto& [a, b] = std::get<0>(pairs[r]); - EXPECT_EQ(a.numel(), kNumel); - EXPECT_EQ(b.numel(), kNumel); - - ring_add_(sum_a, a); - ring_add_(sum_b, b); - } - EXPECT_EQ(ring_arshift(sum_a, kBits), sum_b) << sum_a << sum_b; -} - -} // namespace spu::mpc::spdz2k diff --git a/libspu/mpc/spdz2k/beaver/beaver_tinyot.cc b/libspu/mpc/spdz2k/beaver/beaver_tinyot.cc new file mode 100644 index 00000000..09b4e6dc --- /dev/null +++ b/libspu/mpc/spdz2k/beaver/beaver_tinyot.cc @@ -0,0 +1,1025 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "libspu/mpc/spdz2k/beaver/beaver_tinyot.h" + +#include +#include +#include +#include + +#include "Eigen/Core" +#include "yacl/base/dynamic_bitset.h" +#include "yacl/crypto/primitives/ot/ot_store.h" +#include "yacl/crypto/tools/prg.h" +#include "yacl/crypto/utils/rand.h" +#include "yacl/link/link.h" +#include "yacl/utils/matrix_utils.h" +#include "yacl/utils/serialize.h" + +#include "libspu/mpc/common/prg_tensor.h" +#include "libspu/mpc/spdz2k/commitment.h" +#include "libspu/mpc/spdz2k/ot/kos_ote.h" +#include "libspu/mpc/spdz2k/ot/tiny_ot.h" +#include "libspu/mpc/utils/ring_ops.h" + +namespace spu::mpc::spdz2k { +namespace { + +// sqrt2k algorithm find the smallest root for residue in ring2K +// Polynomial time algorithm to find the root +// reference +// https://github.com/sagemath/sage/blob/2114066f877a28b7473bf9242b1bb11931f3ec3e/src/sage/rings/finite_rings/integer_mod.pyx#L3943 +uint128_t inline Sqrt2k(uint128_t residue, uint128_t bits) { + uint128_t x = 1; + uint128_t N = residue; + SPU_ENFORCE((N & 7) == 1); + while (x < 8 && (N & 31) != ((x * x) & 31)) { + x += 2; + } + uint128_t t = (N - x * x) >> 5; + for (size_t i = 4; i < bits; ++i) { + if (t & 1) { + x |= (uint128_t)1 << i; + t -= x - ((uint128_t)1 << (i - 1)); + } + t >>= 1; + } + + uint128_t half_mod = (uint128_t)1 << (bits - 1); + uint128_t mask = half_mod + (half_mod - 1); + auto l = [&mask](uint128_t val) { return val & mask; }; + return std::min({l(x), l(x + half_mod), l(-x), l(-x + half_mod)}); +} + +ArrayRef ring_sqrt2k(const ArrayRef& x, size_t bits = 0) { + const auto field = x.eltype().as()->field(); + const auto numel = x.numel(); + if (bits == 0) { + bits = SizeOf(field) * 8; + } + + ArrayRef ret = ring_zeros(field, x.numel()); + DISPATCH_ALL_FIELDS(field, "_", [&]() { + using U = std::make_unsigned::type; + + auto x_data = ArrayView(x); + auto ret_data = ArrayView(ret); + yacl::parallel_for(0, numel, 4096, [&](int64_t beg, int64_t end) { + for (int64_t idx = beg; idx < end; 
++idx) { + ret_data[idx] = Sqrt2k(x_data[idx], bits); + } + }); + }); + return ret; +} + +// reference https://github.com/data61/MP-SPDZ/blob/master/Math/Z2k.hpp +uint128_t inline Invert2k(const uint128_t value, const size_t bits) { + SPU_ENFORCE((value & 1) == 1); + uint128_t ret = 1; + for (size_t i = 0; i < bits; ++i) { + if (!((value * ret >> i) & 1)) { + ret += (uint128_t)1 << i; + } + } + return ret; +} + +ArrayRef ring_inv2k(const ArrayRef& x, size_t bits = 0) { + const auto field = x.eltype().as()->field(); + const auto numel = x.numel(); + if (bits == 0) { + bits = SizeOf(field) * 8; + } + + ArrayRef ret = ring_zeros(field, x.numel()); + DISPATCH_ALL_FIELDS(field, "_", [&]() { + using U = std::make_unsigned::type; + + auto x_data = ArrayView(x); + auto ret_data = ArrayView(ret); + yacl::parallel_for(0, numel, 4096, [&](int64_t beg, int64_t end) { + for (int64_t idx = beg; idx < end; ++idx) { + ret_data[idx] = Invert2k(x_data[idx], bits); + } + }); + }); + return ret; +} + +std::vector ring_cast_vector_boolean(const ArrayRef& x) { + const auto field = x.eltype().as()->field(); + + std::vector res(x.numel()); + DISPATCH_ALL_FIELDS(field, "RingOps", [&]() { + auto x_eigen = Eigen::Map, 0, + Eigen::InnerStride>( + &x.at(0), x.numel(), + Eigen::InnerStride(x.stride())); + yacl::parallel_for(0, x.numel(), 4096, [&](size_t start, size_t end) { + for (size_t i = start; i < end; i++) { + res[i] = static_cast(x_eigen[i] & 0x1); + } + }); + }); + return res; +} + +} // namespace + +BeaverTinyOt::BeaverTinyOt(std::shared_ptr lctx) + : seed_(yacl::crypto::SecureRandSeed()) { + comm_ = std::make_shared(lctx); + prg_state_ = std::make_shared(lctx); + spdz2k_ot_primitives_ = std::make_shared(comm_); + + auto buf = yacl::SerializeUint128(seed_); + std::vector all_bufs = + yacl::link::Gather(lctx, buf, 0, "BEAVER_TINY:SYNC_SEEDS"); + + if (comm_->getRank() == 0) { + // Collects seeds from all parties. 
+ for (size_t rank = 0; rank < comm_->getWorldSize(); ++rank) { + PrgSeed seed = yacl::DeserializeUint128(all_bufs[rank]); + tp_.setSeed(rank, comm_->getWorldSize(), seed); + } + } + + auto recv_opts_choices = yacl::dynamic_bitset(kappa_); + auto recv_opts_blocks = std::vector(kappa_); + + auto send_opts_blocks = std::vector>(kappa_); + + if (comm_->getRank() == 0) { + yacl::crypto::BaseOtRecv(comm_->lctx(), recv_opts_choices, + absl::MakeSpan(recv_opts_blocks)); + yacl::crypto::BaseOtSend(comm_->lctx(), absl::MakeSpan(send_opts_blocks)); + } else { + yacl::crypto::BaseOtSend(comm_->lctx(), absl::MakeSpan(send_opts_blocks)); + yacl::crypto::BaseOtRecv(comm_->lctx(), recv_opts_choices, + absl::MakeSpan(recv_opts_blocks)); + } + + recv_opts_ = std::make_shared( + yacl::crypto::MakeOtRecvStore(recv_opts_choices, recv_opts_blocks)); + + send_opts_ = std::make_shared( + yacl::crypto::MakeOtSendStore(send_opts_blocks)); + + // the choices of BaseOT options would be the delta in delta OT + // which means that delta is the "key" in TinyOT + tinyot_key_ = 0; + for (size_t k = 0; k < kappa_; ++k) { + if (recv_opts_->GetChoice(k)) { + tinyot_key_ |= (uint128_t)1 << k; + } + } +} + +uint128_t BeaverTinyOt::InitSpdzKey(FieldType field, size_t s) { + spdz_key_ = yacl::crypto::SecureRandSeed(); + spdz_key_ &= ((uint128_t)1 << s) - 1; + return spdz_key_; +} + +// Refer to: +// Fig. 11 Protocol for authenticating secret-shared values +// SPDZ2k: Efficient MPC mod 2k for Dishonest Majority +// - https://eprint.iacr.org/2018/482.pdf +ArrayRef BeaverTinyOt::AuthArrayRef(const ArrayRef& x, FieldType field, + size_t k, size_t s) { + return DISPATCH_ALL_FIELDS(field, "_", [&]() { + using T = ring2k_t; + + // 1. 
l_ = max(l, r + s, 2s) + SPDLOG_DEBUG("AuthArrayRef start with numel {}", x.numel()); + const int l = k + s; + const int r = k; + int l_ = std::max(l, r + static_cast(s)); + l_ = std::max(l_, 2 * static_cast(s)); + l_ = std::min(l_, static_cast(SizeOf(field) * 8)); + SPU_ENFORCE(l_ >= static_cast(SizeOf(field) * 8), "k = s"); + + // 2. sample random masks + int64_t t = x.numel(); + size_t new_numel = t + 1; + ArrayRef x_hat(x.eltype(), new_numel); + auto x_mask = ring_rand(field, 1); + for (int i = 0; i < t; ++i) { + x_hat.at(i) = x.at(i); + } + x_hat.at(t) = x_mask.at(0); + + // 3. every pair calls vole && 4. receives vole output + size_t WorldSize = comm_->getWorldSize(); + size_t rank = comm_->getRank(); + + std::vector a, b; + auto alpha = ring_mul(ring_ones(field, new_numel), spdz_key_); + for (size_t i = 0; i < WorldSize; ++i) { + for (size_t j = 0; j < WorldSize; ++j) { + if (i == j) { + continue; + } + + if (i == rank) { + auto tmp = voleRecv(field, alpha); + a.emplace_back(tmp); + } + if (j == rank) { + auto tmp = voleSend(field, x_hat); + b.emplace_back(tmp); + } + } + } + + // 5. each party defines the MAC share + auto a_b = ring_zeros(field, new_numel); + for (size_t i = 0; i < WorldSize - 1; ++i) { + ring_add_(a_b, ring_sub(a[i], b[i])); + } + + auto m = ring_add(ring_mul(x_hat, spdz_key_), a_b); + + // Consistency check + // 6. get l public random values + auto pub_r = prg_state_->genPubl(field, new_numel); + std::vector rv; + size_t numel = x.numel(); + for (size_t i = 0; i < numel; ++i) { + rv.emplace_back(pub_r.at(i)); + } + rv.emplace_back(1); + + // 7. caculate x_angle && 8. caculate m_angle + T x_angle = 0; + T m_angle = 0; + for (size_t i = 0; i < new_numel; ++i) { + // x_hat, not x + x_angle += rv[i] * x_hat.at(i); + m_angle += rv[i] * m.at(i); + } + + auto x_angle_sum = + comm_->allReduce(std::vector{x_angle}, "allReduce x_ref"); + + // 9. 
commmit and open + auto z = m_angle - x_angle_sum[0] * spdz_key_; + std::string z_str((char*)&z, sizeof(z)); + std::vector recv_strs; + SPU_ENFORCE(commit_and_open(comm_->lctx(), z_str, &recv_strs)); + SPU_ENFORCE(recv_strs.size() == WorldSize); + + // 10. check + T plain_z = 0; + for (const auto& str : recv_strs) { + T t = *(reinterpret_cast(str.data())); + plain_z += t; + } + + SPU_ENFORCE(plain_z == 0); + + // 11. output MAC share + return m.slice(0, m.numel() - 1); + }); +} + +BeaverTinyOt::Pair BeaverTinyOt::AuthCoinTossing(FieldType field, size_t size, + size_t k, size_t s) { + auto rand = ring_rand(field, size); + auto mac = AuthArrayRef(rand, field, k, s); + return {rand, mac}; +} + +// Refer to: +// New Primitives for Actively-Secure MPC over Rings with Applications to +// Private Machine Learning. +// Figure 2: TinyOT share to binary SPDZ2K share conversion. +// - https://eprint.iacr.org/2019/599.pdf +BeaverTinyOt::Triple_Pair BeaverTinyOt::AuthAnd(FieldType field, size_t size, + size_t s) { + const size_t elsize = SizeOf(field); + const size_t tinyot_num = size; + // extra sigma bits = 64 + const size_t sigma = 64; + + auto [auth_a, auth_b, auth_c] = + TinyMul(comm_, send_opts_, recv_opts_, tinyot_num, tinyot_key_); + + // we need extra sigma bits to check + auto auth_r = RandomBits(comm_, send_opts_, recv_opts_, sigma, tinyot_key_); + + // For convenient, put a,b,c,r together + // Then authorize them in SPDZ2k form + // todo: maybe we can use uint64_t in FM64 + AuthBit auth_abcr{std::vector(3 * tinyot_num + sigma, false), + std::vector(3 * tinyot_num + sigma, 0), + tinyot_key_}; + for (size_t i = 0; i < tinyot_num; ++i) { + auth_abcr.choices[i] = auth_a.choices[i]; + auth_abcr.choices[tinyot_num + i] = auth_b.choices[i]; + auth_abcr.choices[tinyot_num * 2 + i] = auth_c.choices[i]; + } + for (size_t i = 0; i < sigma; ++i) { + auth_abcr.choices[tinyot_num * 3 + i] = auth_r.choices[i]; + } + std::memcpy(&auth_abcr.mac[0], &auth_a.mac[0], + tinyot_num * 
sizeof(uint128_t)); + std::memcpy(&auth_abcr.mac[tinyot_num], &auth_b.mac[0], + tinyot_num * sizeof(uint128_t)); + std::memcpy(&auth_abcr.mac[tinyot_num * 2], &auth_c.mac[0], + tinyot_num * sizeof(uint128_t)); + std::memcpy(&auth_abcr.mac[tinyot_num * 3], &auth_r.mac[0], + sigma * sizeof(uint128_t)); + + // Generate authorize bits in the form of B-Share + ArrayRef spdz_choices(makeType(field), tinyot_num * 3 + sigma); + + DISPATCH_ALL_FIELDS(field, "_", [&]() { + using U = std::make_unsigned::type; + + auto _choices = ArrayView(spdz_choices); + auto _size = auth_abcr.choices.size(); + // copy authbit choices + yacl::parallel_for(0, _size, 4096, [&](int64_t beg, int64_t end) { + for (int64_t idx = beg; idx < end; ++idx) { + _choices[idx] = auth_abcr.choices[idx]; + } + }); + }); + + ArrayRef spdz_mac(makeType(field), tinyot_num * 3 + sigma); + ArrayRef mask0(makeType(field), tinyot_num * 3 + sigma); + ArrayRef mask1(makeType(field), tinyot_num * 3 + sigma); + ArrayRef t(makeType(field), tinyot_num * 3 + sigma); + auto ext_spdz_key = + ring_mul(ring_ones(field, tinyot_num * 3 + sigma), spdz_key_); + + if (comm_->getRank() == 0) { + rotRecv(field, spdz_choices, &t); + auto recv = comm_->recv(comm_->nextRank(), makeType(field), "recv"); + + rotSend(field, &mask0, &mask1); + auto diff = ring_add(ring_sub(mask0, mask1), ext_spdz_key); + comm_->sendAsync(comm_->nextRank(), diff, "send"); + spdz_mac = ring_add(t, ring_mul(spdz_choices, recv)); + } else { + rotSend(field, &mask0, &mask1); + auto diff = ring_add(ring_sub(mask0, mask1), ext_spdz_key); + comm_->sendAsync(comm_->nextRank(), diff, "send"); + + rotRecv(field, spdz_choices, &t); + auto recv = comm_->recv(comm_->nextRank(), makeType(field), "recv"); + spdz_mac = ring_add(t, ring_mul(spdz_choices, recv)); + } + spdz_mac = ring_sub(spdz_mac, mask0); + spdz_mac = ring_add(spdz_mac, ring_mul(spdz_choices, ext_spdz_key)); + + AuthBit check_tiny_bit = {std::vector(sigma, false), + std::vector(sigma, 0), tinyot_key_}; + 
ArrayRef check_spdz_bit = ring_zeros(field, sigma); + ArrayRef check_spdz_mac = ring_zeros(field, sigma); + auto seed = GenSharedSeed(comm_); + auto prg = yacl::crypto::Prg(seed); + + DISPATCH_ALL_FIELDS(field, "_", [&]() { + using U = std::make_unsigned::type; + + auto _spdz_bit = ArrayView(spdz_choices); + auto _spdz_mac = ArrayView(spdz_mac); + auto _check_spdz_bit = ArrayView(check_spdz_bit); + auto _check_spdz_mac = ArrayView(check_spdz_mac); + + for (size_t i = 0; i < sigma; ++i) { + _check_spdz_bit[i] = _spdz_bit[3 * tinyot_num + i]; + _check_spdz_mac[i] = _spdz_mac[3 * tinyot_num + i]; + check_tiny_bit.mac[i] = auth_abcr.mac[tinyot_num * 3 + i]; + } + for (size_t j = 0; j < tinyot_num * 3; ++j) { + // we can ignore check_tiny_bit.choices + uint64_t ceof = prg(); + // sigma = 64 + for (size_t i = 0; i < sigma; ++i) { + if (ceof & 1) { + check_tiny_bit.mac[i] ^= auth_abcr.mac[j]; + _check_spdz_bit[i] += _spdz_bit[j]; + _check_spdz_mac[i] += _spdz_mac[j]; + } + ceof >>= 1; + } + } + }); + + // Open sigma bits + auto [open_bit, zero_mac] = BatchOpen(check_spdz_bit, check_spdz_mac, 1, s); + check_tiny_bit.choices = ring_cast_vector_boolean(open_bit); + + // TINY Maccheck & SPDZ Maccheck!! 
+ size_t k = s; + SPU_ENFORCE(TinyMacCheck(comm_, check_tiny_bit.choices, check_tiny_bit)); + SPU_ENFORCE(BatchMacCheck(open_bit, zero_mac, k, s)); + + // Pack a,b,c and their mac + auto a = + ArrayRef(spdz_choices.buf(), spdz_choices.eltype(), tinyot_num, 1, 0); + auto b = ArrayRef(spdz_choices.buf(), spdz_choices.eltype(), tinyot_num, 1, + tinyot_num * elsize); + auto c = ArrayRef(spdz_choices.buf(), spdz_choices.eltype(), tinyot_num, 1, + 2 * tinyot_num * elsize); + + auto a_mac = ArrayRef(spdz_mac.buf(), spdz_mac.eltype(), tinyot_num, 1, 0); + auto b_mac = ArrayRef(spdz_mac.buf(), spdz_mac.eltype(), tinyot_num, 1, + tinyot_num * elsize); + auto c_mac = ArrayRef(spdz_mac.buf(), spdz_mac.eltype(), tinyot_num, 1, + 2 * tinyot_num * elsize); + + return {{a, b, c}, {a_mac, b_mac, c_mac}}; +} + +BeaverTinyOt::Triple BeaverTinyOt::dot(FieldType field, size_t M, size_t N, + size_t K, size_t k, size_t s) { + size_t WorldSize = comm_->getWorldSize(); + size_t rank = comm_->getRank(); + + auto a = ring_rand(field, M * K); + auto b = ring_rand(field, K * N); + ring_bitmask_(a, 0, k); + ring_bitmask_(b, 0, k); + + auto c = ring_mmul(a, b, M, N, K); + + // w = a * b + v + std::vector w; + std::vector v; + // every pair calls voleDot + for (size_t i = 0; i < WorldSize; ++i) { + for (size_t j = 0; j < WorldSize; ++j) { + if (i == j) { + continue; + } + if (i == rank) { + auto tmp = voleRecvDot(field, b, M, N, K); + w.emplace_back(tmp); + } + if (j == rank) { + auto tmp = voleSendDot(field, a, M, N, K); + v.emplace_back(tmp); + } + } + } + + for (size_t i = 0; i < WorldSize - 1; ++i) { + ring_add_(c, ring_sub(w[i], v[i])); + } + return {a, b, c}; +} + +// Refer to: +// 6 PreProcessing: Creating Multiplication Triples, +// SPDZ2k: Efficient MPC mod 2k for Dishonest Majority +// - https://eprint.iacr.org/2018/482.pdf +BeaverTinyOt::Triple_Pair BeaverTinyOt::AuthDot(FieldType field, size_t M, + size_t N, size_t K, size_t k, + size_t s) { + // Dot + auto [a_ext, b, c_ext] = 
dot(field, 2 * M, N, K, k, s); + + // Authenticate + auto a_ext_mac = AuthArrayRef(a_ext, field, k, s); + auto b_mac = AuthArrayRef(b, field, k, s); + auto c_ext_mac = AuthArrayRef(c_ext, field, k, s); + + auto a = a_ext.slice(0, M * K, 1); + auto a_mac = a_ext_mac.slice(0, M * K, 1); + auto c = c_ext.slice(0, M * N, 1); + auto c_mac = c_ext_mac.slice(0, M * N, 1); + + // Sacrifice + auto a2 = a_ext.slice(M * K, 2 * M * K, 1); + auto a2_mac = a_ext_mac.slice(M * K, 2 * M * K, 1); + auto c2 = c_ext.slice(M * N, 2 * M * N, 1); + auto c2_mac = c_ext_mac.slice(M * N, 2 * M * N, 1); + + auto t = prg_state_->genPubl(field, M * M); + auto rou = ring_sub(ring_mmul(t, a, M, K, M), a2); + auto rou_mac = ring_sub(ring_mmul(t, a_mac, M, K, M), a2_mac); + + auto [pub_rou, check_rou_mac] = BatchOpen(rou, rou_mac, k, s); + SPU_ENFORCE(BatchMacCheck(pub_rou, check_rou_mac, k, s)); + + auto t_delta = ring_sub(ring_mmul(t, c, M, N, M), c2); + auto delta = ring_sub(t_delta, ring_mmul(pub_rou, b, M, N, K)); + + auto t_delta_mac = ring_sub(ring_mmul(t, c_mac, M, N, M), c2_mac); + auto delta_mac = ring_sub(t_delta_mac, ring_mmul(pub_rou, b_mac, M, N, K)); + + auto [pub_delta, check_delta_mac] = BatchOpen(delta, delta_mac, k, s); + SPU_ENFORCE(BatchMacCheck(pub_delta, check_delta_mac, k, s)); + + // Output + return {{a, b, c}, {a_mac, b_mac, c_mac}}; +} + +BeaverTinyOt::Pair_Pair BeaverTinyOt::AuthTrunc(FieldType field, size_t size, + size_t bits, size_t k, + size_t s) { + size_t nbits = k; + + auto [b_val, b_mac] = AuthRandBit(field, nbits * size, k, s); + + // compose + ArrayRef r_val(b_val.eltype(), size); + ArrayRef r_mac(b_val.eltype(), size); + ArrayRef tr_val(b_val.eltype(), size); + ArrayRef tr_mac(b_val.eltype(), size); + + DISPATCH_ALL_FIELDS(field, "_", [&]() { + using PShrT = ring2k_t; + auto _val = ArrayView(b_val); + auto _mac = ArrayView(b_mac); + auto _r_val = ArrayView(r_val); + auto _r_mac = ArrayView(r_mac); + auto _tr_val = ArrayView(tr_val); + auto _tr_mac = 
ArrayView(tr_mac); + pforeach(0, size, [&](int64_t idx) { + _r_val[idx] = 0; + _r_mac[idx] = 0; + _tr_val[idx] = 0; + _tr_mac[idx] = 0; + for (size_t bit = 0; bit < nbits; bit++) { + size_t flat_idx = idx * nbits + bit; + _r_val[idx] += _val[flat_idx] << bit; + _r_mac[idx] += _mac[flat_idx] << bit; + } + for (size_t bit = 0; bit + bits < nbits; bit++) { + size_t flat_idx = idx * nbits + bits + bit; + _tr_val[idx] += _val[flat_idx] << bit; + _tr_mac[idx] += _mac[flat_idx] << bit; + } + + for (size_t bit = nbits - bits; bit < nbits; bit++) { + size_t flat_idx = idx * nbits + nbits - 1; + _tr_val[idx] += _val[flat_idx] << bit; + _tr_mac[idx] += _mac[flat_idx] << bit; + } + }); + }); + + return {{r_val, tr_val}, {r_mac, tr_mac}}; +} + +// Refer to: +// New Primitives for Actively-Secure MPC over Rings with Applications to +// Private Machine Learning. +// Figure 5: Protocol for obtaining authenticated shared bits +// - https://eprint.iacr.org/2019/599.pdf +BeaverTinyOt::Pair BeaverTinyOt::AuthRandBit(FieldType field, size_t size, + size_t k, size_t s) { + auto u = ring_rand(field, size); + ring_bitmask_(u, 0, k + 2); + auto u_mac = AuthArrayRef(u, field, k + 2, s); + + auto y = ring_mul(u, 2); + auto y_mac = ring_mul(u_mac, 2); + auto ones = ring_ones(field, size); + auto ones_mac = ring_mul(ones, spdz_key_); + + if (comm_->getRank() == 0) { + ring_add_(y, ones); + } + ring_add_(y_mac, ones_mac); + + auto [beaver_vec, beaver_mac] = AuthMul(field, size, k, s); + auto& [a, b, c] = beaver_vec; + auto& [a_mac, b_mac, c_mac] = beaver_mac; + + auto e = ring_sub(y, a); + auto e_mac = ring_sub(y_mac, a_mac); + auto f = ring_sub(y, b); + auto f_mac = ring_sub(y_mac, b_mac); + + // Open the least significant bit and Check them + auto [p_e, pe_mac] = BatchOpen(e, e_mac, k + 2, s); + auto [p_f, pf_mac] = BatchOpen(f, f_mac, k + 2, s); + + SPU_ENFORCE(BatchMacCheck(p_e, pe_mac, k, s)); + SPU_ENFORCE(BatchMacCheck(p_f, pf_mac, k, s)); + + // Reserve the least significant bit only + 
ring_bitmask_(p_e, 0, k + 2); + ring_bitmask_(p_f, 0, k + 2); + auto p_ef = ring_mul(p_e, p_f); + + // z = p_e * b + p_f * a + c; + auto z = ring_add(ring_mul(p_e, b), ring_mul(p_f, a)); + ring_add_(z, c); + if (comm_->getRank() == 0) { + // z += p_e * p_f; + ring_add_(z, p_ef); + } + + // z_mac = p_e * b_mac + p_f * a_mac + c_mac + p_e * p_f * key; + auto z_mac = ring_add(ring_mul(p_e, b_mac), ring_mul(p_f, a_mac)); + ring_add_(z_mac, c_mac); + ring_add_(z_mac, ring_mul(p_ef, spdz_key_)); + + auto [square, zero_mac] = BatchOpen(z, z_mac, k + 2, s); + SPU_ENFORCE(BatchMacCheck(square, zero_mac, k, s)); + SPU_ENFORCE(ring_all_equal(ring_bitmask(square, 0, 1), ones)); + auto root = ring_sqrt2k(square, k + 2); + auto root_inv = ring_inv2k(root, k + 2); + auto root_inv_div2 = ring_rshift(root_inv, 1); + + auto d = ring_mul(root_inv_div2, y); + auto d_mac = ring_mul(root_inv_div2, y_mac); + ring_add_(d, u); + ring_add_(d_mac, u_mac); + if (comm_->getRank() == 0) { + ring_add_(d, ones); + } + ring_add_(d_mac, ones_mac); + + return {d, d_mac}; +} + +ArrayRef BeaverTinyOt::genPublCoin(FieldType field, size_t numel) { + ArrayRef res(makeType(field), numel); + + // generate new seed + uint128_t seed = yacl::crypto::SecureRandSeed(); + std::vector all_strs; + + std::string seed_str(reinterpret_cast(&seed), sizeof(seed)); + SPU_ENFORCE(commit_and_open(comm_->lctx(), seed_str, &all_strs)); + + uint128_t public_seed = 0; + for (const auto& str : all_strs) { + uint128_t seed = *(reinterpret_cast(str.data())); + public_seed += seed; + } + + const auto kAesType = yacl::crypto::SymmetricCrypto::CryptoType::AES128_CTR; + yacl::crypto::FillPRand( + kAesType, public_seed, 0, 0, + absl::MakeSpan(static_cast(res.data()), res.buf()->size())); + + return res; +} + +// Refer to: +// Procedure BatchCheck, 3.2 Batch MAC Checking with Random Linear +// Combinations, SPDZ2k: Efficient MPC mod 2k for Dishonest Majority +// - https://eprint.iacr.org/2018/482.pdf +// +// Check the opened value 
only +bool BeaverTinyOt::BatchMacCheck(const ArrayRef& open_value, + const ArrayRef& mac, size_t k, size_t s) { + SPDLOG_DEBUG("BatchMacCheck start..."); + SPU_ENFORCE(open_value.numel() == mac.numel()); + const auto field = open_value.eltype().as()->field(); + const size_t mac_bits = k + s; + const size_t key = spdz_key_; + size_t num = open_value.numel(); + + // 1. Generate ceof + auto coef = genPublCoin(field, num); + ring_bitmask_(coef, 0, s); + + // 3. check_value = coef * open_value + // check_mac = coef * mac + auto check_value = ring_mmul(coef, open_value, 1, 1, num); + auto check_mac = ring_mmul(coef, mac, 1, 1, num); + + // 4. local_mac = check_mac - check_value * key + auto local_mac = ring_sub(check_mac, ring_mul(check_value, key)); + // commit and reduce all macs + std::string mac_str(reinterpret_cast(local_mac.data()), + local_mac.numel() * local_mac.elsize()); + std::vector all_mac_strs; + SPU_ENFORCE(commit_and_open(comm_->lctx(), mac_str, &all_mac_strs)); + SPU_ENFORCE(all_mac_strs.size() == comm_->getWorldSize()); + + // 5. compute the sum of all macs + auto zero_mac = ring_zeros(field, 1); + for (size_t i = 0; i < comm_->getWorldSize(); ++i) { + const auto& _mac_str = all_mac_strs[i]; + auto buf = std::make_shared(_mac_str.data(), _mac_str.size()); + ArrayRef _mac(buf, zero_mac.eltype(), _mac_str.size() / SizeOf(field), 1, + 0); + ring_add_(zero_mac, _mac); + } + + // 6. In B-share, the range of Mac is Z_2^{s+1} + if (mac_bits != 0) { + ring_bitmask_(zero_mac, 0, mac_bits); + } + + // 7. 
verify whether the sum of all macs is zero + auto res = ring_all_equal(zero_mac, ring_zeros(field, 1)); + SPDLOG_DEBUG("BatchMacCheck end with ret {}.", res); + return res; +} + +// Refer to: +// Procedure BatchCheck, 3.2 Batch MAC Checking with Random Linear +// Combinations, SPDZ2k: Efficient MPC mod 2k for Dishonest Majority +// - https://eprint.iacr.org/2018/482.pdf +// +// Open the value only +// Notice return { open_val , zero_mac = open_val * \sum spdz_key_ } +// the last kth bits in open_val is valid +std::pair BeaverTinyOt::BatchOpen(const ArrayRef& value, + const ArrayRef& mac, + size_t k, size_t s) { + static constexpr char kBindName[] = "batch_open"; + SPU_ENFORCE(value.numel() == mac.numel()); + const auto field = value.eltype().as()->field(); + size_t field_bits = std::min(SizeOf(field) * 8, (size_t)64); + auto [r_val, r_mac] = AuthCoinTossing(field, value.numel(), field_bits, s); + // Open the low k_bits only + // value = value + r * 2^k + // mac = mac + r_mac * 2^k + auto masked_val = ring_add(value, ring_lshift(r_val, k)); + auto masked_mac = ring_add(mac, ring_lshift(r_mac, k)); + + // Because we would use Maccheck to comfirm the open value. + // Thus, we don't need commit them. 
+ auto open_val = comm_->allReduce(ReduceOp::ADD, masked_val, kBindName); + return {open_val, masked_mac}; +} + +void BeaverTinyOt::rotSend(FieldType field, ArrayRef* q0, ArrayRef* q1) { + DISPATCH_ALL_FIELDS(field, "_", [&]() { + using T = ring2k_t; + + SPDLOG_DEBUG("rotSend start with numel {}", q0->numel()); + SPU_ENFORCE(q0->numel() == q1->numel()); + size_t numel = q0->numel(); + T* data0 = reinterpret_cast(q0->data()); + T* data1 = reinterpret_cast(q1->data()); + + SPU_ENFORCE(spdz2k_ot_primitives_ != nullptr); + SPU_ENFORCE(spdz2k_ot_primitives_->GetSenderCOT() != nullptr); + + spdz2k_ot_primitives_->GetSenderCOT()->SendRMCC( + absl::MakeSpan(data0, numel), absl::MakeSpan(data1, numel)); + spdz2k_ot_primitives_->GetSenderCOT()->Flush(); + + SPDLOG_DEBUG("rotSend end"); + }); +} + +// todo: use dynamic_bitset instead of ArrayRef for `a` to improve performance +void BeaverTinyOt::rotRecv(FieldType field, const ArrayRef& a, ArrayRef* s) { + DISPATCH_ALL_FIELDS(field, "_", [&]() { + using T = ring2k_t; + + SPDLOG_DEBUG("rotRecv start with numel {}", a.numel()); + size_t numel = a.numel(); + std::vector b_v(numel); + for (size_t i = 0; i < numel; ++i) { + b_v[i] = a.at(i); + } + + SPU_ENFORCE(spdz2k_ot_primitives_ != nullptr); + SPU_ENFORCE(spdz2k_ot_primitives_->GetSenderCOT() != nullptr); + SPU_ENFORCE(spdz2k_ot_primitives_->GetReceiverCOT() != nullptr); + + T* data = reinterpret_cast(s->data()); + spdz2k_ot_primitives_->GetReceiverCOT()->RecvRMCC( + b_v, absl::MakeSpan(data, numel)); + spdz2k_ot_primitives_->GetReceiverCOT()->Flush(); + + SPDLOG_DEBUG("rotRecv end"); + }); +} + +// Refer to: +// Appendix C. 
Implementing Vector-OLE mod 2^l, P35 +// SPDZ2k: Efficient MPC mod 2k for Dishonest Majority +// - https://eprint.iacr.org/2018/482.pdf +ArrayRef BeaverTinyOt::voleSend(FieldType field, const ArrayRef& x) { + return DISPATCH_ALL_FIELDS(field, "_", [&]() { + using T = ring2k_t; + + SPU_ENFORCE(spdz2k_ot_primitives_ != nullptr); + SPU_ENFORCE(spdz2k_ot_primitives_->GetSenderCOT() != nullptr); + + size_t numel = x.numel(); + ArrayRef res(x.eltype(), numel); + T* data = reinterpret_cast(res.data()); + spdz2k_ot_primitives_->GetSenderCOT()->SendVole( + absl::MakeConstSpan(reinterpret_cast(x.data()), numel), + absl::MakeSpan(data, numel)); + + return res; + }); +} + +ArrayRef BeaverTinyOt::voleRecv(FieldType field, const ArrayRef& alpha) { + return DISPATCH_ALL_FIELDS(field, "_", [&]() { + using T = ring2k_t; + + SPU_ENFORCE(spdz2k_ot_primitives_ != nullptr); + SPU_ENFORCE(spdz2k_ot_primitives_->GetReceiverCOT() != nullptr); + + size_t size = alpha.numel(); + ArrayRef res(makeType(field), size); + T* data = reinterpret_cast(res.data()); + spdz2k_ot_primitives_->GetReceiverCOT()->RecvVole( + absl::MakeConstSpan(reinterpret_cast(alpha.data()), + alpha.numel()), + absl::MakeSpan(data, size)); + + return res; + }); +} + +// Private Matrix Multiplication by VOLE +// W = V + A dot B +// Sender: input A, receive V +// +// Input: (M, K) matrix +// Output: (M, N) matrix +ArrayRef BeaverTinyOt::voleSendDot(FieldType field, const ArrayRef& x, size_t M, + size_t N, size_t K) { + SPU_ENFORCE(x.numel() == static_cast(M * K)); + + auto ret = ring_zeros(field, M * N); + for (size_t i = 0; i < N; ++i) { + // t: (M, K) matrix + auto t = voleSend(field, x); + + // process the matrix + auto ret_col = ret.slice(i, M * N, N); + for (size_t j = 0; j < K; ++j) { + ring_add_(ret_col, t.slice(j, M * K, K)); + } + } + + return ret; +} + +// Private Matrix Multiplication by VOLE +// W = V + A dot B +// Receiver: input B, receive W +// +// Input: (K, N) matrix +// Output: (M, N) matrix +ArrayRef 
BeaverTinyOt::voleRecvDot(FieldType field, const ArrayRef& alpha, + size_t M, size_t N, size_t K) { + SPU_ENFORCE(alpha.numel() == static_cast(K * N)); + + auto ret = ring_zeros(field, M * N); + for (size_t i = 0; i < N; ++i) { + auto alpha_col = alpha.slice(i, K * N, N); + + ArrayRef alpha_ext(alpha.eltype(), M * K); + for (size_t i = 0; i < M; ++i) { + auto alpha_ext_row = alpha_ext.slice(i * K, (i + 1) * K, 1); + ring_assign(alpha_ext_row, alpha_col); + } + + // t: (m, k) matrix + auto t = voleRecv(field, alpha_ext); + + // process the matrix + auto ret_col = ret.slice(i, M * N, N); + for (size_t j = 0; j < K; ++j) { + ring_add_(ret_col, t.slice(j, M * K, K)); + } + } + + return ret; +} + +// Refer to: +// 6 PreProcessing: Creating Multiplication Triples, +// SPDZ2k: Efficient MPC mod 2k for Dishonest Majority +// - https://eprint.iacr.org/2018/482.pdf +BeaverTinyOt::Triple_Pair BeaverTinyOt::AuthMul(FieldType field, size_t size, + size_t k, size_t s) { + return DISPATCH_ALL_FIELDS(field, "_", [&]() { + using T = ring2k_t; + + SPDLOG_DEBUG("AuthMul start..."); + size_t tao = 4 * s + 2 * k; + size_t expand_tao = tao * size; + auto a = ring_randbit(field, expand_tao); + + auto b = ring_rand(field, size); + auto b_arr = ring_zeros(field, expand_tao); + for (size_t i = 0; i < expand_tao; ++i) { + b_arr.at(i) = b.at(i / tao); + } + + // Every ordered pair does following + size_t WorldSize = comm_->getWorldSize(); + size_t rank = comm_->getRank(); + ArrayRef q0(makeType(field), expand_tao); + ArrayRef q1(makeType(field), expand_tao); + ArrayRef t_s(makeType(field), expand_tao); + + std::vector ci, cj; + + for (size_t i = 0; i < WorldSize; ++i) { + for (size_t j = 0; j < WorldSize; ++j) { + if (i == j) { + continue; + } + + if (i == rank) { + rotRecv(field, a, &t_s); + auto tmp = comm_->lctx()->Recv(j, "recv_d"); + ArrayRef recv_d(std::make_shared(tmp), a.eltype(), + a.numel(), a.stride(), a.offset()); + auto t = ring_add(t_s, ring_mul(a, recv_d)); + 
ci.emplace_back(t); + } + + if (j == rank) { + rotSend(field, &q0, &q1); + auto d = ring_add(ring_sub(q0, q1), b_arr); + comm_->lctx()->SendAsync(i, *(d.buf().get()), "send_d"); + cj.emplace_back(ring_neg(q0)); + } + } + } + + auto cij = ring_zeros(field, expand_tao); + auto cji = ring_zeros(field, expand_tao); + for (size_t i = 0; i < WorldSize - 1; ++i) { + ring_add_(cij, ci[i]); + ring_add_(cji, cj[i]); + } + + // Construct c + auto c = ring_mul(a, b_arr); + auto other_c = ring_add(cij, cji); + ring_add_(c, other_c); + + // Combine + auto r = prg_state_->genPubl(field, expand_tao); + auto r_hat = prg_state_->genPubl(field, expand_tao); + auto ra = ring_mul(r, a); + auto ra_hat = ring_mul(r_hat, a); + auto rc = ring_mul(r, c); + auto rc_hat = ring_mul(r_hat, c); + + ArrayRef cra = ring_zeros(field, size); + ArrayRef cra_hat = ring_zeros(field, size); + ArrayRef crc = ring_zeros(field, size); + ArrayRef crc_hat = ring_zeros(field, size); + + for (size_t i = 0; i < expand_tao; ++i) { + cra.at(i / tao) += ra.at(i); + cra_hat.at(i / tao) += ra_hat.at(i); + + crc.at(i / tao) += rc.at(i); + crc_hat.at(i / tao) += rc_hat.at(i); + } + + // Authenticate + auto a_mac = AuthArrayRef(cra, field, k, s); + auto b_mac = AuthArrayRef(b, field, k, s); + auto c_mac = AuthArrayRef(crc, field, k, s); + + auto a_hat_mac = AuthArrayRef(cra_hat, field, k, s); + auto c_hat_mac = AuthArrayRef(crc_hat, field, k, s); + + // Sacrifice + auto t = prg_state_->genPubl(field, size); + auto rou = ring_sub(ring_mul(t, cra), cra_hat); + auto rou_mac = ring_sub(ring_mul(t, a_mac), a_hat_mac); + + auto [pub_rou, check_rou_mac] = BatchOpen(rou, rou_mac, k, s); + SPU_ENFORCE(BatchMacCheck(pub_rou, check_rou_mac, k, s)); + + auto t_delta = ring_sub(ring_mul(t, crc), crc_hat); + auto delta = ring_sub(t_delta, ring_mul(b, pub_rou)); + + auto t_delta_mac = ring_sub(ring_mul(t, c_mac), c_hat_mac); + auto delta_mac = ring_sub(t_delta_mac, ring_mul(b_mac, pub_rou)); + + auto [pub_delta, check_delta_mac] = 
BatchOpen(delta, delta_mac, k, s); + SPU_ENFORCE(BatchMacCheck(pub_delta, check_delta_mac, k, s)); + + SPDLOG_DEBUG("AuthMul end"); + // Output + return BeaverTinyOt::Triple_Pair{{cra, b, crc}, {a_mac, b_mac, c_mac}}; + }); +} + +} // namespace spu::mpc::spdz2k diff --git a/libspu/mpc/spdz2k/beaver/beaver_tinyot.h b/libspu/mpc/spdz2k/beaver/beaver_tinyot.h new file mode 100644 index 00000000..455b3bdc --- /dev/null +++ b/libspu/mpc/spdz2k/beaver/beaver_tinyot.h @@ -0,0 +1,120 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "yacl/crypto/primitives/ot/base_ot.h" +#include "yacl/link/context.h" + +#include "libspu/mpc/common/prg_state.h" +#include "libspu/mpc/spdz2k/beaver/beaver_interface.h" +#include "libspu/mpc/spdz2k/beaver/trusted_party.h" +#include "libspu/mpc/spdz2k/ot/basic_ot_prot.h" + +namespace spu::mpc::spdz2k { + +class BeaverTinyOt final : public Beaver { + protected: + // Only for rank0 party. 
+ TrustedParty tp_; + + std::shared_ptr comm_; + + std::shared_ptr prg_state_; + + PrgSeed seed_; + + // tinyOT alpha + uint128_t tinyot_key_; + + // spzd key + uint128_t spdz_key_; + + // base OT + std::shared_ptr recv_opts_; + std::shared_ptr send_opts_; + + // ferret ot + std::shared_ptr spdz2k_ot_primitives_; + + // security parameters + static constexpr int kappa_ = 128; + + public: + using Triple = std::tuple; + using Pair = std::pair; + using Pair_Pair = std::pair; + using Triple_Pair = std::pair; + + BeaverTinyOt(std::shared_ptr lctx); + + uint128_t InitSpdzKey(FieldType field, size_t s) override; + + ArrayRef AuthArrayRef(const ArrayRef& value, FieldType field, size_t k, + size_t s) override; + + Pair AuthCoinTossing(FieldType field, size_t size, size_t k, + size_t s) override; + + Triple_Pair AuthMul(FieldType field, size_t size, size_t k, + size_t s) override; + + Triple_Pair AuthDot(FieldType field, size_t M, size_t N, size_t K, size_t k, + size_t s) override; + + Triple_Pair AuthAnd(FieldType field, size_t size, size_t s) override; + + Pair_Pair AuthTrunc(FieldType field, size_t size, size_t bits, size_t k, + size_t s) override; + + Pair AuthRandBit(FieldType field, size_t size, size_t k, size_t s) override; + + // Check the opened value only + bool BatchMacCheck(const ArrayRef& open_value, const ArrayRef& mac, size_t k, + size_t s); + // Open the low k_bits of value only + std::pair BatchOpen(const ArrayRef& value, + const ArrayRef& mac, size_t k, + size_t s); + + // public coin, used in malicious model, all party generate new seed, then + // get exactly the same random variable. + ArrayRef genPublCoin(FieldType field, size_t numel); + + // ROT encapsulation + // s[i] = (a[i] == 0) ? 
q0[i] : q1[i] + void rotSend(FieldType field, ArrayRef* q0, ArrayRef* q1); + void rotRecv(FieldType field, const ArrayRef& a, ArrayRef* s); + + // Vector-OLE encapsulation + // a[i] = b[i] + x[i] * alpha[i] + // Sender: input x, receive b + // Receiver: input alpha, receive a + ArrayRef voleSend(FieldType field, const ArrayRef& x); + ArrayRef voleRecv(FieldType field, const ArrayRef& alpha); + + // Private Matrix Multiplication by VOLE + // W = V + A dot B + // Sender: input A, receive V + // Receiver: input B, receive W + ArrayRef voleSendDot(FieldType field, const ArrayRef& x, size_t M, size_t N, + size_t K); + ArrayRef voleRecvDot(FieldType field, const ArrayRef& alpha, size_t M, + size_t N, size_t K); + + // Generate semi-honest dot triple + Triple dot(FieldType field, size_t M, size_t N, size_t K, size_t k, size_t s); +}; + +} // namespace spu::mpc::spdz2k diff --git a/libspu/mpc/spdz2k/beaver/trusted_party.cc b/libspu/mpc/spdz2k/beaver/trusted_party.cc index c0a076c0..17538325 100644 --- a/libspu/mpc/spdz2k/beaver/trusted_party.cc +++ b/libspu/mpc/spdz2k/beaver/trusted_party.cc @@ -105,11 +105,34 @@ ArrayRef TrustedParty::adjustSpdzKey(const PrgArrayDesc& desc) const { } std::vector TrustedParty::adjustAuthCoinTossing( + const PrgArrayDesc& desc, const PrgArrayDesc& mac_desc, + uint128_t global_key, size_t k, size_t s) const { + SPU_ENFORCE(s <= SizeOf(desc.field) * 8); + + auto [r0, rs] = reconstruct(RecOp::ADD, getSeeds(), absl::MakeSpan(&desc, 1)); + SPU_ENFORCE(r0.size() == 1 && rs.size() == 1); + auto r = ring_bitmask(ring_rand(desc.field, desc.numel), 0, k); + ring_add_(r0[0], ring_sub(r, rs[0])); + + auto [mac_r0, mac_rs] = + reconstruct(RecOp::ADD, getSeeds(), absl::MakeSpan(&mac_desc, 1)); + SPU_ENFORCE(mac_r0.size() == 1 && mac_rs.size() == 1); + + // mac_r0[0] += r * global_key - mac_rs[0]; + auto mac = ring_mul(r, global_key); + ring_add_(mac_r0[0], ring_sub(mac, mac_rs[0])); + + return {r0[0], mac_r0[0]}; +} + +std::vector 
TrustedParty::adjustAuthRandBit( const PrgArrayDesc& desc, const PrgArrayDesc& mac_desc, uint128_t global_key, size_t s) const { auto [r0, rs] = reconstruct(RecOp::ADD, getSeeds(), absl::MakeSpan(&desc, 1)); SPU_ENFORCE(r0.size() == 1 && rs.size() == 1); - auto r = ring_bitmask(ring_rand(FM128, desc.numel), 0, s); + + // r0[0] += r - rs[0]; + auto r = ring_bitmask(ring_rand(desc.field, desc.numel), 0, 1); ring_add_(r0[0], ring_sub(r, rs[0])); auto [mac_r0, mac_rs] = @@ -181,26 +204,65 @@ std::vector TrustedParty::adjustAuthDot( return {r0[2], mac_r0[0], mac_r0[1], mac_r0[2]}; } +std::vector TrustedParty::adjustAuthAnd( + absl::Span descs, + absl::Span mac_descs, uint128_t global_key) const { + SPU_ENFORCE_EQ(descs.size(), 3U); + checkDescs(descs); + + SPU_ENFORCE_EQ(mac_descs.size(), 3U); + checkDescs(mac_descs); + + auto [r0, rs] = reconstruct(RecOp::ADD, getSeeds(), descs); + // r0[2] += rs[0] * rs[1] - rs[2]; + ring_add_(r0[2], ring_sub(ring_mul(rs[0], rs[1]), rs[2])); + + auto [mac_r0, mac_rs] = reconstruct(RecOp::ADD, getSeeds(), mac_descs); + // mac_r0[0] += rs[0] * global_key - mac_rs[0]; + auto amac = ring_mul(rs[0], global_key); + ring_add_(mac_r0[0], ring_sub(amac, mac_rs[0])); + + // mac_r0[1] += rs[1] * global_key - mac_rs[1]; + auto bmac = ring_mul(rs[1], global_key); + ring_add_(mac_r0[1], ring_sub(bmac, mac_rs[1])); + + // mac_r0[2] += rs[0] * rs[1] * global_key - mac_rs[2]; + auto c = ring_mul(rs[0], rs[1]); + auto cmac = ring_mul(c, global_key); + ring_add_(mac_r0[2], ring_sub(cmac, mac_rs[2])); + return {r0[2], mac_r0[0], mac_r0[1], mac_r0[2]}; +} + std::vector TrustedParty::adjustAuthTrunc( absl::Span descs, - absl::Span mac_descs, size_t bits, - uint128_t global_key) const { + absl::Span mac_descs, size_t bits, uint128_t global_key, + size_t k, size_t s) const { SPU_ENFORCE_EQ(descs.size(), 2U); checkDescs(descs); + const auto field = descs[0].field; auto [r0, rs] = reconstruct(RecOp::ADD, getSeeds(), descs); + // r0[0] += (rs[0] & ((1 << k) - 
1)) - rs[0]; + auto t_rs = rs[0].clone(); + ring_bitmask_(rs[0], 0, k); + ring_add_(r0[0], ring_sub(rs[0], t_rs)); + // r0[1] += (rs[0] >> bits) - rs[1]; - ring_add_(r0[1], ring_sub(ring_arshift(rs[0], bits), rs[1])); + const size_t bit_len = SizeOf(field) * 8; + auto tr_rs0 = + ring_arshift(ring_lshift(rs[0], bit_len - k), bit_len - k + bits); + ring_bitmask_(tr_rs0, 0, k); + ring_add_(r0[1], ring_sub(tr_rs0, rs[1])); auto [mac_r0, mac_rs] = reconstruct(RecOp::ADD, getSeeds(), mac_descs); - // mac_r0[0] += rs[0] * global_key - mac_rs[1]; - auto mac = ring_mul(ring_arshift(rs[0], bits), global_key); + // mac_r0[0] += rs[0] * global_key - mac_rs[0]; + auto mac = ring_mul(rs[0], global_key); ring_add_(mac_r0[0], ring_sub(mac, mac_rs[0])); // mac_r0[1] += (rs[0] >> bits) * global_key - mac_rs[1]; - auto tr_mac = ring_mul(ring_arshift(rs[0], bits), global_key); + auto tr_mac = ring_mul(tr_rs0, global_key); ring_add_(mac_r0[1], ring_sub(tr_mac, mac_rs[1])); - return {r0[1], mac_r0[0], mac_r0[1]}; + return {r0[0], r0[1], mac_r0[0], mac_r0[1]}; } } // namespace spu::mpc::spdz2k diff --git a/libspu/mpc/spdz2k/beaver/trusted_party.h b/libspu/mpc/spdz2k/beaver/trusted_party.h index a178b638..ebb339df 100644 --- a/libspu/mpc/spdz2k/beaver/trusted_party.h +++ b/libspu/mpc/spdz2k/beaver/trusted_party.h @@ -37,9 +37,13 @@ class TrustedParty { std::vector adjustAuthCoinTossing(const PrgArrayDesc& desc, const PrgArrayDesc& mac_desc, - uint128_t global_key, + uint128_t global_key, size_t k, size_t s) const; + std::vector adjustAuthRandBit(const PrgArrayDesc& desc, + const PrgArrayDesc& mac_desc, + uint128_t global_key, size_t s) const; + std::vector adjustAuthMul(absl::Span descs, absl::Span mac_descs, uint128_t global_key) const; @@ -49,10 +53,14 @@ class TrustedParty { size_t m, size_t n, size_t k, uint128_t global_key) const; + std::vector adjustAuthAnd(absl::Span descs, + absl::Span mac_descs, + uint128_t global_key) const; + std::vector adjustAuthTrunc( absl::Span descs, 
absl::Span mac_descs, size_t bits, - uint128_t global_key) const; + uint128_t global_key, size_t k, size_t s) const; }; } // namespace spu::mpc::spdz2k diff --git a/libspu/mpc/spdz2k/boolean.cc b/libspu/mpc/spdz2k/boolean.cc new file mode 100644 index 00000000..e75da37a --- /dev/null +++ b/libspu/mpc/spdz2k/boolean.cc @@ -0,0 +1,630 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "libspu/mpc/spdz2k/boolean.h" + +#include + +#include "libspu/core/parallel_utils.h" +#include "libspu/core/trace.h" +#include "libspu/mpc/ab_api.h" +#include "libspu/mpc/common/communicator.h" +#include "libspu/mpc/common/prg_state.h" +#include "libspu/mpc/common/pv2k.h" +#include "libspu/mpc/spdz2k/type.h" +#include "libspu/mpc/spdz2k/value.h" + +namespace spu::mpc::spdz2k { + +namespace { + +// Input a plaintext +// Output the B-share without MAC +// LSB first, MSB last +ArrayRef P2Value(FieldType out_field, const ArrayRef& in, size_t k, + size_t new_nbits = 0) { + const auto* in_ty = in.eltype().as(); + const auto in_field = in_ty->field(); + return DISPATCH_ALL_FIELDS(in_field, "_", [&]() { + using PShrT = ring2k_t; + auto _in = ArrayView(in); + size_t nbits = std::min(k, maxBitWidth(_in)); + + if (new_nbits == 0) { + new_nbits = nbits; + } + + size_t valid_nbits = new_nbits; + size_t min_nbits = std::min(nbits, new_nbits); + + const size_t out_numel = in.numel() * valid_nbits; + auto out = ring_zeros(out_field, out_numel); 
+ return DISPATCH_ALL_FIELDS(out_field, "_", [&]() { + using BShrT = ring2k_t; + auto _out = ArrayView(out); + pforeach(0, in.numel(), [&](int64_t idx) { + pforeach(0, min_nbits, [&](int64_t jdx) { + size_t offset = idx * valid_nbits + jdx; + _out[offset] = static_cast((_in[idx] >> jdx) & 1); + }); + }); + + return out; + }); + }); +} + +// RShift implementation +std::pair RShiftBImpl(const ArrayRef& in, size_t bits) { + const auto field = in.eltype().as()->field(); + const auto old_nbits = in.eltype().as()->nbits(); + const size_t p_num = in.numel(); + size_t new_nbits = old_nbits - bits; + + if (bits == 0) return {getValueShare(in), getMacShare(in)}; + if (new_nbits <= 0) { + return {ring_zeros(field, p_num), ring_zeros(field, p_num)}; + } + + auto out_val = ring_zeros(field, p_num * new_nbits); + auto out_mac = ring_zeros(field, p_num * new_nbits); + + size_t out_offset = 0; + size_t in_offset = bits; + + auto in_val = getValueShare(in).clone(); + auto in_mac = getMacShare(in).clone(); + + for (size_t i = 0; i < p_num; ++i) { + auto _in_val = ArrayRef(in_val.buf(), makeType(field), new_nbits, 1, + (i * old_nbits + in_offset) * SizeOf(field)); + auto _in_mac = ArrayRef(in_mac.buf(), makeType(field), new_nbits, 1, + (i * old_nbits + in_offset) * SizeOf(field)); + auto _out_val = ArrayRef(out_val.buf(), makeType(field), new_nbits, + 1, (i * new_nbits + out_offset) * SizeOf(field)); + auto _out_mac = ArrayRef(out_mac.buf(), makeType(field), new_nbits, + 1, (i * new_nbits + out_offset) * SizeOf(field)); + ring_add_(_out_val, _in_val); + ring_add_(_out_mac, _in_mac); + } + return {out_val, out_mac}; +} + +// ARShift implementation +std::pair ARShiftBImpl(const ArrayRef& in, size_t bits, + size_t k) { + const auto old_nbits = in.eltype().as()->nbits(); + // Only process negative number + SPU_ENFORCE(old_nbits == k); + const auto field = in.eltype().as()->field(); + + if (bits == 0) return {getValueShare(in), getMacShare(in)}; + size_t p_num = in.numel(); + + ArrayRef 
out_val(in.eltype(), p_num * old_nbits); + ArrayRef out_mac(in.eltype(), p_num * old_nbits); + + size_t offset1 = bits < old_nbits ? bits : old_nbits; + size_t offset2 = bits < old_nbits ? old_nbits - bits : 0; + + auto in_val = getValueShare(in).clone(); + auto in_mac = getMacShare(in).clone(); + + auto ones = ring_ones(field, offset1); + + for (size_t i = 0; i < p_num; ++i) { + auto _in_val1 = ArrayRef(in_val.buf(), makeType(field), offset2, 1, + (i * old_nbits + offset1) * SizeOf(field)); + auto _in_mac1 = ArrayRef(in_mac.buf(), makeType(field), offset2, 1, + (i * old_nbits + offset1) * SizeOf(field)); + auto _out_val1 = ArrayRef(out_val.buf(), makeType(field), offset2, + 1, (i * old_nbits) * SizeOf(field)); + auto _out_mac1 = ArrayRef(out_mac.buf(), makeType(field), offset2, + 1, (i * old_nbits) * SizeOf(field)); + ring_assign(_out_val1, _in_val1); + ring_assign(_out_mac1, _in_mac1); + + auto _in_val_sign = ArrayRef(in_val.buf(), makeType(field), 1, 1, + ((i + 1) * old_nbits - 1) * SizeOf(field)); + auto _in_mac_sign = ArrayRef(in_mac.buf(), makeType(field), 1, 1, + ((i + 1) * old_nbits - 1) * SizeOf(field)); + + // sign extension + auto _in_val2 = ring_mmul(_in_val_sign, ones, 1, offset1, 1); + auto _in_mac2 = ring_mmul(_in_mac_sign, ones, 1, offset1, 1); + auto _out_val2 = ArrayRef(out_val.buf(), makeType(field), offset1, + 1, (i * old_nbits + offset2) * SizeOf(field)); + auto _out_mac2 = ArrayRef(out_mac.buf(), makeType(field), offset1, + 1, (i * old_nbits + offset2) * SizeOf(field)); + ring_assign(_out_val2, _in_val2); + ring_assign(_out_mac2, _in_mac2); + } + return {out_val, out_mac}; +} + +// LShift implementation +std::pair LShiftBImpl(const ArrayRef& in, size_t bits, + size_t k) { + const auto old_nbits = in.eltype().as()->nbits(); + if (bits == 0) return {getValueShare(in), getMacShare(in)}; + + size_t p_num = in.numel(); + size_t new_nbits = old_nbits + bits; + + if (new_nbits > k) { + new_nbits = k; + } + size_t min_nbits = new_nbits - bits; + + 
const auto field = in.eltype().as()->field(); + auto out_val = ring_zeros(field, p_num * new_nbits); + auto out_mac = ring_zeros(field, p_num * new_nbits); + if (bits >= k) { + return {ring_zeros(field, in.numel()), ring_zeros(field, in.numel())}; + } + + size_t out_offset = bits; + size_t in_offset = 0; + + auto in_val = getValueShare(in).clone(); + auto in_mac = getMacShare(in).clone(); + + for (size_t i = 0; i < p_num; ++i) { + auto _in_val = ArrayRef(in_val.buf(), makeType(field), min_nbits, 1, + (i * old_nbits + in_offset) * SizeOf(field)); + auto _in_mac = ArrayRef(in_mac.buf(), makeType(field), min_nbits, 1, + (i * old_nbits + in_offset) * SizeOf(field)); + + auto _out_val = ArrayRef(out_val.buf(), makeType(field), min_nbits, + 1, (i * new_nbits + out_offset) * SizeOf(field)); + auto _out_mac = ArrayRef(out_mac.buf(), makeType(field), min_nbits, + 1, (i * new_nbits + out_offset) * SizeOf(field)); + + ring_add_(_out_val, _in_val); + ring_add_(_out_mac, _in_mac); + } + return {out_val, out_mac}; +} + +}; // namespace + +void CommonTypeB::evaluate(KernelEvalContext* ctx) const { + const Type& lhs = ctx->getParam(0); + const Type& rhs = ctx->getParam(1); + + SPU_TRACE_MPC_DISP(ctx, lhs, rhs); + + SPU_ENFORCE(lhs == rhs, "spdz2k always use same bshare type, lhs={}, rhs={}", + lhs, rhs); + + ctx->setOutput(lhs); +} + +ArrayRef CastTypeB::proc(KernelEvalContext* ctx, const ArrayRef& in, + const Type& to_type) const { + SPU_ENFORCE(in.eltype() == to_type, + "spdz2k always use same bshare type, lhs={}, rhs={}", in.eltype(), + to_type); + return in; +} + +ArrayRef B2P::proc(KernelEvalContext* ctx, const ArrayRef& in) const { + SPU_TRACE_MPC_LEAF(ctx, in); + + auto* beaver_ptr = ctx->getState()->beaver(); + const auto s = ctx->getState()->s(); + const auto field = in.eltype().as()->field(); + const auto out_field = ctx->getState()->getDefaultField(); + const auto nbits = in.eltype().as()->nbits(); + const size_t out_numel = in.numel(); + + // 1. 
Open the least significant bit + auto [pub, mac] = + beaver_ptr->BatchOpen(getValueShare(in), getMacShare(in), 1, s); + + // 2. Maccheck + SPU_ENFORCE(beaver_ptr->BatchMacCheck(pub, mac, 1, s)); + + return DISPATCH_ALL_FIELDS(field, "_", [&]() { + using BShrT = ring2k_t; + auto& value = pub; + auto _value = ArrayView(value); + return DISPATCH_ALL_FIELDS(out_field, "_", [&]() { + using PShrT = ring2k_t; + + ArrayRef out(makeType(out_field), out_numel); + auto _out = ArrayView(out); + + pforeach(0, out_numel, [&](int64_t idx) { + PShrT t = 0; + for (size_t jdx = 0; jdx < nbits; ++jdx) { + t |= static_cast((_value[idx * nbits + jdx] & 1) << jdx); + } + _out[idx] = t; + }); + + return out; + }); + }); +} + +ArrayRef P2B::proc(KernelEvalContext* ctx, const ArrayRef& in) const { + SPU_TRACE_MPC_LEAF(ctx, in); + + auto* comm = ctx->getState(); + const auto k = ctx->getState()->k(); + const auto key = ctx->getState()->key(); + const auto out_field = ctx->getState()->getDefaultField(); + auto* prg_state = ctx->getState(); + + // 1. Convert plaintext into B-value + auto p = P2Value(out_field, in, k); + auto out = ring_zeros(out_field, p.numel()); + + // 2. out = p + if (comm->getRank() == 0) { + ring_add_(out, p); + } + ArrayRef& out_mac = p; + // 3. out_mac = p * key + ring_mul_(p, key); + // 4. add some random mask + auto [r0, r1] = prg_state->genPrssPair(out_field, out.numel()); + auto [r2, r3] = prg_state->genPrssPair(out_field, out.numel()); + ring_add_(out, ring_sub(r0, r1)); + ring_add_(out_mac, ring_sub(r2, r3)); + // 5. makeBShare + const auto nbits = out.numel() / in.numel(); + return makeBShare(out, out_mac, out_field, nbits); +} + +ArrayRef NotB::proc(KernelEvalContext* ctx, const ArrayRef& in) const { + SPU_TRACE_MPC_LEAF(ctx, in); + + const auto field = in.eltype().as()->field(); + const auto nbits = in.eltype().as()->nbits(); + const auto key = ctx->getState()->key(); + const auto numel = in.numel() * nbits; + auto* comm = ctx->getState(); + // 1. 
convert B-share into x & x_mac + auto [x, x_mac] = BShareSwitch2Nbits(in, nbits); + + // 2. create ones + auto ones = ring_ones(field, numel); + + // 3. ret = x + one + auto ret = x.clone(); + if (comm->getRank() == 0) { + ring_add_(ret, ones); + } + + // 4. z_mac = x_mac + ones * key + ring_mul_(ones, key); + ArrayRef& ret_mac = ones; + ring_add_(ret_mac, x_mac); + + return makeBShare(ret, ret_mac, field, nbits); +} + +ArrayRef BitrevB::proc(KernelEvalContext* ctx, const ArrayRef& in, size_t start, + size_t end) const { + SPU_TRACE_MPC_LEAF(ctx, in); + + const auto field = in.eltype().as()->field(); + const auto nbits = in.eltype().as()->nbits(); + const auto numel = in.numel(); + + SPU_ENFORCE(start <= end); + SPU_ENFORCE(end <= nbits); + + auto x = getValueShare(in); + auto x_mac = getMacShare(in); + auto ret = x.clone(); + auto ret_mac = x_mac.clone(); + + DISPATCH_ALL_FIELDS(field, "_", [&]() { + auto _x = ArrayView(x); + auto _x_mac = ArrayView(x_mac); + auto _ret = ArrayView(ret); + auto _ret_mac = ArrayView(ret_mac); + for (size_t i = 0; i < static_cast(numel); ++i) { + for (size_t j = start; j < end; ++j) { + _ret[i * nbits + j] = _x[i * nbits + end + start - j - 1]; + _ret_mac[i * nbits + j] = _x_mac[i * nbits + end + start - j - 1]; + } + } + }); + + return makeBShare(ret, ret_mac, field, nbits); +} + +ArrayRef XorBB::proc(KernelEvalContext* ctx, const ArrayRef& lhs, + const ArrayRef& rhs) const { + SPU_TRACE_MPC_LEAF(ctx, lhs, rhs); + const auto field = lhs.eltype().as()->field(); + const auto nbits = maxNumBits(lhs, rhs); + + // lhs + const auto [x, x_mac] = BShareSwitch2Nbits(lhs, nbits); + + // rhs + const auto [y, y_mac] = BShareSwitch2Nbits(rhs, nbits); + + // ret + const auto& z = ring_add(x, y); + const auto& z_mac = ring_add(x_mac, y_mac); + return makeBShare(z, z_mac, field, nbits); +} + +ArrayRef XorBP::proc(KernelEvalContext* ctx, const ArrayRef& lhs, + const ArrayRef& rhs) const { + SPU_TRACE_MPC_LEAF(ctx, lhs, rhs); + + const auto field = 
lhs.eltype().as()->field(); + const auto nbits = maxNumBits(lhs, rhs); + const auto k = ctx->getState()->k(); + const auto key = ctx->getState()->key(); + auto* comm = ctx->getState(); + + // lhs + auto [x, x_mac] = BShareSwitch2Nbits(lhs, nbits); + + // convert plaintext to B-value + auto p = P2Value(field, rhs, k, nbits); + + // ret + auto z = x.clone(); + if (comm->getRank() == 0) { + // z += p + ring_add_(z, p); + } + + // z_mac = x_mac + p * key + ring_mul_(p, key); + ArrayRef& z_mac = p; + ring_add_(z_mac, x_mac); + + return makeBShare(z, z_mac, field, nbits); +} + +ArrayRef AndBB::proc(KernelEvalContext* ctx, const ArrayRef& lhs, + const ArrayRef& rhs) const { + SPU_TRACE_MPC_LEAF(ctx, lhs, rhs); + + const auto field = lhs.eltype().as()->field(); + auto* comm = ctx->getState(); + auto* beaver_ptr = ctx->getState()->beaver(); + const auto key = ctx->getState()->key(); + const auto s = ctx->getState()->s(); + + // 1. find the min nbits + const auto nbits = minNumBits(lhs, rhs); + + // 2. 
convert B-share into value & mac + auto [x, x_mac] = BShareSwitch2Nbits(lhs, nbits); + auto [y, y_mac] = BShareSwitch2Nbits(rhs, nbits); + + SPU_ENFORCE(x.numel() == y.numel()); + const auto numel = x.numel(); + // e = x - a, f = y - b + auto [beaver_vec, beaver_mac] = beaver_ptr->AuthAnd(field, numel, s); + auto& [a, b, c] = beaver_vec; + auto& [a_mac, b_mac, c_mac] = beaver_mac; + + auto e = ring_sub(x, a); + auto e_mac = ring_sub(x_mac, a_mac); + auto f = ring_sub(y, b); + auto f_mac = ring_sub(y_mac, b_mac); + + // Open the least significant bit and Check them + auto [p_e, pe_mac] = beaver_ptr->BatchOpen(e, e_mac, 1, s); + auto [p_f, pf_mac] = beaver_ptr->BatchOpen(f, f_mac, 1, s); + + SPU_ENFORCE(beaver_ptr->BatchMacCheck(p_e, pe_mac, 1, s)); + SPU_ENFORCE(beaver_ptr->BatchMacCheck(p_f, pf_mac, 1, s)); + + // Reserve the least significant bit only + ring_bitmask_(p_e, 0, 1); + ring_bitmask_(p_f, 0, 1); + auto p_ef = ring_mul(p_e, p_f); + + // z = p_e * b + p_f * a + c; + auto z = ring_add(ring_mul(p_e, b), ring_mul(p_f, a)); + ring_add_(z, c); + if (comm->getRank() == 0) { + // z += p_e * p_f; + ring_add_(z, p_ef); + } + + // z_mac = p_e * b_mac + p_f * a_mac + c_mac + p_e * p_f * key; + auto z_mac = ring_add(ring_mul(p_e, b_mac), ring_mul(p_f, a_mac)); + ring_add_(z_mac, c_mac); + ring_add_(z_mac, ring_mul(p_ef, key)); + + return makeBShare(z, z_mac, field, nbits); +} + +ArrayRef AndBP::proc(KernelEvalContext* ctx, const ArrayRef& lhs, + const ArrayRef& rhs) const { + SPU_TRACE_MPC_LEAF(ctx, lhs, rhs); + + const auto k = ctx->getState()->k(); + const auto field = lhs.eltype().as()->field(); + const auto nbits = minNumBits(lhs, rhs); + + // lhs + auto [x, x_mac] = BShareSwitch2Nbits(lhs, nbits); + + // convert rhs to B-share value + const auto p = P2Value(field, rhs, k, nbits); + + SPU_ENFORCE(x.numel() == p.numel(), "x {} p {}", x.numel(), p.numel()); + // ret + // z = x * p + const auto z = ring_mul(x, p); + // z = x_mac * p + const auto z_mac = 
ring_mul(x_mac, p); + return makeBShare(z, z_mac, field, nbits); +} + +ArrayRef LShiftB::proc(KernelEvalContext* ctx, const ArrayRef& in, + size_t bits) const { + const auto field = in.eltype().as()->field(); + const auto k = ctx->getState()->k(); + const size_t nbits = in.eltype().as()->nbits(); + size_t res_nbits = nbits + bits; + + if (bits >= k) { + res_nbits = 1; + } else if (res_nbits > k) { + res_nbits = k; + } + auto [ret, ret_mac] = LShiftBImpl(in, bits, k); + return makeBShare(ret, ret_mac, field, res_nbits); +} + +ArrayRef RShiftB::proc(KernelEvalContext* ctx, const ArrayRef& in, + size_t bits) const { + SPU_TRACE_MPC_LEAF(ctx, in, bits); + + const auto field = in.eltype().as()->field(); + const auto nbits = in.eltype().as()->nbits(); + // size_t rbits = std::min(nbits, bits); + // nbits -= rbits; + // if (nbits <= 0) nbits = 1; + size_t new_nbis = nbits > bits ? nbits - bits : 1; + auto [ret, ret_mac] = RShiftBImpl(in, bits); + return makeBShare(ret, ret_mac, field, new_nbis); +} + +static ArrayRef wrap_rshift_b(SPUContext* ctx, const ArrayRef& x, size_t bits) { + const Shape shape = {x.numel()}; + auto [res, _s, _t] = UnwrapValue(rshift_b(ctx, WrapValue(x, shape), bits)); + return res; +} + +ArrayRef ARShiftB::proc(KernelEvalContext* ctx, const ArrayRef& in, + size_t bits) const { + SPU_TRACE_MPC_LEAF(ctx, in, bits); + + const auto field = in.eltype().as()->field(); + const auto k = ctx->getState()->k(); + const auto nbits = in.eltype().as()->nbits(); + + if (nbits != k) { + return wrap_rshift_b(ctx->sctx(), in, bits); + // size_t new_nbis = nbits > bits ? nbits - bits : 1; + // auto [ret, ret_mac] = RShiftBImpl(in, bits); + // return makeBShare(ret, ret_mac, field, new_nbis); + } else { + auto [ret, ret_mac] = ARShiftBImpl(in, bits, k); + return makeBShare(ret, ret_mac, field, k); + } +} + +// Only process k bits at now. 
+ArrayRef BitIntlB::proc(KernelEvalContext* ctx, const ArrayRef& in,
+                        size_t stride) const {  // builds and returns a rearranged clone of `in`; `in` is only read
+  const auto field = in.eltype().as()->field();  // ring field of the input, used for the dispatch below
+  const auto k = ctx->getState()->k();  // protocol parameter k: bit count passed to BitIntl / records per element below
+  // NOTE(review): the template arguments of as()/getState()/isa()/at()/BitIntl()/ArrayView appear to be missing from this patch text -- restore them from the upstream file before applying.
+  ArrayRef out = in.clone();  // start from a copy; positions not written below keep the input's contents
+  DISPATCH_ALL_FIELDS(field, "_", [&]() {  // presumably binds ring2k_t to the ring type matching `field` -- confirm
+    using T = ring2k_t;
+    // Two layouts are handled: a per-element path through BitIntl, and a path that permutes 2-entry records.
+    if (in.eltype().isa()) {  // NOTE(review): the tested type is not visible here; looks like the plaintext case -- confirm
+      pforeach(0, in.numel(), [&](int64_t idx) {
+        out.at(idx) = BitIntl(in.at(idx), stride, k);  // each element is transformed independently
+      });
+    } else {
+      auto _in = ArrayView>(in);  // elements are accessed as [i][0] and [i][1]
+      auto _out = ArrayView>(out);  // presumably [0] is the value share and [1] the MAC share -- TODO confirm
+      size_t num_per_group = 1 << stride;  // 2^stride records per group (the shift itself is evaluated as int)
+      size_t group_num = k / num_per_group + (k % num_per_group != 0);  // ceil(k / num_per_group)
+      size_t half_group_num = (group_num + 1) / 2;  // ceil(group_num / 2)
+      pforeach(0, in.numel(), [&](size_t jdx) {  // one iteration per input element
+        size_t base_offset = jdx * k;  // element jdx owns records [jdx * k, jdx * k + k)
+        pforeach(0, k, [&](size_t idx) {  // idx walks the k source records of this element
+          auto group = idx / num_per_group;  // source group index
+          auto offset = idx % num_per_group;  // position inside the group, preserved in the destination
+          size_t src_idx = base_offset;
+          size_t dest_idx = base_offset;
+          if (idx < (k + 1) / 2) {  // first half of the records
+            src_idx += idx;
+            dest_idx += 2 * group * num_per_group + offset;  // source group g lands in destination group 2 * g
+            _out[dest_idx][0] = _in[src_idx][0];
+            _out[dest_idx][1] = _in[src_idx][1];
+          } else {  // second half of the records
+            src_idx += idx;
+            dest_idx +=  // source group g lands in destination group 2 * (g - half_group_num) + 1
+                (2 * (group - half_group_num) + 1) * num_per_group + offset;  // NOTE(review): unsigned subtraction wraps if group < half_group_num (e.g. k = 64, stride = 6, idx = 32) -- confirm callers never pass such a stride
+            _out[dest_idx][0] = _in[src_idx][0];
+            _out[dest_idx][1] = _in[src_idx][1];
+          }
+        });
+      });
+    }
+  });
+  // `out` now holds the rearranged data.
+  return out;
+}
+
+// Only processes k bits for now.
+ArrayRef BitDeintlB::proc(KernelEvalContext* ctx, const ArrayRef& in, + size_t stride) const { + const auto field = in.eltype().as()->field(); + const auto k = ctx->getState()->k(); + + ArrayRef out = in.clone(); + DISPATCH_ALL_FIELDS(field, "_", [&]() { + using T = ring2k_t; + + if (in.eltype().isa()) { + pforeach(0, in.numel(), [&](int64_t idx) { + out.at(idx) = BitDeintl(in.at(idx), stride, k); + }); + } else { + auto _in = ArrayView>(in); + auto _out = ArrayView>(out); + size_t num_per_group = 1 << stride; + size_t group_num = k / num_per_group + (k % num_per_group != 0); + size_t half_group_num = (group_num + 1) / 2; + pforeach(0, in.numel(), [&](size_t jdx) { + size_t base_offset = jdx * k; + pforeach(0, k, [&](size_t idx) { + auto group = idx / num_per_group; + auto offset = idx % num_per_group; + size_t src_idx = base_offset; + size_t dest_idx = base_offset; + if (idx < (k + 1) / 2) { + dest_idx += idx; + src_idx += 2 * group * num_per_group + offset; + _out[dest_idx][0] = _in[src_idx][0]; + _out[dest_idx][1] = _in[src_idx][1]; + } else { + dest_idx += idx; + src_idx += + (2 * (group - half_group_num) + 1) * num_per_group + offset; + _out[dest_idx][0] = _in[src_idx][0]; + _out[dest_idx][1] = _in[src_idx][1]; + } + }); + }); + } + }); + + return out; +} + +} // namespace spu::mpc::spdz2k diff --git a/libspu/mpc/spdz2k/boolean.h b/libspu/mpc/spdz2k/boolean.h new file mode 100644 index 00000000..ce3e3ce7 --- /dev/null +++ b/libspu/mpc/spdz2k/boolean.h @@ -0,0 +1,203 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "libspu/core/array_ref.h" +#include "libspu/core/cexpr.h" +#include "libspu/mpc/kernel.h" +#include "libspu/mpc/spdz2k/state.h" +#include "libspu/mpc/spdz2k/value.h" + +namespace spu::mpc::spdz2k { + +class CommonTypeB : public Kernel { + public: + static constexpr char kBindName[] = "common_type_b"; + + ce::CExpr latency() const override { return ce::Const(0); } + + ce::CExpr comm() const override { return ce::Const(0); } + + void evaluate(KernelEvalContext* ctx) const override; +}; + +class CastTypeB : public CastTypeKernel { + public: + static constexpr char kBindName[] = "cast_type_b"; + + Kind kind() const override { return Kind::Dynamic; } + + ArrayRef proc(KernelEvalContext* ctx, const ArrayRef& in, + const Type& to_type) const override; +}; + +class B2P : public UnaryKernel { + public: + static constexpr char kBindName[] = "b2p"; + + Kind kind() const override { return Kind::Dynamic; } + + ArrayRef proc(KernelEvalContext* ctx, const ArrayRef& in) const override; +}; + +class P2B : public UnaryKernel { + public: + static constexpr char kBindName[] = "p2b"; + + ce::CExpr latency() const override { return ce::Const(0); } + + ce::CExpr comm() const override { return ce::Const(0); } + + ArrayRef proc(KernelEvalContext* ctx, const ArrayRef& in) const override; +}; + +class NotB : public UnaryKernel { + public: + static constexpr char kBindName[] = "not_b"; + + ce::CExpr latency() const override { return ce::Const(0); } + + ce::CExpr comm() const override { return ce::Const(0); } + + ArrayRef proc(KernelEvalContext* ctx, const ArrayRef& in) const override; +}; + +class BitrevB : public BitrevKernel { + public: + static constexpr char kBindName[] = "bitrev_b"; + + ce::CExpr latency() const override { return ce::Const(0); } + + ce::CExpr comm() const override { return ce::Const(0); } + + ArrayRef proc(KernelEvalContext* ctx, const 
ArrayRef& in, size_t start, + size_t end) const override; +}; + +class AndBP : public BinaryKernel { + public: + static constexpr char kBindName[] = "and_bp"; + + ce::CExpr latency() const override { return ce::Const(0); } + + ce::CExpr comm() const override { return ce::Const(0); } + + ArrayRef proc(KernelEvalContext* ctx, const ArrayRef& lhs, + const ArrayRef& rhs) const override; +}; + +class AndBB : public BinaryKernel { + public: + static constexpr char kBindName[] = "and_bb"; + + ce::CExpr latency() const override { + // rotate : 1 + return ce::Const(0); + } + + ce::CExpr comm() const override { + // rotate : k + return ce::Const(0); + } + + ArrayRef proc(KernelEvalContext* ctx, const ArrayRef& lhs, + const ArrayRef& rhs) const override; +}; + +class XorBP : public BinaryKernel { + public: + static constexpr char kBindName[] = "xor_bp"; + + ce::CExpr latency() const override { return ce::Const(0); } + + ce::CExpr comm() const override { return ce::Const(0); } + + ArrayRef proc(KernelEvalContext* ctx, const ArrayRef& lhs, + const ArrayRef& rhs) const override; +}; + +class XorBB : public BinaryKernel { + public: + static constexpr char kBindName[] = "xor_bb"; + + ce::CExpr latency() const override { return ce::Const(0); } + + ce::CExpr comm() const override { return ce::Const(0); } + + ArrayRef proc(KernelEvalContext* ctx, const ArrayRef& lhs, + const ArrayRef& rhs) const override; +}; + +class LShiftB : public ShiftKernel { + public: + static constexpr char kBindName[] = "lshift_b"; + + ce::CExpr latency() const override { return ce::Const(0); } + + ce::CExpr comm() const override { return ce::Const(0); } + + ArrayRef proc(KernelEvalContext* ctx, const ArrayRef& in, + size_t bits) const override; +}; + +class RShiftB : public ShiftKernel { + public: + static constexpr char kBindName[] = "rshift_b"; + + ce::CExpr latency() const override { return ce::Const(0); } + + ce::CExpr comm() const override { return ce::Const(0); } + + ArrayRef proc(KernelEvalContext* 
ctx, const ArrayRef& in, + size_t bits) const override; +}; + +class ARShiftB : public ShiftKernel { + public: + static constexpr char kBindName[] = "arshift_b"; + + ce::CExpr latency() const override { return ce::Const(0); } + + ce::CExpr comm() const override { return ce::Const(0); } + + ArrayRef proc(KernelEvalContext* ctx, const ArrayRef& in, + size_t bits) const override; +}; + +class BitIntlB : public BitSplitKernel { + public: + static constexpr char kBindName[] = "bitintl_b"; + + ce::CExpr latency() const override { return ce::Const(0); } + + ce::CExpr comm() const override { return ce::Const(0); } + + ArrayRef proc(KernelEvalContext* ctx, const ArrayRef& in, + size_t stride) const override; +}; + +class BitDeintlB : public BitSplitKernel { + public: + static constexpr char kBindName[] = "bitdeintl_b"; + + ce::CExpr latency() const override { return ce::Const(0); } + + ce::CExpr comm() const override { return ce::Const(0); } + + ArrayRef proc(KernelEvalContext* ctx, const ArrayRef& in, + size_t stride) const override; +}; + +} // namespace spu::mpc::spdz2k diff --git a/libspu/mpc/spdz2k/commitment.cc b/libspu/mpc/spdz2k/commitment.cc index 389c3198..065e1811 100644 --- a/libspu/mpc/spdz2k/commitment.cc +++ b/libspu/mpc/spdz2k/commitment.cc @@ -19,8 +19,9 @@ #include "yacl/crypto/base/hash/hash_utils.h" #include "yacl/crypto/utils/rand.h" +#include "libspu/core/prelude.h" + namespace spu::mpc { -// TODO: Maybe we need a better commit scheme std::string commit(size_t send_player, absl::string_view msg, absl::string_view r, size_t hash_len, yacl::crypto::HashAlgorithm hash_type) { @@ -30,14 +31,14 @@ std::string commit(size_t send_player, absl::string_view msg, hash_algo = std::make_unique(); break; default: - YACL_THROW("Unsupported hash algo in commitment scheme"); + SPU_THROW("Unsupported hash algo in commitment scheme"); } hash_algo->Update(std::to_string(send_player)); hash_algo->Update(msg); hash_algo->Update(r); std::vector hash = 
hash_algo->CumulativeHash(); - YACL_ENFORCE(hash_len <= hash.size()); + SPU_ENFORCE(hash_len <= hash.size()); std::string hash_str(reinterpret_cast(hash.data()), hash_len); diff --git a/libspu/mpc/spdz2k/conversion.cc b/libspu/mpc/spdz2k/conversion.cc new file mode 100644 index 00000000..17095af6 --- /dev/null +++ b/libspu/mpc/spdz2k/conversion.cc @@ -0,0 +1,564 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "libspu/mpc/spdz2k/conversion.h" + +#include "libspu/core/parallel_utils.h" +#include "libspu/core/trace.h" +#include "libspu/core/vectorize.h" +#include "libspu/core/xt_helper.h" +#include "libspu/mpc/ab_api.h" +#include "libspu/mpc/api.h" +#include "libspu/mpc/common/communicator.h" +#include "libspu/mpc/common/prg_state.h" +#include "libspu/mpc/common/pv2k.h" +#include "libspu/mpc/spdz2k/arithmetic.h" +#include "libspu/mpc/spdz2k/boolean.h" +#include "libspu/mpc/spdz2k/type.h" +#include "libspu/mpc/spdz2k/value.h" +#include "libspu/mpc/utils/circuits.h" +#include "libspu/mpc/utils/ring_ops.h" + +namespace spu::mpc::spdz2k { + +namespace { + +static ArrayRef wrap_add_bb(SPUContext* ctx, const ArrayRef& x, + const ArrayRef& y) { + SPU_ENFORCE(x.numel() == y.numel()); + const Shape shape = {x.numel()}; + auto [res, _s, _t] = + UnwrapValue(add_bb(ctx, WrapValue(x, shape), WrapValue(y, shape))); + return res; +} + +static ArrayRef wrap_and_bb(SPUContext* ctx, const ArrayRef& x, + const ArrayRef& y) { 
+ SPU_ENFORCE(x.numel() == y.numel()); + const Shape shape = {x.numel()}; + auto [res, _s, _t] = + UnwrapValue(and_bb(ctx, WrapValue(x, shape), WrapValue(y, shape))); + return res; +} + +static ArrayRef wrap_a2bit(SPUContext* ctx, const ArrayRef& x) { + const Shape shape = {x.numel()}; + auto [res, _s, _t] = + UnwrapValue(dynDispatch(ctx, "a2bit", WrapValue(x, shape))); + return res; +} + +static ArrayRef wrap_bit2a(SPUContext* ctx, const ArrayRef& x) { + const Shape shape = {x.numel()}; + auto [res, _s, _t] = + UnwrapValue(dynDispatch(ctx, "bit2a", WrapValue(x, shape))); + return res; +} + +static ArrayRef wrap_b2a(SPUContext* ctx, const ArrayRef& x) { + const Shape shape = {x.numel()}; + auto [res, _s, _t] = UnwrapValue(b2a(ctx, WrapValue(x, shape))); + return res; +} + +static ArrayRef wrap_p2b(SPUContext* ctx, const ArrayRef& x) { + const Shape shape = {x.numel()}; + auto [res, _s, _t] = UnwrapValue(p2b(ctx, WrapValue(x, shape))); + return res; +} + +static ArrayRef wrap_not_b(SPUContext* ctx, const ArrayRef& x) { + const Shape shape = {x.numel()}; + KernelEvalContext kctx(ctx); + auto [res, _s, _t] = + UnwrapValue(dynDispatch(ctx, "not_b", WrapValue(x, shape))); + return res; +} + +static ArrayRef wrap_bitle_bb(SPUContext* ctx, const ArrayRef& x, + const ArrayRef& y) { + SPU_ENFORCE(x.numel() == y.numel()); + const Shape shape = {x.numel()}; + auto [res, _s, _t] = UnwrapValue( + dynDispatch(ctx, "bitle_bb", WrapValue(x, shape), WrapValue(y, shape))); + return res; +} + +static ArrayRef wrap_carray_out(const CircuitBasicBlock& cbb, + const ArrayRef& x, const ArrayRef& y, + size_t nbits) { + SPU_ENFORCE(x.numel() == y.numel()); + const Shape shape = {x.numel()}; + auto [res, _s, _t] = UnwrapValue( + carry_out(cbb, WrapValue(x, shape), WrapValue(y, shape), nbits)); + return res; +} + +[[maybe_unused]] auto pt_ones(FieldType field, size_t numel) { + return DISPATCH_ALL_FIELDS(field, "_", [&]() { + using PShrT = ring2k_t; + ArrayRef out(makeType(field), numel); + 
auto _out = ArrayView(out); + pforeach(0, numel, [&](int64_t idx) { + PShrT t = 1; + _out[idx] = t; + }); + return out; + }); +} + +inline bool _IsB(const Value& x) { return x.storage_type().isa(); } +inline bool _IsP(const Value& x) { return x.storage_type().isa(); } + +#define COMMUTATIVE_DISPATCH(FnPP, FnBP, FnBB) \ + if (_IsP(x) && _IsP(y)) { \ + return FnPP(ctx, x, y); \ + } else if (_IsB(x) && _IsP(y)) { \ + return FnBP(ctx, x, y); \ + } else if (_IsP(x) && _IsB(y)) { \ + return FnBP(ctx, y, x); \ + } else if (_IsB(x) && _IsB(y)) { \ + return FnBB(ctx, y, x); \ + } else { \ + SPU_THROW("unsupported op x={}, y={}", x, y); \ + } + +CircuitBasicBlock MakeSPDZBasicBlock(SPUContext* ctx) { + using T = Value; + CircuitBasicBlock cbb; + cbb._xor = [=](T const& x, T const& y) -> T { + COMMUTATIVE_DISPATCH(xor_pp, xor_bp, xor_bb); + }; + cbb._and = [=](T const& x, T const& y) -> T { + COMMUTATIVE_DISPATCH(and_pp, and_bp, and_bb); + }; + cbb.lshift = [=](T const& x, size_t bits) -> T { + if (_IsP(x)) { + return lshift_p(ctx, x, bits); + } else if (_IsB(x)) { + return lshift_b(ctx, x, bits); + } + SPU_THROW("unsupported op x={}", x); + }; + cbb.rshift = [=](T const& x, size_t bits) -> T { + if (_IsP(x)) { + return rshift_p(ctx, x, bits); + } else if (_IsB(x)) { + return rshift_b(ctx, x, bits); + } + SPU_THROW("unsupported op x={}", x); + }; + cbb.init_like = [=](T const& x, uint128_t init) -> T { + return make_p(ctx, init, {x.numel()}); + }; + cbb.set_nbits = [=](T& x, size_t nbits) { + return x.storage_type().as()->setNbits(nbits); + }; + return cbb; +} + +}; // namespace + +ArrayRef A2Bit::proc(KernelEvalContext* ctx, const ArrayRef& in) const { + // ArrayRef a2bit_impl(KernelEvalContext* ctx, const ArrayRef& in) { + SPU_TRACE_MPC_LEAF(ctx, in); + + const auto field = in.eltype().as()->field(); + const size_t s = ctx->getState()->s(); + const size_t nbits = 1; + + // 1. 
value mod 2^{s+1} + // mac mod 2^{s+1} + const auto& in_val = getValueShare(in); + const auto& in_mac = GetMacShare(ctx, in); + auto res_val = ring_bitmask(in_val, 0, s + 1); + auto res_mac = ring_bitmask(in_mac, 0, s + 1); + // 2. makeBShare + return makeBShare(res_val, res_mac, field, nbits); +} + +ArrayRef Bit2A::proc(KernelEvalContext* ctx, const ArrayRef& in) const { + // ArrayRef bit2a_impl(KernelEvalContext* ctx, const ArrayRef& in) { + SPU_TRACE_MPC_LEAF(ctx, in); + + const auto field = in.eltype().as()->field(); + auto* comm = ctx->getState(); + auto* beaver = ctx->getState()->beaver(); + const size_t s = ctx->getState()->s(); + const size_t k = ctx->getState()->k(); + const auto key = ctx->getState()->key(); + + // The protocol for Bit2A in SPDZ2k + // reference: https://eprint.iacr.org/2019/599.pdf + // Page 6, Figure 3. + + // 1. Reserve the least significant bit + const auto [_in_val, _in_mac] = BShareSwitch2Nbits(in, 1); + const auto _in = makeBShare(_in_val, _in_mac, field, 1); + const size_t out_numel = _in.numel(); + + // 2. get random bit [r] in the form of A-share + auto [r, r_mac] = beaver->AuthRandBit(field, out_numel, k, s); + auto ar = makeAShare(r, r_mac, field); + + // 3. Convert [r] into B-share + auto br = wrap_a2bit(ctx->sctx(), ar); + + // 4. c = open([x] + [r]) + auto bc = wrap_add_bb(ctx->sctx(), _in, br); + // Notice we only reserve the least significant bit + const auto [bc_val, bc_mac] = BShareSwitch2Nbits(bc, 1); + auto [c, zero_mac] = beaver->BatchOpen(bc_val, bc_mac, 1, s); + SPU_ENFORCE(beaver->BatchMacCheck(c, zero_mac, 1, s)); + ring_bitmask_(c, 0, 1); + + // 5. 
[x] = c + [r] - 2 * c * [r] + return DISPATCH_ALL_FIELDS(field, "_", [&]() { + auto _c = ArrayView(c); + auto _r = ArrayView(r); + auto _r_mac = ArrayView(r_mac); + + ArrayRef out(makeType(field, true), out_numel); + auto _out = ArrayView>(out); + + pforeach(0, out_numel, [&](int64_t idx) { + _out[idx][0] = (_r[idx] - 2 * _c[idx] * _r[idx]); + if (comm->getRank() == 0) { + _out[idx][0] += _c[idx]; + } + _out[idx][1] = _c[idx] * key + _r_mac[idx] - 2 * _c[idx] * _r_mac[idx]; + }); + + return out; + }); +} + +ArrayRef A2B::proc(KernelEvalContext* ctx, const ArrayRef& in) const { + SPU_TRACE_MPC_LEAF(ctx, in); + + const auto field = in.eltype().as()->field(); + auto* beaver = ctx->getState()->beaver(); + const size_t k = ctx->getState()->k(); + const size_t s = ctx->getState()->s(); + + // 1. get rand bit r + auto [rbit, rbit_mac] = beaver->AuthRandBit(field, in.numel() * k, k, s); + auto arbit = makeAShare(rbit, rbit_mac, field); + + // 2. a2bit + auto _br = wrap_a2bit(ctx->sctx(), arbit); + auto br_val = getValueShare(_br); + auto br_mac = getMacShare(_br); + + auto br = makeBShare(br_val, br_mac, field, k); + auto ar = wrap_b2a(ctx->sctx(), br); + + // 3. open a - r + const auto& in_val = getValueShare(in); + const auto& r_val = getValueShare(ar); + auto a_r_val = ring_sub(in_val, r_val); + + const auto& in_mac = GetMacShare(ctx, in); + const auto& r_mac = getMacShare(ar); + auto a_r_mac = ring_sub(in_mac, r_mac); + + auto [c, check_mac] = beaver->BatchOpen(a_r_val, a_r_mac, k, s); + SPU_ENFORCE(beaver->BatchMacCheck(c, check_mac, k, s)); + + // 4. 
binary add + auto ty = makeType(field); + ring_bitmask_(c, 0, k); + auto bc = wrap_p2b(ctx->sctx(), c.as(ty)); + auto [bc_val, bc_mac] = BShareSwitch2Nbits(bc, k); + + bc = makeBShare(bc_val, bc_mac, field, k); + auto res = wrap_add_bb(ctx->sctx(), br, bc); + return res; +} + +ArrayRef B2A::proc(KernelEvalContext* ctx, const ArrayRef& in) const { + SPU_TRACE_MPC_LEAF(ctx, in); + + const auto field = in.eltype().as()->field(); + const auto nbits = in.eltype().as()->nbits(); + auto* comm = ctx->getState(); + auto* beaver = ctx->getState()->beaver(); + const size_t k = ctx->getState()->k(); + const size_t s = ctx->getState()->s(); + const auto key = ctx->getState()->key(); + + const auto _in_val = getValueShare(in); + const auto _in_mac = getMacShare(in); + const auto _in = makeBShare(_in_val, _in_mac, field, 1); + const size_t out_numel = _in.numel() / nbits; + + // 1. get rand bit [r] + auto [r, r_mac] = beaver->AuthRandBit(field, _in.numel(), k, s); + auto ar = makeAShare(r, r_mac, field); + + // 2. a2bit + auto br = wrap_a2bit(ctx->sctx(), ar); + + // 3. c = open([x] + [r]) + auto bc = wrap_add_bb(ctx->sctx(), _in, br); + // Notice we only reserve the least significant bit + const auto [bc_val, bc_mac] = BShareSwitch2Nbits(bc, 1); + + auto [c, zero_mac] = beaver->BatchOpen(bc_val, bc_mac, 1, s); + SPU_ENFORCE(beaver->BatchMacCheck(c, zero_mac, 1, s)); + ring_bitmask_(c, 0, 1); + + // 4. 
[x] = c + [r] - 2 * c * [r] + return DISPATCH_ALL_FIELDS(field, "_", [&]() { + using PShrT = ring2k_t; + auto _c = ArrayView(c); + auto _r = ArrayView(r); + auto _r_mac = ArrayView(r_mac); + + ArrayRef out(makeType(field, true), out_numel); + ArrayRef expand_out(makeType(field, true), _in.numel()); + auto _out = ArrayView>(out); + auto _expand_out = ArrayView>(expand_out); + + pforeach(0, _in.numel(), [&](int64_t idx) { + _expand_out[idx][0] = (_r[idx] - 2 * _c[idx] * _r[idx]); + if (comm->getRank() == 0) { + _expand_out[idx][0] += _c[idx]; + } + _expand_out[idx][1] = + _c[idx] * key + _r_mac[idx] - 2 * _c[idx] * _r_mac[idx]; + }); + + pforeach(0, out_numel, [&](int64_t idx) { + _out[idx][0] = 0; + _out[idx][1] = 0; + for (size_t jdx = 0; jdx < nbits; ++jdx) { + _out[idx][0] += (_expand_out[idx * nbits + jdx][0]) << jdx; + _out[idx][1] += (_expand_out[idx * nbits + jdx][1]) << jdx; + } + }); + + return out; + }); +} + +ArrayRef MSB::proc(KernelEvalContext* ctx, const ArrayRef& in) const { + SPU_TRACE_MPC_LEAF(ctx, in); + const auto field = in.eltype().as()->field(); + const size_t k = ctx->getState()->k(); + const size_t s = ctx->getState()->s(); + const auto key = ctx->getState()->key(); + const auto* comm = ctx->getState(); + + auto* beaver = ctx->getState()->beaver(); + const auto numel = in.numel(); + + // The protocol for extracting MSB in SPDZ2k + // reference: https://eprint.iacr.org/2019/599.pdf + // Page7, Figure 6. + + auto _in = getValueShare(in); + auto _in_mac = GetMacShare(ctx, in); + + auto [c_in, c_in_mac] = beaver->BatchOpen(_in, _in_mac, k, s); + SPU_ENFORCE(beaver->BatchMacCheck(c_in, c_in_mac, k, s)); + + auto _r_val = ring_zeros(field, numel); + auto _r_mac = ring_zeros(field, numel); + std::vector _r_vec; + std::vector _r_mac_vec; + + // 1. generate random bit b , r_0 , ... 
, r_{k-1} + // then set r = \sum r_i 2^{i} + for (size_t i = 0; i < k; ++i) { + auto [_r_i, _r_i_mac] = beaver->AuthRandBit(field, numel, k, s); + ring_add_(_r_val, ring_lshift(_r_i, i)); + ring_add_(_r_mac, ring_lshift(_r_i_mac, i)); + // record r_i & r_i_mac + _r_vec.emplace_back(std::move(_r_i)); + _r_mac_vec.emplace_back(std::move(_r_i_mac)); + } + + // 2. reveal a + r + auto _c = ring_add(_in, _r_val); + auto _c_mac = ring_add(_in_mac, _r_mac); + auto [c_open, zero_mac] = beaver->BatchOpen(_c, _c_mac, k, s); + SPU_ENFORCE(beaver->BatchMacCheck(c_open, zero_mac, k, s)); + auto _c_open = ring_bitmask(c_open, 0, k - 1); + + // 3. convert r from A-share to B-share + // set r' be the B-share for sum_{i=0}^{k-2} r_i + ArrayRef _bt_r = ring_zeros(field, numel * (k - 1)); + ArrayRef _bt_r_mac = ring_zeros(field, numel * (k - 1)); + ArrayRef _ar = ring_zeros(field, numel); + ArrayRef _ar_mac = ring_zeros(field, numel); + + const auto ty = makeType(field); + for (size_t i = 0; i < k - 1; ++i) { + ring_add_(_ar, ring_lshift(_r_vec[i], i)); + ring_add_(_ar_mac, ring_lshift(_r_mac_vec[i], i)); + + auto at_r_i = makeAShare(_r_vec[i], _r_mac_vec[i], field); + auto bt_r_i = wrap_a2bit(ctx->sctx(), at_r_i); + const auto _bt_r_i = getValueShare(bt_r_i); + const auto _bt_r_i_mac = getMacShare(bt_r_i); + auto _sub_bt_r = + ArrayRef(_bt_r.buf(), ty, numel, _bt_r.stride() * (k - 1), + _bt_r.offset() + i * static_cast(ty.size())); + auto _sub_bt_r_mac = + ArrayRef(_bt_r_mac.buf(), ty, numel, _bt_r_mac.stride() * (k - 1), + _bt_r_mac.offset() + i * static_cast(ty.size())); + ring_add_(_sub_bt_r, _bt_r_i); + ring_add_(_sub_bt_r_mac, _bt_r_i_mac); + } + auto br = makeBShare(_bt_r, _bt_r_mac, field, k - 1); + + // 4. 
u = BitLT( c , r' ) + // todo: Here should be ctx->caller()->call("bitlt_pb", _pc , br) + // or ctx->caller()->call("bitle_bp", br , _pc) + // or ctx->caller()->call("bitlt_bb", _bc, br) + auto _pc = _c_open.as(makeType(field)); + ring_bitmask_(_pc, 0, k); + auto _bc = wrap_p2b(ctx->sctx(), _pc); + + // auto not_u = BitLEBB().proc(ctx, br, _bc); + // auto bu = NotB().proc(ctx, not_u); + auto not_u = wrap_bitle_bb(ctx->sctx(), br, _bc); + auto bu = wrap_not_b(ctx->sctx(), not_u); + + // 5. convert u from B-share to A-share + auto au = wrap_bit2a(ctx->sctx(), bu); + + // 6. Compute a' = c' - r' + 2^{k-1} u + // d = a - a' + auto _au = getValueShare(au); + auto _au_mac = getMacShare(au); + auto _aa = ring_sub(ring_lshift(_au, k - 1), _ar); + auto _aa_mac = ring_sub(ring_lshift(_au_mac, k - 1), _ar_mac); + if (comm->getRank() == 0) { + ring_add_(_aa, _c_open); + } + ring_add_(_aa_mac, ring_mul(_c_open, key)); + auto _d = ring_sub(_in, _aa); + auto _d_mac = ring_sub(_in_mac, _aa_mac); + + // 7. let e = d + 2^{k-1} b, then open e + auto [_b, _b_mac] = beaver->AuthRandBit(field, numel, k, s); + auto _e = ring_add(_d, ring_lshift(_b, k - 1)); + auto _e_mac = ring_add(_d_mac, ring_lshift(_b_mac, k - 1)); + + auto [e_open, e_zero_mac] = beaver->BatchOpen(_e, _e_mac, k, s); + SPU_ENFORCE(beaver->BatchMacCheck(e_open, e_zero_mac, k, s)); + + // 8. e' be the most significant bit of e + auto _ee = ring_bitmask(ring_rshift(e_open, k - 1), 0, 1); + + // 9. 
output e_{k-1} + b - 2 e_{k-1} b + auto _ret = ring_sub(_b, ring_lshift(ring_mul(_b, _ee), 1)); + auto _ret_mac = ring_sub(_b_mac, ring_lshift(ring_mul(_b_mac, _ee), 1)); + if (comm->getRank() == 0) { + ring_add_(_ret, _ee); + } + ring_add_(_ret_mac, ring_mul(_ee, key)); + SPU_ENFORCE(_ret.numel() == in.numel(), "_ret numel {}, in numel is {}", + _ret.numel(), in.numel()); + + return makeBShare(_ret, _ret_mac, field, 1); +} + +static ArrayRef wrap_kogge_stone(const CircuitBasicBlock& cbb, + const ArrayRef& x, const ArrayRef& y, + size_t nbits) { + SPU_ENFORCE(x.numel() == y.numel()); + const Shape shape = {x.numel()}; + auto [res, _s, _t] = UnwrapValue( + kogge_stone(cbb, WrapValue(x, shape), WrapValue(y, shape), nbits)); + return res; +} + +ArrayRef AddBB::proc(KernelEvalContext* ctx, const ArrayRef& lhs, + const ArrayRef& rhs) const { + SPU_TRACE_MPC_LEAF(ctx, lhs, rhs); + const size_t nbits = maxNumBits(lhs, rhs); + auto cbb = MakeSPDZBasicBlock(ctx->sctx()); + // sklansky has more local computation which leads to lower performance. + return wrap_kogge_stone(cbb, lhs, rhs, nbits); +} + +ArrayRef AddBP::proc(KernelEvalContext* ctx, const ArrayRef& lhs, + const ArrayRef& rhs) const { + SPU_TRACE_MPC_LEAF(ctx, lhs, rhs); + const size_t nbits = maxNumBits(lhs, rhs); + auto cbb = MakeSPDZBasicBlock(ctx->sctx()); + // sklansky has more local computation which leads to lower performance. 
+ return wrap_kogge_stone(cbb, lhs, rhs, nbits); +} + +#if 0 +ArrayRef BitLTBB::proc(KernelEvalContext* ctx, const ArrayRef& lhs, + const ArrayRef& rhs) const { + SPU_TRACE_MPC_LEAF(ctx, lhs, rhs); + const auto nbits = maxNumBits(lhs, rhs); + const auto field = lhs.eltype().as()->field(); + const auto numel = lhs.numel(); + auto rhs_not = ctx->caller()->call("not_b", rhs); + auto cbb = MakeSPDZBasicBlock(ctx->sctx()); + + // Full adder implementation using two half adders + // TODO: design an efficient full adder in circuit.h + auto sum = kogge_stone(cbb, lhs, rhs_not, nbits); + auto carry1 = carry_out(cbb, lhs, rhs_not, nbits); + + const auto p_numel = numel; + auto ones = pt_ones(field, p_numel); + + auto carry2 = carry_out(cbb, sum, ones, nbits); + auto ret = ctx->caller()->call("xor_bb", carry1, carry2); + + auto res = ctx->caller()->call("not_b", ret); + SPU_ENFORCE(res.numel() == lhs.numel(), "_ret numel {}, lhs numel is {}", + res.numel(), lhs.numel()); + return res; +} +#else +ArrayRef BitLTBB::proc(KernelEvalContext* ctx, const ArrayRef& lhs, + const ArrayRef& rhs) const { + SPU_TRACE_MPC_LEAF(ctx, lhs, rhs); + + auto res0 = wrap_bitle_bb(ctx->sctx(), lhs, rhs); + auto res1 = wrap_bitle_bb(ctx->sctx(), rhs, lhs); + auto eq = wrap_and_bb(ctx->sctx(), res0, res1); + auto neq = wrap_not_b(ctx->sctx(), eq); + auto res = wrap_and_bb(ctx->sctx(), neq, res0); + + return res; +} +#endif + +ArrayRef BitLEBB::proc(KernelEvalContext* ctx, const ArrayRef& lhs, + const ArrayRef& rhs) const { + SPU_TRACE_MPC_LEAF(ctx, lhs, rhs); + + const auto nbits = maxNumBits(lhs, rhs); + + auto rhs_not = wrap_not_b(ctx->sctx(), rhs); + auto cbb = MakeSPDZBasicBlock(ctx->sctx()); + auto ret = wrap_carray_out(cbb, lhs, rhs_not, nbits); + auto res = wrap_not_b(ctx->sctx(), ret); + SPU_ENFORCE(res.numel() == lhs.numel(), "res numel {}, lhs numel {} ", + res.numel(), lhs.numel()); + return res; +} + +}; // namespace spu::mpc::spdz2k \ No newline at end of file diff --git 
a/libspu/mpc/spdz2k/conversion.h b/libspu/mpc/spdz2k/conversion.h new file mode 100644 index 00000000..e4d48020 --- /dev/null +++ b/libspu/mpc/spdz2k/conversion.h @@ -0,0 +1,205 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "libspu/core/array_ref.h" +#include "libspu/core/cexpr.h" +#include "libspu/mpc/kernel.h" +#include "libspu/mpc/spdz2k/value.h" + +namespace spu::mpc::spdz2k { + +using ce::CExpr; +using ce::Const; +using ce::K; +using ce::Log; +using ce::N; + +class A2B : public UnaryKernel { + public: + static constexpr char kBindName[] = "a2b"; + + CExpr latency() const override { + // 1 * AddBB : log(k) + 1 + // 1 * rotate: 1 + // return Log(K()) + 1 + 1; + return Const(0); + } + + CExpr comm() const override { + // 1 * AddBB : logk * 2k + k + // 1 * rotate: k + // return Log(K()) * K() * 2 + K() * 2; + return Const(0); + } + + ArrayRef proc(KernelEvalContext* ctx, const ArrayRef& in) const override; +}; + +class A2Bit : public UnaryKernel { + public: + static constexpr char kBindName[] = "a2bit"; + + CExpr latency() const override { + // 1 * AddBB : log(k) + 1 + // 1 * rotate: 1 + // return Log(K()) + 1 + 1; + return Const(0); + } + + CExpr comm() const override { + // 1 * AddBB : logk * 2k + k + // 1 * rotate: k + // return Log(K()) * K() * 2 + K() * 2; + return Const(0); + } + + ArrayRef proc(KernelEvalContext* ctx, const ArrayRef& in) const override; +}; + +class Bit2A : 
public UnaryKernel { + public: + static constexpr char kBindName[] = "bit2a"; + + CExpr latency() const override { + // 1 * AddBB : log(k) + 1 + // 1 * rotate: 1 + // return Log(K()) + 1 + 1; + return Const(0); + } + + CExpr comm() const override { + // 1 * AddBB : logk * 2k + k + // 1 * rotate: k + // return Log(K()) * K() * 2 + K() * 2; + return Const(0); + } + + ArrayRef proc(KernelEvalContext* ctx, const ArrayRef& in) const override; +}; + +class BitDec : public UnaryKernel { + public: + static constexpr char kBindName[] = "bit_dec"; + + CExpr latency() const override { + // 1 * AddBB : log(k) + 1 + // 1 * rotate: 1 + return Log(K()) + 1 + 1; + } + + CExpr comm() const override { + // 1 * AddBB : logk * 2k + k + // 1 * rotate: k + return Log(K()) * K() * 2 + K() * 2; + } + + ArrayRef proc(KernelEvalContext* ctx, const ArrayRef& in) const override; +}; + +// Reference: +// IV.E Boolean to Arithmetic Sharing (B2A), extended to 3pc settings. +// https://encrypto.de/papers/DSZ15.pdf +class B2A : public UnaryKernel { + public: + static constexpr char kBindName[] = "b2a"; + + CExpr latency() const override { + // 2 * rotate : 2 + // 1 * AddBB : 1 + logk + // return Const(3) + Log(K()); + return Const(0); + } + + CExpr comm() const override { + // 2 * rotate : 2k + // 1 * AddBB : logk * 2k + k + // return Log(K()) * K() * 2 + 3 * K(); + return Const(0); + } + + ArrayRef proc(KernelEvalContext* ctx, const ArrayRef& x) const override; +}; + +class MSB : public UnaryKernel { + public: + static constexpr char kBindName[] = "msb_a2b"; + + Kind kind() const override { return Kind::Dynamic; } + + ArrayRef proc(KernelEvalContext* ctx, const ArrayRef& x) const override; +}; + +class AddBB : public BinaryKernel { + public: + static constexpr char kBindName[] = "add_bb"; + + CExpr latency() const override { + // Cost from other gates (from KoggeStoneAdder): + // 1 * AddBB : 1 + // logk * AndBB : 2logk (if vectorize, logk) + // return Log(K()) + Const(1); + return Const(0); + }
+ + CExpr comm() const override { + // Cost from other gates (from KoggeStoneAdder): + // 1 * AddBB : k + // logk * AndBB : logk * 2k + // return Log(K()) * K() * 2 + K(); + return Const(0); + } + + ArrayRef proc(KernelEvalContext* ctx, const ArrayRef& lhs, + const ArrayRef& rhs) const override; +}; + +class AddBP : public BinaryKernel { + public: + static constexpr char kBindName[] = "add_bp"; + + CExpr latency() const override { return Const(0); } + + CExpr comm() const override { return Const(0); } + + ArrayRef proc(KernelEvalContext* ctx, const ArrayRef& x, + const ArrayRef& y) const override; +}; + +class BitLTBB : public BinaryKernel { + public: + static constexpr char kBindName[] = "bitlt_bb"; + + CExpr latency() const override { return Const(0); } + + CExpr comm() const override { return Const(0); } + + ArrayRef proc(KernelEvalContext* ctx, const ArrayRef& x, + const ArrayRef& y) const override; +}; + +class BitLEBB : public BinaryKernel { + public: + static constexpr char kBindName[] = "bitle_bb"; + + CExpr latency() const override { return Const(0); } + + CExpr comm() const override { return Const(0); } + + ArrayRef proc(KernelEvalContext* ctx, const ArrayRef& x, + const ArrayRef& y) const override; +}; + +} // namespace spu::mpc::spdz2k \ No newline at end of file diff --git a/libspu/mpc/spdz2k/io.cc b/libspu/mpc/spdz2k/io.cc index 774fedda..0fc94fdb 100644 --- a/libspu/mpc/spdz2k/io.cc +++ b/libspu/mpc/spdz2k/io.cc @@ -21,6 +21,29 @@ namespace spu::mpc::spdz2k { +FieldType getRuntimeField(FieldType data_field) { + switch (data_field) { + case FM32: + return FM64; + case FM64: + return FM128; + default: + SPU_THROW("unsupported data field {} for spdz2k", data_field); + } + return FT_INVALID; +} + +Type Spdz2kIo::getShareType(Visibility vis, int owner_rank) const { + if (vis == VIS_PUBLIC) { + return makeType(field_); + } else if (vis == VIS_SECRET) { + const auto runtime_field = getRuntimeField(field_); + return makeType(runtime_field); + } + + 
SPU_THROW("unsupported vis type {}", vis); +} + std::vector Spdz2kIo::toShares(const ArrayRef& raw, Visibility vis, int owner_rank) const { SPU_ENFORCE(raw.eltype().isa(), "expected RingTy, got {}", @@ -33,13 +56,28 @@ std::vector Spdz2kIo::toShares(const ArrayRef& raw, Visibility vis, const auto share = raw.as(makeType(field)); return std::vector(world_size_, share); } else if (vis == VIS_SECRET) { - // by default, make as arithmetic share. - const auto zeros = ring_zeros(field, raw.numel()); - const auto splits = ring_rand_additive_splits(raw, world_size_); + const auto runtime_field = getRuntimeField(field); + ArrayRef x(makeType(runtime_field), raw.numel()); + + DISPATCH_ALL_FIELDS(field, "_", [&]() { + auto _raw = ArrayView(raw); + DISPATCH_ALL_FIELDS(runtime_field, "_", [&]() { + auto _x = ArrayView(x); + + pforeach(0, raw.numel(), [&](int64_t idx) { + _x[idx] = static_cast(_raw[idx]); + }); + }); + }); + + const auto zeros = ring_zeros(runtime_field, x.numel()); + const auto splits = ring_rand_additive_splits(x, world_size_); + bool has_mac = false; std::vector shares; - shares.reserve(splits.size()); + shares.reserve(world_size_); for (const auto& split : splits) { - shares.push_back(makeAShare(split, zeros, field)); + // due to lack of information about key, MACs of data are set to zeros + shares.push_back(makeAShare(split, zeros, runtime_field, has_mac)); } return shares; } @@ -59,12 +97,34 @@ ArrayRef Spdz2kIo::fromShares(const std::vector& shares) const { if (eltype.isa()) { ring_add_(res, getValueShare(share)); } else if (eltype.isa()) { - ring_xor_(res, getValueShare(share)); + ring_add_(res, getValueShare(share)); } else { SPU_THROW("invalid share type {}", eltype); } } - return res; + + if (eltype.isa()) { + ring_bitmask_(res, 0, SizeOf(field_) * 8); + } else { + ring_bitmask_(res, 0, 1); + } + + // TODO(zxp): use export_s to extract FM64 value from FM128 + { + ArrayRef x(makeType(field_), res.numel()); + + DISPATCH_ALL_FIELDS(field, "_", [&]() { + 
auto _res = ArrayView(res); + DISPATCH_ALL_FIELDS(field_, "_", [&]() { + auto _x = ArrayView(x); + pforeach(0, x.numel(), [&](int64_t idx) { + _x[idx] = static_cast(_res[idx]); + }); + }); + }); + + return x; + } } SPU_THROW("unsupported eltype {}", eltype); } diff --git a/libspu/mpc/spdz2k/io.h b/libspu/mpc/spdz2k/io.h index 303d5e1e..0c07634f 100644 --- a/libspu/mpc/spdz2k/io.h +++ b/libspu/mpc/spdz2k/io.h @@ -24,6 +24,7 @@ class Spdz2kIo final : public BaseIo { std::vector toShares(const ArrayRef& raw, Visibility vis, int owner_rank = -1) const override; + Type getShareType(Visibility vis, int owner_rank = -1) const override; ArrayRef fromShares(const std::vector& shares) const override; }; diff --git a/libspu/mpc/spdz2k/io_test.cc b/libspu/mpc/spdz2k/io_test.cc index 28c20d27..bda67de6 100644 --- a/libspu/mpc/spdz2k/io_test.cc +++ b/libspu/mpc/spdz2k/io_test.cc @@ -21,8 +21,8 @@ namespace spu::mpc::spdz2k { INSTANTIATE_TEST_SUITE_P( Spdz2kIoTest, IoTest, testing::Combine(testing::Values(makeSpdz2kIo), // - testing::Values(2, 3, 5), // - testing::Values(FieldType::FM128)), + testing::Values(2), // + testing::Values(FieldType::FM64)), [](const testing::TestParamInfo& info) { return fmt::format("{}x{}", std::get<1>(info.param), std::get<2>(info.param)); diff --git a/libspu/mpc/spdz2k/ot/BUILD.bazel b/libspu/mpc/spdz2k/ot/BUILD.bazel new file mode 100644 index 00000000..01a9fb84 --- /dev/null +++ b/libspu/mpc/spdz2k/ot/BUILD.bazel @@ -0,0 +1,73 @@ +# Copyright 2023 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +load("//bazel:spu.bzl", "spu_cc_library", "spu_cc_test") +load("@yacl//bazel:yacl.bzl", "AES_COPT_FLAGS") + +package(default_visibility = ["//visibility:public"]) + +spu_cc_library( + name = "ferret", + hdrs = ["ferret.h"], + deps = [ + "//libspu/mpc/cheetah/ot:cheetah_ot", + "//libspu/mpc/common:communicator", + ], +) + +spu_cc_library( + name = "basic_ot_prot", + srcs = ["basic_ot_prot.cc"], + hdrs = ["basic_ot_prot.h"], + deps = [ + ":ferret", + ], +) + +spu_cc_library( + name = "kos_ote", + srcs = ["kos_ote.cc"], + hdrs = ["kos_ote.h"], + copts = AES_COPT_FLAGS + ["-Wno-ignored-attributes"], + deps = [ + "//libspu/core:prelude", + "@yacl//yacl/link", + "@yacl//yacl/utils:serialize", + "@yacl//yacl/crypto/primitives/ot:base_ot", + "@yacl//yacl/crypto/base/hash:hash_utils", + "@yacl//yacl/crypto/base/hash:hash_interface", + "@yacl//yacl/crypto/tools:prg", + "@yacl//yacl/crypto/tools:random_oracle", + "@yacl//yacl/crypto/tools:random_permutation", + "@yacl//yacl/utils:matrix_utils", + "@com_github_emptoolkit_emp_tool//:emp-tool", + ] +) + +spu_cc_library( + name = "tiny_ot", + srcs = ["tiny_ot.cc"], + hdrs = ["tiny_ot.h"], + copts = AES_COPT_FLAGS + ["-Wno-ignored-attributes"], + deps = [ + "//libspu/mpc/spdz2k:commitment", + "//libspu/mpc/spdz2k/ot:kos_ote", + "//libspu/mpc/common:communicator", + "//libspu/mpc/utils:ring_ops", + "@yacl//yacl/link", + "@yacl//yacl/crypto/primitives/ot:ot_store", + "@yacl//yacl/crypto/tools:prg", + "@com_github_emptoolkit_emp_tool//:emp-tool", + ] +) diff --git a/libspu/mpc/spdz2k/ot/basic_ot_prot.cc b/libspu/mpc/spdz2k/ot/basic_ot_prot.cc new file mode 100644 index 00000000..3bf882d5 --- /dev/null +++ b/libspu/mpc/spdz2k/ot/basic_ot_prot.cc @@ -0,0 +1,47 @@ +// Copyright 2023 Ant Group Co., Ltd. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "libspu/mpc/spdz2k/ot/basic_ot_prot.h" + +namespace spu::mpc::spdz2k { + +// Takes ownership of `conn` (must be non-null) and builds two FerretOT +// instances on it: one constructed with is_sender = true and one with +// is_sender = false. Rank 0 constructs the sender first; every other rank +// constructs the receiver first. +// NOTE(review): presumably the mirrored order lets each party's setup pair +// with the opposite role on its peer -- confirm against cheetah::FerretOT. +BasicOTProtocols::BasicOTProtocols(std::shared_ptr conn) + : conn_(std::move(conn)) { + SPU_ENFORCE(conn_ != nullptr); + if (conn_->getRank() == 0) { + ferret_sender_ = std::make_shared(conn_, true); + ferret_receiver_ = std::make_shared(conn_, false); + } else { + ferret_receiver_ = std::make_shared(conn_, false); + ferret_sender_ = std::make_shared(conn_, true); + } +} + +// Returns a fresh BasicOTProtocols that runs over a communicator created from +// a spawned link context (conn_->lctx()->Spawn()); nothing is shared with +// `this` besides the underlying link it was spawned from. +std::unique_ptr BasicOTProtocols::Fork() { + // TODO: we can take from cached ROTs from the caller + auto conn = std::make_shared(conn_->lctx()->Spawn()); + return std::make_unique(conn); +} + +// Flushes before the members are destroyed. +BasicOTProtocols::~BasicOTProtocols() { Flush(); } + +// Flushes the sender-side FerretOT if it exists. The receiver-side instance is +// not flushed here. +void BasicOTProtocols::Flush() { + if (ferret_sender_) { + ferret_sender_->Flush(); + } +} + +// Reports the rank as seen by the sender-side FerretOT. +int BasicOTProtocols::Rank() const { return ferret_sender_->Rank(); } + +} // namespace spu::mpc::spdz2k diff --git a/libspu/mpc/spdz2k/ot/basic_ot_prot.h b/libspu/mpc/spdz2k/ot/basic_ot_prot.h new file mode 100644 index 00000000..014b73af --- /dev/null +++ b/libspu/mpc/spdz2k/ot/basic_ot_prot.h @@ -0,0 +1,44 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "libspu/mpc/common/communicator.h" +#include "libspu/mpc/spdz2k/ot/ferret.h" + +namespace spu::mpc::spdz2k { + +// Bundles a sender-role and a receiver-role FerretOT that share one +// communicator, so a party can act in either OT role. +class BasicOTProtocols { + public: + // `conn` must be non-null; the constructor enforces this. + explicit BasicOTProtocols(std::shared_ptr conn); + + // Calls Flush() before the members are released. + ~BasicOTProtocols(); + + // Creates an independent instance over a communicator spawned from this + // one's link context. + std::unique_ptr Fork(); + + // Rank as reported by the sender-side FerretOT. + int Rank() const; + + // Shared handle to the FerretOT constructed with is_sender = true. + std::shared_ptr GetSenderCOT() { return ferret_sender_; } + + // Shared handle to the FerretOT constructed with is_sender = false. + std::shared_ptr GetReceiverCOT() { return ferret_receiver_; } + + // Flushes the sender-side FerretOT (if any). + void Flush(); + + private: + std::shared_ptr conn_; + std::shared_ptr ferret_sender_; + std::shared_ptr ferret_receiver_; +}; + +} // namespace spu::mpc::spdz2k diff --git a/libspu/mpc/spdz2k/ot/ferret.h b/libspu/mpc/spdz2k/ot/ferret.h new file mode 100644 index 00000000..3dc094bc --- /dev/null +++ b/libspu/mpc/spdz2k/ot/ferret.h @@ -0,0 +1,120 @@ +// +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+// +// This code is mostly from mpc/cheetah/ot/ferret.h, with two modifications: +// RMCC and VOLE + +#pragma once + +#include + +#include "absl/types/span.h" +#include "yacl/base/int128.h" + +#include "libspu/mpc/cheetah/ot/ferret.h" +#include "libspu/mpc/common/communicator.h" + +namespace spu::mpc::spdz2k { + +// cheetah::FerretOT extended with vector-OLE (VOLE) send/receive helpers. +// NOTE(review): `Impl` / `impl_` are declared but never defined or used in +// this header -- confirm whether they can be dropped. +class FerretOT : public cheetah::FerretOT { + private: + struct Impl; + std::shared_ptr impl_; + + public: + FerretOT(std::shared_ptr conn, bool is_sender, + bool malicious = true); + + ~FerretOT(); + + // VOLE, only for SPDZ2K + // data[i] = data0[i] + a[i] * corr[i] + // Sender: input corr, output random data0 + // Receiver: input a, output data + template + void SendVole(absl::Span corr, absl::Span data0); + + template + void RecvVole(absl::Span a, absl::Span data); +}; + +// Both definitions live in this header, so they are marked `inline`: without +// it every translation unit including ferret.h would emit its own external +// definition, violating the one-definition rule (duplicate-symbol link errors). +inline FerretOT::FerretOT(std::shared_ptr conn, bool is_sender, + bool malicious) + : cheetah::FerretOT::FerretOT(conn, is_sender, malicious) {} + +inline FerretOT::~FerretOT() {} + +// Refer to: +// Appendix C. Implementing Vector-OLE mod 2^l, P35 +// SPDZ2k: Efficient MPC mod 2k for Dishonest Majority +// - https://eprint.iacr.org/2018/482.pdf +// +// `corr` and `data0` must have the same size. `iters` is half the bit width of +// T; `corr` is replicated once per iteration so that a single SendCAMCC call +// covers all of them, and data0[j] is then rebuilt as +// sum_i (t_data[i * length + j] << i). +template +void FerretOT::SendVole(absl::Span corr, absl::Span data0) { + SPU_ENFORCE(data0.size() == corr.size()); + size_t length = data0.size(); + constexpr size_t iters = sizeof(T) * 8 / 2; + + std::vector t_data(length * iters, 0); + std::vector corrs; + for (size_t i = 0; i < iters; ++i) { + std::copy(corr.begin(), corr.end(), std::back_inserter(corrs)); + } + + // call parent class method + SendCAMCC(absl::MakeSpan(corrs.data(), corrs.size()), + absl::MakeSpan(t_data.data(), t_data.size())); + Flush(); + + for (size_t j = 0; j < length; ++j) { + data0[j] = 0; + } + for (size_t i = 0; i < iters; ++i) { + for (size_t j = 0; j < length; ++j) { + data0[j] += t_data[i * length + j] << i; + } + } +} + +// Receiver side of SendVole. `a` and `data` must have the same size. Only the +// low `iters` bits of each a[j] are consumed: bit i of a[j] becomes the choice +// bit b[i * length + j] passed to RecvCAMCC, and data[j] is rebuilt as +// sum_i (t_data[i * length + j] << i). +template +void FerretOT::RecvVole(absl::Span a, absl::Span data) { + SPU_ENFORCE(data.size() == a.size()); + + size_t length = data.size(); + constexpr
size_t iters = sizeof(T) * 8 / 2; + + std::vector t_data(length * iters, 0); + std::vector b(length * iters, 0); + + for (size_t i = 0; i < iters; ++i) { + for (size_t j = 0; j < length; ++j) { + b[i * length + j] = (a[j] >> i) & 1; + } + } + + // call parent class method + RecvCAMCC(absl::MakeSpan(b.data(), b.size()), + absl::MakeSpan(t_data.data(), t_data.size())); + + for (size_t j = 0; j < length; ++j) { + data[j] = 0; + } + for (size_t i = 0; i < iters; ++i) { + for (size_t j = 0; j < length; ++j) { + data[j] += t_data[i * length + j] << i; + } + } +} + +} // namespace spu::mpc::spdz2k diff --git a/libspu/mpc/spdz2k/ot/kos_ote.cc b/libspu/mpc/spdz2k/ot/kos_ote.cc new file mode 100644 index 00000000..9270a2b2 --- /dev/null +++ b/libspu/mpc/spdz2k/ot/kos_ote.cc @@ -0,0 +1,370 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +#include "libspu/mpc/spdz2k/ot/kos_ote.h" + +#include + +#include "emp-tool/utils/block.h" +#include "emp-tool/utils/f2k.h" +#include "yacl/crypto/base/hash/hash_utils.h" +#include "yacl/crypto/tools/prg.h" +#include "yacl/crypto/tools/random_oracle.h" +#include "yacl/crypto/tools/random_permutation.h" +#include "yacl/link/link.h" +#include "yacl/utils/matrix_utils.h" +#include "yacl/utils/serialize.h" + +#include "libspu/core/prelude.h" + +// KOS have already implemented in sz/YACL +// However, TinyOT need an active secure delta-OT +namespace spu::mpc::spdz2k { + +namespace { +constexpr size_t kKappa = 128; +constexpr size_t kS = 64; +constexpr size_t kBatchSize = 128; + +// Binary Irreducible Polynomials +// ref: https://www.hpl.hp.com/techreports/98/HPL-98-135.pdf +constexpr uint64_t kGfMod64 = (1 << 4) | (1 << 3) | (1 << 1) | 1; + +#define U128_LHF(Value) absl::Int128High64(static_cast(Value)) +#define U128_RHF(Value) absl::Int128Low64(static_cast(Value)) + +inline uint64_t Block_Lhf(const emp::block& x) { + return static_cast(_mm_extract_epi64(x, 1)); +} + +inline uint64_t Block_Rhf(const emp::block& x) { + return static_cast(_mm_extract_epi64(x, 0)); +} + +inline uint128_t Mul64(uint64_t x, uint64_t y) { + emp::block mblock = emp::zero_block; + emp::block empty = emp::zero_block; + emp::mul128(emp::makeBlock(0, x), emp::makeBlock(0, y), &mblock, &empty); + return yacl::MakeUint128(Block_Lhf(mblock), Block_Rhf(mblock)); +} + +inline uint64_t Reduce64(uint128_t x) { + emp::block xb = emp::makeBlock(U128_LHF(x), U128_RHF(x)); + emp::block tb = emp::zero_block; + emp::block empty = emp::zero_block; // useless + // high 64 of xb + emp::mul128(emp::makeBlock(0, U128_LHF(x)), emp::makeBlock(0, kGfMod64), &tb, + &empty); + xb ^= tb; + // high 64 of mb + emp::mul128(emp::makeBlock(0, Block_Lhf(tb)), emp::makeBlock(0, kGfMod64), + &tb, &empty); + xb ^= tb; + return Block_Rhf(xb); +} + +struct CheckMsg { + uint64_t x; + std::array t; + + // Constructor + CheckMsg() { 
+ x = 0; + std::fill(t.begin(), t.end(), 0); + } + + yacl::Buffer Pack() { + std::vector res; + res.push_back(x); + res.insert(res.end(), t.begin(), t.end()); + return {res.data(), res.size() * sizeof(uint64_t)}; + } + + void Unpack(yacl::ByteContainerView buf) { + std::memcpy(&x, buf.data(), sizeof(uint64_t)); + std::memcpy(t.data(), buf.data() + sizeof(uint64_t), + kKappa * sizeof(uint64_t)); + } +}; + +inline uint64_t GenKosSharedSeed( + const std::shared_ptr& ctx) { + SPU_ENFORCE(ctx->WorldSize() == 2); + if (ctx->Rank() == 0) { + std::random_device rd; + uint64_t seed = static_cast(rd()) << 32 | rd(); + ctx->SendAsync(ctx->NextRank(), + yacl::ByteContainerView(&seed, sizeof(uint64_t)), + fmt::format("KOS-Seed")); + return seed; + } else { + uint64_t seed = 0; + auto buf = ctx->Recv(ctx->NextRank(), fmt::format("KOS-Seed")); + std::memcpy(&seed, buf.data(), sizeof(uint64_t)); + return seed; + } +} + +inline auto ExtendBaseOt( + const std::shared_ptr& base_ot, + const size_t block_num) { + std::array, kKappa> base_ot_ext0; + std::array, kKappa> base_ot_ext1; + for (size_t k = 0; k < base_ot->Size(); ++k) { + base_ot_ext0[k] = std::vector(block_num); + base_ot_ext1[k] = std::vector(block_num); + yacl::crypto::Prg prg0(base_ot->GetBlock(k, 0)); + yacl::crypto::Prg prg1(base_ot->GetBlock(k, 1)); + + prg0.Fill(absl::MakeSpan(base_ot_ext0[k])); + prg1.Fill(absl::MakeSpan(base_ot_ext1[k])); + } + return std::make_pair(base_ot_ext0, base_ot_ext1); +} + +inline auto ExtendBaseOt( + const std::shared_ptr& base_ot, + const size_t block_num) { + std::array, kKappa> base_ot_ext; + for (size_t k = 0; k < base_ot->Size(); ++k) { + base_ot_ext[k] = std::vector(block_num); + yacl::crypto::Prg prg(base_ot->GetBlock(k)); + prg.Fill(absl::MakeSpan(base_ot_ext[k])); + } + return base_ot_ext; +} + +inline auto ExtendChoice(const std::vector& choices, + const size_t final_size) { + // Extend choices to batch_num * kBlockNum bits + // 1st part (valid_ot_num bits): original ot choices 
+ // 2nd part (verify_ot_num bits): rand bits used for checking + // 3rd party (the rest bits): padding 1; + std::vector choices_ext = choices; + + // 2nd part Extension + yacl::crypto::Prg gen; + for (size_t i = 0; i < kS; i++) { + choices_ext.push_back(gen()); + } + // 3rd part Extension + choices_ext.resize(final_size); + return choices_ext; +} + +} // namespace + +// KOS based delta-OTE +void KosOtExtSend(const std::shared_ptr& ctx, + const std::shared_ptr& base_ot, + absl::Span send_blocks, uint128_t& delta) { + SPU_ENFORCE(ctx->WorldSize() == 2); + SPU_ENFORCE(base_ot->Size() == kKappa); + SPU_ENFORCE(!send_blocks.empty()); + SPU_ENFORCE( + kS == 64, + "Currently, KOS only support statistical security = 64 bit, but get {}", + kS); + + const size_t ot_num_valid = send_blocks.size(); + const size_t ot_num_ext = ot_num_valid + kS; // without batch padding + const size_t batch_num = (ot_num_ext + kBatchSize - 1) / kBatchSize; + const size_t block_num = batch_num * kBatchSize / 128; + + // Prepare for batched computation + std::vector q_ext(ot_num_ext); + auto ot_ext = ExtendBaseOt(base_ot, block_num); + + // Prepare for consistency check + std::array q_check; + std::fill(q_check.begin(), q_check.end(), 0); + auto seed = GenKosSharedSeed(ctx); + + auto rand_samples = std::vector(batch_num * 2); + yacl::crypto::Prg prg(seed); + prg.Fill(absl::MakeSpan(rand_samples)); + + // Note the following is identical to the IKNP protocol without the final hash + // code partically copied from yacl/crypto-primitives/ot/extension/kkrt_ote.cc + // For every batch + for (size_t i = 0; i < batch_num; ++i) { + std::array recv_msg; + const size_t offset = i * kBatchSize / 128; // block offsets + + auto buf = ctx->Recv(ctx->NextRank(), fmt::format("KOS:{}", i)); + std::memcpy(recv_msg.data(), buf.data(), buf.size()); + + // Q = (u & s) ^ G(K_s) = ((G(K_0) ^ G(K_1) ^ r)) & s) ^ G(K_s) + // Q = G(K_0) when s is 0 + // Q = G(K_0) ^ r when s is 1 + // Hence we get the wanted behavior in 
IKNP, that is: + // s == 0, the sender receives T = G(K_0) + // s == 1, the sender receives U = G(K_0) ^ r = T ^ r + for (size_t k = 0; k < kKappa; ++k) { + const auto& ot_slice = ot_ext[k][offset]; + + if (base_ot->GetChoice(k)) { + recv_msg[k] ^= ot_slice; + } else { + recv_msg[k] = ot_slice; + } + + // ******************* CONSISTENCY CHECK ******************* + // q_check[k] ^= U128_LHF(recv_msg[k]) & rand_samples[2 * i]; + // q_check[k] ^= U128_RHF(recv_msg[k]) & rand_samples[2 * i + 1]; + uint128_t ret = Mul64(U128_LHF(recv_msg[k]), rand_samples[2 * i]); + ret ^= Mul64(U128_RHF(recv_msg[k]), rand_samples[2 * i + 1]); + q_check[k] ^= Reduce64(ret); + // ******************* CONSISTENCY CHECK ******************* + } + + // Transpose. + yacl::SseTranspose128(&recv_msg); + + // Finalize (without crhash) + const size_t limit = std::min(kBatchSize, ot_num_ext - i * kBatchSize); + for (size_t j = 0; j < limit; ++j) { + q_ext[i * kBatchSize + j] = recv_msg[j]; + } + } + + delta = 0; + for (size_t k = 0; k < 128; ++k) { + if (base_ot->GetChoice(k)) { + delta |= (uint128_t)1 << k; + } + } + + // ******************* CONSISTENCY CHECK ******************* + CheckMsg check_msgs; + check_msgs.Unpack(ctx->Recv(ctx->NextRank(), fmt::format("KOS-CHECK"))); + + for (size_t k = 0; k < kKappa; ++k) { + uint128_t result = 0; + if (base_ot->GetChoice(k)) { + result = check_msgs.t[k] ^ (check_msgs.x); + } else { + result = check_msgs.t[k]; + } + SPU_ENFORCE(result == q_check[k]); + } + // ******************* CONSISTENCY CHECK ******************* + + q_ext.resize(ot_num_valid); + for (size_t i = 0; i < ot_num_valid; i++) { + send_blocks[i] = q_ext[i]; + } +} + +// KOS based delta-OTE +void KosOtExtRecv(const std::shared_ptr& ctx, + const std::shared_ptr& base_ot, + const std::vector& choices, + absl::Span recv_blocks) { + SPU_ENFORCE(ctx->WorldSize() == 2); // Check OT has two parties + SPU_ENFORCE(base_ot->Size() == kKappa); // Check base OT size + SPU_ENFORCE(recv_blocks.size() == 
choices.size()); + SPU_ENFORCE(!recv_blocks.empty()); + SPU_ENFORCE( + kS == 64, + "Currently, KOS only support statistical security = 64 bit, but get {}", + kS); + + const size_t ot_num_valid = recv_blocks.size(); + const size_t ot_num_ext = ot_num_valid + kS; // without batch padding + const size_t batch_num = (ot_num_ext + kBatchSize - 1) / kBatchSize; + const size_t block_num = batch_num * kBatchSize / 128; + + // Prepare for batched computation + std::vector t_ext(ot_num_ext); + auto choice_ext = ExtendChoice(choices, batch_num * kBatchSize); + auto ot_ext = ExtendBaseOt(base_ot, block_num); + + // Prepare for consistency check + CheckMsg check_msgs; + auto seed = GenKosSharedSeed(ctx); + + auto rand_samples = std::vector(batch_num * 2); + yacl::crypto::Prg prg(seed); + prg.Fill(absl::MakeSpan(rand_samples)); + + // Note the following is identical to the IKNP protocol without the final hash + // code partically copied from yacl/crypto-primitives/ot/extension/kkrt_ote.cc + // For a task of generating 129 OTs, we actually generates 128 * 2 = 256 + // OTs. 
+ for (size_t i = 0; i < batch_num; ++i) { + std::array send_msg; + std::array t; + + const size_t offset = i * kBatchSize / 128; // block offsets + // uint128_t choice_slice = *(choice_ext.data() + offset); + uint128_t choice_slice = 0; + for (size_t k = 0; k < kBatchSize; ++k) { + if (choice_ext[i * kBatchSize + k]) { + choice_slice |= (uint128_t)1 << k; + } + } + + // ******************* CONSISTENCY CHECK ******************* + // check_msgs.x ^= U128_LHF(choice_slice) & rand_samples[2 * i]; + // check_msgs.x ^= U128_RHF(choice_slice) & rand_samples[2 * i + 1]; + uint128_t ret = Mul64(U128_LHF(choice_slice), rand_samples[2 * i]); + ret ^= Mul64(U128_RHF(choice_slice), rand_samples[2 * i + 1]); + check_msgs.x ^= Reduce64(ret); + // ******************* CONSISTENCY CHECK ******************* + + for (size_t k = 0; k < kKappa; ++k) { + const auto& ot_slice0 = ot_ext.first[k][offset]; + const auto& ot_slice1 = ot_ext.second[k][offset]; + send_msg[k] = ot_slice0 ^ ot_slice1 ^ choice_slice; + t[k] = ot_slice0; + + // ******************* CONSISTENCY CHECK ******************* + // check_msgs.t[k] ^= U128_LHF(t[k]) & rand_samples[2 * i]; + // check_msgs.t[k] ^= U128_RHF(t[k]) & rand_samples[2 * i + 1]; + uint128_t ret = Mul64(U128_LHF(t[k]), rand_samples[2 * i]); + ret ^= Mul64(U128_RHF(t[k]), rand_samples[2 * i + 1]); + check_msgs.t[k] ^= Reduce64(ret); + // ******************* CONSISTENCY CHECK ******************* + } + + ctx->SendAsync(ctx->NextRank(), + yacl::ByteContainerView(send_msg.data(), + send_msg.size() * sizeof(uint128_t)), + fmt::format("KOS:{}", i)); + + // Transpose. 
+ yacl::SseTranspose128(&t); + + // Finalize (without crhash) + const size_t limit = std::min(kBatchSize, ot_num_ext - i * kBatchSize); + for (size_t j = 0; j < limit; ++j) { + t_ext[i * kBatchSize + j] = t[j]; + } + } + + // ******************* CONSISTENCY CHECK ******************* + // check_msgs.Print(); + auto buf = check_msgs.Pack(); + ctx->SendAsync(ctx->NextRank(), buf, fmt::format("KOS-CHECK")); + // ******************* CONSISTENCY CHECK ******************* + + t_ext.resize(ot_num_valid); + + for (size_t i = 0; i < ot_num_valid; ++i) { + recv_blocks[i] = t_ext[i]; + } +} +// END of KOS +}; // namespace spu::mpc::spdz2k \ No newline at end of file diff --git a/libspu/mpc/spdz2k/ot/kos_ote.h b/libspu/mpc/spdz2k/ot/kos_ote.h new file mode 100644 index 00000000..cdc1d2c3 --- /dev/null +++ b/libspu/mpc/spdz2k/ot/kos_ote.h @@ -0,0 +1,32 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once +#include "absl/types/span.h" +#include "yacl/base/dynamic_bitset.h" +#include "yacl/crypto/primitives/ot/ot_store.h" +#include "yacl/link/link.h" +namespace spu::mpc::spdz2k { + +// TODO: maybe it can move to yacl +void KosOtExtSend(const std::shared_ptr& ctx, + const std::shared_ptr& base_ot, + absl::Span send_blocks, uint128_t& delta); + +void KosOtExtRecv(const std::shared_ptr& ctx, + const std::shared_ptr& base_ot, + const std::vector& choices, + absl::Span recv_blocks); + +}; // namespace spu::mpc::spdz2k \ No newline at end of file diff --git a/libspu/mpc/spdz2k/ot/tiny_ot.cc b/libspu/mpc/spdz2k/ot/tiny_ot.cc new file mode 100644 index 00000000..00d6bff2 --- /dev/null +++ b/libspu/mpc/spdz2k/ot/tiny_ot.cc @@ -0,0 +1,535 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "tiny_ot.h" + +#include "emp-tool/utils/f2k.h" +#include "yacl/crypto/tools/prg.h" +#include "yacl/crypto/utils/rand.h" + +#include "libspu/mpc/spdz2k/commitment.h" +#include "libspu/mpc/spdz2k/ot/kos_ote.h" +#include "libspu/mpc/utils/ring_ops.h" + +namespace spu::mpc::spdz2k { + +namespace { +inline int GetBit(const std::vector& choices, size_t idx) { + uint128_t mask = uint128_t(1) << (idx & 127); + return (choices[idx / 128] & mask) ? 
1 : 0; +} + +inline void SetBit(std::vector& choices, size_t idx) { + choices[idx / 128] |= (uint128_t(1) << (idx & 127)); +} + +inline emp::block U128ToBlock(uint128_t x) { + auto [high, low] = yacl::DecomposeUInt128(x); + return emp::makeBlock(high, low); +} + +inline uint128_t BlockToU128(emp::block x) { + auto high = static_cast(_mm_extract_epi64(x, 1)); + auto low = static_cast(_mm_extract_epi64(x, 0)); + return yacl::MakeInt128(high, low); +} + +inline AuthBit AuthBitSender( + const std::shared_ptr& comm, + const std::shared_ptr& base_ot, size_t size, + uint128_t tinyot_key) { + std::vector send_blocks(size); + bool use_secure_rand = true; + auto choices = + yacl::crypto::RandBits>(size, use_secure_rand); + KosOtExtRecv(comm->lctx(), base_ot, choices, absl::MakeSpan(send_blocks)); + for (size_t k = 0; k < size; ++k) { + if (choices[k]) { + send_blocks[k] ^= tinyot_key; + } + } + return AuthBit{std::move(choices), std::move(send_blocks), tinyot_key}; +} + +inline AuthBit AuthBitReceiver( + const std::shared_ptr& comm, + const std::shared_ptr& base_ot, size_t size, + uint128_t tinyot_key) { + std::vector recv_blocks(size); + uint128_t delta = 0; + KosOtExtSend(comm->lctx(), base_ot, absl::MakeSpan(recv_blocks), delta); + // ENSURE KOS use tinyot_key as delta + SPU_ENFORCE(delta == tinyot_key); + + return AuthBit{std::vector(size, false), std::move(recv_blocks), + tinyot_key}; +} + +inline void BatchSwitchKeySender(const std::shared_ptr& comm, + const std::vector new_tinyot_keys, + AuthBit& bits) { + auto sigmas = comm->recv(comm->nextRank(), "TinyOT:DirtySwitch"); + // NOTICE!!! 
+ // we will not set bits.key although bits have already switch to some new keys + for (size_t i = 0; i < new_tinyot_keys.size(); ++i) { + if (bits.choices[i]) { + bits.mac[i] ^= (sigmas[i] ^ new_tinyot_keys[i] ^ bits.key); + } + } +} + +inline void BatchSwitchKeyReceiver( + const std::shared_ptr& comm, + const std::vector& new_tinyot_keys, AuthBit& bits) { + auto sigmas = new_tinyot_keys; + for (auto& sigma : sigmas) { + sigma ^= bits.key; + } + comm->sendAsync(comm->nextRank(), sigmas, "TinyOT:DirtySwitch"); +} + +inline void SetSender(const std::shared_ptr& comm, + std::vector new_choices, AuthBit& bits) { + SPU_ENFORCE(new_choices.size() <= bits.choices.size()); + std::vector d((new_choices.size() + 127) / 128, 0); + for (size_t i = 0; i < new_choices.size(); ++i) { + if (new_choices[i] ^ bits.choices[i]) { + SetBit(d, i); + bits.mac[i] ^= bits.key; + } + bits.choices[i] = new_choices[i]; + } + comm->sendAsync(comm->nextRank(), d, "TinyOT:Set"); +} + +inline void SetReceiver(const std::shared_ptr& comm, + AuthBit& bits) { + auto d = comm->recv(comm->nextRank(), "TinyOT:Set"); + auto bound = std::min(bits.choices.size(), d.size() * 128); + for (size_t i = 0; i < bound; ++i) { + if (GetBit(d, i) == 1) { + bits.mac[i] ^= bits.key; + } + } +} + +// ShareOT protocol +// Reference: https://eprint.iacr.org/2014/101.pdf +// Page 11, protocol 8. 
+std::tuple ShareOT( + const std::shared_ptr& comm, + const std::shared_ptr& send_opts, + const std::shared_ptr& recv_opts, size_t size, + uint128_t tinyot_key) { + AuthBit local_bits; + AuthBit remote_bits; + + if (comm->getRank() == 0) { + local_bits = AuthBitSender(comm, send_opts, 5 * size, tinyot_key); + remote_bits = AuthBitReceiver(comm, recv_opts, 5 * size, tinyot_key); + } else { + remote_bits = AuthBitReceiver(comm, recv_opts, 5 * size, tinyot_key); + local_bits = AuthBitSender(comm, send_opts, 5 * size, tinyot_key); + } + + std::vector e_chi_z(4 * size); + bool use_secure_rand = true; + std::vector e = + yacl::crypto::RandBits>(size, use_secure_rand); + + // + yacl::crypto::Prg prg(yacl::crypto::SecureRandSeed()); + std::vector eta(size, 0); + prg.Fill(absl::MakeSpan(eta)); + + AuthBit temp_local_bits{std::vector(size), + std::vector(size, 0), tinyot_key}; + AuthBit temp_remote_bits{std::vector(size, false), + std::vector(size, 0), tinyot_key}; + std::memcpy(temp_local_bits.mac.data(), local_bits.mac.data() + 4 * size, + size * sizeof(uint128_t)); + std::memcpy(temp_remote_bits.mac.data(), remote_bits.mac.data() + 4 * size, + size * sizeof(uint128_t)); + for (size_t i = 0; i < size; ++i) { + temp_local_bits.choices[i] = local_bits.choices[4 * size + i]; + } + + if (comm->getRank() == 0) { + SetSender(comm, e, temp_local_bits); + SetReceiver(comm, temp_remote_bits); + BatchSwitchKeySender(comm, eta, temp_local_bits); + BatchSwitchKeyReceiver(comm, eta, temp_remote_bits); + } else { + SetReceiver(comm, temp_remote_bits); + SetSender(comm, e, temp_local_bits); + BatchSwitchKeyReceiver(comm, eta, temp_remote_bits); + BatchSwitchKeySender(comm, eta, temp_local_bits); + } + for (size_t i = 0; i < size; ++i) { + e_chi_z[i] = e[i]; + e_chi_z[size + i] = (1 & temp_remote_bits.mac[i]); + e_chi_z[2 * size + i] = (1 & (eta[i] ^ temp_remote_bits.mac[i])); + e_chi_z[3 * size + i] = (1 & temp_local_bits.mac[i]); + } + + // Authorize choices [e||chi0||chi1||z] + if 
(comm->getRank() == 0) { + SetSender(comm, e_chi_z, local_bits); + SetReceiver(comm, remote_bits); + } else { + SetReceiver(comm, remote_bits); + SetSender(comm, e_chi_z, local_bits); + } + + AuthBit auth_e{std::vector(size), std::vector(size), + tinyot_key}; + AuthBit auth_chi0{std::vector(size), std::vector(size), + tinyot_key}; + AuthBit auth_chi1{std::vector(size), std::vector(size), + tinyot_key}; + AuthBit auth_z{std::vector(size), std::vector(size), + tinyot_key}; + + std::memcpy(auth_e.mac.data(), local_bits.mac.data(), + size * sizeof(uint128_t)); + std::memcpy(auth_chi0.mac.data(), local_bits.mac.data() + size, + size * sizeof(uint128_t)); + std::memcpy(auth_chi1.mac.data(), local_bits.mac.data() + 2 * size, + size * sizeof(uint128_t)); + std::memcpy(auth_z.mac.data(), local_bits.mac.data() + 3 * size, + size * sizeof(uint128_t)); + + for (size_t k = 0; k < size; ++k) { + auth_e.choices[k] = local_bits.choices[k]; + auth_e.mac[k] ^= remote_bits.mac[k]; + + auth_chi0.choices[k] = local_bits.choices[size + k]; + auth_chi0.mac[k] ^= remote_bits.mac[size + k]; + + auth_chi1.choices[k] = local_bits.choices[2 * size + k]; + auth_chi1.mac[k] ^= remote_bits.mac[2 * size + k]; + + auth_z.choices[k] = local_bits.choices[3 * size + k]; + auth_z.mac[k] ^= remote_bits.mac[3 * size + k]; + } + + return {auth_e, auth_chi0, auth_chi1, auth_z}; +} + +// GaOT protocol (includes Authenticate OT and Sacrifice OT) +// Reference: https://eprint.iacr.org/2014/101.pdf +// Page 14, Protocol 9. +std::tuple GaOT( + const std::shared_ptr& comm, + const std::shared_ptr& send_opts, + const std::shared_ptr& recv_opts, size_t size, + uint128_t tinyot_key) { + constexpr size_t kS = 64; + + // Theorem 3. 
T >= (k + log2(t)) / log2(t) + size_t expand_factor = std::ceil((kS + log2(size)) / log2(size)) + 1; + size_t total_num = expand_factor * size; + auto [e, chi0, chi1, z] = + ShareOT(comm, send_opts, recv_opts, total_num, tinyot_key); + + // Phase-I: cut and choose + yacl::crypto::Prg prg(GenSharedSeed(comm)); + auto swap_lambda = [](AuthBit& a, size_t i0, size_t i1) { + std::swap(a.choices[i0], a.choices[i1]); + std::swap(a.mac[i0], a.mac[i1]); + }; + + for (size_t i = 1; i <= total_num; ++i) { + const size_t k = prg() % total_num; + swap_lambda(e, total_num - i, k); + swap_lambda(chi0, total_num - i, k); + swap_lambda(chi1, total_num - i, k); + swap_lambda(z, total_num - i, k); + } + // Open last "size" quadruples + AuthBit s_e_chi_z{std::vector(4 * size, false), + std::vector(4 * size, 0), tinyot_key}; + const size_t bias = total_num - size; + + // s_e_chi_z_vec = [ e || chi0 || chi1 || z ] + std::vector s_e_chi_z_vec((4 * size + 127) / 128, 0); + + std::memcpy(s_e_chi_z.mac.data(), e.mac.data() + bias, + size * sizeof(uint128_t)); + std::memcpy(s_e_chi_z.mac.data() + size, chi0.mac.data() + bias, + size * sizeof(uint128_t)); + std::memcpy(s_e_chi_z.mac.data() + 2 * size, chi1.mac.data() + bias, + size * sizeof(uint128_t)); + std::memcpy(s_e_chi_z.mac.data() + 3 * size, z.mac.data() + bias, + size * sizeof(uint128_t)); + + for (size_t i = 0; i < size; ++i) { + if (e.choices[bias + i]) { + s_e_chi_z.choices[i] = true; + SetBit(s_e_chi_z_vec, i); + } + if (chi0.choices[bias + i]) { + s_e_chi_z.choices[size + i] = true; + SetBit(s_e_chi_z_vec, size + i); + } + if (chi1.choices[bias + i]) { + s_e_chi_z.choices[2 * size + i] = true; + SetBit(s_e_chi_z_vec, 2 * size + i); + } + if (z.choices[bias + i]) { + s_e_chi_z.choices[3 * size + i] = true; + SetBit(s_e_chi_z_vec, 3 * size + i); + } + } + + s_e_chi_z_vec = comm->allReduce( + s_e_chi_z_vec, "GaOT:cut_choose_e_chi_z"); + std::vector s_e_chi_z_bool(4 * size, false); + + for (size_t i = 0; i < size; ++i) { + 
s_e_chi_z_bool[i] = GetBit(s_e_chi_z_vec, i); + s_e_chi_z_bool[size + i] = GetBit(s_e_chi_z_vec, size + i); + s_e_chi_z_bool[2 * size + i] = GetBit(s_e_chi_z_vec, 2 * size + i); + s_e_chi_z_bool[3 * size + i] = GetBit(s_e_chi_z_vec, 3 * size + i); + if (s_e_chi_z_bool[i]) { + SPU_ENFORCE(s_e_chi_z_bool[3 * size + i] == s_e_chi_z_bool[2 * size + i]); + } else { + SPU_ENFORCE(s_e_chi_z_bool[3 * size + i] == s_e_chi_z_bool[size + i]); + } + } + + // Phase-II: bucket sacrifice + const size_t sacrifice_size = (expand_factor - 2) * size; + AuthBit a{std::vector(sacrifice_size, false), + std::vector(sacrifice_size, 0), tinyot_key}; + AuthBit b{std::vector(sacrifice_size, false), + std::vector(sacrifice_size, 0), tinyot_key}; + AuthBit c{std::vector(sacrifice_size, false), + std::vector(sacrifice_size, 0), tinyot_key}; + + for (size_t offset = 0; offset < sacrifice_size; offset += size) { + for (size_t i = 0; i < size; ++i) { + a.choices[offset + i] = e.choices[i] ^ e.choices[size + offset + i]; + a.mac[offset + i] = e.mac[i] ^ e.mac[size + offset + i]; + b.choices[offset + i] = chi0.choices[i] ^ + chi0.choices[size + offset + i] ^ + chi1.choices[i] ^ chi1.choices[size + offset + i]; + b.mac[offset + i] = chi0.mac[i] ^ chi0.mac[size + offset + i] ^ + chi1.mac[i] ^ chi1.mac[size + offset + i]; + } + } + + std::vector a_vec((sacrifice_size + 127) / 128, 0); + std::vector b_vec((sacrifice_size + 127) / 128, 0); + + for (size_t i = 0; i < sacrifice_size; ++i) { + if (a.choices[i]) { + SetBit(a_vec, i); + } + if (b.choices[i]) { + SetBit(b_vec, i); + } + } + + a_vec = + comm->allReduce(a_vec, "GaOT:Sacrifice_open_a"); + b_vec = + comm->allReduce(b_vec, "GaOT:Sacrifice_open_b"); + + for (size_t offset = 0; offset < sacrifice_size; offset += size) { + for (size_t i = 0; i < size; ++i) { + const size_t si = offset + i; + c.choices[si] = z.choices[i] ^ chi0.choices[i] ^ z.choices[size + si] ^ + chi0.choices[size + si] ^ + (GetBit(a_vec, si) * + (chi0.choices[size + si] ^ 
chi1.choices[size + si])) ^ + (GetBit(b_vec, si) * e.choices[i]); + c.mac[si] = + z.mac[i] ^ chi0.mac[i] ^ z.mac[size + si] ^ chi0.mac[size + si] ^ + (GetBit(a_vec, si) * (chi0.mac[size + si] ^ chi1.mac[size + si])) ^ + (GetBit(b_vec, si) * e.mac[i]); + } + } + + std::vector c_vec((sacrifice_size + 127) / 128, 0); + for (size_t i = 0; i < sacrifice_size; ++i) { + if (c.choices[i]) { + SetBit(c_vec, i); + } + } + + c_vec = + comm->allReduce(c_vec, "GaOT:Sacrifice_open_c"); + for (const auto& c : c_vec) { + SPU_ENFORCE(c == 0); + } + + // Phase-III: Maccheck + std::vector a_open_bool(sacrifice_size, false); + std::vector b_open_bool(sacrifice_size, false); + std::vector c_open_bool(sacrifice_size, false); + // all elements in c are "false". + for (size_t i = 0; i < sacrifice_size; ++i) { + if (GetBit(a_vec, i)) { + a_open_bool[i] = true; + } + if (GetBit(b_vec, i)) { + b_open_bool[i] = true; + } + } + + TinyMacCheck(comm, s_e_chi_z_bool, s_e_chi_z); + TinyMacCheck(comm, a_open_bool, a); + TinyMacCheck(comm, b_open_bool, b); + TinyMacCheck(comm, c_open_bool, c); + return {e, chi0, chi1, z}; +} + +}; // namespace + +uint128_t GenSharedSeed(const std::shared_ptr& comm) { + uint128_t seed = yacl::crypto::SecureRandSeed(); + std::string seed_str(reinterpret_cast(&seed), sizeof(uint128_t)); + std::vector all_seed_strs; + SPU_ENFORCE(commit_and_open(comm->lctx(), seed_str, &all_seed_strs)); + SPU_ENFORCE(all_seed_strs.size() == comm->getWorldSize()); + uint128_t ret = 0; + for (auto& str : all_seed_strs) { + const uint128_t cur_seed = *reinterpret_cast(str.data()); + ret ^= cur_seed; + } + return ret; +} + +// TinyMacCheck protocol +// Reference: https://eprint.iacr.org/2014/101.pdf +// Page 17, protocol 10. 
+bool TinyMacCheck(const std::shared_ptr& comm, + std::vector open_bits, const AuthBit& bits) { + // WARNING: As of now, uint128_t does NOT support gfmul128 + const size_t size = open_bits.size(); + // generate the coefficient for "almost universal hash" + const uint128_t seed = GenSharedSeed(comm); + std::vector coeff(size); + emp::uni_hash_coeff_gen(coeff.data(), U128ToBlock(seed), size); + + emp::block ret_val = emp::zero_block; + std::vector mac(size); + for (size_t i = 0; i < size; ++i) { + // convert uint128_t to emp::block + mac[i] = U128ToBlock(bits.mac[i]); + if (open_bits[i]) { + ret_val = ret_val ^ coeff[i]; + } + } + emp::block _sigma = emp::zero_block; + // inner product over gf128 + emp::vector_inn_prdt_sum_red(&_sigma, coeff.data(), mac.data(), size); + emp::block offset = emp::zero_block; + emp::gfmul(ret_val, U128ToBlock(bits.key), &offset); + uint128_t sigma = BlockToU128(_sigma ^ offset); + + std::string sigma_str(reinterpret_cast(&sigma), sizeof(sigma)); + std::vector all_sigma_strs; + SPU_ENFORCE(commit_and_open(comm->lctx(), sigma_str, &all_sigma_strs)); + SPU_ENFORCE(all_sigma_strs.size() == comm->getWorldSize()); + + uint128_t simga_sum = 0; + for (auto& str : all_sigma_strs) { + const uint128_t _sigma = *reinterpret_cast(str.data()); + simga_sum ^= _sigma; + } + return simga_sum == 0; +} + +AuthBit RandomBits(const std::shared_ptr& comm, + const std::shared_ptr& send_opts, + const std::shared_ptr& recv_opts, + size_t size, uint128_t tinyot_key) { + AuthBit local_bits; + AuthBit remote_bits; + if (comm->getRank() == 0) { + local_bits = AuthBitSender(comm, send_opts, size, tinyot_key); + remote_bits = AuthBitReceiver(comm, recv_opts, size, tinyot_key); + } else { + remote_bits = AuthBitReceiver(comm, recv_opts, size, tinyot_key); + local_bits = AuthBitSender(comm, send_opts, size, tinyot_key); + } + for (size_t k = 0; k < size; ++k) { + local_bits.choices[k] = local_bits.choices[k] ^ remote_bits.choices[k]; + local_bits.mac[k] ^= 
remote_bits.mac[k]; + } + return local_bits; +} + +// Reference: https://eprint.iacr.org/2014/101.pdf +// Page 10, Protocol 7. +std::tuple TinyMul( + const std::shared_ptr& comm, + const std::shared_ptr& send_opts, + const std::shared_ptr& recv_opts, size_t size, + uint128_t tinyot_key) { + AuthBit a = RandomBits(comm, send_opts, recv_opts, size, tinyot_key); + AuthBit b = RandomBits(comm, send_opts, recv_opts, size, tinyot_key); + + auto [e, chi0, chi1, z] = GaOT(comm, send_opts, recv_opts, size, tinyot_key); + + AuthBit f{std::vector(size), std::vector(size), tinyot_key}; + AuthBit g{std::vector(size), std::vector(size), tinyot_key}; + for (size_t k = 0; k < size; ++k) { + f.choices[k] = b.choices[k] ^ e.choices[k]; + g.choices[k] = chi0.choices[k] ^ chi1.choices[k] ^ a.choices[k]; + f.mac[k] = b.mac[k] ^ e.mac[k]; + g.mac[k] = chi0.mac[k] ^ chi1.mac[k] ^ a.mac[k]; + } + + std::vector f_buf((size + 127) / 128, 0); + std::vector g_buf((size + 127) / 128, 0); + for (size_t k = 0; k < size; ++k) { + if (f.choices[k]) { + SetBit(f_buf, k); + } + if (g.choices[k]) { + SetBit(g_buf, k); + } + } + + f_buf = comm->allReduce(f_buf, "TinyOT:Mul_f"); + g_buf = comm->allReduce(g_buf, "TinyOT:Mul_g"); + + std::vector true_f(size); + std::vector true_g(size); + + for (size_t k = 0; k < size; ++k) { + if (GetBit(f_buf, k)) { + true_f[k] = true; + } + if (GetBit(g_buf, k)) { + true_g[k] = true; + } + } + // Theoretical, we could try bits f and g together in a single TinyMacCheck + TinyMacCheck(comm, true_f, f); + TinyMacCheck(comm, true_g, g); + + AuthBit c{std::vector(size), std::vector(size), tinyot_key}; + for (size_t k = 0; k < size; ++k) { + c.choices[k] = chi0.choices[k] ^ (GetBit(f_buf, k) * a.choices[k]) ^ + (GetBit(g_buf, k) * e.choices[k]) ^ z.choices[k]; + c.mac[k] = chi0.mac[k] ^ (GetBit(f_buf, k) * a.mac[k]) ^ + (GetBit(g_buf, k) * e.mac[k]) ^ z.mac[k]; + } + return {a, b, c}; +} + +}; // namespace spu::mpc::spdz2k \ No newline at end of file diff --git 
a/libspu/mpc/spdz2k/ot/tiny_ot.h b/libspu/mpc/spdz2k/ot/tiny_ot.h new file mode 100644 index 00000000..ef99be70 --- /dev/null +++ b/libspu/mpc/spdz2k/ot/tiny_ot.h @@ -0,0 +1,48 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include + +#include "yacl/crypto/primitives/ot/ot_store.h" + +#include "libspu/mpc/common/communicator.h" + +namespace spu::mpc::spdz2k { + +struct AuthBit { + std::vector choices; + std::vector mac; + uint128_t key; +}; + +uint128_t GenSharedSeed(const std::shared_ptr& comm); +// TinyMacCheck protocol +// Reference: https://eprint.iacr.org/2014/101.pdf +// Page 17, protocol 10. +bool TinyMacCheck(const std::shared_ptr& comm, + std::vector open_bits, const AuthBit& bits); + +AuthBit RandomBits(const std::shared_ptr& comm, + const std::shared_ptr& send_opts, + const std::shared_ptr& recv_opts, + size_t size, uint128_t tinyot_key); + +// Reference: https://eprint.iacr.org/2014/101.pdf +// Page 10, Protocol 7. 
+std::tuple TinyMul( + const std::shared_ptr& comm, + const std::shared_ptr& send_opts, + const std::shared_ptr& recv_opts, size_t size, + uint128_t tinyot_key); + +}; // namespace spu::mpc::spdz2k \ No newline at end of file diff --git a/libspu/mpc/spdz2k/protocol.cc b/libspu/mpc/spdz2k/protocol.cc index 469bb73e..80c8006c 100644 --- a/libspu/mpc/spdz2k/protocol.cc +++ b/libspu/mpc/spdz2k/protocol.cc @@ -19,6 +19,8 @@ #include "libspu/mpc/common/prg_state.h" #include "libspu/mpc/common/pv2k.h" #include "libspu/mpc/spdz2k/arithmetic.h" +#include "libspu/mpc/spdz2k/boolean.h" +#include "libspu/mpc/spdz2k/conversion.h" #include "libspu/mpc/spdz2k/state.h" #include "libspu/mpc/spdz2k/type.h" @@ -41,9 +43,11 @@ void regSpdz2kProtocol(SPUContext* ctx, regPV2kKernels(ctx->prot()); // register arithmetic kernels - ctx->prot()->addState(lctx); + ctx->prot()->addState(ctx->config(), lctx); ctx->prot()->regKernel(); ctx->prot()->regKernel(); + ctx->prot()->regKernel(); + ctx->prot()->regKernel(); ctx->prot()->regKernel(); ctx->prot()->regKernel(); ctx->prot()->regKernel(); @@ -54,6 +58,34 @@ void regSpdz2kProtocol(SPUContext* ctx, ctx->prot()->regKernel(); ctx->prot()->regKernel(); ctx->prot()->regKernel(); + + // register boolean kernels + ctx->prot()->regKernel(); + ctx->prot()->regKernel(); + ctx->prot()->regKernel(); + + ctx->prot()->regKernel(); + ctx->prot()->regKernel(); + ctx->prot()->regKernel(); + ctx->prot()->regKernel(); + ctx->prot()->regKernel(); + ctx->prot()->regKernel(); + ctx->prot()->regKernel(); + ctx->prot()->regKernel(); + ctx->prot()->regKernel(); + ctx->prot()->regKernel(); + ctx->prot()->regKernel(); + + // register conversion kernels + ctx->prot()->regKernel(); + ctx->prot()->regKernel(); + ctx->prot()->regKernel(); + ctx->prot()->regKernel(); + ctx->prot()->regKernel(); + ctx->prot()->regKernel(); + ctx->prot()->regKernel(); + ctx->prot()->regKernel(); + ctx->prot()->regKernel(); } std::unique_ptr makeSpdz2kProtocol( diff --git 
a/libspu/mpc/spdz2k/protocol_test.cc b/libspu/mpc/spdz2k/protocol_test.cc index 96aaa2cd..86ac79cb 100644 --- a/libspu/mpc/spdz2k/protocol_test.cc +++ b/libspu/mpc/spdz2k/protocol_test.cc @@ -14,7 +14,8 @@ #include "libspu/mpc/spdz2k/protocol.h" -#include "libspu/mpc/spdz2k/abprotocol_spdz2k_test.h" +#include "libspu/mpc/ab_api_test.h" +#include "libspu/mpc/api_test.h" namespace spu::mpc::test { namespace { @@ -26,16 +27,58 @@ RuntimeConfig makeConfig(FieldType field) { return conf; } +std::unique_ptr makeMpcSpdz2kProtocol( + const RuntimeConfig& rt, const std::shared_ptr& lctx) { + RuntimeConfig mpc_rt = rt; + mpc_rt.set_beaver_type(RuntimeConfig_BeaverType_MultiParty); + + return makeSpdz2kProtocol(mpc_rt, lctx); +} } // namespace INSTANTIATE_TEST_SUITE_P( - Spdz2kArithmeticTest, ArithmeticTest, - testing::Combine(testing::Values(makeSpdz2kProtocol), // - testing::Values(makeConfig(FieldType::FM128)), // - testing::Values(2, 3, 5)), // + Spdz2k, ApiTest, + testing::Combine(testing::Values(CreateObjectFn(makeSpdz2kProtocol, + "tfp")), // + testing::Values(makeConfig(FieldType::FM64)), // + testing::Values(2)), // + [](const testing::TestParamInfo& p) { + return fmt::format("{}x{}x{}", std::get<0>(p.param).name(), + std::get<1>(p.param).field(), std::get<2>(p.param)); + }); + +INSTANTIATE_TEST_SUITE_P( + Spdz2k, ArithmeticTest, + testing::Values(std::tuple{CreateObjectFn(makeSpdz2kProtocol, "tfp"), + makeConfig(FieldType::FM64), 2}, + std::tuple{CreateObjectFn(makeMpcSpdz2kProtocol, "mpc"), + makeConfig(FieldType::FM32), 2}), [](const testing::TestParamInfo& p) { - return fmt::format("{}x{}", std::get<1>(p.param).field(), - std::get<2>(p.param)); + return fmt::format("{}x{}x{}", std::get<0>(p.param).name(), + std::get<1>(p.param).field(), std::get<2>(p.param)); + }); + +// TODO : improve performance of boolean share and conversion in offline phase +INSTANTIATE_TEST_SUITE_P( + Spdz2k, BooleanTest, + testing::Combine(testing::Values(CreateObjectFn(makeSpdz2kProtocol, 
+ "tfp")), // + testing::Values(makeConfig(FieldType::FM64)), // + testing::Values(2)), // + [](const testing::TestParamInfo& p) { + return fmt::format("{}x{}x{}", std::get<0>(p.param).name(), + std::get<1>(p.param).field(), std::get<2>(p.param)); + }); + +INSTANTIATE_TEST_SUITE_P( + Spdz2k, ConversionTest, + testing::Combine(testing::Values(CreateObjectFn(makeSpdz2kProtocol, + "tfp")), // + testing::Values(makeConfig(FieldType::FM64)), // + testing::Values(2)), // + [](const testing::TestParamInfo& p) { + return fmt::format("{}x{}x{}", std::get<0>(p.param).name(), + std::get<1>(p.param).field(), std::get<2>(p.param)); }); } // namespace spu::mpc::test diff --git a/libspu/mpc/spdz2k/state.h b/libspu/mpc/spdz2k/state.h index 64438b96..49663eed 100644 --- a/libspu/mpc/spdz2k/state.h +++ b/libspu/mpc/spdz2k/state.h @@ -24,6 +24,7 @@ #include "libspu/core/object.h" #include "libspu/mpc/common/communicator.h" #include "libspu/mpc/spdz2k/beaver/beaver_tfp.h" +#include "libspu/mpc/spdz2k/beaver/beaver_tinyot.h" #include "libspu/mpc/spdz2k/commitment.h" namespace spu::mpc { @@ -41,71 +42,73 @@ template using Share = std::complex; class Spdz2kState : public State { - std::unique_ptr beaver_; + // #ifdef TINYOT + // using Beaver = spdz2k::BeaverTinyOt; + // #else + // using Beaver = spdz2k::BeaverTfpUnsafe; + // #endif + + std::unique_ptr beaver_; std::shared_ptr lctx_; // share of global key, share key has length of 128 bit - uint128_t key_; + uint128_t key_ = 0; + + // plaintext ring size, default set to half field bit length + size_t k_ = 0; - // shares to be checked - std::unique_ptr> arr_ref_v_; + // statistical security parameter, default set to half field bit length + size_t s_ = 0; - // plaintext ring size - const size_t k_ = 64; + FieldType data_field_ = FT_INVALID; - // statistical security parameter - const size_t s_ = 64; + FieldType runtime_field_ = FT_INVALID; - // default in FM128 - const FieldType field_ = FM128; + private: + FieldType 
getRuntimeField(FieldType data_field) { + switch (data_field) { + case FM32: + return FM64; + case FM64: + return FM128; + default: + SPU_THROW("unsupported data field {} for spdz2k", data_field); + } + return FT_INVALID; + } public: static constexpr char kBindName[] = "Spdz2kState"; static constexpr auto kAesType = yacl::crypto::SymmetricCrypto::CryptoType::AES128_CTR; - explicit Spdz2kState(std::shared_ptr lctx) { - beaver_ = std::make_unique(lctx); + explicit Spdz2kState(const RuntimeConfig& conf, + std::shared_ptr lctx) + : data_field_(conf.field()) { + if (conf.beaver_type() == RuntimeConfig_BeaverType_TrustedFirstParty) { + beaver_ = std::make_unique(lctx); + } else if (conf.beaver_type() == RuntimeConfig_BeaverType_MultiParty) { + beaver_ = std::make_unique(lctx); + } else { + SPU_THROW("unsupported beaver type {}", conf.beaver_type()); + } lctx_ = lctx; - key_ = beaver_->GetSpdzKey(field_, s_); - arr_ref_v_ = std::make_unique>(); + runtime_field_ = getRuntimeField(data_field_); + k_ = SizeOf(data_field_) * 8; + s_ = k_; + key_ = beaver_->InitSpdzKey(runtime_field_, s_); } - spdz2k::BeaverTfpUnsafe* beaver() { return beaver_.get(); } + FieldType getDefaultField() const { return runtime_field_; } + + spdz2k::Beaver* beaver() { return beaver_.get(); } uint128_t key() const { return key_; } size_t k() const { return k_; } size_t s() const { return s_; } - - std::vector* arr_ref_v() { return arr_ref_v_.get(); } - - // public coin, used in malicious model, all party generate new seed, then - // get exactly the same random variable. 
- ArrayRef genPublCoin(FieldType field, size_t numel) { - ArrayRef res(makeType(field), numel); - - // generate new seed - uint128_t self_pk = yacl::crypto::RandSeed(true); - std::vector all_strs; - - std::string self_pk_str(reinterpret_cast(&self_pk), sizeof(self_pk)); - YACL_ENFORCE(commit_and_open(lctx_, self_pk_str, &all_strs)); - - uint128_t public_seed = 0; - for (const auto& str : all_strs) { - uint128_t seed = *(reinterpret_cast(str.data())); - public_seed += seed; - } - - yacl::crypto::FillPRand( - kAesType, public_seed, 0, 0, - absl::MakeSpan(static_cast(res.data()), res.buf()->size())); - - return res; - } }; } // namespace spu::mpc diff --git a/libspu/mpc/spdz2k/type.cc b/libspu/mpc/spdz2k/type.cc index 9fef8c83..d84b53d9 100644 --- a/libspu/mpc/spdz2k/type.cc +++ b/libspu/mpc/spdz2k/type.cc @@ -25,8 +25,9 @@ void registerTypes() { static std::once_flag flag; - std::call_once(flag, - []() { TypeContext::getTypeContext()->addTypes(); }); + std::call_once(flag, []() { + TypeContext::getTypeContext()->addTypes(); + }); } } // namespace spu::mpc::spdz2k diff --git a/libspu/mpc/spdz2k/type.h b/libspu/mpc/spdz2k/type.h index 5dd90312..8b43301d 100644 --- a/libspu/mpc/spdz2k/type.h +++ b/libspu/mpc/spdz2k/type.h @@ -15,21 +15,86 @@ #pragma once #include "libspu/core/type.h" +#include "libspu/core/type_util.h" namespace spu::mpc::spdz2k { class AShrTy : public TypeImpl { using Base = TypeImpl; + bool has_mac_ = false; + public: using Base::Base; static std::string_view getStaticId() { return "spdz2k.AShr"; } explicit AShrTy(FieldType field) { field_ = field; } + explicit AShrTy(FieldType field, bool has_mac) { + field_ = field; + has_mac_ = has_mac; + } + + bool hasMac() const { return has_mac_; } + size_t size() const override { return SizeOf(GetStorageType(field_)) * 2; } }; +class BShrTy : public TypeImpl { + using Base = TypeImpl; + PtType back_type_ = PT_INVALID; + size_t k_ = 0; + + public: + using Base::Base; + + explicit BShrTy(PtType back_type, size_t 
nbits, FieldType field) { + SPU_ENFORCE(SizeOf(back_type) * 8 >= nbits, + "backtype={} has not enough bits={}", back_type, nbits); + back_type_ = back_type; + nbits_ = nbits; + field_ = field; + k_ = SizeOf(field) * 8 / 2; + } + + PtType getBacktype() const { return back_type_; } + + static std::string_view getStaticId() { return "spdz2k.BShr"; } + + void fromString(std::string_view detail) override { + auto comma = detail.find_first_of(','); + auto last_comma = detail.find_last_of(','); + auto back_type_str = detail.substr(0, comma); + auto nbits_str = detail.substr(comma + 1, last_comma); + SPU_ENFORCE(PtType_Parse(std::string(back_type_str), &back_type_), + "parse failed from={}", back_type_str); + nbits_ = std::stoul(std::string(nbits_str)); + auto field_str = detail.substr(last_comma + 1); + SPU_ENFORCE(FieldType_Parse(std::string(field_str), &field_), + "parse failed from={}", field_str); + }; + + std::string toString() const override { + return fmt::format("{},{},{}", PtType_Name(back_type_), nbits_, field_); + } + + size_t nbits() const { return nbits_; } + + size_t k() const { return k_; } + + size_t size() const override { + return SizeOf(GetStorageType(field_)) * 2 * k_; + } + + bool equals(TypeObject const* other) const override { + auto const* derived_other = dynamic_cast(other); + SPU_ENFORCE(derived_other); + return getBacktype() == derived_other->getBacktype() && + nbits() == derived_other->nbits() && + field() == derived_other->field(); + } +}; + void registerTypes(); } // namespace spu::mpc::spdz2k diff --git a/libspu/mpc/spdz2k/type_test.cc b/libspu/mpc/spdz2k/type_test.cc index c0d42cd0..531c16e0 100644 --- a/libspu/mpc/spdz2k/type_test.cc +++ b/libspu/mpc/spdz2k/type_test.cc @@ -36,4 +36,31 @@ TEST(AShrTy, Simple) { } } +TEST(BShrTy, Simple) { + // spdz2k::BShr constructor with field and nbits. 
+ { + Type ty = makeType(PT_U128, 127, FM128); + EXPECT_EQ(ty.size(), 16 * 128); + + EXPECT_TRUE(ty.isa()); + EXPECT_FALSE(ty.isa()); + EXPECT_FALSE(ty.isa()); + EXPECT_TRUE(ty.isa()); + + EXPECT_EQ(ty.toString(), "spdz2k.BShr"); + + EXPECT_EQ(Type::fromString(ty.toString()), ty); + + // clone + Type cty = ty; + EXPECT_EQ(cty, ty); + EXPECT_TRUE(cty.isa()); + EXPECT_FALSE(cty.isa()); + EXPECT_FALSE(cty.isa()); + EXPECT_TRUE(cty.isa()); + + EXPECT_EQ(cty.toString(), "spdz2k.BShr"); + } +} + } // namespace spu::mpc::spdz2k diff --git a/libspu/mpc/spdz2k/value.cc b/libspu/mpc/spdz2k/value.cc index 87fb8d07..befffa14 100644 --- a/libspu/mpc/spdz2k/value.cc +++ b/libspu/mpc/spdz2k/value.cc @@ -14,21 +14,21 @@ #include "libspu/mpc/spdz2k/value.h" +#include "libspu/mpc/common/pv2k.h" #include "libspu/mpc/spdz2k/type.h" #include "libspu/mpc/utils/ring_ops.h" namespace spu::mpc::spdz2k { -namespace { - -ArrayRef makeShare(const ArrayRef& s1, const ArrayRef& s2, Type ty) { - const auto field = ty.as()->field(); +ArrayRef makeAShare(const ArrayRef& s1, const ArrayRef& s2, FieldType field, + bool has_mac) { SPU_ENFORCE(s2.eltype().as()->field() == field); SPU_ENFORCE(s1.eltype().as()->field() == field); - SPU_ENFORCE(s1.numel() == s2.numel(), "got s1={}, s2={}", s1.numel(), - s2.numel()); - SPU_ENFORCE(ty.size() == 2 * s1.elsize()); + SPU_ENFORCE(s1.numel() == s2.numel(), "s1 numel ={}, s2 numel ={}", + s1.numel(), s2.numel()); + const auto ty = makeType(field, has_mac); + SPU_ENFORCE(ty.size() == 2 * s1.elsize()); ArrayRef res(ty, s1.numel()); auto res_s1 = getValueShare(res); @@ -39,27 +39,176 @@ ArrayRef makeShare(const ArrayRef& s1, const ArrayRef& s2, Type ty) { return res; } -} // namespace +ArrayRef makeBShare(const ArrayRef& s1, const ArrayRef& s2, FieldType field, + size_t nbits) { + SPU_ENFORCE(s2.eltype().as()->field() == field); + SPU_ENFORCE(s1.eltype().as()->field() == field); + SPU_ENFORCE(s1.numel() == s2.numel(), "s1 numel ={}, s2 numel ={}", + s2.numel()); + 
SPU_ENFORCE(s1.numel() % nbits == 0 && s1.numel() / nbits != 0, + "s1 numel = {}, nbits = {}", s1.numel(), nbits); -ArrayRef getValueShare(const ArrayRef& in) { - const auto field = in.eltype().as()->field(); - auto ty = makeType(field); + const PtType btype = calcBShareBacktype(nbits); + const auto ty = makeType(btype, nbits, field); + const size_t k = ty.as()->k(); + SPU_ENFORCE(nbits <= k, "nbits = {}", nbits); + ArrayRef res(ty, s1.numel() / nbits); + size_t res_numel = res.numel(); + + DISPATCH_ALL_FIELDS(field, "_", [&]() { + auto _res = ArrayView>(res); + auto _s1 = ArrayView(s1); + auto _s2 = ArrayView(s2); + + pforeach(0, res_numel * k, [&](int64_t i) { + _res[i][0] = 0; + _res[i][1] = 0; + }); + + pforeach(0, res_numel, [&](int64_t i) { + pforeach(0, nbits, [&](int64_t j) { + _res[i * k + j][0] = _s1[i * nbits + j]; + _res[i * k + j][1] = _s2[i * nbits + j]; + }); + }); + }); + return res; +} + +ArrayRef getShare(const ArrayRef& in, int64_t share_idx) { SPU_ENFORCE(in.stride() != 0); - return ArrayRef(in.buf(), ty, in.numel(), in.stride() * 2, in.offset()); + SPU_ENFORCE(share_idx == 0 || share_idx == 1); + + if (in.eltype().isa()) { + const auto field = in.eltype().as()->field(); + const auto ty = makeType(field); + return ArrayRef{in.buf(), ty, in.numel(), in.stride() * 2, + in.offset() + share_idx * static_cast(ty.size())}; + } else if (in.eltype().isa()) { + const auto field = in.eltype().as()->field(); + const auto nbits = in.eltype().as()->nbits(); + const auto k = in.eltype().as()->k(); + const auto ty = makeType(field); + + if (nbits == k) { + return ArrayRef{ + in.buf(), ty, in.numel() * static_cast(nbits), + in.stride() * 2, + in.offset() + share_idx * static_cast(ty.size())}; + } else { + ArrayRef ret(ty, in.numel() * static_cast(nbits)); + + DISPATCH_ALL_FIELDS(field, "_", [&]() { + auto _in = ArrayView>(in); + auto _ret = ArrayView(ret); + size_t numel = in.numel(); + pforeach(0, numel, [&](int64_t i) { + pforeach(0, nbits, [&](int64_t j) { + 
_ret[i * nbits + j] = _in[i * k + j][share_idx]; + }); + }); + }); + + return ret; + } + } else { + SPU_THROW("unsupported type {}", in.eltype()); + } +} + +const ArrayRef getValueShare(const ArrayRef& in) { return getShare(in, 0); } + +const ArrayRef getMacShare(const ArrayRef& in) { return getShare(in, 1); } + +size_t maxNumBits(const ArrayRef& lhs, const ArrayRef& rhs) { + SPU_ENFORCE(lhs.eltype().isa()); + SPU_ENFORCE(rhs.eltype().isa() || rhs.eltype().isa()); + + if (rhs.eltype().isa()) { + return std::max(lhs.eltype().as()->nbits(), + rhs.eltype().as()->nbits()); + } + const auto* rhs_ty = rhs.eltype().as(); + const auto rhs_field = rhs_ty->field(); + return DISPATCH_ALL_FIELDS(rhs_field, "_", [&]() { + using PShrT = ring2k_t; + auto _rhs = ArrayView(rhs); + return std::max(lhs.eltype().as()->nbits(), maxBitWidth(_rhs)); + }); +} + +size_t minNumBits(const ArrayRef& lhs, const ArrayRef& rhs) { + SPU_ENFORCE(lhs.eltype().isa()); + SPU_ENFORCE(rhs.eltype().isa() || rhs.eltype().isa()); + + if (rhs.eltype().isa()) { + return std::min(lhs.eltype().as()->nbits(), + rhs.eltype().as()->nbits()); + } + const auto* rhs_ty = rhs.eltype().as(); + const auto rhs_field = rhs_ty->field(); + return DISPATCH_ALL_FIELDS(rhs_field, "_", [&]() { + using PShrT = ring2k_t; + auto _rhs = ArrayView(rhs); + return std::min(lhs.eltype().as()->nbits(), maxBitWidth(_rhs)); + }); } -ArrayRef getMacShare(const ArrayRef& in) { +// Convert a BShare in new_nbits +// then output only the values and macs of valid bits +std::pair BShareSwitch2Nbits(const ArrayRef& in, + size_t new_nbits) { + const auto old_nbits = in.eltype().as()->nbits(); + if (old_nbits == new_nbits) { + return {getValueShare(in), getMacShare(in)}; + } + + // const size_t p_num = in.numel() / old_nbits; + const size_t p_num = in.numel(); const auto field = in.eltype().as()->field(); - auto ty = makeType(field); + auto out_val = ring_zeros(field, p_num * new_nbits); + auto out_mac = ring_zeros(field, p_num * new_nbits); - 
SPU_ENFORCE(in.stride() != 0); - return ArrayRef(in.buf(), ty, in.numel(), in.stride() * 2, - in.offset() + static_cast(ty.size())); + auto in_val = getValueShare(in).clone(); + auto in_mac = getMacShare(in).clone(); + auto min_nbits = std::min(old_nbits, new_nbits); + + for (size_t i = 0; i < p_num; ++i) { + auto _in_val = ArrayRef(in_val.buf(), makeType(field), min_nbits, 1, + i * old_nbits * SizeOf(field)); + auto _in_mac = ArrayRef(in_mac.buf(), makeType(field), min_nbits, 1, + i * old_nbits * SizeOf(field)); + + auto _out_val = ArrayRef(out_val.buf(), makeType(field), min_nbits, + 1, i * new_nbits * SizeOf(field)); + auto _out_mac = ArrayRef(out_mac.buf(), makeType(field), min_nbits, + 1, i * new_nbits * SizeOf(field)); + + ring_add_(_out_val, _in_val); + ring_add_(_out_mac, _in_mac); + } + + return {out_val, out_mac}; } -ArrayRef makeAShare(const ArrayRef& s1, const ArrayRef& s2, FieldType field) { - return makeShare(s1, s2, makeType(field)); +PtType calcBShareBacktype(size_t nbits) { + if (nbits <= 8) { + return PT_U8; + } + if (nbits <= 16) { + return PT_U16; + } + if (nbits <= 32) { + return PT_U32; + } + if (nbits <= 64) { + return PT_U64; + } + if (nbits <= 128) { + return PT_U128; + } + SPU_THROW("invalid number of bits={}", nbits); } } // namespace spu::mpc::spdz2k diff --git a/libspu/mpc/spdz2k/value.h b/libspu/mpc/spdz2k/value.h index 39a4a013..2f2a93f5 100644 --- a/libspu/mpc/spdz2k/value.h +++ b/libspu/mpc/spdz2k/value.h @@ -16,6 +16,8 @@ #include "libspu/core/array_ref.h" #include "libspu/core/type_util.h" +#include "libspu/mpc/spdz2k/beaver/beaver_tfp.h" +#include "libspu/mpc/spdz2k/beaver/beaver_tinyot.h" namespace spu::mpc::spdz2k { @@ -33,15 +35,46 @@ namespace spu::mpc::spdz2k { // a[n-1].share0 (n-1)*2*k+0 // a[n-1].share1 (n-1)*2*k+k // -// you can treat spdz2k share as std::complex, where -// real(x) is the first share piece. -// imag(x) is the second share piece. 
+// you can imagine spdz2k share as std::complex, where +// real(x) is the value share piece. +// imag(x) is the mac share piece. -ArrayRef getValueShare(const ArrayRef& in); +// Different with other protocls! +// Only output values of valid bits for optimal memory usage +const ArrayRef getValueShare(const ArrayRef& in); -ArrayRef getMacShare(const ArrayRef& in); +// Only output macs of valid bits for optimal memory usage +const ArrayRef getMacShare(const ArrayRef& in); -ArrayRef makeAShare(const ArrayRef& s1, const ArrayRef& s2, FieldType field); +ArrayRef makeAShare(const ArrayRef& s1, const ArrayRef& s2, FieldType field, + bool has_mac = true); + +// Different with other protocls! +// input s1: value shares of valid bits +// input s2: mac shares of valid bits +// output: boolean shares of fixed length +ArrayRef makeBShare(const ArrayRef& s1, const ArrayRef& s2, FieldType field, + size_t nbits); + +size_t maxNumBits(const ArrayRef& lhs, const ArrayRef& rhs); +size_t minNumBits(const ArrayRef& lhs, const ArrayRef& rhs); + +size_t minNumBits(const ArrayRef& lhs, const ArrayRef& rhs); + +// Convert a BShare in new_nbits +// then output the corresponding value and mac +std::pair BShareSwitch2Nbits(const ArrayRef& in, + size_t new_nbits); + +PtType calcBShareBacktype(size_t nbits); + +template +size_t maxBitWidth(ArrayView av) { + // TODO: use av.maxBitWidth to improve performance + return sizeof(T) * 8; +} + +ArrayRef getShare(const ArrayRef& in, int64_t share_idx); #define PFOR_GRAIN_SIZE 8192 diff --git a/libspu/mpc/utils/BUILD.bazel b/libspu/mpc/utils/BUILD.bazel index 017ef23b..cce3ca41 100644 --- a/libspu/mpc/utils/BUILD.bazel +++ b/libspu/mpc/utils/BUILD.bazel @@ -86,9 +86,9 @@ spu_cc_library( srcs = ["linalg.cc"], hdrs = ["linalg.h"], deps = [ - "@com_github_eigenteam_eigen//:eigen3", "//libspu/core:parallel_utils", "//libspu/core:prelude", + "@com_github_eigenteam_eigen//:eigen3", ] + select({ "@bazel_tools//src/conditions:darwin_x86_64": 
["@local_homebrew_x64//:openmp"], "@bazel_tools//src/conditions:darwin_arm64": ["@local_homebrew_arm64//:openmp"], diff --git a/libspu/mpc/utils/ring_ops.cc b/libspu/mpc/utils/ring_ops.cc index 1a899986..c38a1c5f 100644 --- a/libspu/mpc/utils/ring_ops.cc +++ b/libspu/mpc/utils/ring_ops.cc @@ -155,7 +155,11 @@ void ring_bitmask_impl(ArrayRef& ret, const ArrayRef& x, size_t low, return DISPATCH_ALL_FIELDS(field, kModule, [&]() { using U = ring2k_t; - U mask = (((U)1U << (high - low)) - 1) << low; + U mask = 0; + if (high - low < SizeOf(field) * 8) { + mask = (U)1U << (high - low); + } + mask = (mask - 1) << low; auto mark_fn = [&](U el) { return el & mask; }; diff --git a/libspu/psi/bucket_psi.cc b/libspu/psi/bucket_psi.cc index 6f151293..8e9fcc81 100644 --- a/libspu/psi/bucket_psi.cc +++ b/libspu/psi/bucket_psi.cc @@ -44,8 +44,6 @@ #include "libspu/psi/utils/serialize.h" #include "libspu/psi/utils/utils.h" -#include "interconnection/algos/psi.pb.h" - namespace spu::psi { namespace { diff --git a/libspu/psi/core/BUILD.bazel b/libspu/psi/core/BUILD.bazel index 66ba960f..1ab22927 100644 --- a/libspu/psi/core/BUILD.bazel +++ b/libspu/psi/core/BUILD.bazel @@ -30,7 +30,7 @@ spu_cc_library( cc_proto_library( name = "ic_protocol_psi_cc_proto", - deps = ["@org_interconnection//interconnection/algos:psi"], + deps = ["@org_interconnection//interconnection/runtime:ecdh_psi"], ) spu_cc_library( diff --git a/libspu/psi/core/communication.cc b/libspu/psi/core/communication.cc index 3ba40640..266e9e8b 100644 --- a/libspu/psi/core/communication.cc +++ b/libspu/psi/core/communication.cc @@ -18,7 +18,7 @@ #include "libspu/core/prelude.h" -#include "interconnection/algos/psi.pb.h" +#include "interconnection/runtime/ecdh_psi.pb.h" namespace spu::psi { @@ -46,7 +46,7 @@ std::shared_ptr CreateP2PLinkCtx( } yacl::Buffer IcPsiBatchSerializer::Serialize(PsiDataBatch&& batch) { - org::interconnection::algos::psi::EcdhPsiCipherBatch proto; + 
org::interconnection::v2::runtime::EcdhPsiCipherBatch proto; proto.set_type(batch.type); proto.set_batch_index(batch.batch_index); proto.set_is_last_batch(batch.is_last_batch); @@ -60,7 +60,7 @@ yacl::Buffer IcPsiBatchSerializer::Serialize(PsiDataBatch&& batch) { } PsiDataBatch IcPsiBatchSerializer::Deserialize(yacl::ByteContainerView buf) { - org::interconnection::algos::psi::EcdhPsiCipherBatch proto; + org::interconnection::v2::runtime::EcdhPsiCipherBatch proto; SPU_ENFORCE(proto.ParseFromArray(buf.data(), buf.size()), "parse EcdhPsiCipherBatch proto fail"); diff --git a/libspu/spu.proto b/libspu/spu.proto index e1e11596..80cdab79 100644 --- a/libspu/spu.proto +++ b/libspu/spu.proto @@ -115,8 +115,7 @@ enum ProtocolKind { CHEETAH = 4; } -// The spu Value proto, used for spu value serialization. -message ValueProto { +message ValueMetaProto { // The data type. DataType data_type = 1; @@ -130,20 +129,15 @@ message ValueProto { // i.e. `aby3.AShr` means an aby3 arithmetic share in FM64. // usually, the application does not care about this attribute. string storage_type = 4; - - // The runtime/protocol dependent value data. - bytes content = 5; } -message ValueMeta { - // The data type. - DataType data_type = 1; - - // The data visibility. - Visibility visibility = 2; - - // The shape of the value. - ShapeProto shape = 3; +// The spu Value proto, used for spu value serialization. +message ValueChunkProto { + // chunk info + uint64 total_bytes = 1; + uint64 chunk_offset = 2; + // chunk bytes + bytes content = 3; } ////////////////////////////////////////////////////////////////////////// @@ -194,9 +188,15 @@ message RuntimeConfig { // When enabled, runtime records detailed pphlo timing data, debug purpose // only. + // WARNING: the `send bytes` information is only accurate when + // `experimental_enable_inter_op_par` and `experimental_enable_intra_op_par` + // options are disabled. 
bool enable_pphlo_profile = 15; // When enabled, runtime records detailed hal timing data, debug purpose only. + // WARNING: the `send bytes` information is only accurate when + // `experimental_enable_inter_op_par` and `experimental_enable_intra_op_par` + // options are disabled. bool enable_hal_profile = 16; reserved 17, 18; @@ -207,6 +207,10 @@ message RuntimeConfig { // to do with security. uint64 public_random_seed = 19; + // max chunk size for Value::toProto + // default: 128 * 1024 * 1024 + uint64 share_max_chunk_size = 20; + // @exclude // Fixed-point arithmetic related, reserved for [50, 100) @@ -275,9 +279,10 @@ message RuntimeConfig { TrustedFirstParty = 0; // generate beaver triple through an additional trusted third party. TrustedThirdParty = 1; + // generate beaver triple through multi-party. + MultiParty = 2; } - - // beaver config, works for semi2k only for now. + // beaver config, works for semi2k and spdz2k for now. BeaverType beaver_type = 70; // TrustedThirdParty configs. 
diff --git a/requirements.txt b/requirements.txt index ea2a7a7c..d8df1fac 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,6 +4,5 @@ protobuf>=3.19, <4 cloudpickle>=2.0.0 multiprocess>=0.70.12.2 cachetools>=5.0.0 -jax>=0.3.15,<=0.4.12 # Last version that supports numpy 1.9 -jaxlib>=0.3.15,<=0.4.12 +jax[cpu]>=0.3.15 # Last version that supports numpy 1.9 termcolor>=2.0.0 diff --git a/spu/__init__.py b/spu/__init__.py index b45876e8..7e3d52a2 100644 --- a/spu/__init__.py +++ b/spu/__init__.py @@ -21,7 +21,6 @@ PtType, ProtocolKind, FieldType, - ValueProto, ShapeProto, RuntimeConfig, ExecutableProto, @@ -41,7 +40,6 @@ "PtType", "ProtocolKind", "FieldType", - "ValueProto", "ShapeProto", "RuntimeConfig", "ExecutableProto", diff --git a/spu/api.py b/spu/api.py index 39b9f227..a0946a18 100644 --- a/spu/api.py +++ b/spu/api.py @@ -44,37 +44,48 @@ def run(self, executable: spu_pb2.ExecutableProto) -> None: """ return self._vm.Run(executable.SerializeToString()) - def set_var(self, name: str, value: bytes) -> None: + def set_var(self, name: str, value: libspu.Share) -> None: """Set an SPU value. Args: name (str): Id of value. - value (spu_pb2.ValueProto): value data. + value (libspu.Share): value data. """ return self._vm.SetVar(name, value) - def get_var(self, name: str) -> bytes: + def get_var(self, name: str) -> libspu.Share: """Get an SPU value. Args: name (str): Id of value. Returns: - spu_pb2.ValueProto: Data data. + libspu.Share: Data data. """ return self._vm.GetVar(name) - def get_var_meta(self, name: str) -> spu_pb2.ValueMeta: + def get_var_chunk_count(self, name: str) -> int: + """Get an SPU value. + + Args: + name (str): Id of value. + + Returns: + int: chunks count in libspu.Share + """ + return self._vm.GetVarChunksCount(name) + + def get_var_meta(self, name: str) -> spu_pb2.ValueMetaProto: """Get an SPU value without content. Args: name (str): Id of value. Returns: - spu_pb2.ValueProto: Data with out content. 
+ spu_pb2.ValueMeta: Data meta with out content. """ - ret = spu_pb2.ValueProto() + ret = spu_pb2.ValueMetaProto() ret.ParseFromString(self._vm.GetVarMeta(name)) return ret @@ -103,9 +114,14 @@ def __init__(self, world_size: int, config: spu_pb2.RuntimeConfig): """ self._io = libspu.IoWrapper(world_size, config.SerializeToString()) + def get_share_chunk_count( + self, x: 'np.ndarray', vtype: spu_pb2.Visibility, owner_rank: int = -1 + ) -> int: + return self._io.GetShareChunkCount(x, vtype, owner_rank) + def make_shares( self, x: 'np.ndarray', vtype: spu_pb2.Visibility, owner_rank: int = -1 - ) -> List[bytes]: + ) -> List[libspu.Share]: """Convert from NumPy array to list of SPU value(s). Args: @@ -114,20 +130,20 @@ def make_shares( owner_rank (int): the index of the trusted piece. if >= 0, colocation optimization may be applied. Returns: - [spu_pb2.ValueProto]: output. + [libspu.Share]: output. """ return self._io.MakeShares(x, vtype, owner_rank) - def reconstruct(self, str_shares: List[bytes]) -> 'np.ndarray': + def reconstruct(self, shares: List[libspu.Share]) -> 'np.ndarray': """Convert from list of SPU value(s) to NumPy array. Args: - xs (spu_pb2.ValueProto]): input. + xs ([libspu.Share]): input. Returns: np.ndarray: output. 
""" - return self._io.Reconstruct(str_shares) + return self._io.Reconstruct(shares) @cached(cache=LRUCache(maxsize=128)) diff --git a/spu/libspu.cc b/spu/libspu.cc index c8bb1293..73f20a56 100644 --- a/spu/libspu.cc +++ b/spu/libspu.cc @@ -257,6 +257,37 @@ void BindLink(py::module& m) { }); } +struct PyBindShare { + py::bytes meta; + std::vector share_chunks; +}; + +static spu::Value ValueFromPyBindShare(const PyBindShare& py_share) { + spu::ValueProto value; + spu::ValueMetaProto meta; + SPU_ENFORCE(meta.ParseFromString(py_share.meta)); + value.meta.Swap(&meta); + for (const auto& s : py_share.share_chunks) { + spu::ValueChunkProto chunk; + SPU_ENFORCE(chunk.ParseFromString(s)); + value.chunks.emplace_back(std::move(chunk)); + } + return Value::fromProto(value); +} + +static PyBindShare ValueToPyBindShare(const spu::Value& value, + size_t max_chunk_size) { + PyBindShare ret; + + const auto value_pb = value.toProto(max_chunk_size); + ret.meta = value_pb.meta.SerializeAsString(); + ret.share_chunks.reserve(value_pb.chunks.size()); + for (const auto& s : value_pb.chunks) { + ret.share_chunks.emplace_back(s.SerializeAsString()); + } + return ret; +} + // Wrap Runtime, it's workaround for protobuf pybind11/protoc conflict. class RuntimeWrapper { std::unique_ptr sctx_; @@ -264,6 +295,8 @@ class RuntimeWrapper { // the golbals, could be used to cross session stuffs. 
spu::device::SymbolTable env_; + size_t max_chunk_size_; + public: explicit RuntimeWrapper(const std::shared_ptr& lctx, const std::string& config_pb) { @@ -275,6 +308,10 @@ class RuntimeWrapper { sctx_ = std::make_unique(config, lctx); mpc::Factory::RegisterProtocol(sctx_.get(), lctx); + max_chunk_size_ = config.share_max_chunk_size(); + if (max_chunk_size_ == 0) { + max_chunk_size_ = 128UL * 1024 * 1024; + } } void Run(const py::bytes& exec_pb) { @@ -285,15 +322,16 @@ class RuntimeWrapper { spu::device::execute(&executor, sctx_.get(), exec, &env_); } - void SetVar(const std::string& name, const py::bytes& value) { - ValueProto proto; - SPU_ENFORCE(proto.ParseFromString(value)); + void SetVar(const std::string& name, const PyBindShare& share) { + env_.setVar(name, ValueFromPyBindShare(share)); + } - env_.setVar(name, spu::Value::fromProto(proto)); + PyBindShare GetVar(const std::string& name) const { + return ValueToPyBindShare(env_.getVar(name), max_chunk_size_); } - py::bytes GetVar(const std::string& name) const { - return env_.getVar(name).toProto().SerializeAsString(); + size_t GetVarChunksCount(const std::string& name) { + return env_.getVar(name).chunksCount(max_chunk_size_); } py::bytes GetVarMeta(const std::string& name) const { @@ -368,7 +406,9 @@ constexpr void SizeCheck() { } class IoWrapper { + private: std::unique_ptr ptr_; + size_t max_chunk_size_; public: IoWrapper(size_t world_size, const std::string& config_pb) { @@ -376,10 +416,30 @@ class IoWrapper { SPU_ENFORCE(config.ParseFromString(config_pb)); ptr_ = std::make_unique(world_size, config); + max_chunk_size_ = config.share_max_chunk_size(); + if (max_chunk_size_ == 0) { + max_chunk_size_ = 128UL * 1024 * 1024; + } } - std::vector MakeShares(const py::array& arr, int visibility, - int owner_rank = -1) { + size_t GetShareChunkCount(const py::array& arr, int visibility, + int owner_rank) { + const py::buffer_info& binfo = arr.request(); + const PtType pt_type = PyFormatToPtType(binfo.format); + + 
spu::PtBufferView view( + binfo.ptr, pt_type, + std::vector(binfo.shape.begin(), binfo.shape.end()), + ByteToElementStrides(binfo.strides.begin(), binfo.strides.end(), + binfo.itemsize)); + const size_t share_size = ptr_->getShareSize( + view, static_cast(visibility), owner_rank); + size_t num_chunks = (share_size + max_chunk_size_ - 1) / max_chunk_size_; + return num_chunks; + } + + std::vector MakeShares(const py::array& arr, int visibility, + int owner_rank = -1) { // When working with Python, do a static size check, this has no runtime // cost SizeCheck(); @@ -393,29 +453,27 @@ class IoWrapper { ByteToElementStrides(binfo.strides.begin(), binfo.strides.end(), binfo.itemsize)); - auto shares = ptr_->makeShares( + const auto shares = ptr_->makeShares( view, static_cast(visibility), owner_rank); - std::vector serialized(shares.size()); - for (size_t idx = 0; idx < shares.size(); ++idx) { - std::string s; - SPU_ENFORCE(shares[idx].toProto().SerializeToString(&s)); - serialized[idx] = py::bytes(s); + + std::vector serialized; + serialized.reserve(shares.size()); + for (const auto& share : shares) { + serialized.emplace_back(ValueToPyBindShare(share, max_chunk_size_)); } return serialized; } - py::array reconstruct(const std::vector& vals) { + py::array Reconstruct(const std::vector& vals) { std::vector shares; SPU_ENFORCE(!vals.empty()); - for (const auto& val_str : vals) { - spu::ValueProto vp; - SPU_ENFORCE(vp.ParseFromString(val_str)); - shares.push_back(spu::Value::fromProto(vp)); + shares.reserve(vals.size()); + for (const auto& val : vals) { + shares.emplace_back(ValueFromPyBindShare(val)); } - // sanity - for (size_t idx = 1; idx < vals.size(); ++idx) { + for (size_t idx = 1; idx < shares.size(); ++idx) { const auto& cur = shares[idx]; const auto& prev = shares[idx - 1]; SPU_ENFORCE(cur.storage_type() == prev.storage_type(), @@ -593,6 +651,19 @@ PYBIND11_MODULE(libspu, m) { } }); + py::class_(m, "Share", "Share in python runtime") + .def(py::init<>(), NO_GIL) 
+ .def_readwrite("share_chunks", &PyBindShare::share_chunks, "share chunks") + .def_readwrite("meta", &PyBindShare::meta, "meta of share") + .def(py::pickle( + [](const PyBindShare& s) { // dump + return py::make_tuple(s.meta, s.share_chunks); + }, + [](const py::tuple& t) { // load + return PyBindShare{t[0].cast(), + t[1].cast>()}; + })); + // bind spu virtual machine. py::class_(m, "RuntimeWrapper", "SPU virtual device") .def(py::init, std::string>(), @@ -604,6 +675,7 @@ PYBIND11_MODULE(libspu, m) { // SetVar & GetVar are using // py::byte, so they must acquire gil... .def("GetVar", &RuntimeWrapper::GetVar) + .def("GetVarChunksCount", &RuntimeWrapper::GetVarChunksCount) .def("GetVarMeta", &RuntimeWrapper::GetVarMeta) .def("DelVar", &RuntimeWrapper::DelVar); @@ -611,7 +683,8 @@ PYBIND11_MODULE(libspu, m) { py::class_(m, "IoWrapper", "SPU VM IO") .def(py::init()) .def("MakeShares", &IoWrapper::MakeShares) - .def("Reconstruct", &IoWrapper::reconstruct); + .def("GetShareChunkCount", &IoWrapper::GetShareChunkCount) + .def("Reconstruct", &IoWrapper::Reconstruct); // bind compiler. 
m.def( diff --git a/spu/tests/frontend_test.py b/spu/tests/frontend_test.py index 2822c960..6e27ca07 100644 --- a/spu/tests/frontend_test.py +++ b/spu/tests/frontend_test.py @@ -53,13 +53,18 @@ def test_jax_compile_static_args(self): self.assertEqual(executable.name, "test_jax_add") self.assertEqual(executable.input_names, ["in1", "in2", "in3", "in4"]) self.assertEqual(executable.output_names, ["test-out0"]) + print(executable.code.decode()) self.assertTrue( " func.func @main(%arg0: tensor<2x!pphlo.pub>, %arg1: tensor<2x!pphlo.pub>) -> tensor<2x!pphlo.pub> {\n" " %0 = \"pphlo.constant\"() {value = dense<3.000000e+00> : tensor<2xf32>} : () -> tensor<2x!pphlo.pub>\n" " %1 = \"pphlo.convert\"(%arg0) : (tensor<2x!pphlo.pub>) -> tensor<2x!pphlo.pub>\n" " %2 = \"pphlo.add\"(%1, %0) : (tensor<2x!pphlo.pub>, tensor<2x!pphlo.pub>) -> tensor<2x!pphlo.pub>\n" + " \"pphlo.free\"(%0) : (tensor<2x!pphlo.pub>) -> ()\n" + " \"pphlo.free\"(%1) : (tensor<2x!pphlo.pub>) -> ()\n" " %3 = \"pphlo.convert\"(%arg1) : (tensor<2x!pphlo.pub>) -> tensor<2x!pphlo.pub>\n" " %4 = \"pphlo.add\"(%2, %3) : (tensor<2x!pphlo.pub>, tensor<2x!pphlo.pub>) -> tensor<2x!pphlo.pub>\n" + " \"pphlo.free\"(%3) : (tensor<2x!pphlo.pub>) -> ()\n" + " \"pphlo.free\"(%2) : (tensor<2x!pphlo.pub>) -> ()\n" " return %4 : tensor<2x!pphlo.pub>\n" in executable.code.decode() ) self.assertEqual(output.shape, (2,)) diff --git a/spu/tests/spu_io_test.py b/spu/tests/spu_io_test.py index 281d31a6..580c0f83 100644 --- a/spu/tests/spu_io_test.py +++ b/spu/tests/spu_io_test.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
- import unittest import numpy as np @@ -24,7 +23,7 @@ def _bytes_to_pb(msg: bytes): - ret = spu_pb2.ValueProto() + ret = spu_pb2.ValueMetaProto() ret.ParseFromString(msg) return ret @@ -37,13 +36,19 @@ def _bytes_to_pb(msg: bytes): spu_pb2.ProtocolKind.ABY3, ), field=(spu_pb2.FieldType.FM64, spu_pb2.FieldType.FM128), + chunk_size=(4, 11, 33, 67, 127, 65535), ) class UnitTests(parameterized.TestCase): - def test_io(self, wsize, prot, field): + def test_io(self, wsize, prot, field, chunk_size): if prot == spu_pb2.ProtocolKind.ABY3 and wsize != 3: return - config = spu_pb2.RuntimeConfig(protocol=prot, field=field, fxp_fraction_bits=18) + config = spu_pb2.RuntimeConfig( + protocol=prot, + field=field, + fxp_fraction_bits=18, + share_max_chunk_size=chunk_size, + ) io = ppapi.Io(wsize, config) # SINT @@ -52,9 +57,11 @@ def test_io(self, wsize, prot, field): xs = io.make_shares(x, spu_pb2.Visibility.VIS_SECRET) self.assertEqual(len(xs), wsize) self.assertEqual( - _bytes_to_pb(xs[0]).shape, + _bytes_to_pb(xs[0].meta).shape, spu_pb2.ShapeProto(dims=(3, 4, 5)), ) + chunk_count = io.get_share_chunk_count(x, spu_pb2.Visibility.VIS_SECRET) + self.assertEqual(len(xs[0].share_chunks), chunk_count) y = io.reconstruct(xs) npt.assert_equal(x, y) @@ -65,9 +72,11 @@ def test_io(self, wsize, prot, field): xs = io.make_shares(x, spu_pb2.Visibility.VIS_SECRET) self.assertEqual(len(xs), wsize) self.assertEqual( - _bytes_to_pb(xs[0]).shape, + _bytes_to_pb(xs[0].meta).shape, spu_pb2.ShapeProto(dims=(3, 4, 5)), ) + chunk_count = io.get_share_chunk_count(x, spu_pb2.Visibility.VIS_SECRET) + self.assertEqual(len(xs[0].share_chunks), chunk_count) y = io.reconstruct(xs) npt.assert_almost_equal(x, y, decimal=5) @@ -78,18 +87,41 @@ def test_io(self, wsize, prot, field): xs = io.make_shares(x, spu_pb2.Visibility.VIS_PUBLIC) self.assertEqual(len(xs), wsize) self.assertEqual( - _bytes_to_pb(xs[0]).shape, + _bytes_to_pb(xs[0].meta).shape, spu_pb2.ShapeProto(dims=(3, 4, 5)), ) + chunk_count = 
io.get_share_chunk_count(x, spu_pb2.Visibility.VIS_PUBLIC) + self.assertEqual(len(xs[0].share_chunks), chunk_count) + y = io.reconstruct(xs) + + npt.assert_almost_equal(x, y, decimal=5) + + # empty + x = np.random.rand(1, 0) + + xs = io.make_shares(x, spu_pb2.Visibility.VIS_SECRET) + self.assertEqual(len(xs), wsize) + self.assertEqual( + _bytes_to_pb(xs[0].meta).shape, + spu_pb2.ShapeProto(dims=(1, 0)), + ) + chunk_count = io.get_share_chunk_count(x, spu_pb2.Visibility.VIS_SECRET) + + self.assertEqual(len(xs[0].share_chunks), chunk_count) y = io.reconstruct(xs) npt.assert_almost_equal(x, y, decimal=5) - def test_io_strides(self, wsize, prot, field): + def test_io_strides(self, wsize, prot, field, chunk_size): if prot == spu_pb2.ProtocolKind.ABY3 and wsize != 3: return - config = spu_pb2.RuntimeConfig(protocol=prot, field=field, fxp_fraction_bits=18) + config = spu_pb2.RuntimeConfig( + protocol=prot, + field=field, + fxp_fraction_bits=18, + share_max_chunk_size=chunk_size, + ) io = ppapi.Io(wsize, config) # SINT @@ -99,9 +131,11 @@ def test_io_strides(self, wsize, prot, field): xs = io.make_shares(x, spu_pb2.Visibility.VIS_SECRET) self.assertEqual(len(xs), wsize) self.assertEqual( - _bytes_to_pb(xs[0]).shape, + _bytes_to_pb(xs[0].meta).shape, spu_pb2.ShapeProto(dims=(3, 4, 4)), ) + chunk_count = io.get_share_chunk_count(x, spu_pb2.Visibility.VIS_SECRET) + self.assertEqual(len(xs[0].share_chunks), chunk_count) y = io.reconstruct(xs) npt.assert_equal(x, y) @@ -113,10 +147,12 @@ def test_io_strides(self, wsize, prot, field): xs = io.make_shares(x, spu_pb2.Visibility.VIS_SECRET) self.assertEqual(len(xs), wsize) self.assertEqual( - _bytes_to_pb(xs[0]).shape, + _bytes_to_pb(xs[0].meta).shape, spu_pb2.ShapeProto(dims=(3, 4, 4)), ) y = io.reconstruct(xs) + chunk_count = io.get_share_chunk_count(x, spu_pb2.Visibility.VIS_SECRET) + self.assertEqual(len(xs[0].share_chunks), chunk_count) npt.assert_almost_equal(x, y, decimal=5) @@ -127,18 +163,25 @@ def test_io_strides(self, 
wsize, prot, field): xs = io.make_shares(x, spu_pb2.Visibility.VIS_PUBLIC) self.assertEqual(len(xs), wsize) self.assertEqual( - _bytes_to_pb(xs[0]).shape, + _bytes_to_pb(xs[0].meta).shape, spu_pb2.ShapeProto(dims=(3, 4, 4)), ) y = io.reconstruct(xs) + chunk_count = io.get_share_chunk_count(x, spu_pb2.Visibility.VIS_PUBLIC) + self.assertEqual(len(xs[0].share_chunks), chunk_count) npt.assert_almost_equal(x, y, decimal=5) - def test_io_scalar(self, wsize, prot, field): + def test_io_scalar(self, wsize, prot, field, chunk_size): if prot == spu_pb2.ProtocolKind.ABY3 and wsize != 3: return - config = spu_pb2.RuntimeConfig(protocol=prot, field=field, fxp_fraction_bits=18) + config = spu_pb2.RuntimeConfig( + protocol=prot, + field=field, + fxp_fraction_bits=18, + share_max_chunk_size=chunk_size, + ) io = ppapi.Io(wsize, config) # SINT @@ -146,8 +189,10 @@ def test_io_scalar(self, wsize, prot, field): xs = io.make_shares(x, spu_pb2.Visibility.VIS_SECRET) self.assertEqual(len(xs), wsize) - self.assertEqual(_bytes_to_pb(xs[0]).shape, spu_pb2.ShapeProto(dims=())) + self.assertEqual(_bytes_to_pb(xs[0].meta).shape, spu_pb2.ShapeProto(dims=())) y = io.reconstruct(xs) + chunk_count = io.get_share_chunk_count(x, spu_pb2.Visibility.VIS_SECRET) + self.assertEqual(len(xs[0].share_chunks), chunk_count) npt.assert_equal(x, y) @@ -156,8 +201,10 @@ def test_io_scalar(self, wsize, prot, field): xs = io.make_shares(x, spu_pb2.Visibility.VIS_SECRET) self.assertEqual(len(xs), wsize) - self.assertEqual(_bytes_to_pb(xs[0]).shape, spu_pb2.ShapeProto(dims=())) + self.assertEqual(_bytes_to_pb(xs[0].meta).shape, spu_pb2.ShapeProto(dims=())) y = io.reconstruct(xs) + chunk_count = io.get_share_chunk_count(x, spu_pb2.Visibility.VIS_SECRET) + self.assertEqual(len(xs[0].share_chunks), chunk_count) npt.assert_almost_equal(x, y, decimal=5) @@ -166,8 +213,10 @@ def test_io_scalar(self, wsize, prot, field): xs = io.make_shares(x, spu_pb2.Visibility.VIS_PUBLIC) self.assertEqual(len(xs), wsize) - 
self.assertEqual(_bytes_to_pb(xs[0]).shape, spu_pb2.ShapeProto(dims=())) + self.assertEqual(_bytes_to_pb(xs[0].meta).shape, spu_pb2.ShapeProto(dims=())) y = io.reconstruct(xs) + chunk_count = io.get_share_chunk_count(x, spu_pb2.Visibility.VIS_PUBLIC) + self.assertEqual(len(xs[0].share_chunks), chunk_count) npt.assert_almost_equal(x, y, decimal=5) diff --git a/spu/utils/distributed.py b/spu/utils/distributed.py index 34a2b13d..cdd2d065 100644 --- a/spu/utils/distributed.py +++ b/spu/utils/distributed.py @@ -516,11 +516,13 @@ def get(self, obj: PYU.Object): class ValueWrapper: """Workarounds for ValueProto could not be pickled.""" - def __init__(self, shape: Sequence[int], dtype: np.dtype, vtype, value_str: str): + def __init__( + self, shape: Sequence[int], dtype: np.dtype, vtype, spu_share: libspu.Share + ): self.shape = shape self.dtype = dtype self.vtype = vtype - self.value_str = value_str + self.spu_share = spu_share def __repr__(self): return f"ValueWrapper({self.shape},{self.dtype},{self.vtype})" @@ -569,7 +571,7 @@ def builtin_spu_run( # do infeed. 
for idx, arg in enumerate(args_flat): if isinstance(arg, ValueWrapper): - rt.set_var(spu_exec.input_names[idx], arg.value_str) + rt.set_var(spu_exec.input_names[idx], arg.spu_share) else: arg = np.asarray(jax.numpy.asarray(arg)) fst, *_ = io.make_shares(arg, spu_pb2.Visibility.VIS_PUBLIC) @@ -586,9 +588,9 @@ def builtin_spu_run( shape_spu_to_np(value_meta.shape), dtype_spu_to_np(value_meta.data_type), value_meta.visibility, - value_str, + spu_share, ) - for value_str, value_meta in zip(values_str, values_meta) + for spu_share, value_meta in zip(values_str, values_meta) ] # cleanup @@ -976,7 +978,7 @@ def compile( def get(self, obj: SPU.Object): value_wrappers = [nc.get(ref) for nc, ref in zip(self.node_clients, obj.refs)] io = spu_api.Io(len(self.internal_addrs), self.runtime_config) - return io.reconstruct([w.value_str for w in value_wrappers]) + return io.reconstruct([w.spu_share for w in value_wrappers]) def save(self, spu_objects: List[SPU.Object], filename: str): assert ( @@ -1199,13 +1201,7 @@ def reconstruct(wsize: int, spu_config_str: str, shares: List[ValueWrapper]): spu_config.ParseFromString(spu_config_str) spu_io = spu_api.Io(wsize, spu_config) - protos = [] - for share in shares: - proto = spu_pb2.ValueProto() - proto.ParseFromString(share.value_str) - protos.append(proto) - - return spu_io.reconstruct([s.value_str for s in shares]) + return spu_io.reconstruct([share.spu_share for share in shares]) return to(reconstruct)( len(obj.device.node_clients), diff --git a/spu/utils/frontend.py b/spu/utils/frontend.py index c517ff25..9556b6dd 100644 --- a/spu/utils/frontend.py +++ b/spu/utils/frontend.py @@ -70,6 +70,23 @@ def _jax_compilation( ): import jax + from jax._src.xla_bridge import register_backend_factory, _backend_lock, _backends + from jax._src.lib import xla_client + + # Register interpreter backend since we don't want any cpu/gpu/tpu specific optimization + try: + has_interpreter_backend = False + with _backend_lock: + if 'interpreter' in 
_backends: + has_interpreter_backend = True + + if not has_interpreter_backend: + register_backend_factory( + 'interpreter', xla_client.make_interpreter_client, priority=-100 + ) + finally: + pass # Silent re-register error.... + fn, kwargs = _argnames_partial_except(fn, static_argnames, kwargs) cfn, output = jax.xla_computation( diff --git a/spu/version.py b/spu/version.py index c08bffe2..c455dee4 100644 --- a/spu/version.py +++ b/spu/version.py @@ -13,4 +13,4 @@ # limitations under the License. -__version__ = "0.4.1b0" +__version__ = "0.4.2b0"