From c1a5f3e5349f6280f6e4547bf239277cef2525b4 Mon Sep 17 00:00:00 2001 From: anakinxc <103552181+anakinxc@users.noreply.github.com> Date: Fri, 4 Aug 2023 21:42:22 +0800 Subject: [PATCH] Repo sync (#282) --- CONTRIBUTING.md | 2 +- docs/development/add_protocols.rst | 224 ++-- docs/reference/complexity.md | 3 - docs/reference/np_op_status.md | 526 ++------ docs/reference/pphlo_op_doc.md | 437 +++--- docs/reference/psi_config.md | 103 +- docs/reference/runtime_config.md | 217 +-- docs/reference/xla_status.md | 12 +- examples/cpp/pir/keyword_pir_mem_server.cc | 14 + examples/cpp/pir/keyword_pir_server.cc | 2 +- examples/cpp/pir/keyword_pir_setup.cc | 14 + examples/python/pir/pir_mem_server.py | 8 + examples/python/pir/pir_server.py | 3 +- examples/python/pir/pir_setup.py | 8 + libspu/compiler/passes/lower_mixed_type_op.cc | 19 +- .../compiler/tests/lower_mixed_type_op.mlir | 9 + libspu/core/encoding.cc | 34 +- libspu/core/encoding_test.cc | 12 +- libspu/core/ndarray_ref.cc | 5 +- libspu/core/ndarray_ref.h | 69 +- libspu/core/ndarray_ref_test.cc | 15 +- libspu/core/pt_buffer_view.cc | 2 +- libspu/core/shape.h | 8 +- libspu/core/value.cc | 2 +- libspu/core/xt_helper.h | 8 +- .../device/pphlo/pphlo_executor_test_runner.h | 8 +- libspu/kernel/hlo/indexing_test.cc | 10 +- libspu/mpc/aby3/arithmetic.cc | 280 ++-- libspu/mpc/aby3/boolean.cc | 242 ++-- libspu/mpc/aby3/conversion.cc | 186 +-- libspu/mpc/aby3/io.cc | 34 +- libspu/mpc/aby3/value.h | 3 +- .../mpc/cheetah/arith/cheetah_conv2d_test.cc | 6 +- libspu/mpc/cheetah/arithmetic.cc | 16 +- libspu/mpc/common/prg_state.cc | 14 +- libspu/mpc/common/pv2k.cc | 6 +- libspu/mpc/io_interface.h | 6 +- libspu/mpc/ref2k/ref2k.cc | 10 +- libspu/mpc/semi2k/arithmetic.cc | 44 +- libspu/mpc/semi2k/beaver/beaver_test.cc | 20 +- libspu/mpc/semi2k/boolean.cc | 38 +- libspu/mpc/semi2k/conversion.cc | 17 +- libspu/mpc/spdz2k/abprotocol_spdz2k_test.cc | 12 +- libspu/mpc/spdz2k/arithmetic.cc | 16 +- libspu/mpc/spdz2k/beaver/beaver_test.cc | 18 
+- libspu/mpc/spdz2k/beaver/beaver_tfp.cc | 15 +- libspu/mpc/spdz2k/beaver/beaver_tinyot.cc | 102 +- libspu/mpc/spdz2k/boolean.cc | 46 +- libspu/mpc/spdz2k/conversion.cc | 25 +- libspu/mpc/spdz2k/io.cc | 10 +- libspu/mpc/spdz2k/value.cc | 14 +- libspu/mpc/utils/ring_ops.cc | 127 +- libspu/mpc/utils/ring_ops.h | 3 +- libspu/mpc/utils/tiling_util.h | 4 +- libspu/pir/BUILD.bazel | 9 + libspu/pir/pir.cc | 319 +++-- libspu/pir/pir.proto | 9 + libspu/pir/pir_test.cc | 310 +++++ libspu/psi/core/labeled_psi/BUILD.bazel | 20 +- libspu/psi/core/labeled_psi/apsi_bench.cc | 139 +- .../psi/core/labeled_psi/apsi_label_test.cc | 112 +- libspu/psi/core/labeled_psi/apsi_test.cc | 58 +- libspu/psi/core/labeled_psi/kv_test.cc | 9 +- libspu/psi/core/labeled_psi/psi_params.cc | 56 +- libspu/psi/core/labeled_psi/psi_params.h | 3 +- libspu/psi/core/labeled_psi/receiver.cc | 2 +- libspu/psi/core/labeled_psi/sender.cc | 54 +- libspu/psi/core/labeled_psi/sender.h | 4 +- libspu/psi/core/labeled_psi/sender_db.cc | 1167 +---------------- libspu/psi/core/labeled_psi/sender_db.h | 181 +-- libspu/psi/core/labeled_psi/sender_kvdb.cc | 844 ++++++++++++ libspu/psi/core/labeled_psi/sender_kvdb.h | 164 +++ libspu/psi/core/labeled_psi/sender_memdb.cc | 958 ++++++++++++++ libspu/psi/core/labeled_psi/sender_memdb.h | 228 ++++ .../psi/core/labeled_psi/serializable.proto | 2 +- libspu/psi/core/labeled_psi/serialize.h | 25 +- sml/lr/simple_lr.py | 247 ++-- sml/lr/simple_lr_emul.py | 33 +- sml/lr/simple_lr_test.py | 21 +- sml/utils/fxp_approx.py | 63 + spu/libspu.cc | 6 + spu/ops/groupby/BUILD.bazel | 32 + spu/ops/groupby/groupby.py | 255 ++++ spu/ops/groupby/groupby_test.py | 294 +++++ spu/version.py | 2 +- 85 files changed, 5473 insertions(+), 3241 deletions(-) create mode 100644 libspu/pir/pir_test.cc create mode 100644 libspu/psi/core/labeled_psi/sender_kvdb.cc create mode 100644 libspu/psi/core/labeled_psi/sender_kvdb.h create mode 100644 libspu/psi/core/labeled_psi/sender_memdb.cc create mode 100644 
libspu/psi/core/labeled_psi/sender_memdb.h create mode 100644 spu/ops/groupby/BUILD.bazel create mode 100644 spu/ops/groupby/groupby.py create mode 100644 spu/ops/groupby/groupby_test.py diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index b51aa049..05bb0593 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -58,7 +58,7 @@ docker exec -it spu-dev-$(whoami) bash #### Linux ```sh -Install gcc>=11.2, cmake>=3.18, ninja, nasm>=2.15, python==3.8, bazel==5.4.1, golang +Install gcc>=11.2, cmake>=3.18, ninja, nasm>=2.15, python==3.8, bazel==6.2.1, golang python3 -m pip install -r requirements.txt python3 -m pip install -r requirements-dev.txt diff --git a/docs/development/add_protocols.rst b/docs/development/add_protocols.rst index 91d85afc..31b4b5e2 100644 --- a/docs/development/add_protocols.rst +++ b/docs/development/add_protocols.rst @@ -6,16 +6,16 @@ Adding New MPC Protocols Introduction ------------ -This document is mainly for developers who want to add custom MPC protocols in SPU. -Before reading this document, we recommend that developers have a basic understanding -of the SPU system architecture (i.e., :ref:`/development/compiler.rst`, :ref:`/development/type_system.rst` and :ref:`/development/runtime.rst`) -and the `layout `_ of the SPU code repository. -In short, SPU translates the high-level applications (such as machine learning model training) written in JAX -to an MPC-specific intermediate representation named PPHLO and then dispatches PPHLO operations to the low-level MPC protocols. -In theory, protocol developers only need to implement a basic set of MPC operation APIs to fully use the SPU infrastructure to -run machine learning or data analysis programs. That is to say, for most MPC protocol development, -it only needs to add some source code files into the libspu/mpc folder. -At present, SPU has integrated several protocols such as ABY3 and Cheetah, +This document is mainly for developers who want to add custom MPC protocols in SPU. 
+Before reading this document, we recommend that developers have a basic understanding +of the SPU system architecture (i.e., :ref:`/development/compiler.rst`, :ref:`/development/type_system.rst` and :ref:`/development/runtime.rst`) +and the `layout `_ of the SPU code repository. +In short, SPU translates the high-level applications (such as machine learning model training) written in JAX +to an MPC-specific intermediate representation named PPHLO and then dispatches PPHLO operations to the low-level MPC protocols. +In theory, protocol developers only need to implement a basic set of MPC operation APIs to fully use the SPU infrastructure to +run machine learning or data analysis programs. That is to say, for most MPC protocol development, +it only needs to add some source code files into the libspu/mpc folder. +At present, SPU has integrated several protocols such as ABY3 and Cheetah, which can be regarded as a guide for protocol implementation. A walk-through guide @@ -24,14 +24,14 @@ We further illustrate the procedures of adding protocols step by step in a **top Add a new protocol kind ~~~~~~~~~~~~~~~~~~~~~~~ -When users launch the SPU backend runtime, they will specify the running MPC protocol kind -through the runtime config. The protocol kinds supported by SPU are enumerations defined -in `spu.proto `_. Thus, -developers must add their protocol kinds in this protobuf file to enable SPU to be aware +When users launch the SPU backend runtime, they will specify the running MPC protocol kind +through the runtime config. The protocol kinds supported by SPU are enumerations defined +in `spu.proto `_. Thus, +developers must add their protocol kinds in this protobuf file to enable SPU to be aware of the new protocols. .. code-block:: protobuf - :caption: ProtocolKind enumerations + :caption: ProtocolKind enumerations enum ProtocolKind { PROT_INVALID = 0; @@ -45,48 +45,55 @@ of the new protocols. 
CHEETAH = 4; } -Generate protocol object -~~~~~~~~~~~~~~~~~~~~~~~~~~~ -Each MPC protocol execution environment is abstracted as a C++ instance of an `Object `_ -class in SPU. SPU generates an MPC object through a factory function named `CreateContext `_ -according to the runtime config. Therefore, protocol developers need to add their functions in **CreateCompute** to generate protocol objects. +Register protocol +~~~~~~~~~~~~~~~~~ +Each MPC protocol execution environment is abstracted as a C++ instance of an `Object `_ +class in SPU. SPU constructs an MPC object when creating an **SPUContext**. Then, SPU registers a concrete protocol implementation through a factory function +named `RegisterProtocol `_ according to the runtime config. Therefore, protocol developers +need to add their functions in **RegisterProtocol** to implement protocols. .. code-block:: c++ - :caption: CreateCompute function + :caption: RegisterProtocol function - std::unique_ptr Factory::CreateContext( - const RuntimeConfig& conf, - const std::shared_ptr& lctx) { - switch (conf.protocol()) { + void Factory::RegisterProtocol( + SPUContext* ctx, const std::shared_ptr& lctx) { + // TODO: support multi-protocols. + switch (ctx->config().protocol()) { ... ... case ProtocolKind::ABY3: { - return makeAby3Protocol(conf, lctx); + return regAby3Protocol(ctx, lctx); } case ProtocolKind::CHEETAH: { - return makeCheetahProtocol(conf, lctx); + return regCheetahProtocol(ctx, lctx); } default: { - SPU_THROW("Invalid protocol kind {}", conf.protocol()); + SPU_THROW("Invalid protocol kind {}", ctx->config().protocol()); } } - return nullptr; } +Implement protocol IO interface +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Another function called by the factory class is the **CreateIO** function. As different protocols use different secret sharing schemas, +which means SPU has to use different ways to input/output secret data from plaintext data. 
As a results, developers have to implement these protocol-specific APIs +defined in `io_interface.h `_. +Developers can check the `ABY3 implementation `_ as a reference. + Understand protocol object ~~~~~~~~~~~~~~~~~~~~~~~~~~ SPU protocol `Object `_ may be the key concept for adding new protocols. Let's take a closer look at its design. -The goal of Object class is to realize the generalization and flexibility of developing MPC protocols through dynamic binding. -An Object instance has a series of kernels and states. A kernel and a state can be regarded as a +The goal of **Object** class is to realize the generalization and flexibility of developing MPC protocols through dynamic binding. +An Object instance has a series of kernels and states. A kernel and a state can be regarded as a member function and a member variable of an Object, respectively. .. code-block:: c++ :caption: SPU protocol Object class class Object final { - std::map> kernels_; - std::map> states_; + std::map> kernels_; + std::map> states_; ... public: @@ -97,11 +104,16 @@ member function and a member variable of an Object, respectively. // register customized kernels template - void regKernel(std::string_view name) { + void regKernel() { + regKernel(KernelT::kBindName, std::make_unique()); + } + + template + void regKernel(const std::string& name) { return regKernel(name, std::make_unique()); } - // add customized kernels + // add customized states template void addState(Args&&... args) { addState(StateT::kBindName, @@ -116,83 +128,75 @@ Construct protocol object We take the ABY3 implementation as a specific example to further explain the description above. First of all, we can see that there is an independent aby3 directory under the `libspu/mpc `_ -directory in SPU's repository layout. The aby3 directory includes the C++ source files and header -files required by the ABY3 protocol implementation. These files may be confusing at first glance. +directory in SPU's repository layout. 
The aby3 directory includes the C++ source files and header +files required by the ABY3 protocol implementation. These files may be confusing at first glance. The key to know its code organization is to open the `protocol `_ -file, which defines the **makeAby3Protocol** function for generating an Object instance. +file, which defines the **regAby3Protocol** function for registering kernels and states. This function will be called by the factory class described in previous step. .. code-block:: c++ - :caption: ABY3 protocol Object instance generation + :caption: ABY3 protocol registration - std::unique_ptr makeAby3Protocol( - const RuntimeConfig& conf, - const std::shared_ptr& lctx) { + void regAby3Protocol(SPUContext* ctx, + const std::shared_ptr& lctx) { // register ABY3 arithmetic shares and boolean shares aby3::registerTypes(); - // instantiate Object instance - auto obj = - std::make_unique(fmt::format("{}-{}", lctx->Rank(), "ABY3")); - // add ABY3 required states - obj->addState(conf.field()); - obj->addState(lctx); - obj->addState(lctx); + ctx->prot()->addState(ctx->config().field()); + ctx->prot()->addState(lctx); + ctx->prot()->addState(lctx); - // register public kernels and api kernels - regPV2kKernels(obj.get()); - regABKernels(obj.get()); + // register public kernels + regPV2kKernels(ctx->prot()); // register arithmetic & binary kernels ... - obj->regKernel(); - obj->regKernel(); - obj->regKernel(); - obj->regKernel(); + ctx->prot()->regKernel(); + ctx->prot()->regKernel(); + ctx->prot()->regKernel(); + ctx->prot()->regKernel(); ... return obj; } -Inside the **makeAby3Protocol** function, it does three things. +Inside the **regAby3Protocol** function, it does three things. - The first is to register the protocol types. These types are defined in the `type.h `_ header file, \ - representing an arithmetic secret share and a boolean secret share, respectively. + representing an arithmetic secret share and a boolean secret share, respectively. 
- The second is to register protocol states (variables), specifically including the three states of Z2kState, \ Communicator, and PrgState, which are used to store the ring information, communication facilities, and \ - pseudorandom number generator for protocol implementation. + pseudorandom number generator for protocol implementation. -- The third is to register the protocol kernels (functions). We can see that three types of kernels are registered. \ - The first type is the kernels implemented in the `pv2k.cc `_ \ - file, using **Pub2k** as the naming prefix of kernel classes. The second type is the kernels implemented in the \ - `ab_api.cc `_ file, using **ABProt** as the \ - naming prefix of kernel classes. The third type is implemented in `arithmetic.cc `_, \ +- The third is to register the protocol kernels (functions). We can see that two types of kernels are registered. \ + The first type is the common kernels implemented in the `pv2k.cc `_ \ + file. The second type is implemented in `arithmetic.cc `_, \ `boolean.cc `_ and other files under the aby3 directory. Implement protocol kernels ~~~~~~~~~~~~~~~~~~~~~~~~~~ -In this section, we further explain why the ABY3 developer registers these three types of kernels. -In SPU, the interfaces between MPC and HAL layers are defined in the `api.cc `_ -file, which consists of a set of operations with public or secret operands (referred as **basic APIs** for the rest of this document). -As long as a protocol developer implements basic APIs, he/her can use the SPU full-stack infrastructure +In this section, we further explain why the ABY3 developer registers these two types of kernels. +In SPU, the interfaces between MPC and HAL layers are defined in the `api.h `_ +file, which consists of a set of operations with public or secret operands (referred as **basic APIs** for the rest of this document). 
+As long as a protocol developer implements basic APIs, he/she can use the SPU full-stack infrastructure to run high-level applications, e.g., training complex neural network models. .. code-block:: c++ :caption: Some SPU MPC basic APIs ... - ArrayRef mul_pp(Object* ctx, const ArrayRef&, const ArrayRef&); - ArrayRef mul_sp(Object* ctx, const ArrayRef&, const ArrayRef&); - ArrayRef mul_ss(Object* ctx, const ArrayRef&, const ArrayRef&); - ArrayRef and_pp(Object* ctx, const ArrayRef&, const ArrayRef&); - ArrayRef and_sp(Object* ctx, const ArrayRef&, const ArrayRef&); - ArrayRef and_ss(Object* ctx, const ArrayRef&, const ArrayRef&); + Value mul_pp(SPUContext* ctx, const Value& x, const Value& y); + Value mul_sp(SPUContext* ctx, const Value& x, const Value& y); + Value mul_ss(SPUContext* ctx, const Value& x, const Value& y); + Value and_pp(SPUContext* ctx, const Value& x, const Value& y); + Value and_sp(SPUContext* ctx, const Value& x, const Value& y); + Value and_ss(SPUContext* ctx, const Value& x, const Value& y); ... -Among the basic APIs, some protocols working on Rings share the same logic on some operations processing public operands, -so SPU developers pre-implement these APIs as kernels and place them in the common directory. +Among the basic APIs, some protocols working on Rings share the same logic on some operations processing public operands, +so SPU developers pre-implement these APIs as kernels and place them in the common directory. As a result, the ABY3 developer can directly register these kernels through the **regPV2kKernels** function. .. 
code-block:: c++ @@ -208,10 +212,8 @@ As a result, the ABY3 developer can directly register these kernels through the ce::CExpr comm() const override { return ce::Const(0); } // protocol implementation - ArrayRef proc(KernelEvalContext* ctx, const ArrayRef& lhs, - const ArrayRef& rhs) const override { - // used for logging and tracing - SPU_TRACE_MPC_LEAF(ctx, lhs, rhs); + NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& lhs, + const NdArrayRef& rhs) const override { // sanity check SPU_ENFORCE(lhs.eltype() == rhs.eltype()); return ring_and(lhs, rhs).as(lhs.eltype()); @@ -225,48 +227,21 @@ As a result, the ABY3 developer can directly register these kernels through the ... obj->regKernel(); obj->regKernel(); - // and_pp kernel is implemented as a AndPP class + // and_pp kernel is implemented as an AndPP class obj->regKernel(); obj->regKernel(); ... -Moreover, as we can see that the basic APIs do not have (arithmetic/boolean) secret sharing semantics. -A series of popular protocols (such as ABY3) use arithmetic secret sharing, boolean secret sharing, -and their conversions to achieve secret operations defined in the basic APIs. Therefore, we call this type of protocol **AB protocols**, -and further pre-implement the basic APIs of AB protocols as kernels by dispatching secret operations to -arithmetic/boolean secret sharing operations. For example, the multiplication of two secret operands will be -further decomposed into the multiplication of two arithmetic secret shares or the AND of two boolean secret shares. -As ABY3 is a type of AB protocol, the ABY3 developer can directly call the **regABKernels** function to register these kernels. - -.. code-block:: c++ - :caption: Pre-implemented *mul_ss* kernel for AB protocols - - class ABProtMulSS : public BinaryKernel { - ... - // decompose a secret multiplication into secret share operations - ArrayRef proc(KernelEvalContext* ctx, const ArrayRef& lhs, - const ArrayRef& rhs) const override { - ... 
- if (_LAZY_AB) { - if (_IsB(lhs) && _NBits(lhs) == 1 && _IsB(rhs) && _NBits(rhs) == 1) { - return _AndBB(lhs, rhs); - } - return _MulAA(_2A(lhs), _2A(rhs)); - } - return _MulAA(lhs, rhs); - } - }; - -Finally, ABY3 protocol-specific operations need to be implemented by developers as kernels to register. -For example, the multiplication of two arithmetic secret shares of ABY3 is implemented as the **MulAA** kernel located in the +Besides, ABY3 protocol-specific operations need to be implemented by developers as kernels to register. +For example, the multiplication of two arithmetic secret shares of ABY3 is implemented as the **MulAA** kernel located in the `arithmetic.cc `_ source file. When kernels are implemented and registered, a new protocol is finally added. .. code-block:: c++ :caption: ABY3 *mul_aa* kernel for arithmetic share multiplication - - ArrayRef MulAA::proc(KernelEvalContext* ctx, const ArrayRef& lhs, - const ArrayRef& rhs) const { + + NdArrayRef MulAA::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, + const NdArrayRef& rhs) const { // get required states const auto field = lhs.eltype().as()->field(); auto* comm = ctx->getState(); @@ -274,23 +249,24 @@ When kernels are implemented and registered, a new protocol is finally added. // dispatch the real implementation to different fields return DISPATCH_ALL_FIELDS(field, "aby3.mulAA", [&]() { - // the real protocol implementation + // the real protocol implementation + ... }); } Testing ~~~~~~~ -After a protocol is added, the developer usually wants to test whether the protocol works as expected. -There are two ways to test the protocol functionality in SPU. The first way is to run python examples. +After a protocol is added, the developer usually wants to test whether the protocol works as expected. +There are two ways to test the protocol functionality in SPU. The first way is to run python examples. SPU has provided users with a series of application `examples `_. 
-If a protocol fully implements SPU's basic APIs, the developer can run these high-level examples to verify +If a protocol fully implements SPU's basic APIs, the developer can run these high-level examples to verify whether the low-level protocol development is correct. -The second way is to write and run unittest. Some protocols do not cover all the basic APIs and cannot run examples, -or developers only want to test the functionalities of some specific MPC operations (such as addition and multiplication). -In these cases it is more practical to run unittest. SPU developers have construct a general test frameworks in -`api_test.cc `_ and -`ab_api_test.cc `_. -Developers of new protocols need to instantiate these frameworks to test their own protocol functionalities. -Developers can refer to the `protocol_test.cc `_ +The second way is to write and run unittest. Some protocols do not cover all the basic APIs and cannot run examples, +or developers only want to test the functionalities of some specific MPC operations (such as addition and multiplication). +In these cases it is more practical to run unittest. SPU developers have construct a general test frameworks in +`api_test.cc `_ and +`ab_api_test.cc `_. +Developers of new protocols need to instantiate these frameworks to test their own protocol functionalities. +Developers can refer to the `protocol_test.cc `_ file in the aby3 directory to learn how to write their own protocol test files. 
diff --git a/docs/reference/complexity.md b/docs/reference/complexity.md index 2700815c..6315763d 100644 --- a/docs/reference/complexity.md +++ b/docs/reference/complexity.md @@ -1,5 +1,4 @@ # Semi2k - |kernel | latency | comm | |-------|------------------|---------------------------------| |a2b |(log(K)+1)\*log(N)|(2\*log(K)+1)\*2\*K\*(N-1)\*(N-1)| @@ -18,9 +17,7 @@ |xor_bp |0 |0 | |and_bb |1 |K\*2\*(N-1) | |and_bp |0 |0 | - # Aby3 - |kernel | latency | comm | |-------|----------|--------------| |a2b |log(K)+1+1|log(K)\*K+K\*2| diff --git a/docs/reference/np_op_status.md b/docs/reference/np_op_status.md index 4d1da4f6..7f5b6757 100644 --- a/docs/reference/np_op_status.md +++ b/docs/reference/np_op_status.md @@ -4,26 +4,26 @@ JAX NumPy Operators Status # Overview + SPU recommends users to use JAX as frontend language to write logics. We found most of users would utilize **jax.numpy** modules in their programs. We have conducted tests with some *selected* Operators for reference. Just keep in mind, if you couldn't find a **jax.numpy** operator in this list, it doesn't mean it's not supported. We just haven't test it, e.g. jax.numpy.sort. And we don't test other **JAX** modules at this moment. Please contact us if - - You need to confirm the status of another **jax.numpy** operator not listed here. - You find a **jax.numpy** is not working even it is marked as **PASS**. e.g. The precision is bad. + + # Tested Operators List ## abs -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.abs.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - bool_ @@ -35,13 +35,11 @@ Please check *Supported Dtypes* as well. ## add -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.add.html ### Status **PASS** Please check *Supported Dtypes* as well. 
- ### Supported Dtypes - bool_ @@ -53,78 +51,66 @@ Please check *Supported Dtypes* as well. ## angle -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.angle.html ### Status Not supported by compiler or runtime. But we could implement on demand in future. Please check *Note* for details. - ### Note -stablehlo.atan2 +stablehlo.atan2 ## angle -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.angle.html ### Status Not supported by compiler or runtime. But we could implement on demand in future. Please check *Note* for details. - ### Note -stablehlo.atan2 +stablehlo.atan2 ## arccos -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.arccos.html ### Status Not supported by compiler or runtime. But we could implement on demand in future. Please check *Note* for details. - ### Note -stablehlo.atan2 +stablehlo.atan2 ## arccosh -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.arccosh.html ### Status Not supported by compiler or runtime. But we could implement on demand in future. Please check *Note* for details. - ### Note -nan +nan ## arcsin -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.arcsin.html ### Status Not supported by compiler or runtime. But we could implement on demand in future. Please check *Note* for details. - ### Note -stablehlo.atan2 +stablehlo.atan2 ## arcsinh -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.arcsinh.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - float32 @@ -135,52 +121,44 @@ Please check *Supported Dtypes* as well. 
## arctan -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.arctan.html ### Status Not supported by compiler or runtime. But we could implement on demand in future. Please check *Note* for details. - ### Note -stablehlo.atan2 +stablehlo.atan2 ## arctan2 -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.arctan2.html ### Status Not supported by compiler or runtime. But we could implement on demand in future. Please check *Note* for details. - ### Note -stablehlo.atan2 +stablehlo.atan2 ## arctanh -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.arctanh.html ### Status Not supported by compiler or runtime. But we could implement on demand in future. Please check *Note* for details. - ### Note -[-1, 1] nan +[-1, 1] nan ## argmax -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.argmax.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - bool_ @@ -192,13 +170,11 @@ Please check *Supported Dtypes* as well. ## argmin -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.argmin.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - bool_ @@ -210,13 +186,11 @@ Please check *Supported Dtypes* as well. ## array_equal -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.array_equal.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - float32 @@ -227,13 +201,11 @@ Please check *Supported Dtypes* as well. 
## array_equiv -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.array_equiv.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - float32 @@ -244,13 +216,11 @@ Please check *Supported Dtypes* as well. ## atleast_1d -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.atleast_1d.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - float32 @@ -261,13 +231,11 @@ Please check *Supported Dtypes* as well. ## atleast_2d -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.atleast_2d.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - float32 @@ -278,13 +246,11 @@ Please check *Supported Dtypes* as well. ## atleast_3d -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.atleast_3d.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - float32 @@ -295,13 +261,11 @@ Please check *Supported Dtypes* as well. ## bitwise_and -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.bitwise_and.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - int16 @@ -311,13 +275,11 @@ Please check *Supported Dtypes* as well. ## bitwise_not -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.bitwise_not.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - int16 @@ -327,13 +289,11 @@ Please check *Supported Dtypes* as well. 
## bitwise_or -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.bitwise_or.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - int16 @@ -343,13 +303,11 @@ Please check *Supported Dtypes* as well. ## bitwise_xor -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.bitwise_xor.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - int16 @@ -359,26 +317,22 @@ Please check *Supported Dtypes* as well. ## cbrt -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.cbrt.html ### Status Not supported by compiler or runtime. But we could implement on demand in future. Please check *Note* for details. - ### Note -stablehlo.cbrt +stablehlo.cbrt ## ceil -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ceil.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - float32 @@ -389,13 +343,11 @@ Please check *Supported Dtypes* as well. ## conj -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.conj.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - float32 @@ -406,13 +358,11 @@ Please check *Supported Dtypes* as well. ## conjugate -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.conjugate.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - float32 @@ -423,39 +373,33 @@ Please check *Supported Dtypes* as well. ## copysign -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.copysign.html ### Status Not supported by compiler or runtime. 
But we could implement on demand in future. Please check *Note* for details. - ### Note -shift +shift ## cos -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.cos.html ### Status Not supported by compiler or runtime. But we could implement on demand in future. Please check *Note* for details. - ### Note -stablehlo.cosine +stablehlo.cosine ## cosh -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.cosh.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - float32 @@ -466,26 +410,22 @@ Please check *Supported Dtypes* as well. ## deg2rad -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.deg2rad.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - float32 ## divide -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.divide.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - float32 @@ -496,13 +436,11 @@ Please check *Supported Dtypes* as well. ## divmod -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.divmod.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - float32 @@ -513,26 +451,22 @@ Please check *Supported Dtypes* as well. ## ediff1d -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ediff1d.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - int32 ## equal -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.equal.html ### Status **PASS** Please check *Supported Dtypes* as well. 
- ### Supported Dtypes - bool_ @@ -544,13 +478,11 @@ Please check *Supported Dtypes* as well. ## exp -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.exp.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - float32 @@ -561,13 +493,11 @@ Please check *Supported Dtypes* as well. ## exp2 -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.exp2.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - float32 @@ -578,13 +508,11 @@ Please check *Supported Dtypes* as well. ## expm1 -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.expm1.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - float32 @@ -595,26 +523,22 @@ Please check *Supported Dtypes* as well. ## fabs -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.fabs.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - float32 ## fix -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.fix.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - float32 @@ -625,26 +549,22 @@ Please check *Supported Dtypes* as well. ## float_power -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.float_power.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - float32 ## floor -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.floor.html ### Status **PASS** Please check *Supported Dtypes* as well. 
- ### Supported Dtypes - float32 @@ -655,13 +575,11 @@ Please check *Supported Dtypes* as well. ## floor_divide -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.floor_divide.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - float32 @@ -672,13 +590,11 @@ Please check *Supported Dtypes* as well. ## fmax -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.fmax.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - float32 @@ -689,13 +605,11 @@ Please check *Supported Dtypes* as well. ## fmin -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.fmin.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - float32 @@ -706,13 +620,11 @@ Please check *Supported Dtypes* as well. ## fmod -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.fmod.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - float32 @@ -723,26 +635,22 @@ Please check *Supported Dtypes* as well. ## gcd -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.gcd.html ### Status Not supported by compiler or runtime. But we could implement on demand in future. Please check *Note* for details. - ### Note -secret while +secret while ## greater -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.greater.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - bool_ @@ -754,13 +662,11 @@ Please check *Supported Dtypes* as well. 
## greater_equal -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.greater_equal.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - bool_ @@ -772,13 +678,11 @@ Please check *Supported Dtypes* as well. ## heaviside -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.heaviside.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - float32 @@ -789,13 +693,11 @@ Please check *Supported Dtypes* as well. ## hypot -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.hypot.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - float32 @@ -806,26 +708,22 @@ Please check *Supported Dtypes* as well. ## i0 -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.i0.html ### Status Result is incorrect. Please check *Note* for details. - ### Note -accuracy +accuracy ## imag -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.imag.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - float32 @@ -836,13 +734,11 @@ Please check *Supported Dtypes* as well. ## invert -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.invert.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - int16 @@ -852,13 +748,11 @@ Please check *Supported Dtypes* as well. ## isclose -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.isclose.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - bool_ @@ -870,13 +764,11 @@ Please check *Supported Dtypes* as well. 
## iscomplex -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.iscomplex.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - float32 @@ -887,66 +779,54 @@ Please check *Supported Dtypes* as well. ## isfinite -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.isfinite.html ### Status Not supported by design. We couldn't fix this in near future. Please check *Note* for details. - ## isinf -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.isinf.html ### Status Not supported by design. We couldn't fix this in near future. Please check *Note* for details. - ## isnan -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.isnan.html ### Status Not supported by design. We couldn't fix this in near future. Please check *Note* for details. - ## isneginf -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.isneginf.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - float32 ## isposinf -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.isposinf.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - float32 ## isreal -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.isreal.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - float32 @@ -957,13 +837,11 @@ Please check *Supported Dtypes* as well. ## isrealobj -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.isrealobj.html ### Status **PASS** Please check *Supported Dtypes* as well. 
- ### Supported Dtypes - float32 @@ -974,13 +852,11 @@ Please check *Supported Dtypes* as well. ## kron -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.kron.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - float32 @@ -991,39 +867,33 @@ Please check *Supported Dtypes* as well. ## lcm -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.lcm.html ### Status Not supported by compiler or runtime. But we could implement on demand in future. Please check *Note* for details. - ### Note -secret while +secret while ## ldexp -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ldexp.html ### Status Result is incorrect. Please check *Note* for details. - ### Note -IEEE-754 +IEEE-754 ## left_shift -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.left_shift.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - int16 @@ -1033,13 +903,11 @@ Please check *Supported Dtypes* as well. ## less -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.less.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - bool_ @@ -1051,13 +919,11 @@ Please check *Supported Dtypes* as well. ## less_equal -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.less_equal.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - bool_ @@ -1069,13 +935,11 @@ Please check *Supported Dtypes* as well. ## log -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.log.html ### Status **PASS** Please check *Supported Dtypes* as well. 
- ### Supported Dtypes - float32 @@ -1086,13 +950,11 @@ Please check *Supported Dtypes* as well. ## log10 -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.log10.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - float32 @@ -1103,13 +965,11 @@ Please check *Supported Dtypes* as well. ## log1p -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.log1p.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - float32 @@ -1120,13 +980,11 @@ Please check *Supported Dtypes* as well. ## log1p -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.log1p.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - float32 @@ -1137,13 +995,11 @@ Please check *Supported Dtypes* as well. ## log2 -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.log2.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - float32 @@ -1154,39 +1010,33 @@ Please check *Supported Dtypes* as well. ## logaddexp -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.logaddexp.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - float32 ## logaddexp2 -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.logaddexp2.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - float32 ## logical_and -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.logical_and.html ### Status **PASS** Please check *Supported Dtypes* as well. 
- ### Supported Dtypes - bool_ @@ -1198,13 +1048,11 @@ Please check *Supported Dtypes* as well. ## logical_not -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.logical_not.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - bool_ @@ -1216,13 +1064,11 @@ Please check *Supported Dtypes* as well. ## logical_or -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.logical_or.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - bool_ @@ -1234,13 +1080,11 @@ Please check *Supported Dtypes* as well. ## logical_xor -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.logical_xor.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - bool_ @@ -1252,13 +1096,11 @@ Please check *Supported Dtypes* as well. ## maximum -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.maximum.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - bool_ @@ -1270,13 +1112,11 @@ Please check *Supported Dtypes* as well. ## mean -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.mean.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - float32 @@ -1287,13 +1127,11 @@ Please check *Supported Dtypes* as well. ## minimum -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.minimum.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - bool_ @@ -1305,13 +1143,11 @@ Please check *Supported Dtypes* as well. 
## mod -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.mod.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - float32 @@ -1322,13 +1158,11 @@ Please check *Supported Dtypes* as well. ## modf -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.modf.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - float32 @@ -1339,13 +1173,11 @@ Please check *Supported Dtypes* as well. ## multiply -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.multiply.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - bool_ @@ -1357,58 +1189,46 @@ Please check *Supported Dtypes* as well. ## nanargmax -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.nanargmax.html ### Status Not supported by design. We couldn't fix this in near future. Please check *Note* for details. - ## nanargmin -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.nanargmin.html ### Status Not supported by design. We couldn't fix this in near future. Please check *Note* for details. - ## nanmean -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.nanmean.html ### Status Not supported by design. We couldn't fix this in near future. Please check *Note* for details. - ## nanprod -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.nanprod.html ### Status Not supported by design. We couldn't fix this in near future. Please check *Note* for details. 
- ## nansum -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.nansum.html ### Status Not supported by design. We couldn't fix this in near future. Please check *Note* for details. - ## negative -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.negative.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - float32 @@ -1419,26 +1239,22 @@ Please check *Supported Dtypes* as well. ## nextafter -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.nextafter.html ### Status Result is incorrect. Please check *Note* for details. - ### Note -IEEE-754 +IEEE-754 ## not_equal -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.not_equal.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - bool_ @@ -1450,13 +1266,11 @@ Please check *Supported Dtypes* as well. ## outer -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.outer.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - float32 @@ -1467,13 +1281,11 @@ Please check *Supported Dtypes* as well. ## polyval -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.polyval.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - float32 @@ -1484,13 +1296,11 @@ Please check *Supported Dtypes* as well. ## positive -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.positive.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - float32 @@ -1501,13 +1311,11 @@ Please check *Supported Dtypes* as well. 
## power -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.power.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - float32 @@ -1518,13 +1326,11 @@ Please check *Supported Dtypes* as well. ## prod -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.prod.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - bool_ @@ -1536,26 +1342,22 @@ Please check *Supported Dtypes* as well. ## rad2deg -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.rad2deg.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - float32 ## ravel -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ravel.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - bool_ @@ -1567,13 +1369,11 @@ Please check *Supported Dtypes* as well. ## real -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.real.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - float32 @@ -1584,26 +1384,22 @@ Please check *Supported Dtypes* as well. ## reciprocal -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.reciprocal.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - float32 ## remainder -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.remainder.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - float32 @@ -1614,13 +1410,11 @@ Please check *Supported Dtypes* as well. 
## right_shift -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.right_shift.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - int16 @@ -1630,26 +1424,22 @@ Please check *Supported Dtypes* as well. ## rint -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.rint.html ### Status Not supported by compiler or runtime. But we could implement on demand in future. Please check *Note* for details. - ### Note -stablehlo.round_nearest_even +stablehlo.round_nearest_even ## sign -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.sign.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - float32 @@ -1660,13 +1450,11 @@ Please check *Supported Dtypes* as well. ## signbit -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.signbit.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - float32 @@ -1677,52 +1465,44 @@ Please check *Supported Dtypes* as well. ## sin -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.sin.html ### Status Not supported by compiler or runtime. But we could implement on demand in future. Please check *Note* for details. - ### Note -stablehlo.sine +stablehlo.sine ## sinc -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.sinc.html ### Status Not supported by compiler or runtime. But we could implement on demand in future. Please check *Note* for details. 
- ### Note -stablehlo.sine +stablehlo.sine ## sinh -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.sinh.html ### Status Not supported by compiler or runtime. But we could implement on demand in future. Please check *Note* for details. - ### Note -stablehlo.cosine +stablehlo.cosine ## sqrt -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.sqrt.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - float32 @@ -1733,13 +1513,11 @@ Please check *Supported Dtypes* as well. ## square -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.square.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - float32 @@ -1750,13 +1528,11 @@ Please check *Supported Dtypes* as well. ## subtract -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.subtract.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - float32 @@ -1767,13 +1543,11 @@ Please check *Supported Dtypes* as well. ## sum -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.sum.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - int16 @@ -1783,26 +1557,22 @@ Please check *Supported Dtypes* as well. ## tan -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.tan.html ### Status Not supported by compiler or runtime. But we could implement on demand in future. Please check *Note* for details. 
- ### Note -stablehlo.sine +stablehlo.sine ## tanh -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.tanh.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - float32 @@ -1813,13 +1583,11 @@ Please check *Supported Dtypes* as well. ## transpose -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.transpose.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - bool_ @@ -1831,13 +1599,11 @@ Please check *Supported Dtypes* as well. ## true_divide -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.true_divide.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - float32 @@ -1848,13 +1614,11 @@ Please check *Supported Dtypes* as well. ## trunc -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.trunc.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - float32 @@ -1865,13 +1629,11 @@ Please check *Supported Dtypes* as well. ## unwrap -JAX NumPy Document link: - +JAX NumPy Document link: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.unwrap.html ### Status **PASS** Please check *Supported Dtypes* as well. - ### Supported Dtypes - float32 diff --git a/docs/reference/pphlo_op_doc.md b/docs/reference/pphlo_op_doc.md index c3e3296a..88753d0e 100644 --- a/docs/reference/pphlo_op_doc.md +++ b/docs/reference/pphlo_op_doc.md @@ -6,7 +6,7 @@ _Pad operator_ Pads the edges of `operand` with the `padding_value` and according to the passed configuration. -See . +See https://www.tensorflow.org/xla/operation_semantics#pad. 
Traits: AlwaysSpeculatableImplTrait, InferTensorType, SameOperandsAndResultElementType @@ -14,7 +14,7 @@ Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, InferTypeOpIn Effects: MemoryEffects::Effect{} -#### Attributes +#### Attributes: | Attribute | MLIR Type | Description | | :-------: | :-------: | ----------- | @@ -22,14 +22,14 @@ Effects: MemoryEffects::Effect{} | `edge_padding_high` | ::mlir::DenseIntElementsAttr | 64-bit signless integer elements attribute | `interior_padding` | ::mlir::DenseIntElementsAttr | 64-bit signless integer elements attribute -#### Operands +#### Operands: | Operand | Description | | :-----: | ----------- | | `operand` | statically shaped tensor of PPHlo public type or PPHlo secret type values | `padding_value` | statically shaped tensor of PPHlo public type or PPHlo secret type values -#### Results +#### Results: | Result | Description | | :----: | ----------- | @@ -42,7 +42,7 @@ _Absolute value operator_ Returns `abs(operand)` element-wise. See -. +https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. Traits: AlwaysSpeculatableImplTrait, Elementwise, SameOperandsAndResultShape, SameOperandsAndResultType @@ -50,13 +50,13 @@ Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (Mem Effects: MemoryEffects::Effect{} -#### Operands +#### Operands: | Operand | Description | | :-----: | ----------- | | `operand` | statically shaped tensor of PPHlo public type or PPHlo secret type values -#### Results +#### Results: | Result | Description | | :----: | ----------- | @@ -69,7 +69,7 @@ _Addition operator_ Returns `lhs + rhs` element-wise. See -. +https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. 
Traits: AlwaysSpeculatableImplTrait, Commutative, Elementwise, SameOperandsAndResultShape @@ -77,14 +77,14 @@ Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface) Effects: MemoryEffects::Effect{} -#### Operands +#### Operands: | Operand | Description | | :-----: | ----------- | | `lhs` | statically shaped tensor of PPHlo public type or PPHlo secret type values | `rhs` | statically shaped tensor of PPHlo public type or PPHlo secret type values -#### Results +#### Results: | Result | Description | | :----: | ----------- | @@ -97,7 +97,7 @@ _Logical and_ Returns `logical_and(lhs, rhs)` element-wise. See -. +https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. Traits: AlwaysSpeculatableImplTrait, Commutative, Elementwise, SameOperandsAndResultShape @@ -105,14 +105,14 @@ Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface) Effects: MemoryEffects::Effect{} -#### Operands +#### Operands: | Operand | Description | | :-----: | ----------- | | `lhs` | statically shaped tensor of public integer type or secret integer type values | `rhs` | statically shaped tensor of public integer type or secret integer type values -#### Results +#### Results: | Result | Description | | :----: | ----------- | @@ -130,7 +130,7 @@ Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface) Effects: MemoryEffects::Effect{} -#### Attributes +#### Attributes: | Attribute | MLIR Type | Description | | :-------: | :-------: | ----------- | @@ -141,13 +141,13 @@ Effects: MemoryEffects::Effect{} | `padding` | ::mlir::DenseIntElementsAttr | 64-bit signless integer elements attribute | `onehot_index` | ::mlir::BoolAttr | bool attribute -#### Operands +#### Operands: | Operand | Description | | :-----: | ----------- | | `input` | statically shaped tensor of PPHlo public type or PPHlo secret type values -#### Results +#### Results: | Result | Description | | :----: | ----------- | @@ -164,7 
+164,7 @@ Similar to a 'tf.bitcast' in TensorFlow, performs an element-wise bitcast low-level cast, so machines with different floating-point representations will give different results. - See . + See https://www.tensorflow.org/xla/operation_semantics#bitcastconverttype. Traits: AlwaysSpeculatableImplTrait @@ -172,13 +172,13 @@ Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface) Effects: MemoryEffects::Effect{} -#### Operands +#### Operands: | Operand | Description | | :-----: | ----------- | | `operand` | statically shaped tensor of PPHlo public type or PPHlo secret type values -#### Results +#### Results: | Result | Description | | :----: | ----------- | @@ -201,7 +201,7 @@ broadcast from must be size 1 (degenerate broadcasting). For a scalar (0D tensor) operand, `broadcast_dimensions` must be empty. The The scalar value will be broadcast to every element in the target shape. -See . +See https://www.tensorflow.org/xla/broadcasting. Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultElementType @@ -209,19 +209,19 @@ Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface) Effects: MemoryEffects::Effect{} -#### Attributes +#### Attributes: | Attribute | MLIR Type | Description | | :-------: | :-------: | ----------- | | `broadcast_dimensions` | ::mlir::DenseIntElementsAttr | 64-bit signless integer elements attribute -#### Operands +#### Operands: | Operand | Description | | :-----: | ----------- | | `operand` | statically shaped tensor of PPHlo public type or PPHlo secret type values -#### Results +#### Results: | Result | Description | | :----: | ----------- | @@ -242,13 +242,13 @@ of index, otherwise all branches will be executed and select results based on in Traits: RecursiveMemoryEffects, SingleBlockImplicitTerminator -#### Operands +#### Operands: | Operand | Description | | :-----: | ----------- | | `index` | statically shaped tensor of public integer type or secret integer type values -#### Results 
+#### Results: | Result | Description | | :----: | ----------- | @@ -261,7 +261,7 @@ _Ceil operator_ Returns `Ceil(operand)` element-wise. See -. +https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. Traits: AlwaysSpeculatableImplTrait, Elementwise, SameOperandsAndResultShape, SameOperandsAndResultType @@ -269,13 +269,13 @@ Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (Mem Effects: MemoryEffects::Effect{} -#### Operands +#### Operands: | Operand | Description | | :-----: | ----------- | | `operand` | statically shaped tensor of public fixed-point type or secret fixed-point type values -#### Results +#### Results: | Result | Description | | :----: | ----------- | @@ -291,7 +291,7 @@ Note: All three arrays must be the same shape. Alternatively, as a restricted form of broadcasting, min and/or max can be a scalar (0D tensor) of the element type of the tensor operand. -See . +See https://www.tensorflow.org/xla/operation_semantics#clamp. Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultShape @@ -299,7 +299,7 @@ Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface) Effects: MemoryEffects::Effect{} -#### Operands +#### Operands: | Operand | Description | | :-----: | ----------- | @@ -307,7 +307,7 @@ Effects: MemoryEffects::Effect{} | `operand` | statically shaped tensor of PPHlo public type or PPHlo secret type values | `max` | statically shaped tensor of PPHlo public type or PPHlo secret type values -#### Results +#### Results: | Result | Description | | :----: | ----------- | @@ -319,7 +319,7 @@ _Concatenate op_ Concatenates a set of tensors along the specified dimension. -See . +See https://www.tensorflow.org/xla/operation_semantics#concatenate. 
Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultElementType @@ -327,19 +327,19 @@ Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (Mem Effects: MemoryEffects::Effect{} -#### Attributes +#### Attributes: | Attribute | MLIR Type | Description | | :-------: | :-------: | ----------- | | `dimension` | ::mlir::IntegerAttr | 64-bit signless integer attribute -#### Operands +#### Operands: | Operand | Description | | :-----: | ----------- | | `val` | statically shaped tensor of PPHlo public type or PPHlo secret type values -#### Results +#### Results: | Result | Description | | :----: | ----------- | @@ -356,13 +356,13 @@ Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface) Effects: MemoryEffects::Effect{} -#### Attributes +#### Attributes: | Attribute | MLIR Type | Description | | :-------: | :-------: | ----------- | | `value` | ::mlir::ElementsAttr | constant vector/tensor attribute -#### Results +#### Results: | Result | Description | | :----: | ----------- | @@ -376,7 +376,7 @@ Performs element-wise conversion of values from one type to another, e.g.fxp to int. See -. +https://www.tensorflow.org/xla/operation_semantics#convertelementtype. Traits: AlwaysSpeculatableImplTrait, Elementwise, SameOperandsAndResultShape @@ -384,13 +384,13 @@ Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface) Effects: MemoryEffects::Effect{} -#### Operands +#### Operands: | Operand | Description | | :-----: | ----------- | | `operand` | statically shaped tensor of PPHlo public type or PPHlo secret type values -#### Results +#### Results: | Result | Description | | :----: | ----------- | @@ -400,6 +400,7 @@ Effects: MemoryEffects::Effect{} _Convolution operator_ + Syntax: ``` @@ -411,7 +412,7 @@ operation ::= `pphlo.convolution` `(`operands`)` Computes a convolution of the kind used in neural networks. -See . +See https://www.tensorflow.org/xla/operation_semantics#conv_convolution. 
Traits: AlwaysSpeculatableImplTrait @@ -419,7 +420,7 @@ Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface) Effects: MemoryEffects::Effect{} -#### Attributes +#### Attributes: | Attribute | MLIR Type | Description | | :-------: | :-------: | ----------- | @@ -428,14 +429,14 @@ Effects: MemoryEffects::Effect{} | `feature_group_count` | ::mlir::IntegerAttr | 64-bit signless integer attribute | `batch_group_count` | ::mlir::IntegerAttr | 64-bit signless integer attribute -#### Operands +#### Operands: | Operand | Description | | :-----: | ----------- | | `lhs` | statically shaped tensor of PPHlo public type or PPHlo secret type values | `rhs` | statically shaped tensor of PPHlo public type or PPHlo secret type values -#### Results +#### Results: | Result | Description | | :----: | ----------- | @@ -445,6 +446,7 @@ Effects: MemoryEffects::Effect{} _CustomCall operation_ + Syntax: ``` @@ -456,10 +458,9 @@ Encapsulates an implementation-defined operation `call_target_name` that takes `inputs` and `called_computations` and produces `results`. 
See: - +https://github.com/openxla/stablehlo/blob/main/docs/spec.md#custom_call Example: - ```mlir %results = stablehlo.custom_call @foo(%input0) { backend_config = "bar", @@ -469,20 +470,20 @@ Example: Interfaces: MemoryEffectOpInterface -#### Attributes +#### Attributes: | Attribute | MLIR Type | Description | | :-------: | :-------: | ----------- | | `call_target_name` | ::mlir::StringAttr | string attribute | `has_side_effect` | ::mlir::BoolAttr | bool attribute -#### Operands +#### Operands: | Operand | Description | | :-----: | ----------- | | `inputs` | statically shaped tensor of PPHlo public type or PPHlo secret type values -#### Results +#### Results: | Result | Description | | :----: | ----------- | @@ -490,7 +491,7 @@ Interfaces: MemoryEffectOpInterface ### `pphlo.dbg_print` (pphlo::DbgPrintOp) -#### Operands +#### Operands: | Operand | Description | | :-----: | ----------- | @@ -503,7 +504,7 @@ _Division operator_ Returns `lhs / rhs` element-wise. See -. +https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. Traits: AlwaysSpeculatableImplTrait, Elementwise, SameOperandsAndResultShape @@ -511,14 +512,14 @@ Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface) Effects: MemoryEffects::Effect{} -#### Operands +#### Operands: | Operand | Description | | :-----: | ----------- | | `lhs` | statically shaped tensor of PPHlo public type or PPHlo secret type values | `rhs` | statically shaped tensor of PPHlo public type or PPHlo secret type values -#### Results +#### Results: | Result | Description | | :----: | ----------- | @@ -531,7 +532,7 @@ _General Dot operator_ Performs general dot products between vectors, vector/matrix and matrix/matrix multiplication. -See . +See https://www.tensorflow.org/xla/operation_semantics#dotgeneral. 
Traits: AlwaysSpeculatableImplTrait @@ -539,20 +540,20 @@ Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface) Effects: MemoryEffects::Effect{} -#### Attributes +#### Attributes: | Attribute | MLIR Type | Description | | :-------: | :-------: | ----------- | | `dot_dimension_numbers` | ::mlir::pphlo::DotDimensionNumbersAttr | Attribute that models the dimension information for dot. -#### Operands +#### Operands: | Operand | Description | | :-----: | ----------- | | `lhs` | statically shaped tensor of PPHlo public type or PPHlo secret type values | `rhs` | statically shaped tensor of PPHlo public type or PPHlo secret type values -#### Results +#### Results: | Result | Description | | :----: | ----------- | @@ -565,7 +566,7 @@ _Dot operator_ Performs dot products between vectors, vector/matrix and matrix/matrix multiplication. -See . +See https://www.tensorflow.org/xla/operation_semantics#dot. Traits: AlwaysSpeculatableImplTrait @@ -573,14 +574,14 @@ Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface) Effects: MemoryEffects::Effect{} -#### Operands +#### Operands: | Operand | Description | | :-----: | ----------- | | `lhs` | statically shaped tensor of PPHlo public type or PPHlo secret type values | `rhs` | statically shaped tensor of PPHlo public type or PPHlo secret type values -#### Results +#### Results: | Result | Description | | :----: | ----------- | @@ -592,7 +593,7 @@ _Dynamic Slice operator_ Extracts a sub-array from the input array at dynamic start_indices. -See . +See https://www.tensorflow.org/xla/operation_semantics#dynamicslice. 
Traits: AlwaysSpeculatableImplTrait @@ -600,20 +601,20 @@ Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (Mem Effects: MemoryEffects::Effect{} -#### Attributes +#### Attributes: | Attribute | MLIR Type | Description | | :-------: | :-------: | ----------- | | `slice_sizes` | ::mlir::DenseIntElementsAttr | 64-bit signless integer elements attribute -#### Operands +#### Operands: | Operand | Description | | :-----: | ----------- | | `operand` | statically shaped tensor of PPHlo public type or PPHlo secret type values | `start_indices` | 0D tensor of public integer type or secret integer type values -#### Results +#### Results: | Result | Description | | :----: | ----------- | @@ -626,7 +627,7 @@ _Dynamic Update Slice operator_ DynamicUpdateSlice generates a result which is the value of the input array operand, with a slice update overwritten at start_indices. -See . +See https://www.tensorflow.org/xla/operation_semantics#dynamicupdateslice. Traits: AlwaysSpeculatableImplTrait @@ -634,7 +635,7 @@ Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (Mem Effects: MemoryEffects::Effect{} -#### Operands +#### Operands: | Operand | Description | | :-----: | ----------- | @@ -642,7 +643,7 @@ Effects: MemoryEffects::Effect{} | `update` | statically shaped tensor of PPHlo public type or PPHlo secret type values | `start_indices` | 0D tensor of public integer type or secret integer type values -#### Results +#### Results: | Result | Description | | :----: | ----------- | @@ -659,7 +660,7 @@ Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface) Effects: MemoryEffects::Effect{} -#### Results +#### Results: | Result | Description | | :----: | ----------- | @@ -672,7 +673,7 @@ _Equal comparison operator_ Returns `lhs` == `rhs` element-wise. See -. +https://www.tensorflow.org/xla/operation_semantics#element-wise_comparison_operations. 
Traits: AlwaysSpeculatableImplTrait, Elementwise, SameOperandsAndResultShape @@ -680,14 +681,14 @@ Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface) Effects: MemoryEffects::Effect{} -#### Operands +#### Operands: | Operand | Description | | :-----: | ----------- | | `lhs` | statically shaped tensor of PPHlo public type or PPHlo secret type values | `rhs` | statically shaped tensor of PPHlo public type or PPHlo secret type values -#### Results +#### Results: | Result | Description | | :----: | ----------- | @@ -700,7 +701,7 @@ _Exponential operator_ Returns `e ^ (operand)` element-wise. See -. +https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. Traits: AlwaysSpeculatableImplTrait, Elementwise, SameOperandsAndResultShape, SameOperandsAndResultType @@ -708,13 +709,13 @@ Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (Mem Effects: MemoryEffects::Effect{} -#### Operands +#### Operands: | Operand | Description | | :-----: | ----------- | | `operand` | statically shaped tensor of public fixed-point type or secret fixed-point type values -#### Results +#### Results: | Result | Description | | :----: | ----------- | @@ -727,7 +728,7 @@ _Exponential minus one operator_ Returns `e^(operand) - 1` element-wise. See -. +https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. Traits: AlwaysSpeculatableImplTrait, Elementwise, SameOperandsAndResultShape, SameOperandsAndResultType @@ -735,13 +736,13 @@ Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (Mem Effects: MemoryEffects::Effect{} -#### Operands +#### Operands: | Operand | Description | | :-----: | ----------- | | `operand` | statically shaped tensor of public fixed-point type or secret fixed-point type values -#### Results +#### Results: | Result | Description | | :----: | ----------- | @@ -754,7 +755,7 @@ _Floor operator_ Returns `Floor(operand)` element-wise. See -. 
+https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. Traits: AlwaysSpeculatableImplTrait, Elementwise, SameOperandsAndResultShape, SameOperandsAndResultType @@ -762,13 +763,13 @@ Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (Mem Effects: MemoryEffects::Effect{} -#### Operands +#### Operands: | Operand | Description | | :-----: | ----------- | | `operand` | statically shaped tensor of public fixed-point type or secret fixed-point type values -#### Results +#### Results: | Result | Description | | :----: | ----------- | @@ -776,7 +777,7 @@ Effects: MemoryEffects::Effect{} ### `pphlo.free` (pphlo::FreeOp) -#### Operands +#### Operands: | Operand | Description | | :-----: | ----------- | @@ -789,7 +790,7 @@ _Gather operator_ Stitches together several slices of `operand` from offsets specified in `start_indices` (each slice at a potentially different runtime offset). -See . +See https://www.tensorflow.org/xla/operation_semantics#gather. Traits: AlwaysSpeculatableImplTrait @@ -797,7 +798,7 @@ Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface) Effects: MemoryEffects::Effect{} -#### Attributes +#### Attributes: | Attribute | MLIR Type | Description | | :-------: | :-------: | ----------- | @@ -805,14 +806,14 @@ Effects: MemoryEffects::Effect{} | `slice_sizes` | ::mlir::DenseIntElementsAttr | 64-bit signless integer elements attribute | `indices_are_sorted` | ::mlir::BoolAttr | bool attribute -#### Operands +#### Operands: | Operand | Description | | :-----: | ----------- | | `operand` | statically shaped tensor of PPHlo public type or PPHlo secret type values | `start_indices` | statically shaped tensor of public integer type or secret integer type values -#### Results +#### Results: | Result | Description | | :----: | ----------- | @@ -825,7 +826,7 @@ _Greater_equal comparison operator_ Returns `lhs` >= `rhs` element-wise. See -. 
+https://www.tensorflow.org/xla/operation_semantics#element-wise_comparison_operations. Traits: AlwaysSpeculatableImplTrait, Elementwise, SameOperandsAndResultShape @@ -833,14 +834,14 @@ Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface) Effects: MemoryEffects::Effect{} -#### Operands +#### Operands: | Operand | Description | | :-----: | ----------- | | `lhs` | statically shaped tensor of PPHlo public type or PPHlo secret type values | `rhs` | statically shaped tensor of PPHlo public type or PPHlo secret type values -#### Results +#### Results: | Result | Description | | :----: | ----------- | @@ -857,14 +858,14 @@ Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface) Effects: MemoryEffects::Effect{} -#### Operands +#### Operands: | Operand | Description | | :-----: | ----------- | | `lhs` | statically shaped tensor of PPHlo public type or PPHlo secret type values | `rhs` | statically shaped tensor of PPHlo public type or PPHlo secret type values -#### Results +#### Results: | Result | Description | | :----: | ----------- | @@ -877,17 +878,17 @@ _If operator_ Returns the result of executing either a true or false function depending on the result of a condition function. -See . +See https://www.tensorflow.org/xla/operation_semantics#conditional. Traits: RecursiveMemoryEffects, SingleBlockImplicitTerminator -#### Operands +#### Operands: | Operand | Description | | :-----: | ----------- | | `condition` | statically shaped tensor of public integer type or secret integer type values -#### Results +#### Results: | Result | Description | | :----: | ----------- | @@ -899,7 +900,7 @@ _Iota operator_ Creates a rank 1 array of values starting at zero and incrementing by one. 
-See +See https://www.tensorflow.org/xla/operation_semantics#iota Traits: AlwaysSpeculatableImplTrait @@ -907,13 +908,13 @@ Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface) Effects: MemoryEffects::Effect{} -#### Attributes +#### Attributes: | Attribute | MLIR Type | Description | | :-------: | :-------: | ----------- | | `iota_dimension` | ::mlir::IntegerAttr | 64-bit signless integer attribute -#### Results +#### Results: | Result | Description | | :----: | ----------- | @@ -930,14 +931,14 @@ Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface) Effects: MemoryEffects::Effect{} -#### Operands +#### Operands: | Operand | Description | | :-----: | ----------- | | `lhs` | statically shaped tensor of PPHlo public type or PPHlo secret type values | `rhs` | statically shaped tensor of PPHlo public type or PPHlo secret type values -#### Results +#### Results: | Result | Description | | :----: | ----------- | @@ -950,7 +951,7 @@ _Less comparison operator_ Returns `lhs` < `rhs` element-wise. See -. +https://www.tensorflow.org/xla/operation_semantics#element-wise_comparison_operations. Traits: AlwaysSpeculatableImplTrait, Elementwise, SameOperandsAndResultShape @@ -958,14 +959,14 @@ Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface) Effects: MemoryEffects::Effect{} -#### Operands +#### Operands: | Operand | Description | | :-----: | ----------- | | `lhs` | statically shaped tensor of PPHlo public type or PPHlo secret type values | `rhs` | statically shaped tensor of PPHlo public type or PPHlo secret type values -#### Results +#### Results: | Result | Description | | :----: | ----------- | @@ -978,7 +979,7 @@ _Log1p operator_ Returns `log(operand + 1)` element-wise. See -. +https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. 
Traits: AlwaysSpeculatableImplTrait, Elementwise, SameOperandsAndResultShape, SameOperandsAndResultType @@ -986,13 +987,13 @@ Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (Mem Effects: MemoryEffects::Effect{} -#### Operands +#### Operands: | Operand | Description | | :-----: | ----------- | | `operand` | statically shaped tensor of public fixed-point type or secret fixed-point type values -#### Results +#### Results: | Result | Description | | :----: | ----------- | @@ -1005,7 +1006,7 @@ _Log operator_ Returns `log(operand)` element-wise. See -. +https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. Traits: AlwaysSpeculatableImplTrait, Elementwise, SameOperandsAndResultShape, SameOperandsAndResultType @@ -1013,13 +1014,13 @@ Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (Mem Effects: MemoryEffects::Effect{} -#### Operands +#### Operands: | Operand | Description | | :-----: | ----------- | | `operand` | statically shaped tensor of public fixed-point type or secret fixed-point type values -#### Results +#### Results: | Result | Description | | :----: | ----------- | @@ -1032,7 +1033,7 @@ _Reciprocal operator_ Returns `logistic(operand)` element-wise. See -. +https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. Traits: AlwaysSpeculatableImplTrait, Elementwise, SameOperandsAndResultShape, SameOperandsAndResultType @@ -1040,13 +1041,13 @@ Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (Mem Effects: MemoryEffects::Effect{} -#### Operands +#### Operands: | Operand | Description | | :-----: | ----------- | | `operand` | statically shaped tensor of public fixed-point type or secret fixed-point type values -#### Results +#### Results: | Result | Description | | :----: | ----------- | @@ -1059,7 +1060,7 @@ _Maximum operator_ Returns `max(lhs, rhs)` element-wise. See -. 
+https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. Traits: AlwaysSpeculatableImplTrait, Commutative, Elementwise, SameOperandsAndResultShape @@ -1067,14 +1068,14 @@ Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface) Effects: MemoryEffects::Effect{} -#### Operands +#### Operands: | Operand | Description | | :-----: | ----------- | | `lhs` | statically shaped tensor of PPHlo public type or PPHlo secret type values | `rhs` | statically shaped tensor of PPHlo public type or PPHlo secret type values -#### Results +#### Results: | Result | Description | | :----: | ----------- | @@ -1088,7 +1089,7 @@ Generates a result which is the value of the input array `operand`, with several slices (at indices specified by `scatter_indices`) updated with the values in `updates` using `update_computation`. -See . +See https://www.tensorflow.org/xla/operation_semantics#scatter. Traits: AlwaysSpeculatableImplTrait @@ -1096,7 +1097,7 @@ Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface) Effects: MemoryEffects::Effect{} -#### Attributes +#### Attributes: | Attribute | MLIR Type | Description | | :-------: | :-------: | ----------- | @@ -1104,14 +1105,14 @@ Effects: MemoryEffects::Effect{} | `window_strides` | ::mlir::DenseIntElementsAttr | 64-bit signless integer elements attribute | `padding` | ::mlir::DenseIntElementsAttr | 64-bit signless integer elements attribute -#### Operands +#### Operands: | Operand | Description | | :-----: | ----------- | | `scatter_indices` | statically shaped tensor of public integer type or secret integer type values | `update` | statically shaped tensor of PPHlo public type or PPHlo secret type values -#### Results +#### Results: | Result | Description | | :----: | ----------- | @@ -1124,7 +1125,7 @@ _Minimum operator_ Returns `min(lhs, rhs)` element-wise. See -. +https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. 
Traits: AlwaysSpeculatableImplTrait, Commutative, Elementwise, SameOperandsAndResultShape @@ -1132,14 +1133,14 @@ Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface) Effects: MemoryEffects::Effect{} -#### Operands +#### Operands: | Operand | Description | | :-----: | ----------- | | `lhs` | statically shaped tensor of PPHlo public type or PPHlo secret type values | `rhs` | statically shaped tensor of PPHlo public type or PPHlo secret type values -#### Results +#### Results: | Result | Description | | :----: | ----------- | @@ -1152,7 +1153,7 @@ _Multiplication operator_ Returns `lhs * rhs` element-wise. See -. +https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. Traits: AlwaysSpeculatableImplTrait, Commutative, Elementwise, SameOperandsAndResultShape @@ -1160,14 +1161,14 @@ Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface) Effects: MemoryEffects::Effect{} -#### Operands +#### Operands: | Operand | Description | | :-----: | ----------- | | `lhs` | statically shaped tensor of PPHlo public type or PPHlo secret type values | `rhs` | statically shaped tensor of PPHlo public type or PPHlo secret type values -#### Results +#### Results: | Result | Description | | :----: | ----------- | @@ -1181,7 +1182,7 @@ Returns `- operand` element - wise. See -. +https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. 
Traits: AlwaysSpeculatableImplTrait, Elementwise, SameOperandsAndResultShape, SameOperandsAndResultType @@ -1189,13 +1190,13 @@ Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (Mem Effects: MemoryEffects::Effect{} -#### Operands +#### Operands: | Operand | Description | | :-----: | ----------- | | `operand` | statically shaped tensor of PPHlo public type or PPHlo secret type values -#### Results +#### Results: | Result | Description | | :----: | ----------- | @@ -1208,7 +1209,7 @@ _Not-equal comparison operator_ Returns `lhs` != `rhs` element-wise. See -. +https://www.tensorflow.org/xla/operation_semantics#element-wise_comparison_operations. Traits: AlwaysSpeculatableImplTrait, Elementwise, SameOperandsAndResultShape @@ -1216,14 +1217,14 @@ Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface) Effects: MemoryEffects::Effect{} -#### Operands +#### Operands: | Operand | Description | | :-----: | ----------- | | `lhs` | statically shaped tensor of PPHlo public type or PPHlo secret type values | `rhs` | statically shaped tensor of PPHlo public type or PPHlo secret type values -#### Results +#### Results: | Result | Description | | :----: | ----------- | @@ -1236,7 +1237,7 @@ _Not operator_ Returns `!operand` element-wise. See -. +https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. Traits: AlwaysSpeculatableImplTrait, Elementwise, SameOperandsAndResultShape, SameOperandsAndResultType @@ -1244,13 +1245,13 @@ Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (Mem Effects: MemoryEffects::Effect{} -#### Operands +#### Operands: | Operand | Description | | :-----: | ----------- | | `operand` | statically shaped tensor of public integer type or secret integer type values -#### Results +#### Results: | Result | Description | | :----: | ----------- | @@ -1263,7 +1264,7 @@ _Logical or_ Returns `logical_or(lhs, rhs)` element-wise. See -. 
+https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. Traits: AlwaysSpeculatableImplTrait, Commutative, Elementwise, SameOperandsAndResultShape @@ -1271,14 +1272,14 @@ Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface) Effects: MemoryEffects::Effect{} -#### Operands +#### Operands: | Operand | Description | | :-----: | ----------- | | `lhs` | statically shaped tensor of public integer type or secret integer type values | `rhs` | statically shaped tensor of public integer type or secret integer type values -#### Results +#### Results: | Result | Description | | :----: | ----------- | @@ -1291,7 +1292,7 @@ _Power operator_ Returns `lhs ^ rhs` element-wise. See -. +https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. Traits: AlwaysSpeculatableImplTrait, Elementwise, SameOperandsAndResultShape @@ -1299,14 +1300,14 @@ Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface) Effects: MemoryEffects::Effect{} -#### Operands +#### Operands: | Operand | Description | | :-----: | ----------- | | `lhs` | statically shaped tensor of PPHlo public type or PPHlo secret type values | `rhs` | statically shaped tensor of PPHlo public type or PPHlo secret type values -#### Results +#### Results: | Result | Description | | :----: | ----------- | @@ -1324,13 +1325,13 @@ Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (Mem Effects: MemoryEffects::Effect{} -#### Operands +#### Operands: | Operand | Description | | :-----: | ----------- | | `operand` | statically shaped tensor of PPHlo public type or PPHlo secret type values -#### Results +#### Results: | Result | Description | | :----: | ----------- | @@ -1343,7 +1344,7 @@ _Reciprocal operator_ Returns `1.0 / operand` element-wise. See -. +https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. 
Traits: AlwaysSpeculatableImplTrait, Elementwise, SameOperandsAndResultShape, SameOperandsAndResultType @@ -1351,13 +1352,13 @@ Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (Mem Effects: MemoryEffects::Effect{} -#### Operands +#### Operands: | Operand | Description | | :-----: | ----------- | | `operand` | statically shaped tensor of public fixed-point type or secret fixed-point type values -#### Results +#### Results: | Result | Description | | :----: | ----------- | @@ -1370,24 +1371,24 @@ _Reduce operator_ Returns the result of executing a reduction function on one or more arrays in parallel. -See . +See https://www.tensorflow.org/xla/operation_semantics#reduce. Traits: RecursiveMemoryEffects, SameVariadicOperandSize, SingleBlockImplicitTerminator -#### Attributes +#### Attributes: | Attribute | MLIR Type | Description | | :-------: | :-------: | ----------- | | `dimensions` | ::mlir::DenseIntElementsAttr | 64-bit signless integer elements attribute -#### Operands +#### Operands: | Operand | Description | | :-----: | ----------- | | `inputs` | statically shaped tensor of PPHlo public type or PPHlo secret type values | `init_values` | statically shaped tensor of PPHlo public type or PPHlo secret type values -#### Results +#### Results: | Result | Description | | :----: | ----------- | @@ -1400,11 +1401,11 @@ _ReduceWindow operator_ Returns the result of executing a reduction function over all elements in each window of one or more arrays. -See . +See https://www.tensorflow.org/xla/operation_semantics#reducewindow. 
Traits: RecursiveMemoryEffects, SameVariadicOperandSize, SingleBlockImplicitTerminator -#### Attributes +#### Attributes: | Attribute | MLIR Type | Description | | :-------: | :-------: | ----------- | @@ -1414,14 +1415,14 @@ Traits: RecursiveMemoryEffects, SameVariadicOperandSize, SingleBlockImplicitTerm | `window_dilations` | ::mlir::DenseIntElementsAttr | 64-bit signless integer elements attribute | `padding` | ::mlir::DenseIntElementsAttr | 64-bit signless integer elements attribute -#### Operands +#### Operands: | Operand | Description | | :-----: | ----------- | | `inputs` | statically shaped tensor of PPHlo public type or PPHlo secret type values | `init_values` | statically shaped tensor of PPHlo public type or PPHlo secret type values -#### Results +#### Results: | Result | Description | | :----: | ----------- | @@ -1434,7 +1435,7 @@ _Remainder operator_ Returns `lhs % rhs` element-wise. See -. +https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. Traits: AlwaysSpeculatableImplTrait, Elementwise, SameOperandsAndResultShape @@ -1442,14 +1443,14 @@ Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface) Effects: MemoryEffects::Effect{} -#### Operands +#### Operands: | Operand | Description | | :-----: | ----------- | | `lhs` | statically shaped tensor of PPHlo public type or PPHlo secret type values | `rhs` | statically shaped tensor of PPHlo public type or PPHlo secret type values -#### Results +#### Results: | Result | Description | | :----: | ----------- | @@ -1461,7 +1462,7 @@ _Reshape operator_ Reshapes the dimensions of `operand` into a new configuration. -See . +See https://www.tensorflow.org/xla/operation_semantics#reshape. 
Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultElementType @@ -1469,13 +1470,13 @@ Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface) Effects: MemoryEffects::Effect{} -#### Operands +#### Operands: | Operand | Description | | :-----: | ----------- | | `operand` | statically shaped tensor of PPHlo public type or PPHlo secret type values -#### Results +#### Results: | Result | Description | | :----: | ----------- | @@ -1487,13 +1488,14 @@ _ The `pphlo.return` operation terminates a region and returns values. _ + Traits: AlwaysSpeculatableImplTrait, Terminator Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface) Effects: MemoryEffects::Effect{} -#### Operands +#### Operands: | Operand | Description | | :-----: | ----------- | @@ -1506,7 +1508,7 @@ _Reverse operator_ Reverses the specified dimensions of `operand` according to the given `dimensions`. -See . +See https://www.tensorflow.org/xla/operation_semantics#rev_reverse. Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultType @@ -1514,19 +1516,19 @@ Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (Mem Effects: MemoryEffects::Effect{} -#### Attributes +#### Attributes: | Attribute | MLIR Type | Description | | :-------: | :-------: | ----------- | | `dimensions` | ::mlir::DenseIntElementsAttr | 64-bit signless integer elements attribute -#### Operands +#### Operands: | Operand | Description | | :-----: | ----------- | | `operand` | statically shaped tensor of PPHlo public type or PPHlo secret type values -#### Results +#### Results: | Result | Description | | :----: | ----------- | @@ -1541,18 +1543,18 @@ following the uniform distribution over the interval `[a,b)`. The parameters and output element type have to be an integral type or a fixed point type, and the types have to be consistent. -See . +See https://www.tensorflow.org/xla/operation_semantics#rnguniform. 
Traits: SameOperandsAndResultElementType -#### Operands +#### Operands: | Operand | Description | | :-----: | ----------- | | `a` | statically shaped tensor of PPHlo public type or PPHlo secret type values | `b` | statically shaped tensor of PPHlo public type or PPHlo secret type values -#### Results +#### Results: | Result | Description | | :----: | ----------- | @@ -1566,7 +1568,7 @@ Returns `Round(operand)` element-wise, rounding to nearest integer with half-way cases rounding away from zero. See -. +https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. Traits: AlwaysSpeculatableImplTrait, Elementwise, SameOperandsAndResultShape, SameOperandsAndResultType @@ -1574,13 +1576,13 @@ Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (Mem Effects: MemoryEffects::Effect{} -#### Operands +#### Operands: | Operand | Description | | :-----: | ----------- | | `operand` | statically shaped tensor of public fixed-point type or secret fixed-point type values -#### Results +#### Results: | Result | Description | | :----: | ----------- | @@ -1593,7 +1595,7 @@ _Reciprocal of square-root operator_ Returns `rsqrt(operand)` element-wise. See -. +https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. Traits: AlwaysSpeculatableImplTrait, Elementwise, SameOperandsAndResultShape, SameOperandsAndResultType @@ -1601,13 +1603,13 @@ Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (Mem Effects: MemoryEffects::Effect{} -#### Operands +#### Operands: | Operand | Description | | :-----: | ----------- | | `operand` | statically shaped tensor of public fixed-point type or secret fixed-point type values -#### Results +#### Results: | Result | Description | | :----: | ----------- | @@ -1624,11 +1626,11 @@ to the output which is initialized with `init_value`. Multiple scattered elements which land in the same output location are combined using the `scatter` function. -See . 
+See https://www.tensorflow.org/xla/operation_semantics#selectandscatter. Traits: RecursiveMemoryEffects -#### Attributes +#### Attributes: | Attribute | MLIR Type | Description | | :-------: | :-------: | ----------- | @@ -1636,7 +1638,7 @@ Traits: RecursiveMemoryEffects | `window_strides` | ::mlir::DenseIntElementsAttr | 64-bit signless integer elements attribute | `padding` | ::mlir::DenseIntElementsAttr | 64-bit signless integer elements attribute -#### Operands +#### Operands: | Operand | Description | | :-----: | ----------- | @@ -1644,7 +1646,7 @@ Traits: RecursiveMemoryEffects | `source` | statically shaped tensor of PPHlo public type or PPHlo secret type values | `init_value` | statically shaped tensor of PPHlo public type or PPHlo secret type values -#### Results +#### Results: | Result | Description | | :----: | ----------- | @@ -1659,7 +1661,7 @@ based on the values of `pred`. All three operands must be of the same shape with the exception of `pred`, which may also be a scalar in which case it is broadcasted. -See . +See https://www.tensorflow.org/xla/operation_semantics#select. Traits: AlwaysSpeculatableImplTrait @@ -1667,7 +1669,7 @@ Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface) Effects: MemoryEffects::Effect{} -#### Operands +#### Operands: | Operand | Description | | :-----: | ----------- | @@ -1675,7 +1677,7 @@ Effects: MemoryEffects::Effect{} | `on_true` | statically shaped tensor of PPHlo public type or PPHlo secret type values | `on_false` | statically shaped tensor of PPHlo public type or PPHlo secret type values -#### Results +#### Results: | Result | Description | | :----: | ----------- | @@ -1688,7 +1690,7 @@ _Shift Left operator_ Returns `lhs << rhs` element-wise. See -. +https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. 
Traits: AlwaysSpeculatableImplTrait, Elementwise, SameOperandsAndResultShape @@ -1696,14 +1698,14 @@ Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface) Effects: MemoryEffects::Effect{} -#### Operands +#### Operands: | Operand | Description | | :-----: | ----------- | | `lhs` | statically shaped tensor of PPHlo public type or PPHlo secret type values | `rhs` | statically shaped tensor of PPHlo public type or PPHlo secret type values -#### Results +#### Results: | Result | Description | | :----: | ----------- | @@ -1716,7 +1718,7 @@ _Shift right arithmetic operator_ Returns arithmetic `lhs >> rhs` element-wise. See -. +https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. Traits: AlwaysSpeculatableImplTrait, Elementwise, SameOperandsAndResultShape @@ -1724,14 +1726,14 @@ Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface) Effects: MemoryEffects::Effect{} -#### Operands +#### Operands: | Operand | Description | | :-----: | ----------- | | `lhs` | statically shaped tensor of PPHlo public type or PPHlo secret type values | `rhs` | statically shaped tensor of PPHlo public type or PPHlo secret type values -#### Results +#### Results: | Result | Description | | :----: | ----------- | @@ -1744,7 +1746,7 @@ _Shift right logical operator_ Returns logical `lhs >> rhs` element-wise. See -. +https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. 
Traits: AlwaysSpeculatableImplTrait, Elementwise, SameOperandsAndResultShape @@ -1752,14 +1754,14 @@ Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface) Effects: MemoryEffects::Effect{} -#### Operands +#### Operands: | Operand | Description | | :-----: | ----------- | | `lhs` | statically shaped tensor of PPHlo public type or PPHlo secret type values | `rhs` | statically shaped tensor of PPHlo public type or PPHlo secret type values -#### Results +#### Results: | Result | Description | | :----: | ----------- | @@ -1778,7 +1780,7 @@ sign(x) = -1 : x < 0 ``` See -. +https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. Traits: AlwaysSpeculatableImplTrait, Elementwise, SameOperandsAndResultShape, SameOperandsAndResultType @@ -1786,26 +1788,25 @@ Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (Mem Effects: MemoryEffects::Effect{} -#### Operands +#### Operands: | Operand | Description | | :-----: | ----------- | | `operand` | statically shaped tensor of PPHlo public type or PPHlo secret type values -#### Results +#### Results: | Result | Description | | :----: | ----------- | «unnamed» | statically shaped tensor of PPHlo public type or PPHlo secret type values ### `pphlo.slice` (pphlo::SliceOp) - The dynamic shape version of SliceOp. Extracts a sub-array from the input array according to start_indices, limit_indices and strides. Expect start_indices/limit_indices/strides to be statically shaped and matching the rank of the input. 
-See +See https://www.tensorflow.org/xla/operation_semantics#slice Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultElementType @@ -1813,7 +1814,7 @@ Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface) Effects: MemoryEffects::Effect{} -#### Attributes +#### Attributes: | Attribute | MLIR Type | Description | | :-------: | :-------: | ----------- | @@ -1821,13 +1822,13 @@ Effects: MemoryEffects::Effect{} | `limit_indices` | ::mlir::DenseIntElementsAttr | 64-bit signless integer elements attribute | `strides` | ::mlir::DenseIntElementsAttr | 64-bit signless integer elements attribute -#### Operands +#### Operands: | Operand | Description | | :-----: | ----------- | | `operand` | statically shaped tensor of PPHlo public type or PPHlo secret type values -#### Results +#### Results: | Result | Description | | :----: | ----------- | @@ -1840,24 +1841,24 @@ _Sort operator_ Sorts the given `operands` at the given `dimension` with the given `comparator`. -See . +See https://www.tensorflow.org/xla/operation_semantics#sort. Traits: RecursiveMemoryEffects, SameOperandsAndResultShape -#### Attributes +#### Attributes: | Attribute | MLIR Type | Description | | :-------: | :-------: | ----------- | | `dimension` | ::mlir::IntegerAttr | 64-bit signless integer attribute | `is_stable` | ::mlir::BoolAttr | bool attribute -#### Operands +#### Operands: | Operand | Description | | :-----: | ----------- | | `operands` | statically shaped tensor of PPHlo public type or PPHlo secret type values -#### Results +#### Results: | Result | Description | | :----: | ----------- | @@ -1870,7 +1871,7 @@ _Square-root operator_ Returns `sqrt(operand)` element-wise. See -. +https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. 
Traits: AlwaysSpeculatableImplTrait, Elementwise, SameOperandsAndResultShape, SameOperandsAndResultType @@ -1878,13 +1879,13 @@ Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (Mem Effects: MemoryEffects::Effect{} -#### Operands +#### Operands: | Operand | Description | | :-----: | ----------- | | `operand` | statically shaped tensor of public fixed-point type or secret fixed-point type values -#### Results +#### Results: | Result | Description | | :----: | ----------- | @@ -1897,7 +1898,7 @@ _Subtraction operator_ Returns `lhs - rhs` element-wise. See -. +https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. Traits: AlwaysSpeculatableImplTrait, Elementwise, SameOperandsAndResultShape @@ -1905,14 +1906,14 @@ Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface) Effects: MemoryEffects::Effect{} -#### Operands +#### Operands: | Operand | Description | | :-----: | ----------- | | `lhs` | statically shaped tensor of PPHlo public type or PPHlo secret type values | `rhs` | statically shaped tensor of PPHlo public type or PPHlo secret type values -#### Results +#### Results: | Result | Description | | :----: | ----------- | @@ -1925,7 +1926,7 @@ _Tanh operator_ Returns `tanh(operand)` element-wise. See -. +https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. Traits: AlwaysSpeculatableImplTrait, Elementwise, SameOperandsAndResultShape, SameOperandsAndResultType @@ -1933,13 +1934,13 @@ Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (Mem Effects: MemoryEffects::Effect{} -#### Operands +#### Operands: | Operand | Description | | :-----: | ----------- | | `operand` | statically shaped tensor of public fixed-point type or secret fixed-point type values -#### Results +#### Results: | Result | Description | | :----: | ----------- | @@ -1953,7 +1954,7 @@ Permutes the dimensions of `operand` according to the given `permutation`. 
`res_dimensions[i] = operand_dimensions[permutation[i]]` -See . +See https://www.tensorflow.org/xla/operation_semantics#transpose. Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultElementType @@ -1961,19 +1962,19 @@ Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface) Effects: MemoryEffects::Effect{} -#### Attributes +#### Attributes: | Attribute | MLIR Type | Description | | :-------: | :-------: | ----------- | | `permutation` | ::mlir::DenseIntElementsAttr | 64-bit signless integer elements attribute -#### Operands +#### Operands: | Operand | Description | | :-----: | ----------- | | `operand` | statically shaped tensor of PPHlo public type or PPHlo secret type values -#### Results +#### Results: | Result | Description | | :----: | ----------- | @@ -1986,17 +1987,17 @@ _While operator_ Returns the result of executing a body function until the cond body returns true. -See . +See https://www.tensorflow.org/xla/operation_semantics#while. Traits: PPHLO_PairwiseSameOperandAndResultType, RecursiveMemoryEffects, SingleBlockImplicitTerminator -#### Operands +#### Operands: | Operand | Description | | :-----: | ----------- | | `args` | statically shaped tensor of PPHlo public type or PPHlo secret type values -#### Results +#### Results: | Result | Description | | :----: | ----------- | @@ -2009,7 +2010,7 @@ _Logical xor_ Returns `logical_xor(lhs, rhs)` element-wise. See -. +https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. 
Traits: AlwaysSpeculatableImplTrait, Commutative, Elementwise, SameOperandsAndResultShape @@ -2017,14 +2018,14 @@ Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface) Effects: MemoryEffects::Effect{} -#### Operands +#### Operands: | Operand | Description | | :-----: | ----------- | | `lhs` | statically shaped tensor of public integer type or secret integer type values | `rhs` | statically shaped tensor of public integer type or secret integer type values -#### Results +#### Results: | Result | Description | | :----: | ----------- | diff --git a/docs/reference/psi_config.md b/docs/reference/psi_config.md index ac699764..7e3628d0 100644 --- a/docs/reference/psi_config.md +++ b/docs/reference/psi_config.md @@ -2,26 +2,34 @@ ## Table of Contents + + - Messages - - [BucketPsiConfig](#bucketpsiconfig) - - [DpPsiParams](#dppsiparams) - - [InputParams](#inputparams) - - [MemoryPsiConfig](#memorypsiconfig) - - [OutputParams](#outputparams) - - [PsiResultReport](#psiresultreport) + - [BucketPsiConfig](#bucketpsiconfig) + - [DpPsiParams](#dppsiparams) + - [InputParams](#inputparams) + - [MemoryPsiConfig](#memorypsiconfig) + - [OutputParams](#outputparams) + - [PsiResultReport](#psiresultreport) + + - Enums - - [CurveType](#curvetype) - - [PsiType](#psitype) + - [CurveType](#curvetype) + - [PsiType](#psitype) + + - [Scalar Value Types](#scalar-value-types) + + ## Messages -### BucketPsiConfig +### BucketPsiConfig The Bucket-psi configuration. ```python @@ -35,46 +43,49 @@ The Bucket-psi configuration. report = psi.bucket_psi(lctx, config) # run psi and get report ``` + | Field | Type | Description | | ----- | ---- | ----------- | -| psi_type | [PsiType](#psitype) | The psi type. | -| receiver_rank | [uint32](#uint32) | Specified the receiver rank. Receiver can get psi result. | -| broadcast_result | [bool](#bool) | Whether to broadcast psi result to all parties. | -| input_params | [InputParams](#inputparams) | The input parameters of psi. 
| -| output_params | [OutputParams](#outputparams) | The output parameters of psi. | -| curve_type | [CurveType](#curvetype) | Optional, specified elliptic curve cryptography used in psi when needed. | -| bucket_size | [uint32](#uint32) | Optional, specified the hash bucket size used in psi. | -| preprocess_path | [string](#string) | Optional,The path of offline preprocess file. | -| ecdh_secret_key_path | [string](#string) | Optional,secret key path of ecdh_oprf, 256bit/32bytes binary file. | -| dppsi_params | [DpPsiParams](#dppsiparams) | Optional,Params for dp-psi | +| psi_type | [ PsiType](#psitype) | The psi type. | +| receiver_rank | [ uint32](#uint32) | Specified the receiver rank. Receiver can get psi result. | +| broadcast_result | [ bool](#bool) | Whether to broadcast psi result to all parties. | +| input_params | [ InputParams](#inputparams) | The input parameters of psi. | +| output_params | [ OutputParams](#outputparams) | The output parameters of psi. | +| curve_type | [ CurveType](#curvetype) | Optional, specified elliptic curve cryptography used in psi when needed. | +| bucket_size | [ uint32](#uint32) | Optional, specified the hash bucket size used in psi. | +| preprocess_path | [ string](#string) | Optional,The path of offline preprocess file. | +| ecdh_secret_key_path | [ string](#string) | Optional,secret key path of ecdh_oprf, 256bit/32bytes binary file. | +| dppsi_params | [ DpPsiParams](#dppsiparams) | Optional,Params for dp-psi | -### DpPsiParams +### DpPsiParams The input parameters of dp-psi. + | Field | Type | Description | | ----- | ---- | ----------- | -| bob_sub_sampling | [double](#double) | bob sub-sampling bernoulli_distribution probability. | -| epsilon | [double](#double) | dp epsilon | +| bob_sub_sampling | [ double](#double) | bob sub-sampling bernoulli_distribution probability. | +| epsilon | [ double](#double) | dp epsilon | -### InputParams +### InputParams The input parameters of psi. 
+ | Field | Type | Description | | ----- | ---- | ----------- | -| path | [string](#string) | The path of input csv file. | +| path | [ string](#string) | The path of input csv file. | | select_fields | [repeated string](#string) | The select fields of input data. | -| precheck | [bool](#bool) | Whether to check select fields duplicate. | +| precheck | [ bool](#bool) | Whether to check select fields duplicate. | -### MemoryPsiConfig +### MemoryPsiConfig The In-memory psi configuration. ```python @@ -88,43 +99,46 @@ The In-memory psi configuration. ) # run psi and get joined list ``` + | Field | Type | Description | | ----- | ---- | ----------- | -| psi_type | [PsiType](#psitype) | The psi type. | -| receiver_rank | [uint32](#uint32) | Specified the receiver rank. Receiver can get psi result. | -| broadcast_result | [bool](#bool) | Whether to broadcast psi result to all parties. | -| curve_type | [CurveType](#curvetype) | Optional, specified elliptic curve cryptography used in psi when needed. | -| dppsi_params | [DpPsiParams](#dppsiparams) | Optional,Params for dp-psi | +| psi_type | [ PsiType](#psitype) | The psi type. | +| receiver_rank | [ uint32](#uint32) | Specified the receiver rank. Receiver can get psi result. | +| broadcast_result | [ bool](#bool) | Whether to broadcast psi result to all parties. | +| curve_type | [ CurveType](#curvetype) | Optional, specified elliptic curve cryptography used in psi when needed. | +| dppsi_params | [ DpPsiParams](#dppsiparams) | Optional,Params for dp-psi | -### OutputParams +### OutputParams The output parameters of psi. + | Field | Type | Description | | ----- | ---- | ----------- | -| path | [string](#string) | The path of output csv file. | -| need_sort | [bool](#bool) | Whether to sort output file by select fields. | +| path | [ string](#string) | The path of output csv file. | +| need_sort | [ bool](#bool) | Whether to sort output file by select fields. 
| -### PsiResultReport +### PsiResultReport The report of psi result. + | Field | Type | Description | | ----- | ---- | ----------- | -| original_count | [int64](#int64) | The data count of input. | -| intersection_count | [int64](#int64) | The count of intersection. Get `-1` when self party can not get result. | +| original_count | [ int64](#int64) | The data count of input. | +| intersection_count | [ int64](#int64) | The count of intersection. Get `-1` when self party can not get result. | ## Enums -### CurveType +### CurveType The specified elliptic curve cryptography used in psi. | Name | Number | Description | @@ -133,18 +147,20 @@ The specified elliptic curve cryptography used in psi. | CURVE_25519 | 1 | Daniel J. Bernstein. Curve25519: new diffie-hellman speed records | | CURVE_FOURQ | 2 | FourQ: four-dimensional decompositions on a Q-curve over the Mersenne prime | | CURVE_SM2 | 3 | SM2 is an elliptic curve based cryptosystem (ECC) published as a Chinese National Standard as GBT.32918.1-2016 and published in ISO/IEC 14888-3:2018 | -| CURVE_SECP256K1 | 4 | parameters of the elliptic curve defined in Standards for Efficient Cryptography (SEC) | +| CURVE_SECP256K1 | 4 | parameters of the elliptic curve defined in Standards for Efficient Cryptography (SEC) http://www.secg.org/sec2-v2.pdf | -### PsiType + + +### PsiType The algorithm type of psi. 
| Name | Number | Description | | ---- | ------ | ----------- | | INVALID_PSI_TYPE | 0 | none | | ECDH_PSI_2PC | 1 | DDH based PSI | -| KKRT_PSI_2PC | 2 | Efficient Batched Oblivious PRF with Applications to Private Set Intersection | -| BC22_PSI_2PC | 3 | PSI from Pseudorandom Correlation Generators | +| KKRT_PSI_2PC | 2 | Efficient Batched Oblivious PRF with Applications to Private Set Intersection https://eprint.iacr.org/2016/799.pdf | +| BC22_PSI_2PC | 3 | PSI from Pseudorandom Correlation Generators https://eprint.iacr.org/2022/334 | | ECDH_PSI_3PC | 4 | Multi-party PSI based on ECDH (Say A, B, C (receiver)) notice: two-party intersection cardinarlity leak (|A intersect B|) | | ECDH_PSI_NPC | 5 | Iterative running 2-party ecdh psi to get n-party PSI. Notice: two-party intersection leak | | KKRT_PSI_NPC | 6 | Iterative running 2-party kkrt psi to get n-party PSI. Notice: two-party intersection leak | @@ -153,7 +169,8 @@ The algorithm type of psi. | ECDH_OPRF_UB_PSI_2PC_OFFLINE | 9 | ecdh-oprf 2-party Unbalanced-PSI offline phase. | | ECDH_OPRF_UB_PSI_2PC_ONLINE | 10 | ecdh-oprf 2-party Unbalanced-PSI online phase. | | ECDH_OPRF_UB_PSI_2PC_SHUFFLE_ONLINE | 11 | ecdh-oprf 2-party Unbalanced-PSI with shuffling online phase. large set party get intersection result | -| DP_PSI_2PC | 12 | Differentially-Private PSI bases on ECDH-PSI, and provides: Differentially private PSI results. | +| DP_PSI_2PC | 12 | Differentially-Private PSI https://arxiv.org/pdf/2208.13249.pdf bases on ECDH-PSI, and provides: Differentially private PSI results. 
| + diff --git a/docs/reference/runtime_config.md b/docs/reference/runtime_config.md index 809d39f5..1522b8b7 100644 --- a/docs/reference/runtime_config.md +++ b/docs/reference/runtime_config.md @@ -2,64 +2,78 @@ ## Table of Contents + + - Messages - - [CompilationSource](#compilationsource) - - [CompilerOptions](#compileroptions) - - [ExecutableProto](#executableproto) - - [RuntimeConfig](#runtimeconfig) - - [ShapeProto](#shapeproto) - - [TTPBeaverConfig](#ttpbeaverconfig) - - [ValueChunkProto](#valuechunkproto) - - [ValueMetaProto](#valuemetaproto) + - [CompilationSource](#compilationsource) + - [CompilerOptions](#compileroptions) + - [ExecutableProto](#executableproto) + - [RuntimeConfig](#runtimeconfig) + - [ShapeProto](#shapeproto) + - [TTPBeaverConfig](#ttpbeaverconfig) + - [ValueChunkProto](#valuechunkproto) + - [ValueMetaProto](#valuemetaproto) + + - Enums - - [DataType](#datatype) - - [FieldType](#fieldtype) - - [ProtocolKind](#protocolkind) - - [PtType](#pttype) - - [RuntimeConfig.BeaverType](#runtimeconfigbeavertype) - - [RuntimeConfig.ExpMode](#runtimeconfigexpmode) - - [RuntimeConfig.LogMode](#runtimeconfiglogmode) - - [RuntimeConfig.SigmoidMode](#runtimeconfigsigmoidmode) - - [SourceIRType](#sourceirtype) - - [Visibility](#visibility) - - [XLAPrettyPrintKind](#xlaprettyprintkind) + - [DataType](#datatype) + - [FieldType](#fieldtype) + - [ProtocolKind](#protocolkind) + - [PtType](#pttype) + - [RuntimeConfig.BeaverType](#runtimeconfigbeavertype) + - [RuntimeConfig.ExpMode](#runtimeconfigexpmode) + - [RuntimeConfig.LogMode](#runtimeconfiglogmode) + - [RuntimeConfig.SigmoidMode](#runtimeconfigsigmoidmode) + - [SourceIRType](#sourceirtype) + - [Visibility](#visibility) + - [XLAPrettyPrintKind](#xlaprettyprintkind) + + - [Scalar Value Types](#scalar-value-types) + + ## Messages + ### CompilationSource + + | Field | Type | Description | | ----- | ---- | ----------- | -| ir_type | [SourceIRType](#sourceirtype) | Input IR type | -| ir_txt | [bytes](#bytes) | 
IR | +| ir_type | [ SourceIRType](#sourceirtype) | Input IR type | +| ir_txt | [ bytes](#bytes) | IR | | input_visibility | [repeated Visibility](#visibility) | Input visibilities | + ### CompilerOptions + + | Field | Type | Description | | ----- | ---- | ----------- | -| enable_pretty_print | [bool](#bool) | Pretty print | -| pretty_print_dump_dir | [string](#string) | none | -| xla_pp_kind | [XLAPrettyPrintKind](#xlaprettyprintkind) | none | -| disable_sqrt_plus_epsilon_rewrite | [bool](#bool) | Disable sqrt(x) + eps to sqrt(x+eps) rewrite | -| disable_div_sqrt_rewrite | [bool](#bool) | Disable x/sqrt(y) to x*rsqrt(y) rewrite | -| disable_reduce_truncation_optimization | [bool](#bool) | Disable reduce truncation optimization | -| disable_maxpooling_optimization | [bool](#bool) | Disable maxpooling optimization | -| disallow_mix_types_opts | [bool](#bool) | Disallow mix type operations | -| disable_select_optimization | [bool](#bool) | Disable SelectOp optimization | -| enable_optimize_denominator_with_broadcast | [bool](#bool) | Enable optimize x/bcast(y) -> x * bcast(1/y) | +| enable_pretty_print | [ bool](#bool) | Pretty print | +| pretty_print_dump_dir | [ string](#string) | none | +| xla_pp_kind | [ XLAPrettyPrintKind](#xlaprettyprintkind) | none | +| disable_sqrt_plus_epsilon_rewrite | [ bool](#bool) | Disable sqrt(x) + eps to sqrt(x+eps) rewrite | +| disable_div_sqrt_rewrite | [ bool](#bool) | Disable x/sqrt(y) to x*rsqrt(y) rewrite | +| disable_reduce_truncation_optimization | [ bool](#bool) | Disable reduce truncation optimization | +| disable_maxpooling_optimization | [ bool](#bool) | Disable maxpooling optimization | +| disallow_mix_types_opts | [ bool](#bool) | Disallow mix type operations | +| disable_select_optimization | [ bool](#bool) | Disable SelectOp optimization | +| enable_optimize_denominator_with_broadcast | [ bool](#bool) | Enable optimize x/bcast(y) -> x * bcast(1/y) | -### ExecutableProto +### ExecutableProto The executable format 
accepted by SPU runtime. - Inputs should be prepared before running executable. @@ -77,98 +91,110 @@ The executable format accepted by SPU runtime. y = rt.get_var('y') # get the executable from spu runtime. ``` + | Field | Type | Description | | ----- | ---- | ----------- | -| name | [string](#string) | The name of the executable. | +| name | [ string](#string) | The name of the executable. | | input_names | [repeated string](#string) | The input names. | | output_names | [repeated string](#string) | The output names. | -| code | [bytes](#bytes) | The bytecode of the program, with format IR_MLIR_SPU. | +| code | [ bytes](#bytes) | The bytecode of the program, with format IR_MLIR_SPU. | -### RuntimeConfig +### RuntimeConfig The SPU runtime configuration. + | Field | Type | Description | | ----- | ---- | ----------- | -| protocol | [ProtocolKind](#protocolkind) | The protocol kind. | -| field | [FieldType](#fieldtype) | The field type. | -| fxp_fraction_bits | [int64](#int64) | Number of fraction bits of fixed-point number. 0(default) indicates implementation defined. | -| enable_action_trace | [bool](#bool) | When enabled, runtime prints verbose info of the call stack, debug purpose only. | -| enable_type_checker | [bool](#bool) | When enabled, runtime checks runtime type infos against the compile-time ones, exceptions are raised if mismatches happen. Note: Runtime outputs prefer runtime type infos even when flag is on. | -| enable_pphlo_trace | [bool](#bool) | When enabled, runtime prints executed pphlo list, debug purpose only. | -| enable_processor_dump | [bool](#bool) | When enabled, runtime dumps executed executables in the dump_dir, debug purpose only. | -| processor_dump_dir | [string](#string) | none | -| enable_pphlo_profile | [bool](#bool) | When enabled, runtime records detailed pphlo timing data, debug purpose only. 
WARNING: the `send bytes` information is only accurate when `experimental_enable_inter_op_par` and `experimental_enable_intra_op_par` options are disabled. | -| enable_hal_profile | [bool](#bool) | When enabled, runtime records detailed hal timing data, debug purpose only. WARNING: the `send bytes` information is only accurate when `experimental_enable_inter_op_par` and `experimental_enable_intra_op_par` options are disabled. | -| public_random_seed | [uint64](#uint64) | The public random variable generated by the runtime, the concrete prg function is implementation defined. Note: this seed only applies to `public variable` only, it has nothing to do with security. | -| share_max_chunk_size | [uint64](#uint64) | max chunk size for Value::toProto default: 128 *1024* 1024 | -| fxp_div_goldschmidt_iters | [int64](#int64) | The iterations use in f_div with Goldschmidt method. 0(default) indicates implementation defined. | -| fxp_exp_mode | [RuntimeConfig.ExpMode](#runtimeconfigexpmode) | The exponent approximation method. | -| fxp_exp_iters | [int64](#int64) | Number of iterations of `exp` approximation, 0(default) indicates impl defined. | -| fxp_log_mode | [RuntimeConfig.LogMode](#runtimeconfiglogmode) | The logarithm approximation method. | -| fxp_log_iters | [int64](#int64) | Number of iterations of `log` approximation, 0(default) indicates impl-defined. | -| fxp_log_orders | [int64](#int64) | Number of orders of `log` approximation, 0(default) indicates impl defined. | -| sigmoid_mode | [RuntimeConfig.SigmoidMode](#runtimeconfigsigmoidmode) | The sigmoid function approximation model. | -| enable_lower_accuracy_rsqrt | [bool](#bool) | Enable a simpler rsqrt approximation | -| beaver_type | [RuntimeConfig.BeaverType](#runtimeconfigbeavertype) | beaver config, works for semi2k and spdz2k for now. | -| ttp_beaver_config | [TTPBeaverConfig](#ttpbeaverconfig) | TrustedThirdParty configs. 
| -| trunc_allow_msb_error | [bool](#bool) | For protocol like SecureML, the most significant bit may have error with low probability, which lead to huge calculation error. | -| experimental_disable_mmul_split | [bool](#bool) | Experimental: DO NOT USE | -| experimental_enable_inter_op_par | [bool](#bool) | inter op parallel, aka, DAG level parallel. | -| experimental_enable_intra_op_par | [bool](#bool) | intra op parallel, aka, hal/mpc level parallel. | -| experimental_disable_vectorization | [bool](#bool) | disable kernel level vectorization. | -| experimental_inter_op_concurrency | [uint64](#uint64) | inter op concurrency. | +| protocol | [ ProtocolKind](#protocolkind) | The protocol kind. | +| field | [ FieldType](#fieldtype) | The field type. | +| fxp_fraction_bits | [ int64](#int64) | Number of fraction bits of fixed-point number. 0(default) indicates implementation defined. | +| enable_action_trace | [ bool](#bool) | When enabled, runtime prints verbose info of the call stack, debug purpose only. | +| enable_type_checker | [ bool](#bool) | When enabled, runtime checks runtime type infos against the compile-time ones, exceptions are raised if mismatches happen. Note: Runtime outputs prefer runtime type infos even when flag is on. | +| enable_pphlo_trace | [ bool](#bool) | When enabled, runtime prints executed pphlo list, debug purpose only. | +| enable_processor_dump | [ bool](#bool) | When enabled, runtime dumps executed executables in the dump_dir, debug purpose only. | +| processor_dump_dir | [ string](#string) | none | +| enable_pphlo_profile | [ bool](#bool) | When enabled, runtime records detailed pphlo timing data, debug purpose only. WARNING: the `send bytes` information is only accurate when `experimental_enable_inter_op_par` and `experimental_enable_intra_op_par` options are disabled. | +| enable_hal_profile | [ bool](#bool) | When enabled, runtime records detailed hal timing data, debug purpose only. 
WARNING: the `send bytes` information is only accurate when `experimental_enable_inter_op_par` and `experimental_enable_intra_op_par` options are disabled. | +| public_random_seed | [ uint64](#uint64) | The public random variable generated by the runtime, the concrete prg function is implementation defined. Note: this seed only applies to `public variable` only, it has nothing to do with security. | +| share_max_chunk_size | [ uint64](#uint64) | max chunk size for Value::toProto default: 128 * 1024 * 1024 | +| fxp_div_goldschmidt_iters | [ int64](#int64) | The iterations use in f_div with Goldschmidt method. 0(default) indicates implementation defined. | +| fxp_exp_mode | [ RuntimeConfig.ExpMode](#runtimeconfigexpmode) | The exponent approximation method. | +| fxp_exp_iters | [ int64](#int64) | Number of iterations of `exp` approximation, 0(default) indicates impl defined. | +| fxp_log_mode | [ RuntimeConfig.LogMode](#runtimeconfiglogmode) | The logarithm approximation method. | +| fxp_log_iters | [ int64](#int64) | Number of iterations of `log` approximation, 0(default) indicates impl-defined. | +| fxp_log_orders | [ int64](#int64) | Number of orders of `log` approximation, 0(default) indicates impl defined. | +| sigmoid_mode | [ RuntimeConfig.SigmoidMode](#runtimeconfigsigmoidmode) | The sigmoid function approximation model. | +| enable_lower_accuracy_rsqrt | [ bool](#bool) | Enable a simpler rsqrt approximation | +| beaver_type | [ RuntimeConfig.BeaverType](#runtimeconfigbeavertype) | beaver config, works for semi2k and spdz2k for now. | +| ttp_beaver_config | [ TTPBeaverConfig](#ttpbeaverconfig) | TrustedThirdParty configs. | +| trunc_allow_msb_error | [ bool](#bool) | For protocol like SecureML, the most significant bit may have error with low probability, which lead to huge calculation error. 
| +| experimental_disable_mmul_split | [ bool](#bool) | Experimental: DO NOT USE | +| experimental_enable_inter_op_par | [ bool](#bool) | inter op parallel, aka, DAG level parallel. | +| experimental_enable_intra_op_par | [ bool](#bool) | intra op parallel, aka, hal/mpc level parallel. | +| experimental_disable_vectorization | [ bool](#bool) | disable kernel level vectorization. | +| experimental_inter_op_concurrency | [ uint64](#uint64) | inter op concurrency. | + ### ShapeProto + + | Field | Type | Description | | ----- | ---- | ----------- | | dims | [repeated int64](#int64) | none | + ### TTPBeaverConfig + + | Field | Type | Description | | ----- | ---- | ----------- | -| server_host | [string](#string) | TrustedThirdParty beaver server's remote ip:port or load-balance uri. | -| session_id | [string](#string) | if empty, use link id as session id. | -| adjust_rank | [int32](#int32) | which rank do adjust rpc call, usually choose the rank closer to the server. | +| server_host | [ string](#string) | TrustedThirdParty beaver server's remote ip:port or load-balance uri. | +| session_id | [ string](#string) | if empty, use link id as session id. | +| adjust_rank | [ int32](#int32) | which rank do adjust rpc call, usually choose the rank closer to the server. | -### ValueChunkProto +### ValueChunkProto The spu Value proto, used for spu value serialization. + | Field | Type | Description | | ----- | ---- | ----------- | -| total_bytes | [uint64](#uint64) | chunk info | -| chunk_offset | [uint64](#uint64) | none | -| content | [bytes](#bytes) | chunk bytes | +| total_bytes | [ uint64](#uint64) | chunk info | +| chunk_offset | [ uint64](#uint64) | none | +| content | [ bytes](#bytes) | chunk bytes | + ### ValueMetaProto + + | Field | Type | Description | | ----- | ---- | ----------- | -| data_type | [DataType](#datatype) | The data type. | -| visibility | [Visibility](#visibility) | The data visibility. | -| shape | [ShapeProto](#shapeproto) | The shape of the value. 
| -| storage_type | [string](#string) | The storage type, defined by the underline evaluation engine. i.e. `aby3.AShr` means an aby3 arithmetic share in FM64. usually, the application does not care about this attribute. | +| data_type | [ DataType](#datatype) | The data type. | +| visibility | [ Visibility](#visibility) | The data visibility. | +| shape | [ ShapeProto](#shapeproto) | The shape of the value. | +| storage_type | [ string](#string) | The storage type, defined by the underline evaluation engine. i.e. `aby3.AShr` means an aby3 arithmetic share in FM64. usually, the application does not care about this attribute. | ## Enums -### DataType +### DataType The SPU datatype | Name | Number | Description | @@ -187,8 +213,10 @@ The SPU datatype | DT_F32 | 11 | float | | DT_F64 | 12 | double | -### FieldType + + +### FieldType A security parameter type. The secure evaluation is based on some algebraic structure (ring or field), @@ -200,8 +228,10 @@ The secure evaluation is based on some algebraic structure (ring or field), | FM64 | 2 | Ring 2^64 | | FM128 | 3 | Ring 2^128 | -### ProtocolKind + + +### ProtocolKind The protocol kind. | Name | Number | Description | @@ -212,8 +242,10 @@ The protocol kind. | ABY3 | 3 | A honest majority 3PC-protocol. SecretFlow provides the semi-honest implementation without Yao. | | CHEETAH | 4 | The famous [Cheetah](https://eprint.iacr.org/2022/207) protocol, a very fast 2PC protocol. | -### PtType + + +### PtType Plaintext type SPU runtime does not process with plaintext directly, plaintext type is @@ -238,16 +270,22 @@ buffer, we have to let spu know which type the plaintext buffer is. | PT_BOOL | 13 | bool | | PT_F16 | 14 | half | + + + ### RuntimeConfig.BeaverType + | Name | Number | Description | | ---- | ------ | ----------- | | TrustedFirstParty | 0 | assume first party (rank0) as trusted party to generate beaver triple. | | TrustedThirdParty | 1 | generate beaver triple through an additional trusted third party. 
| | MultiParty | 2 | generate beaver triple through multi-party. | -### RuntimeConfig.ExpMode + + +### RuntimeConfig.ExpMode The exponential approximation method. | Name | Number | Description | @@ -256,8 +294,10 @@ The exponential approximation method. | EXP_PADE | 1 | The pade approximation. | | EXP_TAYLOR | 2 | Taylor series approximation. | -### RuntimeConfig.LogMode + + +### RuntimeConfig.LogMode The logarithm approximation method. | Name | Number | Description | @@ -266,8 +306,10 @@ The logarithm approximation method. | LOG_PADE | 1 | The pade approximation. | | LOG_NEWTON | 2 | The newton approximation. | -### RuntimeConfig.SigmoidMode + + +### RuntimeConfig.SigmoidMode The sigmoid approximation method. | Name | Number | Description | @@ -277,8 +319,10 @@ The sigmoid approximation method. | SIGMOID_SEG3 | 2 | Piece-wise simulation. f(x) = 0.5 + 0.125x if -4 <= x <= 4 1 if x > 4 0 if -4 > x | | SIGMOID_REAL | 3 | The real definition, which depends on exp's accuracy. f(x) = 1 / (1 + exp(-x)) | -### SourceIRType + + +### SourceIRType Compiler relate definition //////////////////////////////////////////////////////////////////////// @@ -287,8 +331,10 @@ Compiler relate definition | XLA | 0 | none | | MLIR_HLO | 1 | none | -### Visibility + + +### Visibility The visibility type. SPU is a secure evaluation runtime, but not all data are secret, some of them @@ -301,14 +347,19 @@ performance significantly. | VIS_SECRET | 1 | Invisible(unknown) for all or some of the parties. | | VIS_PUBLIC | 2 | Visible(public) for all parties. 
| + + + ### XLAPrettyPrintKind + | Name | Number | Description | | ---- | ------ | ----------- | | TEXT | 0 | none | | DOT | 1 | none | | HTML | 2 | none | + diff --git a/docs/reference/xla_status.md b/docs/reference/xla_status.md index dd4844aa..000f6fa1 100644 --- a/docs/reference/xla_status.md +++ b/docs/reference/xla_status.md @@ -3,7 +3,7 @@ List of XLA(mhlo-mlir) Ops that SPU supports: The list of mhlo ops is obtained from this file: - https://github.com/tensorflow/mlir-hlo/blob/master/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td + https://github.com/openxla/xla/blob/main/xla/mlir_hlo/mhlo/IR/hlo_ops.td General limitation with SPU: * Dynamic shape is not supported @@ -107,7 +107,7 @@ Count: Total = 1, fully supported = 0 | Op Name | supported(fully/partial/no) | notes | | :------------: | :-------------------------: | ----------- | | `after_all` | no | -| `if` | partial | condition variable must be a public scalar +| `if` | fully | | `case` | no | | `while` | partial | condition region must return a public scalar | `all_gather` | no | @@ -116,7 +116,7 @@ Count: Total = 1, fully supported = 0 | `all_to_all` | no | | `reduce` | fully | inherits limitations from reduce function -Count: Total = 9, fully supported = 1, partial = 2 +Count: Total = 9, fully supported = 2, partial = 1 ### XLA tuple ops @@ -132,8 +132,8 @@ Count: Total = 2, fully supported = 2 | Op Name | supported(fully/partial/no) | notes | | :------------: | :-------------------------: | ----------- | | `slice` | fully | -| `dynamic-slice`| partial | start_indices must be public values -| `dynamic-update-slice`| partial | start_indices must be public values +| `dynamic-slice`| fully | +| `dynamic-update-slice`| fully | | `batch_norm_grad`| fully | Rely on XLA's batchnorm_expander pass | `batch_norm_inference`| fully | Rely on XLA's batchnorm_expander pass | `batch_norm_training` | fully | Rely on XLA's batchnorm_expander pass @@ -174,7 +174,7 @@ Count: Total = 2, fully supported = 2 | 
`torch_index_select` | no | | `optimization_barrier` | no | -Count: Total = 42, fully supported = 26, partial = 3 +Count: Total = 42, fully supported = 28, partial = 1 ### XLA RNG ops diff --git a/examples/cpp/pir/keyword_pir_mem_server.cc b/examples/cpp/pir/keyword_pir_mem_server.cc index cb30ce03..23e2002c 100644 --- a/examples/cpp/pir/keyword_pir_mem_server.cc +++ b/examples/cpp/pir/keyword_pir_mem_server.cc @@ -59,6 +59,17 @@ llvm::cl::opt LabelPadLengthOpt( "max_label_length", llvm::cl::init(288), llvm::cl::desc("pad label data to max len")); +llvm::cl::opt CompressOpt("compress", llvm::cl::init(false), + llvm::cl::desc("compress seal he plaintext")); + +llvm::cl::opt BucketSizeOpt("bucket", llvm::cl::init(1000000), + llvm::cl::desc("bucket size of pir query")); + +llvm::cl::opt MaxItemsPerBinOpt( + "max_items_per_bin", llvm::cl::init(0), + llvm::cl::desc( + "max items per bin, i.e. Interpolate polynomial max degree ")); + namespace { constexpr uint32_t kLinkRecvTimeout = 30 * 60 * 1000; @@ -92,6 +103,9 @@ int main(int argc, char **argv) { config.set_label_max_len(LabelPadLengthOpt.getValue()); config.set_oprf_key_path(""); config.set_setup_path("::memory"); + config.set_compressed(CompressOpt.getValue()); + config.set_bucket_size(BucketSizeOpt.getValue()); + config.set_max_items_per_bin(MaxItemsPerBinOpt.getValue()); spu::pir::PirResultReport report = spu::pir::PirMemoryServer(link_ctx, config); diff --git a/examples/cpp/pir/keyword_pir_server.cc b/examples/cpp/pir/keyword_pir_server.cc index 48ecdcb9..fcc0e287 100644 --- a/examples/cpp/pir/keyword_pir_server.cc +++ b/examples/cpp/pir/keyword_pir_server.cc @@ -58,7 +58,7 @@ constexpr uint32_t kLinkRecvTimeout = 30 * 60 * 1000; int main(int argc, char **argv) { llvm::cl::ParseCommandLineOptions(argc, argv); - SPDLOG_INFO("setup"); + SPDLOG_INFO("server"); auto sctx = MakeSPUContext(); auto link_ctx = sctx->lctx(); diff --git a/examples/cpp/pir/keyword_pir_setup.cc b/examples/cpp/pir/keyword_pir_setup.cc index 
055f2aee..befed3dc 100644 --- a/examples/cpp/pir/keyword_pir_setup.cc +++ b/examples/cpp/pir/keyword_pir_setup.cc @@ -62,6 +62,12 @@ llvm::cl::opt LabelsColumnsOpt("label_columns", llvm::cl::init("label"), llvm::cl::desc("label columns")); +llvm::cl::opt CompressOpt("compress", llvm::cl::init(false), + llvm::cl::desc("compress seal he plaintext")); + +llvm::cl::opt BucketSizeOpt("bucket", llvm::cl::init(1000000), + llvm::cl::desc("bucket size of pir query")); + llvm::cl::opt LabelPadLengthOpt( "max_label_length", llvm::cl::init(288), llvm::cl::desc("pad label data to max len")); @@ -70,6 +76,11 @@ llvm::cl::opt SetupPathOpt( "setup_path", llvm::cl::init("."), llvm::cl::desc("[out] output path for db setup data")); +llvm::cl::opt MaxItemsPerBinOpt( + "max_items_per_bin", llvm::cl::init(0), + llvm::cl::desc( + "max items per bin, i.e. Interpolate polynomial max degree")); + int main(int argc, char **argv) { llvm::cl::ParseCommandLineOptions(argc, argv); @@ -96,6 +107,9 @@ int main(int argc, char **argv) { config.set_label_max_len(LabelPadLengthOpt.getValue()); config.set_oprf_key_path(OprfKeyPathOpt.getValue()); config.set_setup_path(SetupPathOpt.getValue()); + config.set_compressed(CompressOpt.getValue()); + config.set_bucket_size(BucketSizeOpt.getValue()); + config.set_max_items_per_bin(MaxItemsPerBinOpt.getValue()); spu::pir::PirResultReport report = spu::pir::PirSetup(config); diff --git a/examples/python/pir/pir_mem_server.py b/examples/python/pir/pir_mem_server.py index 3c40f041..27676b11 100644 --- a/examples/python/pir/pir_mem_server.py +++ b/examples/python/pir/pir_mem_server.py @@ -35,6 +35,11 @@ flags.DEFINE_integer("count_per_query", 1, "count_per_query") flags.DEFINE_integer("max_label_length", 256, "max_label_length") flags.DEFINE_string("setup_path", "setup_path", "data output path") + +flags.DEFINE_boolean("compressed", False, "compress seal he plaintext") +flags.DEFINE_integer("bucket_size", 1000000, "bucket size of pir query") 
+flags.DEFINE_integer("max_items_per_bin", 0, "max items per bin, i.e. Interpolate polynomial max degree") + FLAGS = flags.FLAGS @@ -77,6 +82,9 @@ def main(_): label_max_len=FLAGS.max_label_length, oprf_key_path="", setup_path='::memory', + compressed=FLAGS.compressed, + bucket_size=FLAGS.bucket_size, + max_items_per_bin=FLAGS.max_items_per_bin, ) report = pir.pir_memory_server(link, config) diff --git a/examples/python/pir/pir_server.py b/examples/python/pir/pir_server.py index 5d0adc47..538c286b 100644 --- a/examples/python/pir/pir_server.py +++ b/examples/python/pir/pir_server.py @@ -25,9 +25,10 @@ flags.DEFINE_integer("rank", 0, "rank: 0/1/2...") flags.DEFINE_string("party_ips", "127.0.0.1:9307,127.0.0.1:9308", "party addresses") -flags.DEFINE_string("in_path", "data.csv", "data input path") + flags.DEFINE_string("oprf_key_path", "oprf_key.bin", "oprf key file path") flags.DEFINE_string("setup_path", "setup_path", "data output path") + flags.DEFINE_bool("enable_tls", False, "whether to enable tls for link") flags.DEFINE_string("link_server_certificate", "", "link server certificate file path") flags.DEFINE_string("link_server_private_key", "", "link server private key file path") diff --git a/examples/python/pir/pir_setup.py b/examples/python/pir/pir_setup.py index ba6b6a16..61b61925 100644 --- a/examples/python/pir/pir_setup.py +++ b/examples/python/pir/pir_setup.py @@ -31,6 +31,11 @@ flags.DEFINE_integer("max_label_length", 256, "max_label_length") flags.DEFINE_string("oprf_key_path", "oprf_key.bin", "oprf key file") flags.DEFINE_string("setup_path", "setup_path", "data output path") + +flags.DEFINE_boolean("compressed", False, "compress seal he plaintext") +flags.DEFINE_integer("bucket_size", 1000000, "bucket size of pir query") +flags.DEFINE_integer("max_items_per_bin", 0, "max items per bin, i.e. 
Interpolate polynomial max degree") + FLAGS = flags.FLAGS @@ -54,6 +59,9 @@ def main(_): label_max_len=FLAGS.max_label_length, oprf_key_path=FLAGS.oprf_key_path, setup_path=FLAGS.setup_path, + compressed=FLAGS.compressed, + bucket_size=FLAGS.bucket_size, + max_items_per_bin=FLAGS.max_items_per_bin, ) report = pir.pir_setup(config) diff --git a/libspu/compiler/passes/lower_mixed_type_op.cc b/libspu/compiler/passes/lower_mixed_type_op.cc index 6f05067a..389ac82d 100644 --- a/libspu/compiler/passes/lower_mixed_type_op.cc +++ b/libspu/compiler/passes/lower_mixed_type_op.cc @@ -32,6 +32,20 @@ namespace { // %3 = mul/dot(%0, %2) int, fxp -> fxp // Save one truncation template struct FxpIntMulTruncationRemover : public OpRewritePattern { +private: + bool isLegitConvert(mlir::pphlo::ConvertOp op) const { + if (op == nullptr) { + return true; + } + TypeTools tools; + + // Only int->fxp conversion is considered legit + auto to_type = tools.getExpressedType(op.getType()); + auto from_type = tools.getExpressedType(op.getOperand().getType()); + + return from_type.isa() && to_type.isa(); + } + public: explicit FxpIntMulTruncationRemover(MLIRContext *context) : OpRewritePattern(context) {} @@ -51,8 +65,9 @@ struct FxpIntMulTruncationRemover : public OpRewritePattern { auto lhs_convert = lhs.template getDefiningOp(); auto rhs_convert = rhs.template getDefiningOp(); - if ((lhs_convert != nullptr && rhs_convert == nullptr) || - (lhs_convert == nullptr && rhs_convert != nullptr)) { + if (((lhs_convert != nullptr && rhs_convert == nullptr) || + (lhs_convert == nullptr && rhs_convert != nullptr)) && + (isLegitConvert(lhs_convert) && isLegitConvert(rhs_convert))) { llvm::SmallVector operands(2); operands[0] = lhs_convert == nullptr ? lhs : lhs_convert.getOperand(); operands[1] = rhs_convert == nullptr ? 
rhs : rhs_convert.getOperand(); diff --git a/libspu/compiler/tests/lower_mixed_type_op.mlir b/libspu/compiler/tests/lower_mixed_type_op.mlir index 7da20156..270ad4da 100644 --- a/libspu/compiler/tests/lower_mixed_type_op.mlir +++ b/libspu/compiler/tests/lower_mixed_type_op.mlir @@ -33,3 +33,12 @@ func.func @main(%arg0: tensor<2x2x!pphlo.pub>, %arg1: tensor<2x2x!pphlo.pub %1 = "pphlo.dot"(%arg1, %0) : (tensor<2x2x!pphlo.pub>, tensor<2x2x!pphlo.pub>) -> tensor<2x2x!pphlo.pub> return %1 : tensor<2x2x!pphlo.pub> } + +// ----- + +func.func @main(%arg0: tensor<2x2x!pphlo.pub>, %arg1: tensor<2x2x!pphlo.pub>) -> (tensor<2x2x!pphlo.pub>) { + //CHECK: "pphlo.dot"(%arg1, %0) : (tensor<2x2x!pphlo.pub>, tensor<2x2x!pphlo.pub>) -> tensor<2x2x!pphlo.pub> + %0 = "pphlo.convert"(%arg0) : (tensor<2x2x!pphlo.pub>) -> tensor<2x2x!pphlo.pub> + %1 = "pphlo.dot"(%arg1, %0) : (tensor<2x2x!pphlo.pub>, tensor<2x2x!pphlo.pub>) -> tensor<2x2x!pphlo.pub> + return %1 : tensor<2x2x!pphlo.pub> +} diff --git a/libspu/core/encoding.cc b/libspu/core/encoding.cc index 97cb23ab..79573129 100644 --- a/libspu/core/encoding.cc +++ b/libspu/core/encoding.cc @@ -78,20 +78,22 @@ NdArrayRef encodeToRing(const NdArrayRef& src, FieldType field, size_t fxp_bits, const Float kFlpLower = static_cast(static_cast(kFxpLower) / kScale); + auto _src = NdArrayView(src); + auto _dst = NdArrayView(dst); + pforeach(0, numel, [&](int64_t idx) { - const auto src_value = src.at(idx); - auto& dst_value = dst.at(idx); + const auto src_value = _src[idx]; if (std::isnan(src_value)) { // see numpy.nan_to_num // note(jint) I dont know why nan could be // encoded as zero.. 
- dst_value = 0; + _dst[idx] = 0; } else if (src_value >= kFlpUpper) { - dst_value = kFxpUpper; + _dst[idx] = kFxpUpper; } else if (src_value <= kFlpLower) { - dst_value = kFxpLower; + _dst[idx] = kFxpLower; } else { - dst_value = static_cast(src_value * kScale); + _dst[idx] = static_cast(src_value * kScale); } }); }); @@ -108,11 +110,11 @@ NdArrayRef encodeToRing(const NdArrayRef& src, FieldType field, size_t fxp_bits, field, pt_type); using T = std::make_signed_t; + auto _src = NdArrayView(src); + auto _dst = NdArrayView(dst); // TODO: encoding integer in range [-2^(k-2),2^(k-2)) pforeach(0, numel, [&](int64_t idx) { - const auto src_value = src.at(idx); - auto& dst_value = dst.at(idx); - dst_value = static_cast(src_value); // NOLINT + _dst[idx] = static_cast(_src[idx]); // NOLINT }); }); }); @@ -143,22 +145,24 @@ NdArrayRef decodeFromRing(const NdArrayRef& src, DataType in_dtype, DISPATCH_ALL_PT_TYPES(pt_type, "pt_type", [&]() { using T = std::make_signed_t; + auto _src = NdArrayView(src); + auto _dst = NdArrayView(dst); + if (in_dtype == DT_I1) { constexpr bool kSanity = std::is_same_v; SPU_ENFORCE(kSanity); - pforeach(0, numel, [&](int64_t idx) { - dst.at(idx) = !((src.at(idx) & 0x1) == 0); - }); + pforeach(0, numel, + [&](int64_t idx) { _dst[idx] = !((_src[idx] & 0x1) == 0); }); } else if (in_dtype == DT_F32 || in_dtype == DT_F64 || in_dtype == DT_F16) { const T kScale = T(1) << fxp_bits; pforeach(0, numel, [&](int64_t idx) { - dst.at(idx) = static_cast( - static_cast(src.at(idx)) / kScale); + _dst[idx] = + static_cast(static_cast(_src[idx]) / kScale); }); } else { pforeach(0, numel, [&](int64_t idx) { - dst.at(idx) = static_cast(src.at(idx)); + _dst[idx] = static_cast(_src[idx]); }); } }); diff --git a/libspu/core/encoding_test.cc b/libspu/core/encoding_test.cc index c838a1f5..8a3628bb 100644 --- a/libspu/core/encoding_test.cc +++ b/libspu/core/encoding_test.cc @@ -102,9 +102,7 @@ TYPED_TEST(FloatEncodingTest, Works) { }; NdArrayRef 
frm(makePtType(PtTypeToEnum::value), {samples.size()}); - std::copy(samples.begin(), samples.end(), &frm.at(0)); - - // std::cout << frm.at(0) << std::endl; + std::copy(samples.begin(), samples.end(), &frm.at({0})); DataType encoded_dtype; auto encoded = encodeToRing(frm, kField, kFxpBits, &encoded_dtype); @@ -122,7 +120,7 @@ TYPED_TEST(FloatEncodingTest, Works) { } else { EXPECT_EQ(encoded_dtype, DT_F64); } - auto* out_ptr = &decoded.at(0); + auto* out_ptr = &decoded.at({0}); const int64_t kReprBits = SizeOf(kField) * 8 - 2; const int64_t kScale = 1LL << kFxpBits; EXPECT_EQ(out_ptr[0], -static_cast((1LL << kReprBits)) / kScale); @@ -154,9 +152,7 @@ TYPED_TEST(IntEncodingTest, Works) { }; NdArrayRef frm(makePtType(PtTypeToEnum::value), {samples.size()}); - std::copy(samples.begin(), samples.end(), &frm.at(0)); - - // std::cout << frm.at(0) << std::endl; + std::copy(samples.begin(), samples.end(), &frm.at({0})); DataType encoded_dtype; auto encoded = encodeToRing(frm, kField, kFxpBits, &encoded_dtype); @@ -166,7 +162,7 @@ TYPED_TEST(IntEncodingTest, Works) { auto decoded = decodeFromRing(encoded, encoded_dtype, kFxpBits, &out_pt_type); EXPECT_EQ(out_pt_type, frm_pt_type); - IntT* out_ptr = &decoded.at(0); + IntT* out_ptr = &decoded.at({0}); EXPECT_EQ(out_ptr[0], samples[0]); EXPECT_EQ(out_ptr[1], samples[1]); EXPECT_EQ(out_ptr[2], static_cast(-1)); diff --git a/libspu/core/ndarray_ref.cc b/libspu/core/ndarray_ref.cc index 38d3684f..dd1b2103 100644 --- a/libspu/core/ndarray_ref.cc +++ b/libspu/core/ndarray_ref.cc @@ -119,8 +119,8 @@ static bool attempt_nocopy_reshape(const NdArrayRef& old, return true; } -std::pair can_use_fast_indexing(const Shape& shape, - const Strides& strides) { +std::pair can_use_fast_indexing(const Shape& shape, + const Strides& strides) { Shape stripped_shape; Strides stripped_strides; @@ -222,7 +222,6 @@ NdArrayRef NdArrayRef::clone() const { } return res; - return {}; } std::shared_ptr NdArrayRef::getOrCreateCompactBuf() const { diff --git 
a/libspu/core/ndarray_ref.h b/libspu/core/ndarray_ref.h index 30596f19..d10aef19 100644 --- a/libspu/core/ndarray_ref.h +++ b/libspu/core/ndarray_ref.h @@ -49,7 +49,7 @@ class NdArrayRef { // Indicate this buffer can be indexing in a linear way bool use_fast_indexing_{false}; - int64_t fast_indexing_stride_{0}; + Stride fast_indexing_stride_{0}; public: NdArrayRef() = default; @@ -112,23 +112,13 @@ class NdArrayRef { // Test only bool canUseFastIndexing() const { return use_fast_indexing_; } + const Stride& fastIndexingStride() const { return fast_indexing_stride_; } // View this array ref as another type. // @param force, true if ignore the type check, else only the same elsize type // could be casted. NdArrayRef as(const Type& new_ty, bool force = false) const; - // Get data pointer - template - T* data() { - return reinterpret_cast(buf_->data() + offset_); - } - - template - const T* data() const { - return reinterpret_cast(buf_->data() + offset_); - } - // Get element. template T& at(const Index& pos) { @@ -144,6 +134,17 @@ class NdArrayRef { elsize() * fi); } + // Get data pointer + template + T* data() { + return reinterpret_cast(buf_->data() + offset_); + } + + template + const T* data() const { + return reinterpret_cast(buf_->data() + offset_); + } + template T& at(int64_t pos) { if (use_fast_indexing_) { @@ -381,6 +382,46 @@ NdArrayRef makeConstantArrayRef(const Type& eltype, const Shape& shape); std::ostream& operator<<(std::ostream& out, const NdArrayRef& v); +template +class NdArrayView { + private: + NdArrayRef* arr_; + size_t elsize_; + + public: + // Note: we explicit discard const correctness due to the complexity. 
+ explicit NdArrayView(const NdArrayRef& arr) + : arr_(const_cast(&arr)), elsize_(sizeof(T)) { + if (!arr.canUseFastIndexing()) { + // When elsize does not match, flatten/unflatten computation is dangerous + SPU_ENFORCE(elsize_ == arr_->elsize(), "T size = {}, arr elsize = {}", + elsize_, arr_->elsize()); + } + } + + T& operator[](size_t idx) { + if (arr_->canUseFastIndexing()) { + return *reinterpret_cast(arr_->data() + + elsize_ * idx * arr_->fastIndexingStride()); + } else { + const auto& indices = unflattenIndex(idx, arr_->shape()); + auto fi = calcFlattenOffset(indices, arr_->shape(), arr_->strides()); + return *reinterpret_cast(arr_->data() + elsize_ * fi); + } + } + + T const& operator[](size_t idx) const { + if (arr_->canUseFastIndexing()) { + return *reinterpret_cast(arr_->data() + + elsize_ * idx * arr_->fastIndexingStride()); + } else { + const auto& indices = unflattenIndex(idx, arr_->shape()); + auto fi = calcFlattenOffset(indices, arr_->shape(), arr_->strides()); + return *reinterpret_cast(arr_->data() + elsize_ * fi); + } + } +}; + template size_t maxBitWidth(const NdArrayRef& in) { auto numel = in.numel(); @@ -393,12 +434,14 @@ size_t maxBitWidth(const NdArrayRef& in) { return BitWidth(in.cbegin().getScalarValue()); } + NdArrayView _in(in); + size_t res = preduce( 0, numel, [&](int64_t begin, int64_t end) { size_t partial_max = 0; for (int64_t idx = begin; idx < end; ++idx) { - partial_max = std::max(partial_max, BitWidth(in.at(idx))); + partial_max = std::max(partial_max, BitWidth(_in[idx])); } return partial_max; }, diff --git a/libspu/core/ndarray_ref_test.cc b/libspu/core/ndarray_ref_test.cc index 79c5deeb..af85b323 100644 --- a/libspu/core/ndarray_ref_test.cc +++ b/libspu/core/ndarray_ref_test.cc @@ -59,8 +59,7 @@ TEST(NdArrayRefTest, Iterator) { EXPECT_EQ(a.numel(), 9); // Fill array with 0 1 2 3 4 5 - std::iota(static_cast(a.data()), - static_cast(a.data()) + 9, 0); + std::iota(a.data(), a.data() + 9, 0); verify_array(a, std::vector{0, 1, 2, 
3, 4, 5, 6, 7, 8}); } @@ -73,8 +72,7 @@ TEST(NdArrayRefTest, StridedIterator) { EXPECT_EQ(a.numel(), 9); // Fill array with 0 1 2 3 4 5 - std::iota(static_cast(a.data()), - static_cast(a.data()) + 36, 0); + std::iota(a.data(), a.data() + 36, 0); verify_array(a, std::vector{0, 2, 4, 12, 14, 16, 24, 26, 28}); } @@ -87,8 +85,7 @@ TEST(NdArrayRefTest, NdStrides) { EXPECT_EQ(a.numel(), 9); // Fill array with 0 1 2 3 4 5 - std::iota(static_cast(a.data()), - static_cast(a.data()) + 36, 0); + std::iota(a.data(), a.data() + 36, 0); EXPECT_EQ(a.at({0, 0}), 0); EXPECT_EQ(a.at({0, 1}), 2); @@ -119,10 +116,8 @@ TEST(NdArrayRefTest, UpdateSlice) { // 6,7,8 // b : 0,1 // 2,3 - std::iota(static_cast(a.data()), - static_cast(a.data()) + 9, 0); - std::iota(static_cast(b.data()), - static_cast(b.data()) + 4, 0); + std::iota(a.data(), a.data() + 9, 0); + std::iota(b.data(), b.data() + 4, 0); a.update_slice(b, {0, 1}); // a : 0,0,1 diff --git a/libspu/core/pt_buffer_view.cc b/libspu/core/pt_buffer_view.cc index 1e63d71d..b0c15db2 100644 --- a/libspu/core/pt_buffer_view.cc +++ b/libspu/core/pt_buffer_view.cc @@ -30,7 +30,7 @@ NdArrayRef convertToNdArray(PtBufferView bv) { auto out = NdArrayRef(type, bv.shape); if (bv.shape.numel() > 0) { - auto* out_ptr = static_cast(out.data()); + auto* out_ptr = out.data(); size_t elsize = SizeOf(bv.pt_type); diff --git a/libspu/core/shape.h b/libspu/core/shape.h index 01a4555d..2cc01761 100644 --- a/libspu/core/shape.h +++ b/libspu/core/shape.h @@ -84,14 +84,16 @@ class Index : public std::vector { } }; -class Strides : public std::vector { +using Stride = int64_t; + +class Strides : public std::vector { private: - using Base = std::vector; + using Base = std::vector; public: using Base::Base; - /*explicit*/ Strides(llvm::ArrayRef arr) + /*explicit*/ Strides(llvm::ArrayRef arr) : Base(arr.begin(), arr.end()) {} friend std::ostream &operator<<(std::ostream &out, const Strides &s) { diff --git a/libspu/core/value.cc b/libspu/core/value.cc index 
ae47d6e3..bd1fd1c2 100644 --- a/libspu/core/value.cc +++ b/libspu/core/value.cc @@ -149,7 +149,7 @@ Value Value::fromProto(const ValueProto& value) { for (const auto& [offset, chunk] : ordered_chunks) { SPU_ENFORCE(offset == chunk_end_pos, "offset {} is not match to last chunk's end pos", offset); - memcpy(static_cast(data.data()) + offset, chunk->content().data(), + memcpy(data.data() + offset, chunk->content().data(), chunk->content().size()); chunk_end_pos += chunk->content().size(); } diff --git a/libspu/core/xt_helper.h b/libspu/core/xt_helper.h index 4977d129..9040ae92 100644 --- a/libspu/core/xt_helper.h +++ b/libspu/core/xt_helper.h @@ -35,8 +35,8 @@ auto xt_mutable_adapt(NdArrayRef& aref) { std::vector shape(aref.shape().begin(), aref.shape().end()); std::vector stride(aref.strides().begin(), aref.strides().end()); - return xt::adapt(static_cast(aref.data()), aref.numel(), - xt::no_ownership(), shape, stride); + return xt::adapt(aref.data(), aref.numel(), xt::no_ownership(), shape, + stride); } template @@ -47,8 +47,8 @@ auto xt_adapt(const NdArrayRef& aref) { std::vector shape(aref.shape().begin(), aref.shape().end()); std::vector stride(aref.strides().begin(), aref.strides().end()); - return xt::adapt(static_cast(aref.data()), aref.numel(), - xt::no_ownership(), shape, stride); + return xt::adapt(aref.data(), aref.numel(), xt::no_ownership(), + shape, stride); } // Make a NdArrayRef from an xt expression. 
diff --git a/libspu/device/pphlo/pphlo_executor_test_runner.h b/libspu/device/pphlo/pphlo_executor_test_runner.h index 3b52cbe5..af1b61ab 100644 --- a/libspu/device/pphlo/pphlo_executor_test_runner.h +++ b/libspu/device/pphlo/pphlo_executor_test_runner.h @@ -45,15 +45,15 @@ class Runner { const auto &out = io_->OutFeed(fmt::format("output{}", idx)); size_t numel = out.numel(); - const auto *in_ptr = static_cast(out.data()); + NdArrayView _out(out); // TODO: handle strides for (size_t i = 0; i < numel; ++i) { if constexpr (std::is_integral_v) { - EXPECT_EQ(in_ptr[i], expected[i]) << "i = " << i << "\n"; + EXPECT_EQ(_out[i], expected[i]) << "i = " << i << "\n"; } else { - EXPECT_TRUE(std::abs(in_ptr[i] - expected[i]) <= 1e-2) - << "i = " << i << " in = " << in_ptr[i] + EXPECT_TRUE(std::abs(_out[i] - expected[i]) <= 1e-2) + << "i = " << i << " in = " << _out[i] << " expected = " << expected[i] << "\n"; } } diff --git a/libspu/kernel/hlo/indexing_test.cc b/libspu/kernel/hlo/indexing_test.cc index 99ade5d1..4529cfd7 100644 --- a/libspu/kernel/hlo/indexing_test.cc +++ b/libspu/kernel/hlo/indexing_test.cc @@ -28,11 +28,12 @@ namespace spu::kernel::hlo { TEST(IndexingTest, Take1) { Value a(NdArrayRef(makePtType(PT_I64), {5}), DT_I64); - auto *data_ptr = static_cast(a.data().data()); + auto *data_ptr = a.data().data(); std::iota(data_ptr, data_ptr + 5, 0); + NdArrayView _a(a.data()); for (int64_t idx = 0; idx < 5; ++idx) { - auto v = a.data().at(idx); + auto v = _a[idx]; EXPECT_EQ(v, idx); } @@ -46,15 +47,16 @@ TEST(IndexingTest, Take1) { TEST(IndexingTest, Take2) { Value a(NdArrayRef(makePtType(PT_I64), {10}), DT_I64); - auto *data_ptr = static_cast(a.data().data()); + auto *data_ptr = a.data().data(); std::iota(data_ptr, data_ptr + 10, 0); // Create a strided data Value b(NdArrayRef(a.data().buf(), a.data().eltype(), {5}, {2}, 0), a.dtype()); // Sanity input... 
+ NdArrayView _b(b.data()); for (int64_t idx = 0; idx < 5; ++idx) { - auto v = b.data().at(idx); + auto v = _b[idx]; EXPECT_EQ(v, 2 * idx); } diff --git a/libspu/mpc/aby3/arithmetic.cc b/libspu/mpc/aby3/arithmetic.cc index c64bcffd..e1557e21 100644 --- a/libspu/mpc/aby3/arithmetic.cc +++ b/libspu/mpc/aby3/arithmetic.cc @@ -41,43 +41,40 @@ std::vector a1b_offline(size_t sender, const NdArrayRef& a, auto numel = a.numel(); return DISPATCH_ALL_FIELDS(field, "_", [&]() -> std::vector { - using AShrT = ring2k_t; - + using ashr_el_t = ring2k_t; NdArrayRef m0(makeType(field), a.shape()); NdArrayRef m1(makeType(field), a.shape()); - ring_set_value(m0, AShrT(0)); - ring_set_value(m1, AShrT(1)); + ring_set_value(m0, ashr_el_t(0)); + ring_set_value(m1, ashr_el_t(1)); - auto _m0 = m0.data(); - auto _m1 = m1.data(); + NdArrayView _m0(m0); + NdArrayView _m1(m1); return DISPATCH_UINT_PT_TYPES( b.eltype().as()->getBacktype(), "_", [&]() -> std::vector { - using BSharT = ScalarT; + using bshr_t = std::array; if (self_rank == sender) { auto c1 = prg_state->genPrssPair(field, a.shape(), false, true).first; auto c2 = prg_state->genPrssPair(field, a.shape(), true).second; - auto _c1 = c1.data(); - auto _c2 = c2.data(); + NdArrayView _a(a); + NdArrayView _b(b); + NdArrayView _c1(c1); + NdArrayView _c2(c2); // (i \xor b1 \xor b2) * a - c1 - c2 pforeach(0, numel, [&](int64_t idx) { - _m0[idx] = - (_m0[idx] ^ (b.at>(idx)[0] & 0x1) ^ - (b.at>(idx)[1] & 0x1)) * - a.at(idx) - - _c1[idx] - _c2[idx]; + _m0[idx] = (_m0[idx] ^ (_b[idx][0] & 0x1) ^ (_b[idx][1] & 0x1)) * + _a[idx] - + _c1[idx] - _c2[idx]; }); pforeach(0, numel, [&](int64_t idx) { - _m1[idx] = - (_m1[idx] ^ (b.at>(idx)[0] & 0x1) ^ - (b.at>(idx)[1] & 0x1)) * - a.at(idx) - - _c1[idx] - _c2[idx]; + _m1[idx] = (_m1[idx] ^ (_b[idx][0] & 0x1) ^ (_b[idx][1] & 0x1)) * + _a[idx] - + _c1[idx] - _c2[idx]; }); return {c1, c2, m0, m1}; @@ -104,9 +101,9 @@ std::vector ring_cast_boolean(const NdArrayRef& x) { std::vector res(numel); 
DISPATCH_UINT_PT_TYPES(x.eltype().as()->pt_type(), "_", [&]() { - using BShrT = ScalarT; + NdArrayView _x(x); pforeach(0, numel, [&](int64_t idx) { - res[idx] = static_cast(x.at(idx) & 0x1); + res[idx] = static_cast(_x[idx] & 0x1); }); }); @@ -122,13 +119,13 @@ NdArrayRef RandA::proc(KernelEvalContext* ctx, const Shape& shape) const { NdArrayRef out(makeType(field), shape); DISPATCH_ALL_FIELDS(field, "_", [&]() { - using AShrT = ring2k_t; + using el_t = ring2k_t; - std::vector r0(shape.numel()); - std::vector r1(shape.numel()); + std::vector r0(shape.numel()); + std::vector r1(shape.numel()); prg_state->fillPrssPair(absl::MakeSpan(r0), absl::MakeSpan(r1)); - auto _out = out.data>(); + NdArrayView> _out(out); pforeach(0, out.numel(), [&](int64_t idx) { // Comparison only works for [-2^(k-2), 2^(k-2)). @@ -147,23 +144,22 @@ NdArrayRef A2P::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { auto numel = in.numel(); return DISPATCH_ALL_FIELDS(field, "_", [&]() { - using PShrT = ring2k_t; - using AShrT = ring2k_t; + using pshr_el_t = ring2k_t; + using ashr_el_t = ring2k_t; + using ashr_t = std::array; NdArrayRef out(makeType(field), in.shape()); - auto _out = out.data(); + NdArrayView _out(out); + NdArrayView _in(in); - std::vector x2(numel); + std::vector x2(numel); - pforeach(0, numel, [&](int64_t idx) { - x2[idx] = in.at>(idx)[1]; - }); + pforeach(0, numel, [&](int64_t idx) { x2[idx] = _in[idx][1]; }); - auto x3 = comm->rotate(x2, "a2p"); // comm => 1, k + auto x3 = comm->rotate(x2, "a2p"); // comm => 1, k pforeach(0, numel, [&](int64_t idx) { - _out[idx] = in.at>(idx)[0] + - in.at>(idx)[1] + x3[idx]; + _out[idx] = _in[idx][0] + _in[idx][1] + x3[idx]; }); return out; @@ -179,28 +175,30 @@ NdArrayRef P2A::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { auto rank = comm->getRank(); return DISPATCH_ALL_FIELDS(field, "_", [&]() { - using AShrT = ring2k_t; - using PShrT = ring2k_t; + using ashr_el_t = ring2k_t; + using ashr_t = std::array; + using 
pshr_el_t = ring2k_t; NdArrayRef out(makeType(field), in.shape()); - auto _out = out.data>(); + NdArrayView _out(out); + NdArrayView _in(in); pforeach(0, in.numel(), [&](int64_t idx) { - _out[idx][0] = rank == 0 ? in.at(idx) : 0; - _out[idx][1] = rank == 2 ? in.at(idx) : 0; + _out[idx][0] = rank == 0 ? _in[idx] : 0; + _out[idx][1] = rank == 2 ? _in[idx] : 0; }); // for debug purpose, randomize the inputs to avoid corner cases. #ifdef ENABLE_MASK_DURING_ABY3_P2A - std::vector r0(in.numel()); - std::vector r1(in.numel()); + std::vector r0(in.numel()); + std::vector r1(in.numel()); auto* prg_state = ctx->getState(); prg_state->fillPrssPair(absl::MakeSpan(r0), absl::MakeSpan(r1)); for (int64_t idx = 0; idx < in.numel(); idx++) { r0[idx] = r0[idx] - r1[idx]; } - r1 = comm->rotate(r0, "p2a.zero"); + r1 = comm->rotate(r0, "p2a.zero"); for (int64_t idx = 0; idx < in.numel(); idx++) { _out[idx][0] += r0[idx]; @@ -218,32 +216,31 @@ NdArrayRef A2V::proc(KernelEvalContext* ctx, const NdArrayRef& in, const auto field = in.eltype().as()->field(); return DISPATCH_ALL_FIELDS(field, "_", [&]() { - using VShrT = ring2k_t; - using AShrT = ring2k_t; + using vshr_el_t = ring2k_t; + using ashr_el_t = ring2k_t; + using ashr_t = std::array; - // auto _in = ArrayView>(in); + NdArrayView _in(in); auto out_ty = makeType(field, rank); if (comm->getRank() == rank) { - auto x3 = comm->recv(comm->nextRank(), "a2v"); // comm => 1, k - // + auto x3 = comm->recv(comm->nextRank(), "a2v"); // comm => 1, k + // NdArrayRef out(out_ty, in.shape()); - auto _out = out.data(); + NdArrayView _out(out); pforeach(0, in.numel(), [&](int64_t idx) { - _out[idx] = in.at>(idx)[0] + - in.at>(idx)[1] + x3[idx]; + _out[idx] = _in[idx][0] + _in[idx][1] + x3[idx]; }); return out; } else if (comm->getRank() == (rank + 1) % 3) { - std::vector x2(in.numel()); + std::vector x2(in.numel()); - pforeach(0, in.numel(), [&](int64_t idx) { - x2[idx] = in.at>(idx)[1]; - }); + pforeach(0, in.numel(), [&](int64_t idx) { x2[idx] = 
_in[idx][1]; }); - comm->sendAsync(comm->prevRank(), x2, "a2v"); // comm => 1, k + comm->sendAsync(comm->prevRank(), x2, + "a2v"); // comm => 1, k return makeConstantArrayRef(out_ty, in.shape()); } else { return makeConstantArrayRef(out_ty, in.shape()); @@ -260,28 +257,32 @@ NdArrayRef V2A::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { size_t owner_rank = in_ty->owner(); return DISPATCH_ALL_FIELDS(field, "_", [&]() { - using AShrT = ring2k_t; + using ashr_el_t = ring2k_t; + using ashr_t = std::array; NdArrayRef out(makeType(field), in.shape()); - auto _out = out.data>(); + NdArrayView _out(out); if (comm->getRank() == owner_rank) { auto splits = ring_rand_additive_splits(in, 2); comm->sendAsync(comm->nextRank(), splits[1], "v2a"); // comm => 1, k comm->sendAsync(comm->prevRank(), splits[0], "v2a"); // comm => 1, k + NdArrayView _s0(splits[0]); + NdArrayView _s1(splits[1]); + pforeach(0, in.numel(), [&](int64_t idx) { - _out[idx][0] = splits[0].at(idx); - _out[idx][1] = splits[1].at(idx); + _out[idx][0] = _s0[idx]; + _out[idx][1] = _s1[idx]; }); } else if (comm->getRank() == (owner_rank + 1) % 3) { - auto x0 = comm->recv(comm->prevRank(), "v2a"); // comm => 1, k + auto x0 = comm->recv(comm->prevRank(), "v2a"); // comm => 1, k pforeach(0, in.numel(), [&](int64_t idx) { _out[idx][0] = x0[idx]; _out[idx][1] = 0; }); } else { - auto x1 = comm->recv(comm->nextRank(), "v2a"); // comm => 1, k + auto x1 = comm->recv(comm->nextRank(), "v2a"); // comm => 1, k pforeach(0, in.numel(), [&](int64_t idx) { _out[idx][0] = 0; _out[idx][1] = x1[idx]; @@ -300,16 +301,18 @@ NdArrayRef NotA::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { auto rank = comm->getRank(); return DISPATCH_ALL_FIELDS(field, "_", [&]() { - using S = std::make_unsigned_t; + using el_t = std::make_unsigned_t; + using shr_t = std::array; NdArrayRef out(makeType(field), in.shape()); - auto _out = out.data>(); + NdArrayView _out(out); + NdArrayView _in(in); // neg(x) = not(x) + 1 // not(x) = 
neg(x) - 1 pforeach(0, in.numel(), [&](int64_t idx) { - _out[idx][0] = -in.at>(idx)[0]; - _out[idx][1] = -in.at>(idx)[1]; + _out[idx][0] = -_in[idx][0]; + _out[idx][1] = -_in[idx][1]; if (rank == 0) { _out[idx][1] -= 1; } else if (rank == 1) { @@ -336,16 +339,19 @@ NdArrayRef AddAP::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, auto rank = comm->getRank(); return DISPATCH_ALL_FIELDS(field, "_", [&]() { - using U = ring2k_t; + using el_t = ring2k_t; + using shr_t = std::array; NdArrayRef out(makeType(field), lhs.shape()); - auto* _out = out.data>(); + NdArrayView _out(out); + NdArrayView _lhs(lhs); + NdArrayView _rhs(rhs); pforeach(0, lhs.numel(), [&](int64_t idx) { - _out[idx][0] = lhs.at>(idx)[0]; - _out[idx][1] = lhs.at>(idx)[1]; - if (rank == 0) _out[idx][1] += rhs.at(idx); - if (rank == 1) _out[idx][0] += rhs.at(idx); + _out[idx][0] = _lhs[idx][0]; + _out[idx][1] = _lhs[idx][1]; + if (rank == 0) _out[idx][1] += _rhs[idx]; + if (rank == 1) _out[idx][0] += _rhs[idx]; }); return out; }); @@ -360,16 +366,16 @@ NdArrayRef AddAA::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, const auto field = lhs_ty->field(); return DISPATCH_ALL_FIELDS(field, "_", [&]() { - using U = ring2k_t; + using shr_t = std::array; NdArrayRef out(makeType(field), lhs.shape()); - auto* _out = out.data>(); + NdArrayView _out(out); + NdArrayView _lhs(lhs); + NdArrayView _rhs(rhs); pforeach(0, lhs.numel(), [&](int64_t idx) { - _out[idx][0] = - lhs.at>(idx)[0] + rhs.at>(idx)[0]; - _out[idx][1] = - lhs.at>(idx)[1] + rhs.at>(idx)[1]; + _out[idx][0] = _lhs[idx][0] + _rhs[idx][0]; + _out[idx][1] = _lhs[idx][1] + _rhs[idx][1]; }); return out; }); @@ -387,14 +393,17 @@ NdArrayRef MulAP::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, const auto field = lhs_ty->field(); return DISPATCH_ALL_FIELDS(field, "_", [&]() { - using U = ring2k_t; + using el_t = ring2k_t; + using shr_t = std::array; NdArrayRef out(makeType(field), lhs.shape()); - auto* _out = out.data>(); + NdArrayView _out(out); + 
NdArrayView _lhs(lhs); + NdArrayView _rhs(rhs); pforeach(0, lhs.numel(), [&](int64_t idx) { - _out[idx][0] = lhs.at>(idx)[0] * rhs.at(idx); - _out[idx][1] = lhs.at>(idx)[1] * rhs.at(idx); + _out[idx][0] = _lhs[idx][0] * _rhs[idx]; + _out[idx][1] = _lhs[idx][1] * _rhs[idx]; }); return out; }); @@ -407,27 +416,26 @@ NdArrayRef MulAA::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, auto* prg_state = ctx->getState(); return DISPATCH_ALL_FIELDS(field, "aby3.mulAA", [&]() { - using U = ring2k_t; + using el_t = ring2k_t; + using shr_t = std::array; - std::vector r0(lhs.numel()); - std::vector r1(lhs.numel()); + std::vector r0(lhs.numel()); + std::vector r1(lhs.numel()); prg_state->fillPrssPair(absl::MakeSpan(r0), absl::MakeSpan(r1)); + NdArrayView _lhs(lhs); + NdArrayView _rhs(rhs); + // z1 = (x1 * y1) + (x1 * y2) + (x2 * y1) + (r0 - r1); pforeach(0, lhs.numel(), [&](int64_t idx) { - r0[idx] = (lhs.at>(idx)[0] * - rhs.at>(idx)[0]) + - (lhs.at>(idx)[0] * - rhs.at>(idx)[1]) + - (lhs.at>(idx)[1] * - rhs.at>(idx)[0]) + - (r0[idx] - r1[idx]); + r0[idx] = (_lhs[idx][0] * _rhs[idx][0]) + (_lhs[idx][0] * _rhs[idx][1]) + + (_lhs[idx][1] * _rhs[idx][0]) + (r0[idx] - r1[idx]); }); - r1 = comm->rotate(r0, "mulaa"); // comm => 1, k + r1 = comm->rotate(r0, "mulaa"); // comm => 1, k NdArrayRef out(makeType(field), lhs.shape()); - auto _out = out.data>(); + NdArrayView _out(out); pforeach(0, lhs.numel(), [&](int64_t idx) { _out[idx][0] = r0[idx]; @@ -556,15 +564,17 @@ NdArrayRef MulA1B::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, } DISPATCH_ALL_FIELDS(field, "_", [&]() { - using AShrT = ring2k_t; + // using = ring2k_t; + NdArrayView r1_0(r1.first); + NdArrayView r1_1(r1.second); + NdArrayView r2_0(r2.first); + NdArrayView r2_1(r2.second); // r1.first = r1.first + r2.first // r1.second = r1.second + r2.second pforeach(0, r1.first.numel(), [&](int64_t idx) { - r1.first.at(idx) = - r1.first.at(idx) + r2.first.at(idx); - r1.second.at(idx) = - r1.second.at(idx) + 
r2.second.at(idx); + r1_0[idx] = r1_0[idx] + r2_0[idx]; + r1_1[idx] = r1_1[idx] + r2_1[idx]; }); }); return r1; @@ -641,14 +651,15 @@ NdArrayRef LShiftA::proc(KernelEvalContext* ctx, const NdArrayRef& in, const auto field = in_ty->field(); return DISPATCH_ALL_FIELDS(field, "_", [&]() { - using U = ring2k_t; + using shr_t = std::array; NdArrayRef out(makeType(field), in.shape()); - auto _out = out.data>(); + NdArrayView _out(out); + NdArrayView _in(in); pforeach(0, in.numel(), [&](int64_t idx) { - _out[idx][0] = in.at>(idx)[0] << bits; - _out[idx][1] = in.at>(idx)[1] << bits; + _out[idx][0] = _in[idx][0] << bits; + _out[idx][1] = _in[idx][1] << bits; }); return out; @@ -749,40 +760,41 @@ NdArrayRef TruncAPr::proc(KernelEvalContext* ctx, const NdArrayRef& in, NdArrayRef out(in.eltype(), in.shape()); DISPATCH_ALL_FIELDS(field, "aby3.truncpr", [&]() { - using U = ring2k_t; + using el_t = ring2k_t; + using shr_t = std::array; - auto _out = out.data>(); + NdArrayView _out(out); + NdArrayView _in(in); if (comm->getRank() == P0) { - std::vector r(numel); + std::vector r(numel); prg_state->fillPrssPair(absl::MakeSpan(r), {}, false, true); - std::vector x_plus_r(numel); + std::vector x_plus_r(numel); pforeach(0, numel, [&](int64_t idx) { // convert to 2-outof-2 share. - auto x = - in.at>(idx)[0] + in.at>(idx)[1]; + auto x = _in[idx][0] + _in[idx][1]; // handle negative number. 
// assume secret x in [-2^(k-2), 2^(k-2)), by // adding 2^(k-2) x' = x + 2^(k-2) in [0, 2^(k-1)), with msb(x') == 0 - x += U(1) << (k - 2); + x += el_t(1) << (k - 2); // mask it with ra x_plus_r[idx] = x + r[idx]; }); // open c = + - auto c = openWith(comm, P1, x_plus_r); + auto c = openWith(comm, P1, x_plus_r); // get correlated randomness from P2 // let rb = r{k-1}, // rc = sum(r{i-m}<<(i-m)) for i in range(m, k-2) - auto cr = comm->recv(P2, "cr0"); + auto cr = comm->recv(P2, "cr0"); auto rb = absl::MakeSpan(cr).subspan(0, numel); auto rc = absl::MakeSpan(cr).subspan(numel, numel); - std::vector y2(numel); // the 2-out-of-2 truncation result + std::vector y2(numel); // the 2-out-of-2 truncation result pforeach(0, numel, [&](int64_t idx) { // c_hat = c/2^m mod 2^(k-m-1) = (c << 1) >> (1+m) auto c_hat = (c[idx] << 1) >> (1 + bits); @@ -798,18 +810,18 @@ NdArrayRef TruncAPr::proc(KernelEvalContext* ctx, const NdArrayRef& in, // re-encode negative numbers. // from https://eprint.iacr.org/2020/338.pdf, section 5.1 // y' = y - 2^(k-2-m) - y2[idx] = y - (U(1) << (k - 2 - bits)); + y2[idx] = y - (el_t(1) << (k - 2 - bits)); }); // - std::vector y1(numel); + std::vector y1(numel); prg_state->fillPrssPair(absl::MakeSpan(y1), {}, false, true); pforeach(0, numel, [&](int64_t idx) { // y2[idx] -= y1[idx]; }); - comm->sendAsync(P1, y2, "2to3"); - auto tmp = comm->recv(P1, "2to3"); + comm->sendAsync(P1, y2, "2to3"); + auto tmp = comm->recv(P1, "2to3"); // rebuild the final result. pforeach(0, numel, [&](int64_t idx) { @@ -817,26 +829,26 @@ NdArrayRef TruncAPr::proc(KernelEvalContext* ctx, const NdArrayRef& in, _out[idx][1] = y2[idx] + tmp[idx]; }); } else if (comm->getRank() == P1) { - std::vector r(numel); + std::vector r(numel); prg_state->fillPrssPair({}, absl::MakeSpan(r), true, false); - std::vector x_plus_r(numel); + std::vector x_plus_r(numel); pforeach(0, numel, [&](int64_t idx) { // let t as 2-out-of-2 share, mask it with ra. 
- x_plus_r[idx] = in.at>(idx)[1] + r[idx]; + x_plus_r[idx] = _in[idx][1] + r[idx]; }); // open c = + - auto c = openWith(comm, P0, x_plus_r); + auto c = openWith(comm, P0, x_plus_r); // get correlated randomness from P2 // let rb = r{k-1}, // rc = sum(r{i-m}<<(i-m)) for i in range(m, k-2) - auto cr = comm->recv(P2, "cr1"); + auto cr = comm->recv(P2, "cr1"); auto rb = absl::MakeSpan(cr).subspan(0, numel); auto rc = absl::MakeSpan(cr).subspan(numel, numel); - std::vector y2(numel); // the 2-out-of-2 truncation result + std::vector y2(numel); // the 2-out-of-2 truncation result pforeach(0, numel, [&](int64_t idx) { // = ^ c{k-1} = + c{k-1} - 2*c{k-1}* // note: is a randbit (in r^2k) @@ -847,13 +859,13 @@ NdArrayRef TruncAPr::proc(KernelEvalContext* ctx, const NdArrayRef& in, y2[idx] = 0 - rc[idx] + (b << (k - 1 - bits)); }); - std::vector y3(numel); + std::vector y3(numel); prg_state->fillPrssPair({}, absl::MakeSpan(y3), true, false); pforeach(0, numel, [&](int64_t idx) { // y2[idx] -= y3[idx]; }); - comm->sendAsync(P0, y2, "2to3"); - auto tmp = comm->recv(P0, "2to3"); + comm->sendAsync(P0, y2, "2to3"); + auto tmp = comm->recv(P0, "2to3"); // rebuild the final result. 
pforeach(0, numel, [&](int64_t idx) { @@ -861,12 +873,12 @@ NdArrayRef TruncAPr::proc(KernelEvalContext* ctx, const NdArrayRef& in, _out[idx][1] = y3[idx]; }); } else if (comm->getRank() == P2) { - std::vector r0(numel); - std::vector r1(numel); + std::vector r0(numel); + std::vector r1(numel); prg_state->fillPrssPair(absl::MakeSpan(r0), absl::MakeSpan(r1)); - std::vector cr0(2 * numel); - std::vector cr1(2 * numel); + std::vector cr0(2 * numel); + std::vector cr1(2 * numel); auto rb0 = absl::MakeSpan(cr0).subspan(0, numel); auto rc0 = absl::MakeSpan(cr0).subspan(numel, numel); auto rb1 = absl::MakeSpan(cr1).subspan(0, numel); @@ -888,11 +900,11 @@ NdArrayRef TruncAPr::proc(KernelEvalContext* ctx, const NdArrayRef& in, rc0[idx] += (r << 1) >> (bits + 1); }); - comm->sendAsync(P0, cr0, "cr0"); - comm->sendAsync(P1, cr1, "cr1"); + comm->sendAsync(P0, cr0, "cr0"); + comm->sendAsync(P1, cr1, "cr1"); - std::vector y3(numel); - std::vector y1(numel); + std::vector y3(numel); + std::vector y1(numel); prg_state->fillPrssPair(absl::MakeSpan(y3), absl::MakeSpan(y1)); pforeach(0, numel, [&](int64_t idx) { _out[idx][0] = y3[idx]; diff --git a/libspu/mpc/aby3/boolean.cc b/libspu/mpc/aby3/boolean.cc index 41a9f4f2..3fa5dd11 100644 --- a/libspu/mpc/aby3/boolean.cc +++ b/libspu/mpc/aby3/boolean.cc @@ -44,15 +44,20 @@ NdArrayRef CastTypeB::proc(KernelEvalContext* ctx, const NdArrayRef& in, const Type& to_type) const { NdArrayRef out(to_type, in.shape()); DISPATCH_UINT_PT_TYPES(in.eltype().as()->getBacktype(), "_", [&]() { - using InT = ScalarT; + using in_el_t = ScalarT; + using in_shr_t = std::array; + DISPATCH_UINT_PT_TYPES(to_type.as()->getBacktype(), "_", [&]() { - using OutT = ScalarT; - auto _out = out.data>(); + using out_el_t = ScalarT; + using out_shr_t = std::array; + + NdArrayView _out(out); + NdArrayView _in(in); pforeach(0, in.numel(), [&](int64_t idx) { - const auto& v = in.at>(idx); - _out[idx][0] = static_cast(v[0]); - _out[idx][1] = static_cast(v[1]); + const auto& 
v = _in[idx]; + _out[idx][0] = static_cast(v[0]); + _out[idx][1] = static_cast(v[1]); }); }); }); @@ -66,21 +71,23 @@ NdArrayRef B2P::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { const auto field = ctx->getState()->getDefaultField(); return DISPATCH_UINT_PT_TYPES(btype, "aby3.b2p", [&]() { - using BShrT = ScalarT; + using bshr_el_t = ScalarT; + using bshr_t = std::array; return DISPATCH_ALL_FIELDS(field, "_", [&]() { - using PShrT = ring2k_t; + using pshr_el_t = ring2k_t; NdArrayRef out(makeType(field), in.shape()); - auto _out = out.data(); + NdArrayView _out(out); + NdArrayView _in(in); - auto x2 = getShareAs(in, 1); - auto x3 = comm->rotate(x2, "b2p"); // comm => 1, k + auto x2 = getShareAs(in, 1); + auto x3 = comm->rotate(x2, "b2p"); // comm => 1, k pforeach(0, in.numel(), [&](int64_t idx) { - const auto& v = in.at>(idx); - _out[idx] = static_cast(v[0] ^ v[1] ^ x3[idx]); + const auto& v = _in[idx]; + _out[idx] = static_cast(v[0] ^ v[1] ^ x3[idx]); }); return out; @@ -96,22 +103,25 @@ NdArrayRef P2B::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { return DISPATCH_ALL_FIELDS(field, "_", [&]() { const size_t nbits = maxBitWidth(in); const PtType btype = calcBShareBacktype(nbits); + NdArrayView _in(in); return DISPATCH_UINT_PT_TYPES(btype, "_", [&]() { - using BShrT = ScalarT; + using bshr_el_t = ScalarT; + using bshr_t = std::array; + NdArrayRef out(makeType(btype, nbits), in.shape()); - auto _out = out.data>(); + NdArrayView _out(out); pforeach(0, in.numel(), [&](int64_t idx) { if (comm->getRank() == 0) { - _out[idx][0] = static_cast(in.at(idx)); + _out[idx][0] = static_cast(_in[idx]); _out[idx][1] = 0U; } else if (comm->getRank() == 1) { _out[idx][0] = 0U; _out[idx][1] = 0U; } else { _out[idx][0] = 0U; - _out[idx][1] = static_cast(in.at(idx)); + _out[idx][1] = static_cast(_in[idx]); } }); return out; @@ -126,32 +136,34 @@ NdArrayRef B2V::proc(KernelEvalContext* ctx, const NdArrayRef& in, const auto field = 
ctx->getState()->getDefaultField(); return DISPATCH_UINT_PT_TYPES(btype, "aby3.b2v", [&]() { - using BShrT = ScalarT; + using bshr_el_t = ScalarT; + using bshr_t = std::array; + NdArrayView _in(in); + return DISPATCH_ALL_FIELDS(field, "_", [&]() { - using VShrT = ring2k_t; + using vshr_scalar_t = ring2k_t; auto out_ty = makeType(field, rank); if (comm->getRank() == rank) { - auto x3 = comm->recv(comm->nextRank(), "b2v"); // comm => 1, k + auto x3 = + comm->recv(comm->nextRank(), "b2v"); // comm => 1, k NdArrayRef out(out_ty, in.shape()); - - auto _out = out.data(); + NdArrayView _out(out); pforeach(0, in.numel(), [&](int64_t idx) { - const auto& v = in.at>(idx); + const auto& v = _in[idx]; _out[idx] = v[0] ^ v[1] ^ x3[idx]; }); return out; } else if (comm->getRank() == (rank + 1) % 3) { - std::vector x2(in.numel()); + std::vector x2(in.numel()); - pforeach(0, in.numel(), [&](int64_t idx) { - x2[idx] = in.at>(idx)[1]; - }); + pforeach(0, in.numel(), [&](int64_t idx) { x2[idx] = _in[idx][1]; }); - comm->sendAsync(comm->prevRank(), x2, "b2v"); // comm => 1, k + comm->sendAsync(comm->prevRank(), x2, + "b2v"); // comm => 1, k return makeConstantArrayRef(out_ty, in.shape()); } else { @@ -167,23 +179,30 @@ NdArrayRef AndBP::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, const auto* rhs_ty = rhs.eltype().as(); return DISPATCH_ALL_FIELDS(rhs_ty->field(), "_", [&]() { - using RhsT = ring2k_t; - const size_t rhs_nbits = maxBitWidth(rhs); + using rhs_scalar_t = ring2k_t; + + const size_t rhs_nbits = maxBitWidth(rhs); const size_t out_nbits = std::min(lhs_ty->nbits(), rhs_nbits); const PtType out_btype = calcBShareBacktype(out_nbits); + NdArrayView _rhs(rhs); + return DISPATCH_UINT_PT_TYPES(lhs_ty->getBacktype(), "_", [&]() { - using LhsT = ScalarT; + using lhs_el_t = ScalarT; + using lhs_shr_t = std::array; + + NdArrayView _lhs(lhs); return DISPATCH_UINT_PT_TYPES(out_btype, "_", [&]() { - using OutT = ScalarT; + using out_el_t = ScalarT; + using out_shr_t = std::array; 
NdArrayRef out(makeType(out_btype, out_nbits), lhs.shape()); - auto _out = out.data>(); + NdArrayView _out(out); pforeach(0, lhs.numel(), [&](int64_t idx) { - const auto& l = lhs.at>(idx); - const auto& r = rhs.at(idx); + const auto& l = _lhs[idx]; + const auto& r = _rhs[idx]; _out[idx][0] = l[0] & r; _out[idx][1] = l[1] & r; }); @@ -207,27 +226,34 @@ NdArrayRef AndBB::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, NdArrayRef out(makeType(out_btype, out_nbits), lhs.shape()); return DISPATCH_UINT_PT_TYPES(rhs_ty->getBacktype(), "_", [&]() { - using RhsT = ScalarT; + using rhs_el_t = ScalarT; + using rhs_shr_t = std::array; + NdArrayView _rhs(rhs); + return DISPATCH_UINT_PT_TYPES(lhs_ty->getBacktype(), "_", [&]() { - using LhsT = ScalarT; + using lhs_el_t = ScalarT; + using lhs_shr_t = std::array; + NdArrayView _lhs(lhs); + return DISPATCH_UINT_PT_TYPES(out_btype, "_", [&]() { - using OutT = ScalarT; + using out_el_t = ScalarT; + using out_shr_t = std::array; - std::vector r0(lhs.numel()); - std::vector r1(lhs.numel()); + std::vector r0(lhs.numel()); + std::vector r1(lhs.numel()); prg_state->fillPrssPair(absl::MakeSpan(r0), absl::MakeSpan(r1)); // z1 = (x1 & y1) ^ (x1 & y2) ^ (x2 & y1) ^ (r0 ^ r1); pforeach(0, lhs.numel(), [&](int64_t idx) { - const auto& l = lhs.at>(idx); - const auto& r = rhs.at>(idx); + const auto& l = _lhs[idx]; + const auto& r = _rhs[idx]; r0[idx] = (l[0] & r[0]) ^ (l[0] & r[1]) ^ (l[1] & r[0]) ^ (r0[idx] ^ r1[idx]); }); - r1 = comm->rotate(r0, "andbb"); // comm => 1, k + r1 = comm->rotate(r0, "andbb"); // comm => 1, k - auto _out = out.data>(); + NdArrayView _out(out); pforeach(0, lhs.numel(), [&](int64_t idx) { _out[idx][0] = r0[idx]; _out[idx][1] = r1[idx]; @@ -244,24 +270,30 @@ NdArrayRef XorBP::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, const auto* rhs_ty = rhs.eltype().as(); return DISPATCH_ALL_FIELDS(rhs_ty->field(), "_", [&]() { - using RhsT = ring2k_t; + using rhs_scalar_t = ring2k_t; - const size_t rhs_nbits = 
maxBitWidth(rhs); + const size_t rhs_nbits = maxBitWidth(rhs); const size_t out_nbits = std::max(lhs_ty->nbits(), rhs_nbits); const PtType out_btype = calcBShareBacktype(out_nbits); + NdArrayView _rhs(rhs); + NdArrayRef out(makeType(out_btype, out_nbits), lhs.shape()); return DISPATCH_UINT_PT_TYPES(lhs_ty->getBacktype(), "_", [&]() { - using LhsT = ScalarT; + using lhs_el_t = ScalarT; + using lhs_shr_t = std::array; + + NdArrayView _lhs(lhs); return DISPATCH_UINT_PT_TYPES(out_btype, "_", [&]() { - using OutT = ScalarT; + using out_el_t = ScalarT; + using out_shr_t = std::array; - auto _out = out.data>(); + NdArrayView _out(out); pforeach(0, lhs.numel(), [&](int64_t idx) { - const auto& l = lhs.at>(idx); - const auto& r = rhs.at(idx); + const auto& l = _lhs[idx]; + const auto& r = _rhs[idx]; _out[idx][0] = l[0] ^ r; _out[idx][1] = l[1] ^ r; }); @@ -280,20 +312,27 @@ NdArrayRef XorBB::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, const PtType out_btype = calcBShareBacktype(out_nbits); return DISPATCH_UINT_PT_TYPES(rhs_ty->getBacktype(), "_", [&]() { - using RhsT = ScalarT; + using rhs_el_t = ScalarT; + using rhs_shr_t = std::array; + + NdArrayView _rhs(rhs); return DISPATCH_UINT_PT_TYPES(lhs_ty->getBacktype(), "_", [&]() { - using LhsT = ScalarT; + using lhs_el_t = ScalarT; + using lhs_shr_t = std::array; + + NdArrayView _lhs(lhs); return DISPATCH_UINT_PT_TYPES(out_btype, "_", [&]() { - using OutT = ScalarT; + using out_el_t = ScalarT; + using out_shr_t = std::array; NdArrayRef out(makeType(out_btype, out_nbits), lhs.shape()); - auto _out = out.data>(); + NdArrayView _out(out); pforeach(0, lhs.numel(), [&](int64_t idx) { - const auto& l = lhs.at>(idx); - const auto& r = rhs.at>(idx); + const auto& l = _lhs[idx]; + const auto& r = _rhs[idx]; _out[idx][0] = l[0] ^ r[0]; _out[idx][1] = l[1] ^ r[1]; }); @@ -313,18 +352,22 @@ NdArrayRef LShiftB::proc(KernelEvalContext* ctx, const NdArrayRef& in, const PtType out_btype = calcBShareBacktype(out_nbits); return 
DISPATCH_UINT_PT_TYPES(in_ty->getBacktype(), "_", [&]() { - using InT = ScalarT; + using in_el_t = ScalarT; + using in_shr_t = std::array; + + NdArrayView _in(in); return DISPATCH_UINT_PT_TYPES(out_btype, "_", [&]() { - using OutT = ScalarT; + using out_el_t = ScalarT; + using out_shr_t = std::array; NdArrayRef out(makeType(out_btype, out_nbits), in.shape()); - auto _out = out.data>(); + NdArrayView _out(out); pforeach(0, in.numel(), [&](int64_t idx) { - const auto& v = in.at>(idx); - _out[idx][0] = static_cast(v[0]) << bits; - _out[idx][1] = static_cast(v[1]) << bits; + const auto& v = _in[idx]; + _out[idx][0] = static_cast(v[0]) << bits; + _out[idx][1] = static_cast(v[1]) << bits; }); return out; @@ -342,18 +385,20 @@ NdArrayRef RShiftB::proc(KernelEvalContext* ctx, const NdArrayRef& in, const PtType out_btype = calcBShareBacktype(out_nbits); return DISPATCH_UINT_PT_TYPES(in_ty->getBacktype(), "_", [&]() { - using InT = ScalarT; + using in_shr_t = std::array; + NdArrayView _in(in); return DISPATCH_UINT_PT_TYPES(out_btype, "_", [&]() { - using OutT = ScalarT; + using out_el_t = ScalarT; + using out_shr_t = std::array; NdArrayRef out(makeType(out_btype, out_nbits), in.shape()); - auto _out = out.data>(); + NdArrayView _out(out); pforeach(0, in.numel(), [&](int64_t idx) { - const auto& v = in.at>(idx); - _out[idx][0] = static_cast(v[0] >> bits); - _out[idx][1] = static_cast(v[1] >> bits); + const auto& v = _in[idx]; + _out[idx][0] = static_cast(v[0] >> bits); + _out[idx][1] = static_cast(v[1] >> bits); }); return out; @@ -374,12 +419,15 @@ NdArrayRef ARShiftB::proc(KernelEvalContext* ctx, const NdArrayRef& in, const size_t out_nbits = in_ty->nbits(); return DISPATCH_UINT_PT_TYPES(in_ty->getBacktype(), "_", [&]() { - using T = std::make_signed_t; + using el_t = std::make_signed_t; + using shr_t = std::array; + NdArrayRef out(makeType(out_btype, out_nbits), in.shape()); - auto _out = out.data>(); + NdArrayView _out(out); + NdArrayView _in(in); pforeach(0, in.numel(), 
[&](int64_t idx) { - const auto& v = in.at>(idx); + const auto& v = _in[idx]; _out[idx][0] = v[0] >> bits; _out[idx][1] = v[1] >> bits; }); @@ -397,30 +445,34 @@ NdArrayRef BitrevB::proc(KernelEvalContext* ctx, const NdArrayRef& in, const PtType out_btype = calcBShareBacktype(out_nbits); return DISPATCH_UINT_PT_TYPES(in_ty->getBacktype(), "_", [&]() { - using InT = ScalarT; + using in_el_t = ScalarT; + using in_shr_t = std::array; + + NdArrayView _in(in); return DISPATCH_UINT_PT_TYPES(out_btype, "_", [&]() { - using OutT = ScalarT; + using out_el_t = ScalarT; + using out_shr_t = std::array; NdArrayRef out(makeType(out_btype, out_nbits), in.shape()); - auto _out = out.data>(); + NdArrayView _out(out); - auto bitrev_fn = [&](OutT el) -> OutT { - OutT tmp = 0U; + auto bitrev_fn = [&](out_el_t el) -> out_el_t { + out_el_t tmp = 0U; for (size_t idx = start; idx < end; idx++) { - if (el & ((OutT)1 << idx)) { - tmp |= (OutT)1 << (end - 1 - idx + start); + if (el & ((out_el_t)1 << idx)) { + tmp |= (out_el_t)1 << (end - 1 - idx + start); } } - OutT mask = ((OutT)1U << end) - ((OutT)1U << start); + out_el_t mask = ((out_el_t)1U << end) - ((out_el_t)1U << start); return (el & ~mask) | tmp; }; pforeach(0, in.numel(), [&](int64_t idx) { - const auto& v = in.at>(idx); - _out[idx][0] = bitrev_fn(static_cast(v[0])); - _out[idx][1] = bitrev_fn(static_cast(v[1])); + const auto& v = _in[idx]; + _out[idx][0] = bitrev_fn(static_cast(v[0])); + _out[idx][1] = bitrev_fn(static_cast(v[1])); }); return out; @@ -437,13 +489,15 @@ NdArrayRef BitIntlB::proc(KernelEvalContext* ctx, const NdArrayRef& in, NdArrayRef out(in.eltype(), in.shape()); DISPATCH_UINT_PT_TYPES(in_ty->getBacktype(), "_", [&]() { - using T = ScalarT; - auto _out = out.data>(); + using el_t = ScalarT; + using shr_t = std::array; + NdArrayView _out(out); + NdArrayView _in(in); pforeach(0, in.numel(), [&](int64_t idx) { - const auto& v = in.at>(idx); - _out[idx][0] = BitIntl(v[0], stride, nbits); - _out[idx][1] = BitIntl(v[1], 
stride, nbits); + const auto& v = _in[idx]; + _out[idx][0] = BitIntl(v[0], stride, nbits); + _out[idx][1] = BitIntl(v[1], stride, nbits); }); }); @@ -458,13 +512,15 @@ NdArrayRef BitDeintlB::proc(KernelEvalContext* ctx, const NdArrayRef& in, NdArrayRef out(in.eltype(), in.shape()); DISPATCH_UINT_PT_TYPES(in_ty->getBacktype(), "_", [&]() { - using T = ScalarT; - auto _out = out.data>(); + using el_t = ScalarT; + using shr_t = std::array; + NdArrayView _out(out); + NdArrayView _in(in); pforeach(0, in.numel(), [&](int64_t idx) { - const auto& v = in.at>(idx); - _out[idx][0] = BitDeintl(v[0], stride, nbits); - _out[idx][1] = BitDeintl(v[1], stride, nbits); + const auto& v = _in[idx]; + _out[idx][0] = BitDeintl(v[0], stride, nbits); + _out[idx][1] = BitDeintl(v[1], stride, nbits); }); }); diff --git a/libspu/mpc/aby3/conversion.cc b/libspu/mpc/aby3/conversion.cc index 8cf10e23..8e1bb028 100644 --- a/libspu/mpc/aby3/conversion.cc +++ b/libspu/mpc/aby3/conversion.cc @@ -63,28 +63,29 @@ NdArrayRef A2B::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { auto numel = in.numel(); DISPATCH_ALL_FIELDS(field, "_", [&]() { - // const auto _in = ArrayView>(in); - using AShrT = ring2k_t; + using ashr_t = std::array; + NdArrayView _in(in); DISPATCH_UINT_PT_TYPES(out_btype, "_", [&]() { - using BShrT = ScalarT; + using bshr_el_t = ScalarT; + using bshr_t = std::array; - std::vector r0(in.numel()); - std::vector r1(in.numel()); + std::vector r0(in.numel()); + std::vector r1(in.numel()); prg_state->fillPrssPair(absl::MakeSpan(r0), absl::MakeSpan(r1)); pforeach(0, numel, [&](int64_t idx) { r0[idx] ^= r1[idx]; if (comm->getRank() == 0) { - const auto& v = in.at>(idx); + const auto& v = _in[idx]; r0[idx] ^= v[0] + v[1]; } }); - r1 = comm->rotate(r0, "a2b"); // comm => 1, k + r1 = comm->rotate(r0, "a2b"); // comm => 1, k - auto* _m = m.data>(); - auto* _n = n.data>(); + NdArrayView _m(m); + NdArrayView _n(n); pforeach(0, numel, [&](int64_t idx) { _m[idx][0] = r0[idx]; @@ -95,9 
+96,9 @@ NdArrayRef A2B::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { _n[idx][1] = 0; } else if (comm->getRank() == 1) { _n[idx][0] = 0; - _n[idx][1] = in.at>(idx)[1]; + _n[idx][1] = _in[idx][1]; } else if (comm->getRank() == 2) { - _n[idx][0] = in.at>(idx)[0]; + _n[idx][0] = _in[idx][0]; _n[idx][1] = 0; } }); @@ -148,8 +149,7 @@ NdArrayRef B2AByPPA::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { if (in_nbits == 0) { // special case, it's known to be zero. DISPATCH_ALL_FIELDS(field, "_", [&]() { - using AShrT = ring2k_t; - auto _out = out.data>(); + NdArrayView> _out(out); pforeach(0, numel, [&](int64_t idx) { _out[idx][0] = 0; _out[idx][1] = 0; @@ -162,28 +162,30 @@ NdArrayRef B2AByPPA::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { auto* prg_state = ctx->getState(); DISPATCH_UINT_PT_TYPES(in_ty->getBacktype(), "_", [&]() { - using BShrT = ScalarT; + using bshr_t = std::array; + NdArrayView _in(in); + DISPATCH_ALL_FIELDS(field, "_", [&]() { - using AShrT = ring2k_t; + using ashr_el_t = ring2k_t; + using ashr_t = std::array; // first expand b share to a share length. const auto expanded_ty = makeType( calcBShareBacktype(SizeOf(field) * 8), SizeOf(field) * 8); NdArrayRef x(expanded_ty, in.shape()); - - auto _x = x.data>(); + NdArrayView _x(x); pforeach(0, numel, [&](int64_t idx) { - const auto& v = in.at>(idx); + const auto& v = _in[idx]; _x[idx][0] = v[0]; _x[idx][1] = v[1]; }); // P1 & P2 local samples ra, note P0's ra is not used. 
- std::vector ra0(numel); - std::vector ra1(numel); - std::vector rb0(numel); - std::vector rb1(numel); + std::vector ra0(numel); + std::vector ra1(numel); + std::vector rb0(numel); + std::vector rb1(numel); prg_state->fillPrssPair(absl::MakeSpan(ra0), absl::MakeSpan(ra1)); prg_state->fillPrssPair(absl::MakeSpan(rb0), absl::MakeSpan(rb1)); @@ -196,11 +198,11 @@ NdArrayRef B2AByPPA::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { rb0[idx] = zb; } }); - rb1 = comm->rotate(rb0, "b2a.rand"); // comm => 1, k + rb1 = comm->rotate(rb0, "b2a.rand"); // comm => 1, k // compute [x+r]B NdArrayRef r(expanded_ty, in.shape()); - auto _r = r.data>(); + NdArrayView _r(r); pforeach(0, numel, [&](int64_t idx) { _r[idx][0] = rb0[idx]; _r[idx][1] = rb1[idx]; @@ -208,34 +210,33 @@ NdArrayRef B2AByPPA::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { // comm => log(k) + 1, 2k(logk) + k auto x_plus_r = wrap_add_bb(ctx->sctx(), x, r); - // auto _x_plus_r = ArrayView>(x_plus_r); + NdArrayView _x_plus_r(x_plus_r); // reveal - std::vector x_plus_r_2(numel); + std::vector x_plus_r_2(numel); if (comm->getRank() == 0) { - x_plus_r_2 = comm->recv(2, "reveal.x_plus_r.to.P0"); + x_plus_r_2 = comm->recv(2, "reveal.x_plus_r.to.P0"); } else if (comm->getRank() == 2) { - std::vector x_plus_r_0(numel); - pforeach(0, numel, [&](int64_t idx) { - x_plus_r_0[idx] = x_plus_r.at>(idx)[0]; - }); - comm->sendAsync(0, x_plus_r_0, "reveal.x_plus_r.to.P0"); + std::vector x_plus_r_0(numel); + pforeach(0, numel, + [&](int64_t idx) { x_plus_r_0[idx] = _x_plus_r[idx][0]; }); + comm->sendAsync(0, x_plus_r_0, "reveal.x_plus_r.to.P0"); } // P0 hold x+r, P1 & P2 hold -r, reuse ra0 and ra1 as output auto self_rank = comm->getRank(); pforeach(0, numel, [&](int64_t idx) { if (self_rank == 0) { - const auto& x_r_v = x_plus_r.at>(idx); + const auto& x_r_v = _x_plus_r[idx]; ra0[idx] = x_r_v[0] ^ x_r_v[1] ^ x_plus_r_2[idx]; } else { ra0[idx] = -ra0[idx]; } }); - ra1 = comm->rotate(ra0, "b2a.rotate"); + ra1 
= comm->rotate(ra0, "b2a.rotate"); - auto _out = out.data>(); + NdArrayView _out(out); pforeach(0, numel, [&](int64_t idx) { _out[idx][0] = ra0[idx]; _out[idx][1] = ra1[idx]; @@ -251,8 +252,10 @@ static std::vector bitDecompose(const NdArrayRef& in, size_t nbits) { // decompose each bit of an array of element. std::vector dep(numel * nbits); + NdArrayView _in(in); + pforeach(0, numel, [&](int64_t idx) { - const auto& v = in.at(idx); + const auto& v = _in[idx]; for (size_t bit = 0; bit < nbits; bit++) { size_t flat_idx = idx * nbits + bit; dep[flat_idx] = static_cast((v >> bit) & 0x1); @@ -315,8 +318,7 @@ NdArrayRef B2AByOT::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { if (in_nbits == 0) { // special case, it's known to be zero. DISPATCH_ALL_FIELDS(field, "_", [&]() { - using AShrT = ring2k_t; - auto _out = out.data>(); + NdArrayView> _out(out); pforeach(0, numel, [&](int64_t idx) { _out[idx][0] = 0; _out[idx][1] = 0; @@ -337,25 +339,28 @@ NdArrayRef B2AByOT::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { size_t P2 = (pivot + 2) % 3; DISPATCH_UINT_PT_TYPES(in_ty->getBacktype(), "_", [&]() { - using BShrT = ScalarT; + using bshr_el_t = ScalarT; + using bshr_t = std::array; + NdArrayView _in(in); DISPATCH_ALL_FIELDS(field, "_", [&]() { - using AShrT = ring2k_t; + using ashr_el_t = ring2k_t; + using ashr_t = std::array; - auto _out = out.data>(); + NdArrayView _out(out); const size_t total_nbits = numel * in_nbits; - std::vector r0(total_nbits); - std::vector r1(total_nbits); + std::vector r0(total_nbits); + std::vector r1(total_nbits); prg_state->fillPrssPair(absl::MakeSpan(r0), absl::MakeSpan(r1)); if (comm->getRank() == P0) { // the helper - auto b2 = bitDecompose(getShare(in, 1), in_nbits); + auto b2 = bitDecompose(getShare(in, 1), in_nbits); // gen masks with helper. 
- std::vector m0(total_nbits); - std::vector m1(total_nbits); + std::vector m0(total_nbits); + std::vector m1(total_nbits); prg_state->fillPrssPair(absl::MakeSpan(m0), {}, false, true); prg_state->fillPrssPair(absl::MakeSpan(m1), {}, false, true); @@ -365,10 +370,10 @@ NdArrayRef B2AByOT::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { [&](int64_t idx) { m0[idx] = !b2[idx] ? m0[idx] : m1[idx]; }); // send selected masked to receiver. - comm->sendAsync(P1, m0, "mc"); + comm->sendAsync(P1, m0, "mc"); - auto c1 = bitCompose(r0, in_nbits); - auto c2 = comm->recv(P1, "c2"); + auto c1 = bitCompose(r0, in_nbits); + auto c2 = comm->recv(P1, "c2"); pforeach(0, numel, [&](int64_t idx) { _out[idx][0] = c1[idx]; @@ -379,20 +384,20 @@ NdArrayRef B2AByOT::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { prg_state->fillPrssPair(absl::MakeSpan(r0), {}, false, false); prg_state->fillPrssPair(absl::MakeSpan(r0), {}, false, false); - auto b2 = bitDecompose(getShare(in, 0), in_nbits); + auto b2 = bitDecompose(getShare(in, 0), in_nbits); // ot.recv - auto mc = comm->recv(P0, "mc"); - auto m0 = comm->recv(P2, "m0"); - auto m1 = comm->recv(P2, "m1"); + auto mc = comm->recv(P0, "mc"); + auto m0 = comm->recv(P2, "m0"); + auto m1 = comm->recv(P2, "m1"); // rebuild c2 = (b1^b2^b3)-c1-c3 pforeach(0, total_nbits, [&](int64_t idx) { mc[idx] = !b2[idx] ? m0[idx] ^ mc[idx] : m1[idx] ^ mc[idx]; }); - auto c2 = bitCompose(mc, in_nbits); - comm->sendAsync(P0, c2, "c2"); - auto c3 = bitCompose(r1, in_nbits); + auto c2 = bitCompose(mc, in_nbits); + comm->sendAsync(P0, c2, "c2"); + auto c3 = bitCompose(r1, in_nbits); pforeach(0, numel, [&](int64_t idx) { _out[idx][0] = c2[idx]; @@ -400,26 +405,26 @@ NdArrayRef B2AByOT::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { }); } else if (comm->getRank() == P2) { // the sender. 
- auto c3 = bitCompose(r0, in_nbits); - auto c1 = bitCompose(r1, in_nbits); + auto c3 = bitCompose(r0, in_nbits); + auto c1 = bitCompose(r1, in_nbits); // c3 = r0, c1 = r1 // let mi := (i^b1^b3)−c1−c3 for i in {0, 1} // reuse r's memory for m pforeach(0, numel, [&](int64_t idx) { - const auto x = in.at>(idx); + const auto x = _in[idx]; auto xx = x[0] ^ x[1]; for (size_t bit = 0; bit < in_nbits; bit++) { size_t flat_idx = idx * in_nbits + bit; - AShrT t = r0[flat_idx] + r1[flat_idx]; + ashr_el_t t = r0[flat_idx] + r1[flat_idx]; r0[flat_idx] = ((xx >> bit) & 0x1) - t; r1[flat_idx] = ((~xx >> bit) & 0x1) - t; } }); // gen masks with helper. - std::vector m0(total_nbits); - std::vector m1(total_nbits); + std::vector m0(total_nbits); + std::vector m1(total_nbits); prg_state->fillPrssPair({}, absl::MakeSpan(m0), true, false); prg_state->fillPrssPair({}, absl::MakeSpan(m1), true, false); pforeach(0, total_nbits, [&](int64_t idx) { @@ -427,8 +432,8 @@ NdArrayRef B2AByOT::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { m1[idx] ^= r1[idx]; }); - comm->sendAsync(P1, m0, "m0"); - comm->sendAsync(P1, m1, "m1"); + comm->sendAsync(P1, m0, "m0"); + comm->sendAsync(P1, m1, "m1"); pforeach(0, numel, [&](int64_t idx) { _out[idx][0] = c3[idx]; @@ -476,18 +481,23 @@ NdArrayRef B2AByOT::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { NdArrayRef hi(out_type, in.shape()); DISPATCH_UINT_PT_TYPES(in_ty->getBacktype(), "_", [&]() { - using InT = ScalarT; + using in_el_t = ScalarT; + using in_shr_t = std::array; + NdArrayView _in(in); + DISPATCH_UINT_PT_TYPES(out_backtype, "_", [&]() { - using OutT = ScalarT; - auto _lo = lo.data>(); - auto _hi = hi.data>(); + using out_el_t = ScalarT; + using out_shr_t = std::array; + + NdArrayView _lo(lo); + NdArrayView _hi(hi); - if constexpr (sizeof(InT) <= 8) { + if constexpr (sizeof(out_el_t) <= 8) { pforeach(0, in.numel(), [&](int64_t idx) { constexpr uint64_t S = 0x5555555555555555; // 01010101 - const InT M = (InT(1) << 
(in_nbits / 2)) - 1; + const out_el_t M = (out_el_t(1) << (in_nbits / 2)) - 1; - const auto& r = in.at>(idx); + const auto& r = _in[idx]; _lo[idx][0] = pext_u64(r[0], S) & M; _hi[idx][0] = pext_u64(r[0], ~S) & M; @@ -496,7 +506,7 @@ NdArrayRef B2AByOT::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { }); } else { pforeach(0, in.numel(), [&](int64_t idx) { - auto r = in.at>(idx); + auto r = _in[idx]; // algorithm: // 0101010101010101 // swap ^^ ^^ ^^ ^^ @@ -506,8 +516,8 @@ NdArrayRef B2AByOT::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { // swap ^^^^^^^^ // 0000000011111111 for (int k = 0; k + 1 < Log2Ceil(in_nbits); k++) { - InT keep = static_cast(kKeepMasks[k]); - InT move = static_cast(kSwapMasks[k]); + auto keep = static_cast(kKeepMasks[k]); + auto move = static_cast(kSwapMasks[k]); int shift = 1 << k; r[0] = (r[0] & keep) ^ ((r[0] >> shift) & move) ^ @@ -515,11 +525,11 @@ NdArrayRef B2AByOT::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { r[1] = (r[1] & keep) ^ ((r[1] >> shift) & move) ^ ((r[1] & move) << shift); } - InT mask = (InT(1) << (in_nbits / 2)) - 1; - _lo[idx][0] = static_cast(r[0]) & mask; - _hi[idx][0] = static_cast(r[0] >> (in_nbits / 2)) & mask; - _lo[idx][1] = static_cast(r[1]) & mask; - _hi[idx][1] = static_cast(r[1] >> (in_nbits / 2)) & mask; + in_el_t mask = (in_el_t(1) << (in_nbits / 2)) - 1; + _lo[idx][0] = static_cast(r[0]) & mask; + _hi[idx][0] = static_cast(r[0] >> (in_nbits / 2)) & mask; + _lo[idx][1] = static_cast(r[1]) & mask; + _hi[idx][1] = static_cast(r[1] >> (in_nbits / 2)) & mask; }); } }); @@ -551,27 +561,29 @@ NdArrayRef MsbA2B::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { NdArrayRef m(bshr_type, in.shape()); NdArrayRef n(bshr_type, in.shape()); DISPATCH_ALL_FIELDS(field, "aby3.msb.split", [&]() { - using U = ring2k_t; + using el_t = ring2k_t; + using shr_t = std::array; - auto _m = m.data>(); - auto _n = n.data>(); + NdArrayView _in(in); + NdArrayView _m(m); + NdArrayView _n(n); - 
std::vector r0(numel); - std::vector r1(numel); + std::vector r0(numel); + std::vector r1(numel); prg_state->fillPrssPair(absl::MakeSpan(r0), absl::MakeSpan(r1)); pforeach(0, numel, [&](int64_t idx) { r0[idx] = r0[idx] ^ r1[idx]; if (comm->getRank() == 0) { - const auto& v = in.at>(idx); + const auto& v = _in[idx]; r0[idx] ^= (v[0] + v[1]); } }); - r1 = comm->rotate(r0, "m"); + r1 = comm->rotate(r0, "m"); pforeach(0, numel, [&](int64_t idx) { - const auto& v = in.at>(idx); + const auto& v = _in[idx]; _m[idx][0] = r0[idx]; _m[idx][1] = r1[idx]; _n[idx][0] = comm->getRank() == 2 ? v[0] : 0; diff --git a/libspu/mpc/aby3/io.cc b/libspu/mpc/aby3/io.cc index a0e67150..87c96309 100644 --- a/libspu/mpc/aby3/io.cc +++ b/libspu/mpc/aby3/io.cc @@ -111,21 +111,24 @@ std::vector Aby3Io::makeBitSecret(const NdArrayRef& in) const { NdArrayRef(out_type, in.shape())}; return DISPATCH_UINT_PT_TYPES(in_pt_type, "_", [&]() { - using InT = ScalarT; - using BShrT = uint8_t; + using in_el_t = ScalarT; + using bshr_el_t = uint8_t; + using bshr_t = std::array; - std::vector r0(numel); - std::vector r1(numel); + NdArrayView _in(in); + + std::vector r0(numel); + std::vector r1(numel); yacl::crypto::PrgAesCtr(yacl::crypto::RandSeed(), absl::MakeSpan(r0)); yacl::crypto::PrgAesCtr(yacl::crypto::RandSeed(), absl::MakeSpan(r1)); - auto _s0 = shares[0].data>(); - auto _s1 = shares[1].data>(); - auto _s2 = shares[2].data>(); + NdArrayView _s0(shares[0]); + NdArrayView _s1(shares[1]); + NdArrayView _s2(shares[2]); for (int64_t idx = 0; idx < in.numel(); idx++) { - const BShrT r2 = static_cast(in.at(idx)) - r0[idx] - r1[idx]; + const bshr_el_t r2 = static_cast(_in[idx]) - r0[idx] - r1[idx]; _s0[idx][0] = r0[idx] & 0x1; _s0[idx][1] = r1[idx] & 0x1; @@ -155,13 +158,16 @@ NdArrayRef Aby3Io::fromShares(const std::vector& shares) const { NdArrayRef out(makeType(field_), shares[0].shape()); DISPATCH_ALL_FIELDS(field_, "_", [&]() { - auto _out = out.data(); + using el_t = ring2k_t; + using shr_t = 
std::array; + NdArrayView _out(out); for (size_t si = 0; si < shares.size(); si++) { + NdArrayView _s(shares[si]); for (auto idx = 0; idx < shares[0].numel(); ++idx) { if (si == 0) { _out[idx] = 0; } - _out[idx] += shares[si].at>(idx)[0]; + _out[idx] += _s[idx][0]; } } }); @@ -170,17 +176,17 @@ NdArrayRef Aby3Io::fromShares(const std::vector& shares) const { NdArrayRef out(makeType(field_), shares[0].shape()); DISPATCH_ALL_FIELDS(field_, "_", [&]() { - using OutT = ring2k_t; - auto _out = out.data(); + NdArrayView _out(out); DISPATCH_UINT_PT_TYPES(eltype.as()->getBacktype(), "_", [&] { - using BShrT = ScalarT; + using shr_t = std::array; for (size_t si = 0; si < shares.size(); si++) { + NdArrayView _s(shares[si]); for (auto idx = 0; idx < shares[0].numel(); ++idx) { if (si == 0) { _out[idx] = 0; } - _out[idx] ^= shares[si].at>(idx)[0]; + _out[idx] ^= _s[idx][0]; } } }); diff --git a/libspu/mpc/aby3/value.h b/libspu/mpc/aby3/value.h index ade58f9b..1af40b62 100644 --- a/libspu/mpc/aby3/value.h +++ b/libspu/mpc/aby3/value.h @@ -60,8 +60,9 @@ std::vector getShareAs(const NdArrayRef& in, size_t share_idx) { std::vector res(numel); DISPATCH_UINT_PT_TYPES(share.eltype().as()->pt_type(), "_", [&]() { + NdArrayView _share(share); for (auto idx = 0; idx < numel; ++idx) { - res[idx] = share.at(idx); + res[idx] = _share[idx]; } }); diff --git a/libspu/mpc/cheetah/arith/cheetah_conv2d_test.cc b/libspu/mpc/cheetah/arith/cheetah_conv2d_test.cc index 10ab8d30..b05235a6 100644 --- a/libspu/mpc/cheetah/arith/cheetah_conv2d_test.cc +++ b/libspu/mpc/cheetah/arith/cheetah_conv2d_test.cc @@ -85,10 +85,10 @@ TEST_P(CheetahConv2dTest, Basic) { const int64_t kMaxDiff = 1; DISPATCH_ALL_FIELDS(field, "_", [&]() { - auto c = ArrayView(computed); - + ArrayView c(computed); + NdArrayView exp(expected); for (auto idx = 0; idx < expected.numel(); idx++) { - EXPECT_NEAR(c[idx], expected.at(idx), kMaxDiff); + EXPECT_NEAR(c[idx], exp[idx], kMaxDiff); } }); } diff --git 
a/libspu/mpc/cheetah/arithmetic.cc b/libspu/mpc/cheetah/arithmetic.cc index 3154b05d..e8c8776b 100644 --- a/libspu/mpc/cheetah/arithmetic.cc +++ b/libspu/mpc/cheetah/arithmetic.cc @@ -109,18 +109,15 @@ NdArrayRef MsbA2B::proc(KernelEvalContext* ctx, const NdArrayRef& x) const { using u2k = std::make_unsigned::type; const u2k mask = (static_cast(1) << shft) - 1; NdArrayRef adjusted = ring_zeros(field, {n}); - // auto xinp = ArrayView(x); - // auto xadj = ArrayView(adjusted); + auto xinp = NdArrayView(x); + auto xadj = NdArrayView(adjusted); if (rank == 0) { // x0 - pforeach(0, n, - [&](int64_t i) { adjusted.at(i) = x.at(i) & mask; }); + pforeach(0, n, [&](int64_t i) { xadj[i] = xinp[i] & mask; }); } else { // 2^{k - 1} - 1 - x1 - pforeach(0, n, [&](int64_t i) { - adjusted.at(i) = (mask - x.at(i)) & mask; - }); + pforeach(0, n, [&](int64_t i) { xadj[i] = (mask - xinp[i]) & mask; }); } NdArrayRef carry_bit(x.eltype(), x.shape()); @@ -145,9 +142,8 @@ NdArrayRef MsbA2B::proc(KernelEvalContext* ctx, const NdArrayRef& x) const { }); // [msb(x)]_B <- [1{x0 + x1 > 2^{k- 1} - 1]_B ^ msb(x0) - pforeach(0, n, [&](int64_t i) { - carry_bit.at(i) ^= (x.at(i) >> shft); - }); + NdArrayView _carry_bit(carry_bit); + pforeach(0, n, [&](int64_t i) { _carry_bit[i] ^= (xinp[i] >> shft); }); return carry_bit.as(makeType(field, 1)); }); diff --git a/libspu/mpc/common/prg_state.cc b/libspu/mpc/common/prg_state.cc index 18c1dce4..652487aa 100644 --- a/libspu/mpc/common/prg_state.cc +++ b/libspu/mpc/common/prg_state.cc @@ -88,21 +88,19 @@ std::pair PrgState::genPrssPair(FieldType field, if (!ignore_first) { new_counter = yacl::crypto::FillPRand( kAesType, self_seed_, 0, prss_counter_, - absl::MakeSpan(static_cast(r_self.data()), - r_self.buf()->size())); + absl::MakeSpan(r_self.data(), r_self.buf()->size())); } if (!ignore_second) { new_counter = yacl::crypto::FillPRand( kAesType, next_seed_, 0, prss_counter_, - absl::MakeSpan(static_cast(r_next.data()), - r_next.buf()->size())); + 
absl::MakeSpan(r_next.data(), r_next.buf()->size())); } if (new_counter == prss_counter_) { // both part ignored, dummy run to update counter... new_counter = yacl::crypto::DummyUpdateRandomCount( - prss_counter_, absl::MakeSpan(static_cast(r_next.data()), - r_next.buf()->size())); + prss_counter_, + absl::MakeSpan(r_next.data(), r_next.buf()->size())); } prss_counter_ = new_counter; @@ -113,7 +111,7 @@ NdArrayRef PrgState::genPriv(FieldType field, const Shape& shape) { NdArrayRef res(makeType(field), shape); priv_counter_ = yacl::crypto::FillPRand( kAesType, priv_seed_, 0, priv_counter_, - absl::MakeSpan(static_cast(res.data()), res.buf()->size())); + absl::MakeSpan(res.data(), res.buf()->size())); return res; } @@ -122,7 +120,7 @@ NdArrayRef PrgState::genPubl(FieldType field, const Shape& shape) { NdArrayRef res(makeType(field), shape); pub_counter_ = yacl::crypto::FillPRand( kAesType, pub_seed_, 0, pub_counter_, - absl::MakeSpan(static_cast(res.data()), res.buf()->size())); + absl::MakeSpan(res.data(), res.buf()->size())); return res; } diff --git a/libspu/mpc/common/pv2k.cc b/libspu/mpc/common/pv2k.cc index bd8edbac..3ef0824e 100644 --- a/libspu/mpc/common/pv2k.cc +++ b/libspu/mpc/common/pv2k.cc @@ -71,13 +71,13 @@ class V2P : public UnaryKernel { DISPATCH_ALL_FIELDS(field, "v2p", [&]() { std::vector priv(numel); + NdArrayView _in(in); - pforeach(0, in.numel(), - [&](int64_t idx) { priv[idx] = in.at(idx); }); + pforeach(0, in.numel(), [&](int64_t idx) { priv[idx] = _in[idx]; }); std::vector publ = comm->bcast(priv, owner, "v2p"); - auto* _out = out.data(); + NdArrayView _out(out); pforeach(0, in.numel(), [&](int64_t idx) { _out[idx] = publ[idx]; }); }); return out; diff --git a/libspu/mpc/io_interface.h b/libspu/mpc/io_interface.h index c87072df..085dabdf 100644 --- a/libspu/mpc/io_interface.h +++ b/libspu/mpc/io_interface.h @@ -34,10 +34,10 @@ class IoInterface { // @return a list of random values, each of which could be send to one mpc // engine. 
// - // TODO: currently, the `private type` is transprent to compiler. + // TODO: currently, the `private type` is transparent to compiler. // - // This function deos NOT handle encoding stuffs, it's the upper layer(hal)'s - // resposibility to encode to it to ring. + // This function does NOT handle encoding stuffs, it's the upper layer(hal)'s + // responsibility to encode to it to ring. virtual std::vector toShares(const NdArrayRef& raw, Visibility vis, int owner_rank = -1) const = 0; diff --git a/libspu/mpc/ref2k/ref2k.cc b/libspu/mpc/ref2k/ref2k.cc index 84e7140d..ad117ddb 100644 --- a/libspu/mpc/ref2k/ref2k.cc +++ b/libspu/mpc/ref2k/ref2k.cc @@ -143,15 +143,17 @@ class Ref2kV2S : public UnaryKernel { int64_t numel = in.numel(); DISPATCH_ALL_FIELDS(field, "v2s", [&]() { - std::vector _in(numel); + std::vector _send(numel); + NdArrayView _in(in); for (int64_t idx = 0; idx < numel; ++idx) { - _in[idx] = in.at(idx); + _send[idx] = _in[idx]; } - std::vector _out = comm->bcast(_in, owner, "v2s"); + std::vector _recv = comm->bcast(_send, owner, "v2s"); + NdArrayView _out(out); for (int64_t idx = 0; idx < numel; ++idx) { - out.at(idx) = _out[idx]; + _out[idx] = _recv[idx]; } }); return out; diff --git a/libspu/mpc/semi2k/arithmetic.cc b/libspu/mpc/semi2k/arithmetic.cc index 71c5b02c..2a329ee4 100644 --- a/libspu/mpc/semi2k/arithmetic.cc +++ b/libspu/mpc/semi2k/arithmetic.cc @@ -75,19 +75,21 @@ NdArrayRef A2V::proc(KernelEvalContext* ctx, const NdArrayRef& in, return DISPATCH_ALL_FIELDS(field, "_", [&]() { std::vector share(numel); - pforeach(0, numel, [&](int64_t idx) { share[idx] = in.at(idx); }); + NdArrayView _in(in); + pforeach(0, numel, [&](int64_t idx) { share[idx] = _in[idx]; }); std::vector> shares = comm->gather(share, rank, "a2v"); // comm => 1, k if (comm->getRank() == rank) { SPU_ENFORCE(shares.size() == comm->getWorldSize()); NdArrayRef out(out_ty, in.shape()); + NdArrayView _out(out); pforeach(0, numel, [&](int64_t idx) { ring2k_t s = 0; for (auto& share 
: shares) { s += share[idx]; } - out.at(idx) = s; + _out[idx] = s; }); return out; } else { @@ -187,17 +189,19 @@ NdArrayRef MulAA::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, SPU_ENFORCE(a.isCompact() && b.isCompact() && c.isCompact(), "beaver must be compact"); - auto _a = a.data(); - auto _b = b.data(); - auto _c = c.data(); + NdArrayView _a(a); + NdArrayView _b(b); + NdArrayView _c(c); + NdArrayView _lhs(lhs); + NdArrayView _rhs(rhs); std::vector eu(numel * 2); absl::Span e(eu.data(), numel); absl::Span u(eu.data() + numel, numel); pforeach(0, numel, [&](int64_t idx) { - e[idx] = lhs.at(idx) - _a[idx]; // e = x - a; - u[idx] = rhs.at(idx) - _b[idx]; // u = y - b; + e[idx] = _lhs[idx] - _a[idx]; // e = x - a; + u[idx] = _rhs[idx] - _b[idx]; // u = y - b; }); // open x-a & y-b @@ -213,12 +217,13 @@ NdArrayRef MulAA::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, e = absl::Span(eu.data(), numel); u = absl::Span(eu.data() + numel, numel); + NdArrayView _res(res); // Zi = Ci + (X - A) * Bi + (Y - B) * Ai + <(X - A) * (Y - B)> pforeach(0, a.numel(), [&](int64_t idx) { - res.at(idx) = _c[idx] + e[idx] * _b[idx] + u[idx] * _a[idx]; + _res[idx] = _c[idx] + e[idx] * _b[idx] + u[idx] * _a[idx]; if (comm->getRank() == 0) { // z += (X-A) * (Y-B); - res.at(idx) += e[idx] * u[idx]; + _res[idx] += e[idx] * u[idx]; } }); }); @@ -327,11 +332,18 @@ NdArrayRef TruncAPr::proc(KernelEvalContext* ctx, const NdArrayRef& in, DISPATCH_ALL_FIELDS(field, "semi2k.truncpr", [&]() { using U = ring2k_t; + NdArrayView _in(in); + NdArrayView _r(r); + NdArrayView _rb(rb); + NdArrayView _rc(rc); + NdArrayView _out(out); + std::vector c; { std::vector x_plus_r(numel); + pforeach(0, numel, [&](int64_t idx) { - auto x = in.at(idx); + auto x = _in[idx]; // handle negative number. 
// assume secret x in [-2^(k-2), 2^(k-2)), by // adding 2^(k-2) x' = x + 2^(k-2) in [0, 2^(k-1)), with msb(x') == 0 @@ -339,7 +351,7 @@ NdArrayRef TruncAPr::proc(KernelEvalContext* ctx, const NdArrayRef& in, x += U(1) << (k - 2); } // mask x with r - x_plus_r[idx] = x + r.at(idx); + x_plus_r[idx] = x + _r[idx]; }); // open + = c c = comm->allReduce(x_plus_r, kBindName); @@ -351,21 +363,21 @@ NdArrayRef TruncAPr::proc(KernelEvalContext* ctx, const NdArrayRef& in, U y; if (comm->getRank() == 0) { // = ^ c{k-1} = + c{k-1} - 2*c{k-1}* - auto b = rb.at(idx) + ck_1 - 2 * ck_1 * rb.at(idx); + auto b = _rb[idx] + ck_1 - 2 * ck_1 * _rb[idx]; // c_hat = c/2^m mod 2^(k-m-1) = (c << 1) >> (1+m) auto c_hat = (c[idx] << 1) >> (1 + bits); // y = c_hat - + * 2^(k-m-1) - y = c_hat - rc.at(idx) + (b << (k - 1 - bits)); + y = c_hat - _rc[idx] + (b << (k - 1 - bits)); // re-encode negative numbers. // from https://eprint.iacr.org/2020/338.pdf, section 5.1 // y' = y - 2^(k-2-m) y -= (U(1) << (k - 2 - bits)); } else { - auto b = rb.at(idx) + 0 - 2 * ck_1 * rb.at(idx); - y = 0 - rc.at(idx) + (b << (k - 1 - bits)); + auto b = _rb[idx] + 0 - 2 * ck_1 * _rb[idx]; + y = 0 - _rc[idx] + (b << (k - 1 - bits)); } - out.at(idx) = y; + _out[idx] = y; }); }); diff --git a/libspu/mpc/semi2k/beaver/beaver_test.cc b/libspu/mpc/semi2k/beaver/beaver_test.cc index a0dfeb7e..562d9b34 100644 --- a/libspu/mpc/semi2k/beaver/beaver_test.cc +++ b/libspu/mpc/semi2k/beaver/beaver_test.cc @@ -116,9 +116,9 @@ TEST_P(BeaverTest, Mul_large) { } DISPATCH_ALL_FIELDS(kField, "_", [&]() { - auto _a = sum_a.data(); - auto _b = sum_b.data(); - auto _c = sum_c.data(); + NdArrayView _a(sum_a); + NdArrayView _b(sum_b); + NdArrayView _c(sum_c); for (auto idx = 0; idx < sum_a.numel(); idx++) { auto t = _a[idx] * _b[idx]; auto err = t > _c[idx] ? 
t - _c[idx] : _c[idx] - t; @@ -159,9 +159,9 @@ TEST_P(BeaverTest, Mul) { } DISPATCH_ALL_FIELDS(kField, "_", [&]() { - auto _a = sum_a.data(); - auto _b = sum_b.data(); - auto _c = sum_c.data(); + NdArrayView _a(sum_a); + NdArrayView _b(sum_b); + NdArrayView _c(sum_c); for (auto idx = 0; idx < sum_a.numel(); idx++) { auto t = _a[idx] * _b[idx]; auto err = t > _c[idx] ? t - _c[idx] : _c[idx] - t; @@ -241,8 +241,8 @@ TEST_P(BeaverTest, Dot) { auto res = ring_mmul(sum_a, sum_b); DISPATCH_ALL_FIELDS(kField, "_", [&]() { - auto _r = res.data(); - auto _c = sum_c.data(); + NdArrayView _r(res); + NdArrayView _c(sum_c); for (auto idx = 0; idx < res.numel(); idx++) { auto err = _r[idx] > _c[idx] ? _r[idx] - _c[idx] : _c[idx] - _r[idx]; EXPECT_LE(err, kMaxDiff); @@ -287,8 +287,8 @@ TEST_P(BeaverTest, Dot_large) { auto res = ring_mmul(sum_a, sum_b); DISPATCH_ALL_FIELDS(kField, "_", [&]() { - auto _r = res.data(); - auto _c = sum_c.data(); + NdArrayView _r(res); + NdArrayView _c(sum_c); for (auto idx = 0; idx < res.numel(); idx++) { auto err = _r[idx] > _c[idx] ? 
_r[idx] - _c[idx] : _c[idx] - _r[idx]; EXPECT_LE(err, kMaxDiff); diff --git a/libspu/mpc/semi2k/boolean.cc b/libspu/mpc/semi2k/boolean.cc index 8eb850f4..9d7cb427 100644 --- a/libspu/mpc/semi2k/boolean.cc +++ b/libspu/mpc/semi2k/boolean.cc @@ -118,11 +118,12 @@ NdArrayRef AndBP::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, NdArrayRef out(makeType(field, out_nbits), lhs.shape()); DISPATCH_ALL_FIELDS(field, "_", [&]() { - using T = ring2k_t; + NdArrayView _lhs(lhs); + NdArrayView _rhs(rhs); + NdArrayView _out(out); - pforeach(0, lhs.numel(), [&](int64_t idx) { - out.at(idx) = lhs.at(idx) & rhs.at(idx); - }); + pforeach(0, lhs.numel(), + [&](int64_t idx) { _out[idx] = _lhs[idx] & _rhs[idx]; }); }); return out; } @@ -143,6 +144,9 @@ NdArrayRef AndBB::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, NdArrayRef out(makeType(field, out_nbits), lhs.shape()); DISPATCH_ALL_FIELDS(field, "_", [&]() { using T = ring2k_t; + NdArrayView _lhs(lhs); + NdArrayView _rhs(rhs); + DISPATCH_UINT_PT_TYPES(backtype, "_", [&]() { using V = ScalarT; @@ -154,21 +158,21 @@ NdArrayRef AndBB::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, auto [a, b, c] = beaver->And(field, {numField}); SPU_ENFORCE(a.buf()->size() >= static_cast(numBytes)); - const auto* _a = reinterpret_cast(a.data()); - const auto* _b = reinterpret_cast(b.data()); - const auto* _c = reinterpret_cast(c.data()); + NdArrayView _a(a); + NdArrayView _b(b); + NdArrayView _c(c); // first half mask x^a, second half mask y^b. 
std::vector mask(numel * 2, 0); pforeach(0, numel, [&](int64_t idx) { - mask[idx] = lhs.at(idx) ^ _a[idx]; - mask[numel + idx] = rhs.at(idx) ^ _b[idx]; + mask[idx] = _lhs[idx] ^ _a[idx]; + mask[numel + idx] = _rhs[idx] ^ _b[idx]; }); mask = comm->allReduce(mask, "open(x^a,y^b)"); // Zi = Ci ^ ((X ^ A) & Bi) ^ ((Y ^ B) & Ai) ^ <(X ^ A) & (Y ^ B)> - auto* _z = reinterpret_cast(out.data()); + NdArrayView _z(out); pforeach(0, numel, [&](int64_t idx) { _z[idx] = _c[idx]; _z[idx] ^= mask[idx] & _b[idx]; @@ -263,12 +267,11 @@ NdArrayRef BitIntlB::proc(KernelEvalContext* ctx, const NdArrayRef& in, auto numel = in.numel(); DISPATCH_ALL_FIELDS(field, "_", [&]() { - using T = ring2k_t; - - auto _out = reinterpret_cast(out.data()); + NdArrayView _in(in); + NdArrayView _out(out); pforeach(0, numel, [&](int64_t idx) { - _out[idx] = BitIntl(in.at(idx), stride, nbits); + _out[idx] = BitIntl(_in[idx], stride, nbits); }); }); @@ -285,12 +288,11 @@ NdArrayRef BitDeintlB::proc(KernelEvalContext* ctx, const NdArrayRef& in, auto numel = in.numel(); DISPATCH_ALL_FIELDS(field, "_", [&]() { - using T = ring2k_t; - - auto _out = reinterpret_cast(out.data()); + NdArrayView _in(in); + NdArrayView _out(out); pforeach(0, numel, [&](int64_t idx) { - _out[idx] = BitDeintl(in.at(idx), stride, nbits); + _out[idx] = BitDeintl(_in[idx], stride, nbits); }); }); diff --git a/libspu/mpc/semi2k/conversion.cc b/libspu/mpc/semi2k/conversion.cc index 7897b67d..1c38bc20 100644 --- a/libspu/mpc/semi2k/conversion.cc +++ b/libspu/mpc/semi2k/conversion.cc @@ -102,6 +102,9 @@ NdArrayRef B2A_Randbit::proc(KernelEvalContext* ctx, DISPATCH_ALL_FIELDS(field, kBindName, [&]() { using U = ring2k_t; + NdArrayView _randbits(randbits); + NdArrayView _x(x); + // algorithm begins. 
// Ref: III.D @ https://eprint.iacr.org/2019/599.pdf (SPDZ-2K primitives) std::vector x_xor_r(numel); @@ -110,24 +113,24 @@ NdArrayRef B2A_Randbit::proc(KernelEvalContext* ctx, // use _r[i*nbits, (i+1)*nbits) to construct rb[i] U mask = 0; for (int64_t bit = 0; bit < nbits; ++bit) { - mask += (randbits.at(idx * nbits + bit) & 0x1) << bit; + mask += (_randbits[idx * nbits + bit] & 0x1) << bit; } - x_xor_r[idx] = x.at(idx) ^ mask; + x_xor_r[idx] = _x[idx] ^ mask; }); // open c = x ^ r x_xor_r = comm->allReduce(x_xor_r, "open(x^r)"); + NdArrayView _res(res); pforeach(0, numel, [&](int64_t idx) { - res.at(idx) = 0; + _res[idx] = 0; for (int64_t bit = 0; bit < nbits; bit++) { auto c_i = (x_xor_r[idx] >> bit) & 0x1; if (comm->getRank() == 0) { - res.at(idx) += - (c_i + (1 - c_i * 2) * randbits.at(idx * nbits + bit)) << bit; + _res[idx] += (c_i + (1 - c_i * 2) * _randbits[idx * nbits + bit]) + << bit; } else { - res.at(idx) += ((1 - c_i * 2) * randbits.at(idx * nbits + bit)) - << bit; + _res[idx] += ((1 - c_i * 2) * _randbits[idx * nbits + bit]) << bit; } } }); diff --git a/libspu/mpc/spdz2k/abprotocol_spdz2k_test.cc b/libspu/mpc/spdz2k/abprotocol_spdz2k_test.cc index 465115b6..b49e34c2 100644 --- a/libspu/mpc/spdz2k/abprotocol_spdz2k_test.cc +++ b/libspu/mpc/spdz2k/abprotocol_spdz2k_test.cc @@ -246,9 +246,9 @@ TEST_P(ConversionTest, BitLT) { using U = std::make_unsigned::type; size_t numel = kShape.numel(); - auto p0_data = p0.data().data(); - auto p1_data = p1.data().data(); - auto re_data = re.data().data(); + NdArrayView p0_data(p0.data()); + NdArrayView p1_data(p1.data()); + NdArrayView re_data(re.data()); for (size_t i = 0; i < numel; ++i) { if ((p0_data[i] < p1_data[i])) { SPU_ENFORCE((re_data[i] == 1), "i {}, p0 {}, p1 {}", i, p0_data[i], @@ -294,9 +294,9 @@ TEST_P(ConversionTest, BitLE) { DISPATCH_ALL_FIELDS(field, "_", [&]() { using U = std::make_unsigned::type; size_t numel = kShape.numel(); - auto p0_data = p0.data().data(); - auto p1_data = p1.data().data(); 
- auto re_data = re.data().data(); + NdArrayView p0_data(p0.data()); + NdArrayView p1_data(p1.data()); + NdArrayView re_data(re.data()); for (size_t i = 0; i < numel; ++i) { if ((p0_data[i] <= p1_data[i])) { diff --git a/libspu/mpc/spdz2k/arithmetic.cc b/libspu/mpc/spdz2k/arithmetic.cc index b0949534..d3a54d44 100644 --- a/libspu/mpc/spdz2k/arithmetic.cc +++ b/libspu/mpc/spdz2k/arithmetic.cc @@ -45,13 +45,14 @@ namespace { NdArrayRef CastRing(const NdArrayRef& in, FieldType out_field) { const auto* in_ty = in.eltype().as(); const auto in_field = in_ty->field(); + auto out = ring_zeros(out_field, in.shape()); + return DISPATCH_ALL_FIELDS(in_field, "_", [&]() { - using InT = ring2k_t; - auto out = ring_zeros(out_field, in.shape()); + NdArrayView _in(in); return DISPATCH_ALL_FIELDS(out_field, "_", [&]() { - using OutT = ring2k_t; + NdArrayView _out(out); pforeach(0, in.numel(), [&](int64_t idx) { - out.at(idx) = static_cast(in.at(idx)); + _out[idx] = static_cast(_in[idx]); }); return out; @@ -324,7 +325,7 @@ bool SingleCheck(KernelEvalContext* ctx, const NdArrayRef& in) { // 4. 
Check the consistency of y auto z = ring_sub(y_mac, ring_mul(plain_y, key)); - std::string z_str(reinterpret_cast(z.data()), z.numel() * z.elsize()); + std::string z_str(z.data(), z.numel() * z.elsize()); std::vector z_strs; SPU_ENFORCE(commit_and_open(comm->lctx(), z_str, &z_strs)); SPU_ENFORCE(z_strs.size() == comm->getWorldSize()); @@ -404,8 +405,9 @@ bool BatchCheck(KernelEvalContext* ctx, const std::vector& ins) { auto pub_r = beaver->genPublCoin(field, size); std::vector rv; uint128_t mask = (static_cast(1) << s) - 1; + NdArrayView _pub_r(pub_r); for (size_t i = 0; i < size; ++i) { - rv.emplace_back(pub_r.at(i) & mask); + rv.emplace_back(_pub_r[i] & mask); } auto plain_y = ring_zeros(field, ins[0].shape()); @@ -422,7 +424,7 @@ bool BatchCheck(KernelEvalContext* ctx, const std::vector& ins) { auto plain_y_mac_share = ring_mul(plain_y, key); auto z = ring_sub(m, plain_y_mac_share); - std::string z_str(reinterpret_cast(z.data()), z.numel() * z.elsize()); + std::string z_str(z.data(), z.numel() * z.elsize()); std::vector z_strs; YACL_ENFORCE(commit_and_open(lctx, z_str, &z_strs)); YACL_ENFORCE(z_strs.size() == comm->getWorldSize()); diff --git a/libspu/mpc/spdz2k/beaver/beaver_test.cc b/libspu/mpc/spdz2k/beaver/beaver_test.cc index 4fc257ce..0395ee21 100644 --- a/libspu/mpc/spdz2k/beaver/beaver_test.cc +++ b/libspu/mpc/spdz2k/beaver/beaver_test.cc @@ -127,10 +127,13 @@ TEST_P(BeaverTest, AuthAnd) { << sum_c << sum_key << sum_c_mac; DISPATCH_ALL_FIELDS(kField, "_", [&]() { + NdArrayView _valid_a(valid_a); + NdArrayView _valid_b(valid_b); + NdArrayView _valid_c(valid_c); + for (auto idx = 0; idx < sum_a.numel(); idx++) { - auto t = valid_a.at(idx) * valid_b.at(idx); - auto err = t > valid_c.at(idx) ? t - valid_c.at(idx) - : valid_c.at(idx) - t; + auto t = _valid_a[idx] * _valid_b[idx]; + auto err = t > _valid_c[idx] ? 
t - _valid_c[idx] : _valid_c[idx] - t; EXPECT_LE(err, kMaxDiff); } }); @@ -384,10 +387,13 @@ TEST_P(BeaverTest, AuthDot) { auto res = ring_mmul(sum_a, sum_b); DISPATCH_ALL_FIELDS(kField, "_", [&]() { + NdArrayView _sum_a(sum_a); + NdArrayView _sum_c(sum_c); + NdArrayView _res(res); + for (auto idx = 0; idx < res.numel(); idx++) { - auto err = res.at(idx) > sum_c.at(idx) - ? res.at(idx) - sum_c.at(idx) - : sum_c.at(idx) - res.at(idx); + auto err = _res[idx] > _sum_c[idx] ? _res[idx] - _sum_c[idx] + : _sum_c[idx] - _res[idx]; EXPECT_LE(err, kMaxDiff); } }); diff --git a/libspu/mpc/spdz2k/beaver/beaver_tfp.cc b/libspu/mpc/spdz2k/beaver/beaver_tfp.cc index 3ce26ed1..aab474e8 100644 --- a/libspu/mpc/spdz2k/beaver/beaver_tfp.cc +++ b/libspu/mpc/spdz2k/beaver/beaver_tfp.cc @@ -49,15 +49,17 @@ uint128_t BeaverTfpUnsafe::InitSpdzKey(FieldType field, size_t s) { auto a = prgCreateArray(field, {size}, seed_, &counter_, &desc); return DISPATCH_ALL_FIELDS(field, "_", [&]() { + NdArrayView _a(a); if (comm_->getRank() == 0) { auto t = tp_.adjustSpdzKey(desc); + NdArrayView _t(t); global_key_ = yacl::crypto::SecureRandSeed(); global_key_ &= (static_cast(1) << s) - 1; - a.at(0) += global_key_ - t.at(0); + _a[0] += global_key_ - _t[0]; } - spdz_key_ = a.at(0); + spdz_key_ = _a[0]; return spdz_key_; }); @@ -229,9 +231,8 @@ NdArrayRef BeaverTfpUnsafe::genPublCoin(FieldType field, int64_t numel) { } auto kAesType = yacl::crypto::SymmetricCrypto::CryptoType::AES128_CTR; - yacl::crypto::FillPRand( - kAesType, public_seed, 0, 0, - absl::MakeSpan(static_cast(res.data()), res.buf()->size())); + yacl::crypto::FillPRand(kAesType, public_seed, 0, 0, + absl::MakeSpan(res.data(), res.buf()->size())); return res; } @@ -291,7 +292,7 @@ bool BeaverTfpUnsafe::BatchMacCheck(const NdArrayRef& open_value, // 3. 
compute z, commit and open z auto z = ring_sub(check_mac, ring_mul(check_value, key)); - std::string z_str(reinterpret_cast(z.data()), z.numel() * z.elsize()); + std::string z_str(z.data(), z.numel() * z.elsize()); std::vector z_strs; SPU_ENFORCE(commit_and_open(lctx, z_str, &z_strs)); SPU_ENFORCE(z_strs.size() == comm->getWorldSize()); @@ -310,7 +311,7 @@ bool BeaverTfpUnsafe::BatchMacCheck(const NdArrayRef& open_value, const auto& _z_str = z_strs[i]; auto mem = std::make_shared(_z_str.data(), _z_str.size()); NdArrayRef a(mem, plain_z.eltype(), - {(int64_t)(_z_str.size() / SizeOf(field))}, {1}, 0); + {static_cast(_z_str.size() / SizeOf(field))}, {1}, 0); ring_add_(plain_z, a); } diff --git a/libspu/mpc/spdz2k/beaver/beaver_tinyot.cc b/libspu/mpc/spdz2k/beaver/beaver_tinyot.cc index 512fe694..fdadac65 100644 --- a/libspu/mpc/spdz2k/beaver/beaver_tinyot.cc +++ b/libspu/mpc/spdz2k/beaver/beaver_tinyot.cc @@ -72,9 +72,11 @@ NdArrayRef ring_sqrt2k(const NdArrayRef& x, size_t bits = 0) { auto ret = ring_zeros(field, x.shape()); DISPATCH_ALL_FIELDS(field, "_", [&]() { using U = std::make_unsigned::type; + NdArrayView _ret(ret); + NdArrayView _x(x); yacl::parallel_for(0, numel, 4096, [&](int64_t beg, int64_t end) { for (int64_t idx = beg; idx < end; ++idx) { - ret.at(idx) = Sqrt2k(x.at(idx), bits); + _ret[idx] = Sqrt2k(_x[idx], bits); } }); }); @@ -103,9 +105,11 @@ NdArrayRef ring_inv2k(const NdArrayRef& x, size_t bits = 0) { auto ret = ring_zeros(field, x.shape()); DISPATCH_ALL_FIELDS(field, "_", [&]() { using U = std::make_unsigned::type; + NdArrayView _ret(ret); + NdArrayView _x(x); yacl::parallel_for(0, numel, 4096, [&](int64_t beg, int64_t end) { for (int64_t idx = beg; idx < end; ++idx) { - ret.at(idx) = Invert2k(x.at(idx), bits); + _ret[idx] = Invert2k(_x[idx], bits); } }); }); @@ -117,9 +121,10 @@ std::vector ring_cast_vector_boolean(const NdArrayRef& x) { std::vector res(x.numel()); DISPATCH_ALL_FIELDS(field, "RingOps", [&]() { + NdArrayView _x(x); 
yacl::parallel_for(0, x.numel(), 4096, [&](size_t start, size_t end) { for (size_t i = start; i < end; i++) { - res[i] = static_cast(x.at(i) & 0x1); + res[i] = static_cast(_x[i] & 0x1); } }); }); @@ -206,10 +211,14 @@ NdArrayRef BeaverTinyOt::AuthArrayRef(const NdArrayRef& x, FieldType field, int64_t new_numel = t + 1; NdArrayRef x_hat(x.eltype(), {new_numel}); auto x_mask = ring_rand(field, {1}); + + NdArrayView _x_hat(x_hat); + NdArrayView _x(x); + NdArrayView _x_mask(x_mask); for (int i = 0; i < t; ++i) { - x_hat.at(i) = x.at(i); + _x_hat[i] = _x[i]; } - x_hat.at(t) = x_mask.at(0); + _x_hat[t] = _x_mask[0]; // 3. every pair calls vole && 4. receives vole output size_t WorldSize = comm_->getWorldSize(); @@ -241,14 +250,16 @@ NdArrayRef BeaverTinyOt::AuthArrayRef(const NdArrayRef& x, FieldType field, } auto m = ring_add(ring_mul(x_hat, spdz_key_), a_b); + NdArrayView _m(m); // Consistency check // 6. get l public random values auto pub_r = prg_state_->genPubl(field, {new_numel}); + NdArrayView _pub_r(pub_r); std::vector rv; size_t numel = x.numel(); for (size_t i = 0; i < numel; ++i) { - rv.emplace_back(pub_r.at(i)); + rv.emplace_back(_pub_r[i]); } rv.emplace_back(1); @@ -257,8 +268,8 @@ NdArrayRef BeaverTinyOt::AuthArrayRef(const NdArrayRef& x, FieldType field, T m_angle = 0; for (int64_t i = 0; i < new_numel; ++i) { // x_hat, not x - x_angle += rv[i] * x_hat.at(i); - m_angle += rv[i] * m.at(i); + x_angle += rv[i] * _x_hat[i]; + m_angle += rv[i] * _m[i]; } auto x_angle_sum = @@ -340,10 +351,11 @@ BeaverTinyOt::Triple_Pair BeaverTinyOt::AuthAnd(FieldType field, DISPATCH_ALL_FIELDS(field, "_", [&]() { using U = std::make_unsigned::type; auto _size = auth_abcr.choices.size(); + NdArrayView _spdz_choices(spdz_choices); // copy authbit choices yacl::parallel_for(0, _size, 4096, [&](int64_t beg, int64_t end) { for (int64_t idx = beg; idx < end; ++idx) { - spdz_choices.at(idx) = auth_abcr.choices[idx]; + _spdz_choices[idx] = auth_abcr.choices[idx]; } }); }); @@ -384,10 
+396,14 @@ BeaverTinyOt::Triple_Pair BeaverTinyOt::AuthAnd(FieldType field, DISPATCH_ALL_FIELDS(field, "_", [&]() { using U = std::make_unsigned::type; + NdArrayView _check_spdz_bit(check_spdz_bit); + NdArrayView _check_spdz_mac(check_spdz_mac); + NdArrayView _spdz_choices(spdz_choices); + NdArrayView _spdz_mac(spdz_mac); for (int64_t i = 0; i < sigma; ++i) { - check_spdz_bit.at(i) = spdz_choices.at(3 * tinyot_num + i); - check_spdz_mac.at(i) = spdz_mac.at(3 * tinyot_num + i); + _check_spdz_bit[i] = _spdz_choices[3 * tinyot_num + i]; + _check_spdz_mac[i] = _spdz_mac[3 * tinyot_num + i]; check_tiny_bit.mac[i] = auth_abcr.mac[tinyot_num * 3 + i]; } for (int64_t j = 0; j < tinyot_num * 3; ++j) { @@ -397,8 +413,8 @@ BeaverTinyOt::Triple_Pair BeaverTinyOt::AuthAnd(FieldType field, for (size_t i = 0; i < sigma; ++i) { if (ceof & 1) { check_tiny_bit.mac[i] ^= auth_abcr.mac[j]; - check_spdz_bit.at(i) += spdz_choices.at(j); - check_spdz_mac.at(i) += spdz_mac.at(j); + _check_spdz_bit[i] += _spdz_choices[j]; + _check_spdz_mac[i] += _spdz_mac[j]; } ceof >>= 1; } @@ -534,12 +550,12 @@ BeaverTinyOt::Pair_Pair BeaverTinyOt::AuthTrunc(FieldType field, DISPATCH_ALL_FIELDS(field, "_", [&]() { using PShrT = ring2k_t; - auto _val = b_val.data(); - auto _mac = b_mac.data(); - auto _r_val = r_val.data(); - auto _r_mac = r_mac.data(); - auto _tr_val = tr_val.data(); - auto _tr_mac = tr_mac.data(); + NdArrayView _val(b_val); + NdArrayView _mac(b_mac); + NdArrayView _r_val(r_val); + NdArrayView _r_mac(r_mac); + NdArrayView _tr_val(tr_val); + NdArrayView _tr_mac(tr_mac); pforeach(0, size, [&](int64_t idx) { _r_val[idx] = 0; _r_mac[idx] = 0; @@ -659,9 +675,8 @@ NdArrayRef BeaverTinyOt::genPublCoin(FieldType field, int64_t num) { } const auto kAesType = yacl::crypto::SymmetricCrypto::CryptoType::AES128_CTR; - yacl::crypto::FillPRand( - kAesType, public_seed, 0, 0, - absl::MakeSpan(static_cast(res.data()), res.buf()->size())); + yacl::crypto::FillPRand(kAesType, public_seed, 0, 0, + 
absl::MakeSpan(res.data(), res.buf()->size())); return res; } @@ -693,7 +708,7 @@ bool BeaverTinyOt::BatchMacCheck(const NdArrayRef& open_value, // 4. local_mac = check_mac - check_value * key auto local_mac = ring_sub(check_mac, ring_mul(check_value, key)); // commit and reduce all macs - std::string mac_str(reinterpret_cast(local_mac.data()), + std::string mac_str(local_mac.data(), local_mac.numel() * local_mac.elsize()); std::vector all_mac_strs; SPU_ENFORCE(commit_and_open(comm_->lctx(), mac_str, &all_mac_strs)); @@ -754,8 +769,8 @@ void BeaverTinyOt::rotSend(FieldType field, NdArrayRef* q0, NdArrayRef* q1) { SPDLOG_DEBUG("rotSend start with numel {}", q0->numel()); SPU_ENFORCE(q0->numel() == q1->numel()); size_t numel = q0->numel(); - T* data0 = reinterpret_cast(q0->data()); - T* data1 = reinterpret_cast(q1->data()); + auto* data0 = q0->data(); + auto* data1 = q1->data(); SPU_ENFORCE(spdz2k_ot_primitives_ != nullptr); SPU_ENFORCE(spdz2k_ot_primitives_->GetSenderCOT() != nullptr); @@ -777,15 +792,17 @@ void BeaverTinyOt::rotRecv(FieldType field, const NdArrayRef& a, SPDLOG_DEBUG("rotRecv start with numel {}", a.numel()); size_t numel = a.numel(); std::vector b_v(numel); + + NdArrayView _a(a); for (size_t i = 0; i < numel; ++i) { - b_v[i] = a.at(i); + b_v[i] = _a[i]; } SPU_ENFORCE(spdz2k_ot_primitives_ != nullptr); SPU_ENFORCE(spdz2k_ot_primitives_->GetSenderCOT() != nullptr); SPU_ENFORCE(spdz2k_ot_primitives_->GetReceiverCOT() != nullptr); - T* data = reinterpret_cast(s->data()); + auto* data = s->data(); spdz2k_ot_primitives_->GetReceiverCOT()->RecvRMCC( b_v, absl::MakeSpan(data, numel)); spdz2k_ot_primitives_->GetReceiverCOT()->Flush(); @@ -806,9 +823,9 @@ NdArrayRef BeaverTinyOt::voleSend(FieldType field, const NdArrayRef& x) { SPU_ENFORCE(spdz2k_ot_primitives_->GetSenderCOT() != nullptr); NdArrayRef res(x.eltype(), x.shape()); - T* data = reinterpret_cast(res.data()); + auto* data = res.data(); spdz2k_ot_primitives_->GetSenderCOT()->SendVole( - 
absl::MakeConstSpan(reinterpret_cast(x.data()), x.numel()), + absl::MakeConstSpan(x.data(), x.numel()), absl::MakeSpan(data, x.numel())); return res; @@ -823,10 +840,9 @@ NdArrayRef BeaverTinyOt::voleRecv(FieldType field, const NdArrayRef& alpha) { SPU_ENFORCE(spdz2k_ot_primitives_->GetReceiverCOT() != nullptr); NdArrayRef res(makeType(field), alpha.shape()); - T* data = reinterpret_cast(res.data()); + auto* data = res.data(); spdz2k_ot_primitives_->GetReceiverCOT()->RecvVole( - absl::MakeConstSpan(reinterpret_cast(alpha.data()), - alpha.numel()), + absl::MakeConstSpan(alpha.data(), alpha.numel()), absl::MakeSpan(data, alpha.numel())); return res; @@ -911,8 +927,11 @@ BeaverTinyOt::Triple_Pair BeaverTinyOt::AuthMul(FieldType field, auto b = ring_rand(field, {_size}); auto b_arr = ring_zeros(field, {expand_tao}); + + NdArrayView _b(b); + NdArrayView _b_arr(b_arr); for (int64_t i = 0; i < expand_tao; ++i) { - b_arr.at(i) = b.at(i / tao); + _b_arr[i] = _b[i / tao]; } // Every ordered pair does following @@ -973,12 +992,21 @@ BeaverTinyOt::Triple_Pair BeaverTinyOt::AuthMul(FieldType field, NdArrayRef crc = ring_zeros(field, {_size}); NdArrayRef crc_hat = ring_zeros(field, {_size}); + NdArrayView _cra(cra); + NdArrayView _cra_hat(cra_hat); + NdArrayView _crc(crc); + NdArrayView _crc_hat(cra_hat); + NdArrayView _ra(ra); + NdArrayView _ra_hat(ra_hat); + NdArrayView _rc(rc); + NdArrayView _rc_hat(rc_hat); + for (int64_t i = 0; i < expand_tao; ++i) { - cra.at(i / tao) += ra.at(i); - cra_hat.at(i / tao) += ra_hat.at(i); + _cra[i / tao] += _ra[i]; + _cra_hat[i / tao] += _ra_hat[i]; - crc.at(i / tao) += rc.at(i); - crc_hat.at(i / tao) += rc_hat.at(i); + _crc[i / tao] += _rc[i]; + _crc_hat[i / tao] += _rc_hat[i]; } // Authenticate diff --git a/libspu/mpc/spdz2k/boolean.cc b/libspu/mpc/spdz2k/boolean.cc index 3ff51f18..5036e592 100644 --- a/libspu/mpc/spdz2k/boolean.cc +++ b/libspu/mpc/spdz2k/boolean.cc @@ -54,11 +54,13 @@ NdArrayRef P2Value(FieldType out_field, const 
NdArrayRef& in, size_t k, auto out = ring_zeros(out_field, out_shape); return DISPATCH_ALL_FIELDS(out_field, "_", [&]() { using BShrT = ring2k_t; + NdArrayView _in(in); + NdArrayView _out(out); + pforeach(0, in.numel(), [&](int64_t idx) { pforeach(0, valid_nbits, [&](int64_t jdx) { size_t offset = idx * valid_nbits + jdx; - out.at(offset) = - static_cast((in.at(idx) >> jdx) & 1); + _out[offset] = static_cast((_in[idx] >> jdx) & 1); }); }); @@ -74,7 +76,10 @@ std::pair RShiftBImpl(const NdArrayRef& in, const auto old_nbits = in.eltype().as()->nbits(); int64_t new_nbits = old_nbits - bits; - if (bits == 0) return {getValueShare(in), getMacShare(in)}; + if (bits == 0) { + return {getValueShare(in), getMacShare(in)}; + } + if (new_nbits <= 0) { return {ring_zeros(field, in.shape()), ring_zeros(field, in.shape())}; } @@ -284,13 +289,14 @@ NdArrayRef B2P::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { NdArrayRef out(makeType(out_field), in.shape()); + NdArrayView _out(out); + NdArrayView _value(value); pforeach(0, in.numel(), [&](int64_t idx) { PShrT t = 0; for (size_t jdx = 0; jdx < nbits; ++jdx) { - t |= static_cast((value.at(idx * nbits + jdx) & 1) - << jdx); + t |= static_cast((_value[idx * nbits + jdx] & 1) << jdx); } - out.at(idx) = t; + _out[idx] = t; }); return out; @@ -375,12 +381,15 @@ NdArrayRef BitrevB::proc(KernelEvalContext* ctx, const NdArrayRef& in, auto ret_mac = x_mac.clone(); DISPATCH_ALL_FIELDS(field, "_", [&]() { + NdArrayView _ret(ret); + NdArrayView _ret_mac(ret_mac); + NdArrayView _x(x); + NdArrayView _x_mac(x_mac); + for (size_t i = 0; i < static_cast(numel); ++i) { for (size_t j = start; j < end; ++j) { - ret.at(i * nbits + j) = - x.at(i * nbits + end + start - j - 1); - ret_mac.at(i * nbits + j) = - x_mac.at(i * nbits + end + start - j - 1); + _ret[i * nbits + j] = _x[i * nbits + end + start - j - 1]; + _ret_mac[i * nbits + j] = _x_mac[i * nbits + end + start - j - 1]; } } }); @@ -576,12 +585,14 @@ NdArrayRef 
BitIntlB::proc(KernelEvalContext* ctx, const NdArrayRef& in, using T = ring2k_t; if (in.eltype().isa()) { + NdArrayView _out(out); + NdArrayView _in(in); pforeach(0, in.numel(), [&](int64_t idx) { - out.at(idx) = BitIntl(in.at(idx), stride, k); + _out[idx] = BitIntl(_in[idx], stride, k); }); } else { - auto _in = in.data>(); - auto _out = out.data>(); + NdArrayView> _in(in); + NdArrayView> _out(out); size_t num_per_group = 1 << stride; size_t group_num = k / num_per_group + (k % num_per_group != 0); size_t half_group_num = (group_num + 1) / 2; @@ -621,12 +632,15 @@ NdArrayRef BitDeintlB::proc(KernelEvalContext* ctx, const NdArrayRef& in, using T = ring2k_t; if (in.eltype().isa()) { + NdArrayView _out(out); + NdArrayView _in(in); pforeach(0, in.numel(), [&](int64_t idx) { - out.at(idx) = BitDeintl(in.at(idx), stride, k); + _out[idx] = BitDeintl(_in[idx], stride, k); }); } else { - auto _in = in.data>(); - auto _out = out.data>(); + NdArrayView> _in(in); + NdArrayView> _out(out); + size_t num_per_group = 1 << stride; size_t group_num = k / num_per_group + (k % num_per_group != 0); size_t half_group_num = (group_num + 1) / 2; diff --git a/libspu/mpc/spdz2k/conversion.cc b/libspu/mpc/spdz2k/conversion.cc index 7a082961..6c67c129 100644 --- a/libspu/mpc/spdz2k/conversion.cc +++ b/libspu/mpc/spdz2k/conversion.cc @@ -196,12 +196,12 @@ NdArrayRef Bit2A::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { // 5. [x] = c + [r] - 2 * c * [r] return DISPATCH_ALL_FIELDS(field, "_", [&]() { - auto _c = c.data(); - auto _r = r.data(); - auto _r_mac = r_mac.data(); + NdArrayView _c(c); + NdArrayView _r(r); + NdArrayView _r_mac(r_mac); NdArrayRef out(makeType(field, true), _in.shape()); - auto _out = out.data>(); + NdArrayView> _out(out); pforeach(0, out.numel(), [&](int64_t idx) { _out[idx][0] = _r[idx] - 2 * _c[idx] * _r[idx]; @@ -280,7 +280,8 @@ NdArrayRef B2A::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { out_shape.back() /= nbits; // 1. 
get rand bit [r] - NdArrayRef r, r_mac; + NdArrayRef r; + NdArrayRef r_mac; std::tie(r, r_mac) = beaver->AuthRandBit(field, _in.shape(), k, s); auto ar = makeAShare(r, r_mac, field); @@ -292,7 +293,8 @@ NdArrayRef B2A::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { // Notice we only reserve the least significant bit const auto [bc_val, bc_mac] = BShareSwitch2Nbits(bc, 1); - NdArrayRef c, zero_mac; + NdArrayRef c; + NdArrayRef zero_mac; std::tie(c, zero_mac) = beaver->BatchOpen(bc_val, bc_mac, 1, s); SPU_ENFORCE(beaver->BatchMacCheck(c, zero_mac, 1, s)); ring_bitmask_(c, 0, 1); @@ -301,12 +303,13 @@ NdArrayRef B2A::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { return DISPATCH_ALL_FIELDS(field, "_", [&]() { NdArrayRef out(makeType(field, true), out_shape); NdArrayRef expand_out(makeType(field, true), _in.shape()); - auto _c = c.data(); - auto _r = r.data(); - auto _r_mac = r_mac.data(); - auto _out = out.data>(); - auto _expand_out = expand_out.data>(); + NdArrayView _c(c); + NdArrayView _r(r); + NdArrayView _r_mac(r_mac); + + NdArrayView> _out(out); + NdArrayView> _expand_out(expand_out); pforeach(0, _in.numel(), [&](int64_t idx) { _expand_out[idx][0] = (_r[idx] - 2 * _c[idx] * _r[idx]); diff --git a/libspu/mpc/spdz2k/io.cc b/libspu/mpc/spdz2k/io.cc index 295180c6..4a096587 100644 --- a/libspu/mpc/spdz2k/io.cc +++ b/libspu/mpc/spdz2k/io.cc @@ -61,10 +61,11 @@ std::vector Spdz2kIo::toShares(const NdArrayRef& raw, NdArrayRef x(makeType(runtime_field), raw.shape()); DISPATCH_ALL_FIELDS(field, "_", [&]() { - using raw_t = ring2k_t; + NdArrayView _raw(raw); DISPATCH_ALL_FIELDS(runtime_field, "_", [&]() { + NdArrayView _x(x); pforeach(0, raw.numel(), [&](int64_t idx) { - x.at(idx) = static_cast(raw.at(idx)); + _x[idx] = static_cast(_raw[idx]); }); }); }); @@ -113,10 +114,11 @@ NdArrayRef Spdz2kIo::fromShares(const std::vector& shares) const { NdArrayRef x(makeType(field_), res.shape()); DISPATCH_ALL_FIELDS(field, "_", [&]() { - using res_t = 
ring2k_t; + NdArrayView _res(res); DISPATCH_ALL_FIELDS(field_, "_", [&]() { + NdArrayView _x(x); pforeach(0, x.numel(), [&](int64_t idx) { - x.at(idx) = static_cast(res.at(idx)); + _x[idx] = static_cast(_res[idx]); }); }); }); diff --git a/libspu/mpc/spdz2k/value.cc b/libspu/mpc/spdz2k/value.cc index edda38e3..facdf0a2 100644 --- a/libspu/mpc/spdz2k/value.cc +++ b/libspu/mpc/spdz2k/value.cc @@ -54,20 +54,22 @@ NdArrayRef makeBShare(const NdArrayRef& s1, const NdArrayRef& s2, new_shape.back() /= nbits; NdArrayRef res(ty, new_shape); - size_t res_numel = res.numel(); + int64_t res_numel = res.numel(); DISPATCH_ALL_FIELDS(field, "_", [&]() { - auto _res = res.data>(); + NdArrayView> _res(res); pforeach(0, res_numel * k, [&](int64_t i) { _res[i][0] = 0; _res[i][1] = 0; }); + NdArrayView _s1(s1); + NdArrayView _s2(s2); pforeach(0, res_numel, [&](int64_t i) { pforeach(0, nbits, [&](int64_t j) { - _res[i * k + j][0] = s1.at(i * nbits + j); - _res[i * k + j][1] = s2.at(i * nbits + j); + _res[i * k + j][0] = _s1[i * nbits + j]; + _res[i * k + j][1] = _s2[i * nbits + j]; }); }); }); @@ -108,8 +110,8 @@ NdArrayRef getShare(const NdArrayRef& in, int64_t share_idx) { DISPATCH_ALL_FIELDS(field, "_", [&]() { size_t numel = in.numel(); - auto _ret = ret.data(); - auto _in = in.data>(); + NdArrayView _ret(ret); + NdArrayView> _in(in); pforeach(0, numel, [&](int64_t i) { pforeach(0, nbits, [&](int64_t j) { diff --git a/libspu/mpc/utils/ring_ops.cc b/libspu/mpc/utils/ring_ops.cc index 56f8912f..d480bda6 100644 --- a/libspu/mpc/utils/ring_ops.cc +++ b/libspu/mpc/utils/ring_ops.cc @@ -43,16 +43,17 @@ constexpr char kModule[] = "RingOps"; SPU_ENFORCE((lhs).shape() == (rhs).shape(), \ "numel mismatch, lhs={}, rhs={}", lhs, rhs); -#define DEF_UNARY_RING_OP(NAME, OP) \ - void NAME##_impl(NdArrayRef& ret, const NdArrayRef& x) { \ - ENFORCE_EQ_ELSIZE_AND_SHAPE(ret, x); \ - const auto field = x.eltype().as()->field(); \ - const int64_t numel = ret.numel(); \ - return 
DISPATCH_ALL_FIELDS(field, kModule, [&]() { \ - using T = std::make_signed_t; \ - pforeach(0, numel, \ - [&](int64_t idx) { ret.at(idx) = OP x.at(idx); }); \ - }); \ +#define DEF_UNARY_RING_OP(NAME, OP) \ + void NAME##_impl(NdArrayRef& ret, const NdArrayRef& x) { \ + ENFORCE_EQ_ELSIZE_AND_SHAPE(ret, x); \ + const auto field = x.eltype().as()->field(); \ + const int64_t numel = ret.numel(); \ + return DISPATCH_ALL_FIELDS(field, kModule, [&]() { \ + using T = std::make_signed_t; \ + NdArrayView _x(x); \ + NdArrayView _ret(ret); \ + pforeach(0, numel, [&](int64_t idx) { _ret[idx] = OP _x[idx]; }); \ + }); \ } DEF_UNARY_RING_OP(ring_not, ~); @@ -60,18 +61,20 @@ DEF_UNARY_RING_OP(ring_neg, -); #undef DEF_UNARY_RING_OP -#define DEF_BINARY_RING_OP(NAME, OP) \ - void NAME##_impl(NdArrayRef& ret, const NdArrayRef& x, \ - const NdArrayRef& y) { \ - ENFORCE_EQ_ELSIZE_AND_SHAPE(ret, x); \ - ENFORCE_EQ_ELSIZE_AND_SHAPE(ret, y); \ - const auto field = x.eltype().as()->field(); \ - const int64_t numel = ret.numel(); \ - return DISPATCH_ALL_FIELDS(field, kModule, [&]() { \ - pforeach(0, numel, [&](int64_t idx) { \ - ret.at(idx) = x.at(idx) OP y.at(idx); \ - }); \ - }); \ +#define DEF_BINARY_RING_OP(NAME, OP) \ + void NAME##_impl(NdArrayRef& ret, const NdArrayRef& x, \ + const NdArrayRef& y) { \ + ENFORCE_EQ_ELSIZE_AND_SHAPE(ret, x); \ + ENFORCE_EQ_ELSIZE_AND_SHAPE(ret, y); \ + const auto field = x.eltype().as()->field(); \ + const int64_t numel = ret.numel(); \ + return DISPATCH_ALL_FIELDS(field, kModule, [&]() { \ + NdArrayView _x(x); \ + NdArrayView _y(y); \ + NdArrayView _ret(ret); \ + pforeach(0, numel, \ + [&](int64_t idx) { _ret[idx] = _x[idx] OP _y[idx]; }); \ + }); \ } DEF_BINARY_RING_OP(ring_add, +) @@ -92,9 +95,9 @@ void ring_arshift_impl(NdArrayRef& ret, const NdArrayRef& x, size_t bits) { // According to K&R 2nd edition the results are implementation-dependent for // right shifts of signed values, but "usually" its arithmetic right shift. 
using S = std::make_signed::type; - - pforeach(0, numel, - [&](int64_t idx) { ret.at(idx) = x.at(idx) >> bits; }); + NdArrayView _ret(ret); + NdArrayView _x(x); + pforeach(0, numel, [&](int64_t idx) { _ret[idx] = _x[idx] >> bits; }); }); } @@ -104,9 +107,9 @@ void ring_rshift_impl(NdArrayRef& ret, const NdArrayRef& x, size_t bits) { const auto field = x.eltype().as()->field(); return DISPATCH_ALL_FIELDS(field, kModule, [&]() { using U = ring2k_t; - - pforeach(0, numel, - [&](int64_t idx) { ret.at(idx) = x.at(idx) >> bits; }); + NdArrayView _ret(ret); + NdArrayView _x(x); + pforeach(0, numel, [&](int64_t idx) { _ret[idx] = _x[idx] >> bits; }); }); } @@ -115,9 +118,9 @@ void ring_lshift_impl(NdArrayRef& ret, const NdArrayRef& x, size_t bits) { const auto numel = ret.numel(); const auto field = x.eltype().as()->field(); return DISPATCH_ALL_FIELDS(field, kModule, [&]() { - pforeach(0, numel, [&](int64_t idx) { - ret.at(idx) = x.at(idx) << bits; - }); + NdArrayView _ret(ret); + NdArrayView _x(x); + pforeach(0, numel, [&](int64_t idx) { _ret[idx] = _x[idx] << bits; }); }); } @@ -144,8 +147,9 @@ void ring_bitrev_impl(NdArrayRef& ret, const NdArrayRef& x, size_t start, return (in & ~mask) | tmp; }; - pforeach(0, numel, - [&](int64_t idx) { ret.at(idx) = bitrev_fn(x.at(idx)); }); + NdArrayView _ret(ret); + NdArrayView _x(x); + pforeach(0, numel, [&](int64_t idx) { _ret[idx] = bitrev_fn(_x[idx]); }); }); } @@ -168,8 +172,9 @@ void ring_bitmask_impl(NdArrayRef& ret, const NdArrayRef& x, size_t low, auto mark_fn = [&](U el) { return el & mask; }; - pforeach(0, numel, - [&](int64_t idx) { ret.at(idx) = mark_fn(x.at(idx)); }); + NdArrayView _ret(ret); + NdArrayView _x(x); + pforeach(0, numel, [&](int64_t idx) { _ret[idx] = mark_fn(_x[idx]); }); }); } @@ -185,8 +190,9 @@ void ring_print(const NdArrayRef& x, std::string_view name) { std::string out; out += fmt::format("{} = {{", name); + NdArrayView _x(x); for (int64_t idx = 0; idx < x.numel(); idx++) { - const auto& current_v = 
x.at(idx); + const auto& current_v = _x[idx]; if (idx != 0) { out += fmt::format(", {0:X}", current_v); } else { @@ -212,7 +218,7 @@ NdArrayRef ring_rand(FieldType field, const Shape& shape, uint128_t prg_seed, NdArrayRef res(makeType(field), shape); *prg_counter = yacl::crypto::FillPRand( kCryptoType, prg_seed, kAesInitialVector, *prg_counter, - absl::MakeSpan(static_cast(res.data()), res.buf()->size())); + absl::MakeSpan(res.data(), res.buf()->size())); return res; } @@ -246,8 +252,9 @@ void ring_assign(NdArrayRef& x, const NdArrayRef& y) { const auto field = x.eltype().as()->field(); return DISPATCH_ALL_FIELDS(field, kModule, [&]() { - pforeach(0, numel, - [&](int64_t idx) { x.at(idx) = y.at(idx); }); + NdArrayView _y(y); + NdArrayView _x(x); + pforeach(0, numel, [&](int64_t idx) { _x[idx] = _y[idx]; }); }); } @@ -256,8 +263,8 @@ NdArrayRef ring_zeros(FieldType field, const Shape& shape) { auto numel = ret.numel(); return DISPATCH_ALL_FIELDS(field, kModule, [&]() { - pforeach(0, numel, - [&](int64_t idx) { ret.at(idx) = ring2k_t(0); }); + NdArrayView _ret(ret); + pforeach(0, numel, [&](int64_t idx) { _ret[idx] = ring2k_t(0); }); return ret; }); } @@ -267,8 +274,8 @@ NdArrayRef ring_ones(FieldType field, const Shape& shape) { auto numel = ret.numel(); return DISPATCH_ALL_FIELDS(field, kModule, [&]() { - pforeach(0, numel, - [&](int64_t idx) { ret.at(idx) = ring2k_t(1); }); + NdArrayView _ret(ret); + pforeach(0, numel, [&](int64_t idx) { _ret[idx] = ring2k_t(1); }); return ret; }); } @@ -282,8 +289,9 @@ NdArrayRef ring_randbit(FieldType field, const Shape& shape) { auto numel = ret.numel(); return DISPATCH_ALL_FIELDS(field, kModule, [&]() { + NdArrayView _ret(ret); for (auto idx = 0; idx < numel; ++idx) { - ret.at(idx) = distrib(gen) & 0x1; + _ret[idx] = distrib(gen) & 0x1; } return ret; }); @@ -335,8 +343,9 @@ void ring_mul_impl(NdArrayRef& ret, const NdArrayRef& x, uint128_t y) { const auto field = x.eltype().as()->field(); DISPATCH_ALL_FIELDS(field, kModule, 
[&]() { using U = std::make_unsigned::type; - - pforeach(0, numel, [&](int64_t idx) { ret.at(idx) = x.at(idx) * y; }); + NdArrayView _x(x); + NdArrayView _ret(ret); + pforeach(0, numel, [&](int64_t idx) { _ret[idx] = _x[idx] * y; }); }); } @@ -369,9 +378,9 @@ void ring_mmul_impl(NdArrayRef& z, const NdArrayRef& lhs, const auto LDC = ret_stride_scale * z.strides()[0]; const auto IDC = ret_stride_scale * z.strides()[1]; - linalg::matmul(M, N, K, static_cast(lhs.data()), LDA, IDA, - static_cast(rhs.data()), LDB, IDB, - static_cast(z.data()), LDC, IDC); + linalg::matmul(M, N, K, lhs.data(), LDA, IDA, + rhs.data(), LDB, IDB, z.data(), + LDC, IDC); }); } @@ -492,10 +501,13 @@ bool ring_all_equal(const NdArrayRef& x, const NdArrayRef& y, size_t abs_err) { return DISPATCH_ALL_FIELDS(field, "_", [&]() { using T = std::make_signed_t; + NdArrayView _x(x); + NdArrayView _y(y); + bool passed = true; for (int64_t idx = 0; idx < numel; ++idx) { - auto x_el = x.at(idx); - auto y_el = y.at(idx); + auto x_el = _x[idx]; + auto y_el = _y[idx]; if (std::abs(x_el - y_el) > static_cast(abs_err)) { fmt::print("error: {0} {1} abs_err: {2}\n", x_el, y_el, abs_err); return false; @@ -513,8 +525,9 @@ std::vector ring_cast_boolean(const NdArrayRef& x) { std::vector res(numel); DISPATCH_ALL_FIELDS(field, kModule, [&]() { + NdArrayView _x(x); pforeach(0, numel, [&](int64_t idx) { - res[idx] = static_cast(x.at(idx) & 0x1); + res[idx] = static_cast(_x[idx] & 0x1); }); }); @@ -532,10 +545,12 @@ NdArrayRef ring_select(const std::vector& c, const NdArrayRef& x, const int64_t numel = c.size(); DISPATCH_ALL_FIELDS(field, kModule, [&]() { - pforeach(0, numel, [&](int64_t idx) { - z.at(idx) = - (c[idx] ? y.at(idx) : x.at(idx)); - }); + NdArrayView _x(x); + NdArrayView _y(y); + NdArrayView _z(z); + + pforeach(0, numel, + [&](int64_t idx) { _z[idx] = (c[idx] ? 
_y[idx] : _x[idx]); }); }); return z; diff --git a/libspu/mpc/utils/ring_ops.h b/libspu/mpc/utils/ring_ops.h index ff81e707..c5b86ed1 100644 --- a/libspu/mpc/utils/ring_ops.h +++ b/libspu/mpc/utils/ring_ops.h @@ -103,7 +103,8 @@ std::vector ring_rand_boolean_splits(const NdArrayRef& arr, template void ring_set_value(NdArrayRef& in, const T& value) { - pforeach(0, in.numel(), [&](int64_t idx) { in.at(idx) = value; }); + NdArrayView _in(in); + pforeach(0, in.numel(), [&](int64_t idx) { _in[idx] = value; }); } } // namespace spu::mpc diff --git a/libspu/mpc/utils/tiling_util.h b/libspu/mpc/utils/tiling_util.h index bc3545fd..69abea3f 100644 --- a/libspu/mpc/utils/tiling_util.h +++ b/libspu/mpc/utils/tiling_util.h @@ -130,7 +130,7 @@ Value tiled(Fn&& fn, SPUContext* ctx, const Value& x, Args&&... args) { int64_t offset = 0; for (int64_t slice_idx = 0; slice_idx < num_slice_dim * num_slice; slice_idx++) { - std::memcpy(static_cast(out.data()) + offset, + std::memcpy(out.data() + offset, out_slices[slice_idx].data().data(), out_slices[slice_idx].numel() * out.elsize()); offset += out_slices[slice_idx].numel() * out.elsize(); @@ -247,7 +247,7 @@ Value tiled(Fn&& fn, SPUContext* ctx, const Value& x, const Value& y, int64_t offset = 0; for (int64_t slice_idx = 0; slice_idx < num_slice_dim * num_slice; slice_idx++) { - std::memcpy(static_cast(out.data()) + offset, + std::memcpy(out.data() + offset, out_slices[slice_idx].data().data(), out_slices[slice_idx].numel() * out.elsize()); diff --git a/libspu/pir/BUILD.bazel b/libspu/pir/BUILD.bazel index aef59949..771a33a5 100644 --- a/libspu/pir/BUILD.bazel +++ b/libspu/pir/BUILD.bazel @@ -106,3 +106,12 @@ spu_cc_library( "@yacl//yacl/crypto/base:symmetric_crypto", ], ) + +spu_cc_test( + name = "pir_test", + srcs = ["pir_test.cc"], + deps = [ + ":pir", + "//libspu/psi/io:io", + ], +) diff --git a/libspu/pir/pir.cc b/libspu/pir/pir.cc index bec82fc1..65ab0574 100644 --- a/libspu/pir/pir.cc +++ b/libspu/pir/pir.cc @@ -14,12 +14,14 @@ 
#include "libspu/pir/pir.h" +#include #include #include #include #include #include +#include "yacl/crypto/base/hash/hash_utils.h" #include "yacl/crypto/utils/rand.h" #include "yacl/io/kv/leveldb_kvstore.h" #include "yacl/io/rw/csv_writer.h" @@ -28,6 +30,8 @@ #include "libspu/psi/core/labeled_psi/receiver.h" #include "libspu/psi/core/labeled_psi/sender.h" #include "libspu/psi/core/labeled_psi/sender_db.h" +#include "libspu/psi/core/labeled_psi/sender_kvdb.h" +#include "libspu/psi/core/labeled_psi/sender_memdb.h" #include "libspu/psi/cryptor/ecc_cryptor.h" #include "libspu/psi/io/io.h" #include "libspu/psi/utils/batch_provider.h" @@ -76,17 +80,23 @@ size_t CsvFileDataCount(const std::string &file_path, return data_count; } -constexpr char kMetaInfoStoreName[] = "meta_info"; +constexpr char kMetaInfoStoreName[] = "pir_meta_info"; constexpr char kServerDataCount[] = "server_data_count"; constexpr char kCountPerQuery[] = "count_per_query"; constexpr char kLabelByteCount[] = "label_byte_count"; constexpr char kLabelColumns[] = "label_columns"; constexpr char kPsiParams[] = "psi_params"; +constexpr char kBucketCount[] = "bucket_count"; +constexpr char kBucketSize[] = "bucket_size"; +constexpr char kCompressed[] = "compressed"; + +constexpr size_t kNonceByteCount = 16; void WriteMetaInfo(const std::string &setup_path, size_t server_data_count, size_t count_per_query, size_t label_byte_count, const std::vector &label_cloumns, - const apsi::PSIParams &psi_params) { + const apsi::PSIParams &psi_params, size_t bucket_count, + size_t bucket_size = 1000000, bool compressed = false) { std::string meta_store_name = fmt::format("{}/{}", setup_path, kMetaInfoStoreName); @@ -96,6 +106,8 @@ void WriteMetaInfo(const std::string &setup_path, size_t server_data_count, meta_info_store->Put(kServerDataCount, fmt::format("{}", server_data_count)); meta_info_store->Put(kCountPerQuery, fmt::format("{}", count_per_query)); meta_info_store->Put(kLabelByteCount, fmt::format("{}", 
label_byte_count)); + meta_info_store->Put(kBucketSize, fmt::format("{}", bucket_size)); + meta_info_store->Put(kCompressed, fmt::format("{}", compressed ? 1 : 0)); spu::psi::proto::StrItemsProto proto; for (const auto &label_cloumn : label_cloumns) { @@ -108,6 +120,8 @@ void WriteMetaInfo(const std::string &setup_path, size_t server_data_count, yacl::Buffer params_buffer = spu::psi::PsiParamsToBuffer(psi_params); meta_info_store->Put(kPsiParams, params_buffer); + + meta_info_store->Put(kBucketCount, fmt::format("{}", bucket_count)); } size_t GetSizeFromStore( @@ -124,7 +138,9 @@ size_t GetSizeFromStore( apsi::PSIParams ReadMetaInfo(const std::string &setup_path, size_t *server_data_count, size_t *count_per_query, size_t *label_byte_count, - std::vector *label_cloumns) { + std::vector *label_cloumns, + size_t *bucket_count, size_t *bucket_size, + bool *compressed) { std::string meta_store_name = fmt::format("{}/{}", setup_path, kMetaInfoStoreName); std::shared_ptr meta_info_store = @@ -147,6 +163,18 @@ apsi::PSIParams ReadMetaInfo(const std::string &setup_path, meta_info_store->Get(kPsiParams, ¶ms_buffer); apsi::PSIParams psi_params = spu::psi::ParsePsiParamsProto(params_buffer); + *bucket_count = GetSizeFromStore(meta_info_store, kBucketCount); + *bucket_size = GetSizeFromStore(meta_info_store, kBucketSize); + size_t compress_flag = GetSizeFromStore(meta_info_store, kCompressed); + + if (compress_flag == 1) { + *compressed = true; + } else if (compress_flag == 0) { + *compressed = false; + } else { + SPU_THROW("error compress value:{}", compress_flag); + } + return psi_params; } @@ -164,16 +192,14 @@ PirResultReport LabeledPirSetup(const PirSetupConfig &config) { size_t server_data_count = CsvFileDataCount(config.input_path(), key_columns); size_t count_per_query = config.num_per_query(); - apsi::PSIParams psi_params = - spu::psi::GetPsiParams(count_per_query, server_data_count); + size_t bucket_size = config.bucket_size(); - // 
spu::pir::examples::utils::WritePsiParams(ParamsOutPathOpt.getValue(), - // psi_params); + size_t bucket_count = (server_data_count + bucket_size - 1) / bucket_size; std::vector oprf_key = ReadEcSecretKeyFile(config.oprf_key_path()); size_t label_byte_count = config.label_max_len(); - size_t nonce_byte_count = 16; + size_t nonce_byte_count = kNonceByteCount; std::string kv_store_path = config.setup_path(); // delete store path @@ -187,19 +213,60 @@ PirResultReport LabeledPirSetup(const PirSetupConfig &config) { } std::filesystem::create_directory(kv_store_path); - WriteMetaInfo(kv_store_path, server_data_count, count_per_query, - label_byte_count, label_columns, psi_params); + apsi::PSIParams psi_params = spu::psi::GetPsiParams( + count_per_query, bucket_size, config.max_items_per_bin()); + + SPDLOG_INFO("table_params hash_func_count:{}", + psi_params.table_params().hash_func_count); + SPDLOG_INFO("table_params max_items_per_bin:{}", + psi_params.table_params().max_items_per_bin); + + SPDLOG_INFO("seal_params poly_modulus_degree:{}", + psi_params.seal_params().poly_modulus_degree()); + SPDLOG_INFO("query_params query_powers size:{}", + psi_params.query_params().query_powers.size()); - std::shared_ptr sender_db = - std::make_shared(psi_params, oprf_key, kv_store_path, - label_byte_count, nonce_byte_count, - false); + WriteMetaInfo(kv_store_path, server_data_count, count_per_query, + label_byte_count, label_columns, psi_params, bucket_count, + config.bucket_size(), config.compressed()); std::shared_ptr batch_provider = std::make_shared(config.input_path(), key_columns, label_columns); + for (size_t i = 0; i < bucket_count; i++) { + std::string bucket_setup_path = + fmt::format("{}/bucket_{}", kv_store_path, i); + + SPDLOG_INFO("bucket:{} bucket_setup_path:{}", i, bucket_setup_path); + + std::vector batch_ids; + std::vector batch_labels; + std::tie(batch_ids, batch_labels) = + batch_provider->ReadNextBatchWithLabel(bucket_size); + + if (batch_ids.empty()) { + 
SPDLOG_INFO("Finish read data"); + break; + } + + std::filesystem::create_directory(bucket_setup_path); - sender_db->SetData(batch_provider); + apsi::PSIParams bucket_psi_params = spu::psi::GetPsiParams( + count_per_query, std::min(bucket_size, batch_ids.size()), + config.max_items_per_bin()); + + std::shared_ptr sender_db = + std::make_shared( + bucket_psi_params, oprf_key, bucket_setup_path, label_byte_count, + nonce_byte_count, config.compressed()); + + std::shared_ptr bucket_batch_provider = + std::make_shared(batch_ids, + batch_labels); + + sender_db->SetData(bucket_batch_provider); + SPDLOG_INFO("finish bucket:{}", i); + } PirResultReport report; report.set_data_count(server_data_count); @@ -209,10 +276,11 @@ PirResultReport LabeledPirSetup(const PirSetupConfig &config) { PirResultReport LabeledPirServer( const std::shared_ptr &link_ctx, - const std::shared_ptr &sender_db, + const std::shared_ptr &sender_db, const std::vector &oprf_key, const apsi::PSIParams &psi_params, - const std::vector &label_columns, size_t server_data_count, - size_t count_per_query, size_t label_byte_count) { + const std::vector &label_columns, size_t bucket_count, + size_t server_data_count, size_t count_per_query, size_t label_byte_count, + uint32_t bucket_size) { // send count_per_query link_ctx->SendAsync(link_ctx->NextRank(), spu::psi::utils::SerializeSize(count_per_query), @@ -229,6 +297,11 @@ PirResultReport LabeledPirServer( link_ctx->SendAsync(link_ctx->NextRank(), params_buffer, fmt::format("send psi params")); + // bucket_count + link_ctx->SendAsync(link_ctx->NextRank(), + spu::psi::utils::SerializeSize(bucket_count), + fmt::format("bucket_count:{}", bucket_count)); + // const auto total_query_start = std::chrono::system_clock::now(); size_t query_count = 0; @@ -236,8 +309,6 @@ PirResultReport LabeledPirServer( spu::psi::LabelPsiSender sender(sender_db); - SPDLOG_INFO("LabelPsiSender"); - while (true) { // recv current batch_size size_t batch_data_size = 
spu::psi::utils::DeserializeSize( @@ -287,27 +358,47 @@ PirResultReport LabeledPirServer( size_t label_byte_count; std::vector label_columns; + size_t bucket_count; + size_t bucket_size; + bool compressed; apsi::PSIParams psi_params = ReadMetaInfo(config.setup_path(), &server_data_count, &count_per_query, - &label_byte_count, &label_columns); + &label_byte_count, &label_columns, &bucket_count, + &bucket_size, &compressed); SPU_ENFORCE(label_columns.size() > 0); std::vector oprf_key = ReadEcSecretKeyFile(config.oprf_key_path()); + // server and client sync + auto run_f = std::async([&] { return 0; }); + spu::psi::SyncWait(link_ctx, &run_f); + SPDLOG_INFO("table_params hash_func_count:{}", psi_params.table_params().hash_func_count); - size_t nonce_byte_count = 16; + size_t nonce_byte_count = kNonceByteCount; + + std::vector> sender_db(bucket_count); + std::vector> sender(bucket_count); - bool compressed = false; - std::shared_ptr sender_db = - std::make_shared( - psi_params, oprf_key, config.setup_path(), label_byte_count, - nonce_byte_count, compressed); + std::unique_ptr oprf_server = + spu::psi::CreateEcdhOprfServer(oprf_key, spu::psi::OprfType::Basic, + spu::psi::CurveType::CURVE_FOURQ); - SPDLOG_INFO("db GetItemCount:{}", sender_db->GetItemCount()); + for (size_t bucket_idx = 0; bucket_idx < bucket_count; ++bucket_idx) { + std::string bucket_setup_path = + fmt::format("{}/bucket_{}", config.setup_path(), bucket_idx); + sender_db[bucket_idx] = std::make_shared( + psi_params, oprf_key, bucket_setup_path, label_byte_count, + nonce_byte_count, compressed); + + sender[bucket_idx] = + std::make_shared(sender_db[bucket_idx]); + } + + SPDLOG_INFO("db GetItemCount:{}", sender_db[0]->GetItemCount()); // send count_per_query link_ctx->SendAsync(link_ctx->NextRank(), @@ -325,6 +416,11 @@ PirResultReport LabeledPirServer( link_ctx->SendAsync(link_ctx->NextRank(), params_buffer, fmt::format("send psi params")); + // send bucket_count + 
link_ctx->SendAsync(link_ctx->NextRank(), + spu::psi::utils::SerializeSize(bucket_count), + fmt::format("bucket_count:{}", bucket_count)); + // const auto total_query_start = std::chrono::system_clock::now(); size_t query_count = 0; @@ -346,22 +442,22 @@ PirResultReport LabeledPirServer( spu::psi::CreateEcdhOprfServer(oprf_key, spu::psi::OprfType::Basic, spu::psi::CurveType::CURVE_FOURQ); - spu::psi::LabelPsiSender sender(sender_db); - // const auto oprf_start = std::chrono::system_clock::now(); - sender.RunOPRF(std::move(oprf_server), link_ctx); + sender[0]->RunOPRF(std::move(oprf_server), link_ctx); - // const auto oprf_end = std::chrono::system_clock::now(); - // const DurationMillis oprf_duration = oprf_end - oprf_start; - // SPDLOG_INFO("*** server oprf duration:{}", oprf_duration.count()); + for (size_t bucket_idx = 0; bucket_idx < bucket_count; ++bucket_idx) { + // const auto oprf_end = std::chrono::system_clock::now(); + // const DurationMillis oprf_duration = oprf_end - oprf_start; + // SPDLOG_INFO("*** server oprf duration:{}", oprf_duration.count()); - // const auto query_start = std::chrono::system_clock::now(); + // const auto query_start = std::chrono::system_clock::now(); - sender.RunQuery(link_ctx); + sender[bucket_idx]->RunQuery(link_ctx); - // const auto query_end = std::chrono::system_clock::now(); - // const DurationMillis query_duration = query_end - query_start; - // SPDLOG_INFO("*** server query duration:{}", query_duration.count()); + // const auto query_end = std::chrono::system_clock::now(); + // const DurationMillis query_duration = query_end - query_start; + // SPDLOG_INFO("*** server query duration:{}", query_duration.count()); + } query_count++; } @@ -388,31 +484,43 @@ PirResultReport LabeledPirMemoryServer( size_t count_per_query = config.num_per_query(); SPDLOG_INFO("server_data_count:{}", server_data_count); - apsi::PSIParams psi_params = - spu::psi::GetPsiParams(count_per_query, server_data_count); + SPU_ENFORCE(server_data_count 
<= config.bucket_size(), + "data_count:{} bucket_size:{}", config.bucket_size()); + + apsi::PSIParams psi_params = spu::psi::GetPsiParams( + count_per_query, server_data_count, config.max_items_per_bin()); std::vector oprf_key = yacl::crypto::RandBytes(spu::psi::kEccKeySize); size_t label_byte_count = config.label_max_len(); - size_t nonce_byte_count = 16; + size_t nonce_byte_count = kNonceByteCount; - std::shared_ptr sender_db = - std::make_shared(psi_params, oprf_key, "::memory", - label_byte_count, nonce_byte_count, - false); + std::shared_ptr sender_db = + std::make_shared( + psi_params, oprf_key, label_byte_count, nonce_byte_count, + config.compressed()); - std::shared_ptr batch_provider = - std::make_shared(config.input_path(), - key_columns, label_columns); + // server and client sync + auto run_f = std::async([&] { + std::shared_ptr batch_provider = + std::make_shared( + config.input_path(), key_columns, label_columns); - sender_db->SetData(batch_provider); + sender_db->SetData(batch_provider); + + return 0; + }); + spu::psi::SyncWait(link_ctx, &run_f); SPDLOG_INFO("sender_db->GetItemCount:{}", sender_db->GetItemCount()); - PirResultReport report = LabeledPirServer( - link_ctx, sender_db, oprf_key, psi_params, label_columns, - sender_db->GetItemCount(), count_per_query, label_byte_count); + size_t bucket_count = 1; + + PirResultReport report = + LabeledPirServer(link_ctx, sender_db, oprf_key, psi_params, label_columns, + bucket_count, sender_db->GetItemCount(), count_per_query, + label_byte_count, config.bucket_size()); return report; } @@ -424,6 +532,10 @@ PirResultReport LabeledPirClient( key_columns.insert(key_columns.end(), config.key_columns().begin(), config.key_columns().end()); + // server and client sync + auto run_f = std::async([&] { return 0; }); + spu::psi::SyncWait(link_ctx, &run_f); + // recv count_per_query size_t count_per_query = spu::psi::utils::DeserializeSize( link_ctx->Recv(link_ctx->NextRank(), fmt::format("count_per_query"))); @@ 
-468,6 +580,20 @@ PirResultReport LabeledPirClient( apsi::PSIParams psi_params = spu::psi::ParsePsiParamsProto(params_buffer); + SPDLOG_INFO("table_params hash_func_count:{}", + psi_params.table_params().hash_func_count); + SPDLOG_INFO("table_params max_items_per_bin:{}", + psi_params.table_params().max_items_per_bin); + + SPDLOG_INFO("seal_params poly_modulus_degree:{}", + psi_params.seal_params().poly_modulus_degree()); + SPDLOG_INFO("query_params query_powers size:{}", + psi_params.query_params().query_powers.size()); + + size_t bucket_count = spu::psi::utils::DeserializeSize( + link_ctx->Recv(link_ctx->NextRank(), fmt::format("bucket_count"))); + SPDLOG_INFO("bucket_count:{}", bucket_count); + spu::psi::LabelPsiReceiver receiver(psi_params, true); // const auto total_query_start = std::chrono::system_clock::now(); @@ -475,9 +601,18 @@ PirResultReport LabeledPirClient( size_t query_count = 0; size_t data_count = 0; + size_t table_items = psi_params.table_params().table_size / 2; + if (psi_params.table_params().hash_func_count == 1) { + table_items = 1; + } + + size_t query_result_count = 0; + size_t batch_read_count = std::max(table_items, count_per_query); + SPDLOG_INFO("batch_read_count:{}", batch_read_count); + while (true) { auto query_batch_items = - query_batch_provider->ReadNextBatch(count_per_query); + query_batch_provider->ReadNextBatch(batch_read_count); // send count_batch_size link_ctx->SendAsync( @@ -498,54 +633,68 @@ PirResultReport LabeledPirClient( // const DurationMillis oprf_duration = oprf_end - oprf_start; // SPDLOG_INFO("*** server oprf duration:{}", oprf_duration.count()); - // const auto query_start = std::chrono::system_clock::now(); - std::pair, std::vector> query_result = - receiver.RequestQuery(items_oprf.first, items_oprf.second, link_ctx); + for (size_t bucket_idx = 0; bucket_idx < bucket_count; bucket_idx++) { + SPDLOG_INFO("bucket_idx: {}", bucket_idx); - // const auto query_end = std::chrono::system_clock::now(); - // const 
DurationMillis query_duration = query_end - query_start; - // SPDLOG_INFO("*** server query duration:{}", query_duration.count()); + // const auto query_start = std::chrono::system_clock::now(); + std::pair, std::vector> query_result = + receiver.RequestQuery(items_oprf.first, items_oprf.second, link_ctx); - SPDLOG_INFO("query_result size:{}", query_result.first.size()); + // const auto query_end = std::chrono::system_clock::now(); + // const DurationMillis query_duration = query_end - query_start; + // SPDLOG_INFO("*** server query duration:{}", query_duration.count()); - yacl::io::ColumnVectorBatch batch; + SPDLOG_INFO("query_result size:{}", query_result.first.size()); + query_result_count += query_result.first.size(); - std::vector> query_id_results(key_columns.size()); - std::vector> query_label_results( - label_columns_name.size()); + yacl::io::ColumnVectorBatch batch; - for (size_t i = 0; i < query_result.first.size(); ++i) { - std::vector result_ids = - absl::StrSplit(query_batch_items[query_result.first[i]], ","); + std::vector> query_id_results( + key_columns.size()); + std::vector> query_label_results( + label_columns_name.size()); - SPU_ENFORCE(result_ids.size() == key_columns.size()); + for (size_t i = 0; i < query_result.first.size(); ++i) { + std::vector result_ids = + absl::StrSplit(query_batch_items[query_result.first[i]], ","); - std::vector result_labels = - absl::StrSplit(query_result.second[i], ","); - SPU_ENFORCE(result_labels.size() == label_columns_name.size()); + SPU_ENFORCE(result_ids.size() == key_columns.size()); - for (size_t j = 0; j < result_ids.size(); ++j) { - query_id_results[j].push_back(result_ids[j]); + std::vector result_labels = + absl::StrSplit(query_result.second[i], ","); + + if (result_labels.size() != label_columns_name.size()) { + SPDLOG_INFO("query_result: {}", query_result.second[i]); + } + + SPU_ENFORCE(result_labels.size() == label_columns_name.size(), + fmt::format("{}!={}", result_labels.size(), + 
label_columns_name.size())); + + for (size_t j = 0; j < result_ids.size(); ++j) { + query_id_results[j].push_back(result_ids[j]); + } + for (size_t j = 0; j < result_labels.size(); ++j) { + query_label_results[j].push_back(result_labels[j]); + } + } + + for (size_t i = 0; i < key_columns.size(); ++i) { + batch.AppendCol(query_id_results[i]); } - for (size_t j = 0; j < result_labels.size(); ++j) { - query_label_results[j].push_back(result_labels[j]); + for (size_t i = 0; i < label_columns_name.size(); ++i) { + batch.AppendCol(query_label_results[i]); } - } - for (size_t i = 0; i < key_columns.size(); ++i) { - batch.AppendCol(query_id_results[i]); + writer.Add(batch); } - for (size_t i = 0; i < label_columns_name.size(); ++i) { - batch.AppendCol(query_label_results[i]); - } - - writer.Add(batch); query_count++; } writer.Close(); - SPDLOG_INFO("query_count:{}, data_count:{}", query_count, data_count); + SPDLOG_INFO("query_count:{}, data_count:{}, result_count:{}", query_count, + data_count, query_result_count); PirResultReport report; report.set_data_count(data_count); diff --git a/libspu/pir/pir.proto b/libspu/pir/pir.proto index 6b78a0c8..191bec44 100644 --- a/libspu/pir/pir.proto +++ b/libspu/pir/pir.proto @@ -69,6 +69,15 @@ message PirSetupConfig { // The path of setup output leveldb path. string setup_path = 9; + + // compressed Seal ciphertext + bool compressed = 10; + + // split data bucket to do pir query + uint32 bucket_size = 11; + + // max items per bin, i.e. Interpolate polynomial max degree + uint32 max_items_per_bin = 12; } // The Server's Cfg Params of query phase. diff --git a/libspu/pir/pir_test.cc b/libspu/pir/pir_test.cc new file mode 100644 index 00000000..fb15f754 --- /dev/null +++ b/libspu/pir/pir_test.cc @@ -0,0 +1,310 @@ +// Copyright 2023 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "libspu/pir/pir.h" + +#include +#include +#include +#include +#include + +#include "gtest/gtest.h" +#include "spdlog/spdlog.h" +#include "yacl/crypto/tools/prg.h" +#include "yacl/crypto/utils/rand.h" +#include "yacl/link/test_util.h" + +#include "libspu/psi/core/labeled_psi/psi_params.h" +#include "libspu/psi/io/io.h" + +namespace { + +constexpr size_t kPsiStartPos = 100; + +struct TestParams { + size_t nr; + size_t ns; + size_t label_bytes = 16; + bool use_filedb = true; + bool compressed = false; + size_t bucket_size = 1000000; +}; + +void WriteCsvFile(const std::string &file_name, const std::string &id_name, + + const std::vector &items) { + auto out = + spu::psi::io::BuildOutputStream(spu::psi::io::FileIoOptions(file_name)); + out->Write(fmt::format("{}\n", id_name)); + for (size_t i = 0; i < items.size(); ++i) { + out->Write(fmt::format("{}\n", items[i])); + } + out->Close(); +} + +void WriteCsvFile(const std::string &file_name, const std::string &id_name, + const std::string &label_name, + const std::vector &items, + const std::vector &labels) { + auto out = + spu::psi::io::BuildOutputStream(spu::psi::io::FileIoOptions(file_name)); + out->Write(fmt::format("{},{}\n", id_name, label_name)); + for (size_t i = 0; i < items.size(); ++i) { + out->Write(fmt::format("{},{}\n", items[i], labels[i])); + } + out->Close(); +} + +void WriteSecretKey(const std::string &ecdh_secret_key_path) { + std::ofstream wf(ecdh_secret_key_path, std::ios::out | std::ios::binary); + + std::vector oprf_key = yacl::crypto::RandBytes(32); + + wf.write((const 
char *)oprf_key.data(), oprf_key.size()); + wf.close(); +} + +std::vector GenerateData(size_t seed, size_t item_count) { + yacl::crypto::Prg prg(seed); + + std::vector items; + + for (size_t i = 0; i < item_count; ++i) { + std::string item(16, '\0'); + prg.Fill(absl::MakeSpan(item.data(), item.length())); + items.emplace_back(absl::BytesToHexString(item)); + } + + return items; +} + +std::pair, std::vector> +GenerateSenderData(size_t seed, size_t item_count, size_t label_byte_count, + const absl::Span &receiver_items, + std::vector *intersection_idx, + std::vector *intersection_label) { + std::vector sender_items; + std::vector sender_labels; + + yacl::crypto::Prg prg(seed); + + for (size_t i = 0; i < item_count; ++i) { + std::string item(16, '\0'); + std::string label((label_byte_count - 1) / 2, '\0'); + + prg.Fill(absl::MakeSpan(item.data(), item.length())); + prg.Fill(absl::MakeSpan(label.data(), label.length())); + sender_items.emplace_back(absl::BytesToHexString(item)); + sender_labels.emplace_back(absl::BytesToHexString(label)); + } + + for (size_t i = 0; i < receiver_items.size(); i += 3) { + if ((kPsiStartPos + i * 5) >= sender_items.size()) { + break; + } + sender_items[kPsiStartPos + i * 5] = receiver_items[i]; + (*intersection_idx).emplace_back(i); + + (*intersection_label).emplace_back(sender_labels[kPsiStartPos + i * 5]); + } + + return std::make_pair(sender_items, sender_labels); +} + +} // namespace + +namespace spu::psi { + +class PirTest : public testing::TestWithParam {}; + +TEST_P(PirTest, Works) { + auto params = GetParam(); + auto ctxs = yacl::link::test::SetupWorld(2); + apsi::PSIParams psi_params = spu::psi::GetPsiParams(params.nr, params.ns); + + std::string tmp_store_path = + fmt::format("data_{}_{}", yacl::crypto::RandU64(), params.ns); + + std::filesystem::create_directory(tmp_store_path); + + // register remove of temp dir. 
+ ON_SCOPE_EXIT([&] { + if (!tmp_store_path.empty()) { + std::error_code ec; + std::filesystem::remove_all(tmp_store_path, ec); + if (ec.value() != 0) { + SPDLOG_WARN("can not remove tmp dir: {}, msg: {}", tmp_store_path, + ec.message()); + } + } + }); + + std::string client_csv_path = fmt::format("{}/client.csv", tmp_store_path); + std::string server_csv_path = fmt::format("{}/server.csv", tmp_store_path); + std::string id_cloumn_name = "id"; + std::string label_cloumn_name = "label"; + + std::vector intersection_idx; + std::vector intersection_label; + + // generate test csv data + { + std::vector receiver_items = + GenerateData(yacl::crypto::RandU64(), params.nr); + + std::vector sender_keys; + std::vector sender_labels; + + std::tie(sender_keys, sender_labels) = GenerateSenderData( + yacl::crypto::RandU64(), params.ns, params.label_bytes - 4, + absl::MakeSpan(receiver_items), &intersection_idx, &intersection_label); + + WriteCsvFile(client_csv_path, id_cloumn_name, receiver_items); + WriteCsvFile(server_csv_path, id_cloumn_name, label_cloumn_name, + sender_keys, sender_labels); + } + + // generate 32 bytes oprf_key + std::string oprf_key_path = fmt::format("{}/oprf_key.bin", tmp_store_path); + WriteSecretKey(oprf_key_path); + + std::string setup_path = fmt::format("{}/setup_path", tmp_store_path); + std::string pir_result_path = + fmt::format("{}/pir_result.csv", tmp_store_path); + + std::vector ids = {id_cloumn_name}; + std::vector labels = {label_cloumn_name}; + + if (params.use_filedb) { + spu::pir::PirSetupConfig config; + + config.set_pir_protocol(spu::pir::PirProtocol::KEYWORD_PIR_LABELED_PSI); + config.set_store_type(spu::pir::KvStoreType::LEVELDB_KV_STORE); + config.set_input_path(server_csv_path); + + config.mutable_key_columns()->Add(ids.begin(), ids.end()); + config.mutable_label_columns()->Add(labels.begin(), labels.end()); + + config.set_num_per_query(params.nr); + config.set_label_max_len(params.label_bytes); + 
config.set_oprf_key_path(oprf_key_path); + config.set_setup_path(setup_path); + config.set_compressed(params.compressed); + config.set_bucket_size(params.bucket_size); + + spu::pir::PirResultReport setup_report = spu::pir::PirSetup(config); + + EXPECT_EQ(setup_report.data_count(), params.ns); + + std::future f_server = std::async([&] { + spu::pir::PirServerConfig config; + + config.set_pir_protocol(spu::pir::PirProtocol::KEYWORD_PIR_LABELED_PSI); + config.set_store_type(spu::pir::KvStoreType::LEVELDB_KV_STORE); + + config.set_oprf_key_path(oprf_key_path); + config.set_setup_path(setup_path); + + spu::pir::PirResultReport report = spu::pir::PirServer(ctxs[0], config); + return report; + }); + + std::future f_client = std::async([&] { + spu::pir::PirClientConfig config; + + config.set_pir_protocol(spu::pir::PirProtocol::KEYWORD_PIR_LABELED_PSI); + + config.set_input_path(client_csv_path); + config.mutable_key_columns()->Add(ids.begin(), ids.end()); + config.set_output_path(pir_result_path); + + spu::pir::PirResultReport report = spu::pir::PirClient(ctxs[1], config); + return report; + }); + + spu::pir::PirResultReport server_report = f_server.get(); + spu::pir::PirResultReport client_report = f_client.get(); + + EXPECT_EQ(server_report.data_count(), params.nr); + EXPECT_EQ(client_report.data_count(), params.nr); + + } else { + std::future f_server = std::async([&] { + spu::pir::PirSetupConfig config; + + config.set_pir_protocol(spu::pir::PirProtocol::KEYWORD_PIR_LABELED_PSI); + config.set_store_type(spu::pir::KvStoreType::LEVELDB_KV_STORE); + config.set_input_path(server_csv_path); + + config.mutable_key_columns()->Add(ids.begin(), ids.end()); + config.mutable_label_columns()->Add(labels.begin(), labels.end()); + + config.set_num_per_query(params.nr); + config.set_label_max_len(params.label_bytes); + config.set_oprf_key_path(""); + config.set_setup_path("::memory"); + config.set_compressed(params.compressed); + config.set_bucket_size(params.bucket_size); + + 
spu::pir::PirResultReport report = + spu::pir::PirMemoryServer(ctxs[0], config); + return report; + }); + + std::future f_client = std::async([&] { + spu::pir::PirClientConfig config; + + config.set_pir_protocol(spu::pir::PirProtocol::KEYWORD_PIR_LABELED_PSI); + + config.set_input_path(client_csv_path); + config.mutable_key_columns()->Add(ids.begin(), ids.end()); + config.set_output_path(pir_result_path); + + spu::pir::PirResultReport report = spu::pir::PirClient(ctxs[1], config); + return report; + }); + + spu::pir::PirResultReport server_report = f_server.get(); + spu::pir::PirResultReport client_report = f_client.get(); + + EXPECT_EQ(server_report.data_count(), params.nr); + EXPECT_EQ(client_report.data_count(), params.nr); + } + + std::shared_ptr pir_result_provider = + std::make_shared(pir_result_path, ids, + labels); + + // read pir_result from csv + std::vector pir_result_ids; + std::vector pir_result_labels; + std::tie(pir_result_ids, pir_result_labels) = + pir_result_provider->ReadNextBatchWithLabel(intersection_idx.size()); + + // check pir result correct + EXPECT_EQ(pir_result_ids.size(), intersection_idx.size()); + + EXPECT_EQ(pir_result_labels, intersection_label); +} + +INSTANTIATE_TEST_SUITE_P(Works_Instances, PirTest, + testing::Values( // + TestParams{10, 10000, 32}, // 10-10K-32 + TestParams{2048, 10000, 32}, // 2048-10K-32 + TestParams{100, 100000, 32, false}) // 100-100K-32 + +); + +} // namespace spu::psi diff --git a/libspu/psi/core/labeled_psi/BUILD.bazel b/libspu/psi/core/labeled_psi/BUILD.bazel index cdfac37b..36365146 100644 --- a/libspu/psi/core/labeled_psi/BUILD.bazel +++ b/libspu/psi/core/labeled_psi/BUILD.bazel @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-load("//bazel:spu.bzl", "spu_cc_library", "spu_cc_test") +load("//bazel:spu.bzl", "spu_cc_binary", "spu_cc_library", "spu_cc_test") load("@rules_proto//proto:defs.bzl", "proto_library") load("@rules_cc//cc:defs.bzl", "cc_proto_library") @@ -26,6 +26,8 @@ spu_cc_library( "receiver.cc", "sender.cc", "sender_db.cc", + "sender_kvdb.cc", + "sender_memdb.cc", ], hdrs = [ "package.h", @@ -33,6 +35,8 @@ spu_cc_library( "receiver.h", "sender.h", "sender_db.h", + "sender_kvdb.h", + "sender_memdb.h", "serialize.h", ], deps = [ @@ -96,3 +100,17 @@ spu_cc_test( "@yacl//yacl/crypto/tools:prg", ], ) + +spu_cc_binary( + name = "apsi_bench", + srcs = [ + "apsi_bench.cc", + ], + deps = [ + ":labeled_psi", + "@com_github_google_benchmark//:benchmark_main", + "@com_google_googletest//:gtest", + "@com_google_absl//absl/strings", + "@yacl//yacl/crypto/tools:prg", + ], +) diff --git a/libspu/psi/core/labeled_psi/apsi_bench.cc b/libspu/psi/core/labeled_psi/apsi_bench.cc index dde8cc5a..0454bc23 100644 --- a/libspu/psi/core/labeled_psi/apsi_bench.cc +++ b/libspu/psi/core/labeled_psi/apsi_bench.cc @@ -13,6 +13,7 @@ // limitations under the License. 
#include +#include #include #include #include @@ -20,33 +21,66 @@ #include "absl/strings/escaping.h" #include "absl/strings/str_split.h" #include "benchmark/benchmark.h" +#include "gtest/gtest.h" #include "spdlog/spdlog.h" -#include "spu/psi/core/labeled_psi/psi_params.h" -#include "spu/psi/core/labeled_psi/receiver.h" -#include "spu/psi/core/labeled_psi/sender.h" -#include "spu/psi/cryptor/ecdh_oprf/ecdh_oprf_selector.h" +#include "yacl/crypto/tools/prg.h" #include "yacl/link/test_util.h" +#include "libspu/psi/core/ecdh_oprf/ecdh_oprf_selector.h" +#include "libspu/psi/core/labeled_psi/psi_params.h" +#include "libspu/psi/core/labeled_psi/receiver.h" +#include "libspu/psi/core/labeled_psi/sender.h" +#include "libspu/psi/core/labeled_psi/sender_kvdb.h" + namespace spu::psi { namespace { -std::vector CreateRangeItems(size_t start_pos, size_t size) { - std::vector ret(size); +constexpr char kLinkAddrAB[] = "127.0.0.1:9532,127.0.0.1:9533"; +constexpr uint32_t kLinkRecvTimeout = 30 * 60 * 1000; +constexpr uint32_t kLinkWindowSize = 16; - auto gen_items_proc = [&](size_t begin, size_t end) -> void { - for (size_t i = begin; i < end; ++i) { - ret[i] = std::to_string(start_pos + i); - } - }; +using duration_millis = std::chrono::duration; + +constexpr size_t kPsiStartPos = 100; + +std::vector GenerateData(size_t seed, size_t item_count) { + yacl::crypto::Prg prg(seed); + + std::vector items; + + for (size_t i = 0; i < item_count; ++i) { + std::string item(16, '\0'); + prg.Fill(absl::MakeSpan(item.data(), item.length())); + items.emplace_back(absl::BytesToHexString(item)); + } - std::future f_gen = std::async(gen_items_proc, size / 2, size); + return items; +} + +std::vector GenerateSenderData( + size_t seed, size_t item_count, + const absl::Span& receiver_items, + std::vector* intersection_idx) { + std::vector sender_items; + + yacl::crypto::Prg prg(seed); - gen_items_proc(0, size / 2); + for (size_t i = 0; i < item_count; ++i) { + std::string item(16, '\0'); + 
prg.Fill(absl::MakeSpan(item.data(), item.size())); + sender_items.emplace_back(absl::BytesToHexString(item)); + } - f_gen.get(); + for (size_t i = 0; i < receiver_items.size(); i += 3) { + if ((kPsiStartPos + i * 5) >= sender_items.size()) { + break; + } + sender_items[kPsiStartPos + i * 5] = receiver_items[i]; + (*intersection_idx).emplace_back(i); + } - return ret; + return sender_items; } std::shared_ptr CreateContext( @@ -88,16 +122,13 @@ std::vector> CreateLinks( return links; } -constexpr char kLinkAddrAB[] = "127.0.0.1:9532,127.0.0.1:9533"; -constexpr uint32_t kLinkRecvTimeout = 30 * 60 * 1000; -constexpr uint32_t kLinkWindowSize = 16; - } // namespace static void BM_LabeledPsi(benchmark::State& state) { for (auto _ : state) { state.PauseTiming(); - size_t n = state.range(0); + size_t nr = state.range(0); + size_t ns = state.range(1); auto ctxs = CreateLinks(kLinkAddrAB); @@ -109,14 +140,14 @@ static void BM_LabeledPsi(benchmark::State& state) { state.ResumeTiming(); - apsi::PSIParams psi_params = spu::psi::GetPsiParams(params.nr, params.ns); + apsi::PSIParams psi_params = spu::psi::GetPsiParams(nr, ns); // step 1: PsiParams Request and Response - std::future f_sender_params = std::async( - [&] { return LabelPsiSender::RunPsiParams(params.ns, ctxs[0]); }); + std::future f_sender_params = + std::async([&] { return LabelPsiSender::RunPsiParams(ns, ctxs[0]); }); std::future f_receiver_params = std::async( - [&] { return LabelPsiReceiver::RequestPsiParams(params.nr, ctxs[1]); }); + [&] { return LabelPsiReceiver::RequestPsiParams(nr, ctxs[1]); }); f_sender_params.get(); apsi::PSIParams psi_params2 = f_receiver_params.get(); @@ -124,8 +155,7 @@ static void BM_LabeledPsi(benchmark::State& state) { EXPECT_EQ(psi_params.table_params().table_size, psi_params2.table_params().table_size); - size_t item_count = params.ns; - size_t label_byte_count = params.label_bytes; + size_t item_count = ns; size_t nonce_byte_count = 16; std::random_device rd; @@ -135,12 +165,11 @@ 
static void BM_LabeledPsi(benchmark::State& state) { prg.Fill(absl::MakeSpan(oprf_key)); bool compressed = false; - std::shared_ptr sender_db = - std::make_shared(psi_params, oprf_key, - label_byte_count, nonce_byte_count, - compressed); + std::shared_ptr sender_db = + std::make_shared(psi_params, oprf_key, "::memory", + 0, nonce_byte_count, compressed); - std::vector receiver_items = GenerateData(rd(), params.nr); + std::vector receiver_items = GenerateData(rd(), nr); std::vector intersection_idx; std::vector intersection_label; @@ -149,36 +178,30 @@ static void BM_LabeledPsi(benchmark::State& state) { const auto setdb_start = std::chrono::system_clock::now(); - if (params.label_bytes == 0) { - std::vector sender_items = GenerateSenderData( - rd(), item_count, absl::MakeSpan(receiver_items), &intersection_idx); + std::vector sender_items = GenerateSenderData( + rd(), item_count, absl::MakeSpan(receiver_items), &intersection_idx); - sender_db->SetData(sender_items); - } else { - std::vector> sender_items = - GenerateSenderData(rd(), item_count, label_byte_count, - absl::MakeSpan(receiver_items), &intersection_idx, - &intersection_label); + std::shared_ptr batch_provider = + std::make_shared(sender_items); - sender_db->SetData(sender_items); - } + sender_db->SetData(batch_provider); const auto setdb_end = std::chrono::system_clock::now(); const duration_millis setdb_duration = setdb_end - setdb_start; SPDLOG_INFO("*** step2 set db duration:{}", setdb_duration.count()); - EXPECT_EQ(params.ns, sender_db->GetItemCount()); + EXPECT_EQ(ns, sender_db->GetItemCount()); SPDLOG_INFO("after set db, bin_bundle_count:{}, packing_rate:{}", sender_db->GetBinBundleCount(), sender_db->GetPackingRate()); - std::unique_ptr oprf_server = - spu::CreateEcdhOprfServer(oprf_key, spu::OprfType::Basic, - spu::CurveType::CurveFourQ); + std::unique_ptr oprf_server = + spu::psi::CreateEcdhOprfServer(oprf_key, spu::psi::OprfType::Basic, + spu::psi::CurveType::CURVE_FOURQ); LabelPsiSender 
sender(sender_db); - LabelPsiReceiver receiver(psi_params, params.label_bytes > 0); + LabelPsiReceiver receiver(psi_params, false); // step 3: oprf request and response @@ -224,10 +247,11 @@ static void BM_LabeledPsi(benchmark::State& state) { const duration_millis query_duration = query_end - query_start; SPDLOG_INFO("*** step4 query duration:{}", query_duration.count()); + EXPECT_EQ(query_result.first.size(), intersection_idx.size()); SPDLOG_INFO("index vec size:{} intersection_idx size:{}", query_result.first.size(), intersection_idx.size()); - SPDLOG_INFO("intersection:{}", intersection.size()); + SPDLOG_INFO("intersection:{}", intersection_idx.size()); auto stats0 = ctxs[0]->GetStats(); auto stats1 = ctxs[1]->GetStats(); SPDLOG_INFO("sender ctx0 sent_bytes:{} recv_bytes:{}", stats0->sent_bytes, @@ -238,15 +262,18 @@ static void BM_LabeledPsi(benchmark::State& state) { } // [256k, 512k, 1m, 2m, 4m, 8m, 16m] -BENCHMARK(BM_PcgPsi) +BENCHMARK(BM_LabeledPsi) ->Unit(benchmark::kMillisecond) - ->Arg(256 << 10) - ->Arg(512 << 10) - ->Arg(1 << 20) - ->Arg(2 << 20) - ->Arg(4 << 20) - ->Arg(8 << 20) - ->Arg(16 << 20); + ->Args({100, 256 << 10}) + ->Args({100, 512 << 10}) + ->Args({100, 1 << 20}) + ->Args({100, 2 << 20}) + ->Args({100, 4 << 20}) + ->Args({100, 8 << 20}) + ->Args({100, 16 << 20}) + ->Args({100, 1 << 25}) + ->Args({100, 1 << 26}) + ->Args({100, 1 << 27}); BENCHMARK_MAIN(); diff --git a/libspu/psi/core/labeled_psi/apsi_label_test.cc b/libspu/psi/core/labeled_psi/apsi_label_test.cc index 88c4e4c8..84159d4d 100644 --- a/libspu/psi/core/labeled_psi/apsi_label_test.cc +++ b/libspu/psi/core/labeled_psi/apsi_label_test.cc @@ -31,6 +31,8 @@ #include "libspu/psi/core/labeled_psi/psi_params.h" #include "libspu/psi/core/labeled_psi/receiver.h" #include "libspu/psi/core/labeled_psi/sender.h" +#include "libspu/psi/core/labeled_psi/sender_kvdb.h" +#include "libspu/psi/core/labeled_psi/sender_memdb.h" namespace spu::psi { @@ -43,6 +45,7 @@ struct TestParams { size_t nr; 
size_t ns; size_t label_bytes = 16; + bool use_kvdb = true; }; std::vector GenerateData(size_t seed, size_t item_count) { @@ -53,46 +56,43 @@ std::vector GenerateData(size_t seed, size_t item_count) { for (size_t i = 0; i < item_count; ++i) { std::string item(16, '\0'); prg.Fill(absl::MakeSpan(item.data(), item.length())); - items.emplace_back(item); + items.emplace_back(absl::BytesToHexString(item)); } return items; } -std::vector> GenerateSenderData( - size_t seed, size_t item_count, size_t label_byte_count, - const absl::Span &receiver_items, - std::vector *intersection_idx, - std::vector *intersection_label) { - std::vector> sender_items; +std::pair, std::vector> +GenerateSenderData(size_t seed, size_t item_count, size_t label_byte_count, + const absl::Span &receiver_items, + std::vector *intersection_idx, + std::vector *intersection_label) { + std::vector sender_items; + std::vector sender_labels; yacl::crypto::Prg prg(seed); for (size_t i = 0; i < item_count; ++i) { - apsi::Item item; - apsi::Label label; - label.resize(label_byte_count); - prg.Fill(absl::MakeSpan(item.value())); - prg.Fill(absl::MakeSpan(label)); - sender_items.emplace_back(item, label); + std::string item(16, '\0'); + std::string label((label_byte_count - 1) / 2, '\0'); + + prg.Fill(absl::MakeSpan(item.data(), item.length())); + prg.Fill(absl::MakeSpan(label.data(), label.length())); + sender_items.emplace_back(absl::BytesToHexString(item)); + sender_labels.emplace_back(absl::BytesToHexString(label)); } for (size_t i = 0; i < receiver_items.size(); i += 3) { - apsi::Item item; - std::memcpy(item.value().data(), receiver_items[i].data(), - receiver_items[i].length()); - - sender_items[kPsiStartPos + i * 5].first = item; + if ((kPsiStartPos + i * 5) >= sender_items.size()) { + break; + } + sender_items[kPsiStartPos + i * 5] = receiver_items[i]; (*intersection_idx).emplace_back(i); - std::string label_string(sender_items[kPsiStartPos + i * 5].second.size(), - '\0'); - 
std::memcpy(label_string.data(), - sender_items[kPsiStartPos + i * 5].second.data(), - sender_items[kPsiStartPos + i * 5].second.size()); - (*intersection_label).emplace_back(label_string); + + (*intersection_label).emplace_back(sender_labels[kPsiStartPos + i * 5]); } - return sender_items; + return std::make_pair(sender_items, sender_labels); } } // namespace @@ -127,7 +127,7 @@ TEST_P(LabelPsiTest, Works) { std::array oprf_key; prg.Fill(absl::MakeSpan(oprf_key)); - std::string kv_store_path = fmt::format("data_{}", params.ns); + std::string kv_store_path = fmt::format("data_{}/", params.ns); std::filesystem::create_directory(kv_store_path); // register remove of temp dir. ON_SCOPE_EXIT([&] { @@ -142,10 +142,15 @@ TEST_P(LabelPsiTest, Works) { }); bool compressed = false; - std::shared_ptr sender_db = - std::make_shared(psi_params, oprf_key, kv_store_path, - label_byte_count, nonce_byte_count, - compressed); + std::shared_ptr sender_db; + if (params.use_kvdb) { + sender_db = std::make_shared( + psi_params, oprf_key, kv_store_path, label_byte_count, nonce_byte_count, + compressed); + } else { + sender_db = std::make_shared( + psi_params, oprf_key, label_byte_count, nonce_byte_count, compressed); + } std::vector receiver_items = GenerateData(rd(), params.nr); @@ -156,29 +161,15 @@ TEST_P(LabelPsiTest, Works) { const auto setdb_start = std::chrono::system_clock::now(); - std::vector> sender_items = - GenerateSenderData(rd(), item_count, label_byte_count - 6, - absl::MakeSpan(receiver_items), &intersection_idx, - &intersection_label); - - // sender_db->SetData(sender_items); - - std::vector sender_data(sender_items.size()); - std::vector sender_label(sender_items.size()); - for (size_t i = 0; i < sender_items.size(); ++i) { - sender_data[i].reserve(sender_items[i].first.value().size()); - sender_data[i].append(absl::string_view( - reinterpret_cast(sender_items[i].first.value().data()), - sender_items[i].first.value().size())); - - 
sender_label[i].reserve(sender_items[i].second.size()); - sender_label[i].append(absl::string_view( - reinterpret_cast(sender_items[i].second.data()), - sender_items[i].second.size())); - } + std::vector sender_keys; + std::vector sender_labels; + + std::tie(sender_keys, sender_labels) = GenerateSenderData( + rd(), item_count, label_byte_count - 4, absl::MakeSpan(receiver_items), + &intersection_idx, &intersection_label); std::shared_ptr batch_provider = - std::make_shared(sender_data, sender_label); + std::make_shared(sender_keys, sender_labels); sender_db->SetData(batch_provider); @@ -193,12 +184,6 @@ TEST_P(LabelPsiTest, Works) { const apsi::PSIParams apsi_params = sender_db->GetParams(); SPDLOG_INFO("params.bundle_idx_count={}", apsi_params.bundle_idx_count()); - for (size_t i = 0; i < apsi_params.bundle_idx_count(); ++i) { - SPDLOG_INFO("i={},count={}", i, sender_db->GetBinBundleCount(i)); - } - - std::unique_ptr oprf_server = - CreateEcdhOprfServer(oprf_key, OprfType::Basic, CurveType::CURVE_FOURQ); LabelPsiSender sender(sender_db); @@ -208,8 +193,12 @@ TEST_P(LabelPsiTest, Works) { const auto oprf_start = std::chrono::system_clock::now(); - std::future f_sender_oprf = std::async( - [&] { return sender.RunOPRF(std::move(oprf_server), ctxs[0]); }); + std::future f_sender_oprf = std::async([&] { + std::unique_ptr oprf_server = + CreateEcdhOprfServer(oprf_key, OprfType::Basic, CurveType::CURVE_FOURQ); + + return sender.RunOPRF(std::move(oprf_server), ctxs[0]); + }); std::future< std::pair, std::vector>> @@ -269,8 +258,9 @@ INSTANTIATE_TEST_SUITE_P(Works_Instances, LabelPsiTest, TestParams{10000, 100000, 32}, // 10000-100K-32 TestParams{10000, 1000000, 32}) // 10000-1M-32 #else - TestParams{2048, 100000, 32}, // 2048-100K-32 - TestParams{4096, 100000, 32}) // 4096-100K-32 + TestParams{10, 10000, 32}, // 10-10K-32 + TestParams{2048, 10000, 32}, // 2048-10K-32 + TestParams{100, 100000, 32, false}) // 100-100K-32 #endif ); diff --git 
a/libspu/psi/core/labeled_psi/apsi_test.cc b/libspu/psi/core/labeled_psi/apsi_test.cc index 72ab73ec..894aa6ae 100644 --- a/libspu/psi/core/labeled_psi/apsi_test.cc +++ b/libspu/psi/core/labeled_psi/apsi_test.cc @@ -31,6 +31,8 @@ #include "libspu/psi/core/labeled_psi/psi_params.h" #include "libspu/psi/core/labeled_psi/receiver.h" #include "libspu/psi/core/labeled_psi/sender.h" +#include "libspu/psi/core/labeled_psi/sender_kvdb.h" +#include "libspu/psi/core/labeled_psi/sender_memdb.h" namespace spu::psi { @@ -42,6 +44,7 @@ constexpr size_t kPsiStartPos = 100; struct TestParams { size_t nr; size_t ns; + bool use_kvdb = true; }; std::vector GenerateData(size_t seed, size_t item_count) { @@ -52,32 +55,31 @@ std::vector GenerateData(size_t seed, size_t item_count) { for (size_t i = 0; i < item_count; ++i) { std::string item(16, '\0'); prg.Fill(absl::MakeSpan(item.data(), item.length())); - items.emplace_back(item); + items.emplace_back(absl::BytesToHexString(item)); } return items; } -std::vector GenerateSenderData( +std::vector GenerateSenderData( size_t seed, size_t item_count, const absl::Span &receiver_items, std::vector *intersection_idx) { - std::vector sender_items; + std::vector sender_items; yacl::crypto::Prg prg(seed); for (size_t i = 0; i < item_count; ++i) { - apsi::Item::value_type value{}; - prg.Fill(absl::MakeSpan(value)); - sender_items.emplace_back(value); + std::string item(16, '\0'); + prg.Fill(absl::MakeSpan(item.data(), item.size())); + sender_items.emplace_back(absl::BytesToHexString(item)); } for (size_t i = 0; i < receiver_items.size(); i += 3) { - apsi::Item::value_type value{}; - std::memcpy(value.data(), receiver_items[i].data(), - receiver_items[i].length()); - apsi::Item item(value); - sender_items[kPsiStartPos + i * 5] = item; + if ((kPsiStartPos + i * 5) >= sender_items.size()) { + break; + } + sender_items[kPsiStartPos + i * 5] = receiver_items[i]; (*intersection_idx).emplace_back(i); } @@ -130,9 +132,14 @@ TEST_P(LabelPsiTest, Works) { 
}); bool compressed = false; - std::shared_ptr sender_db = - std::make_shared(psi_params, oprf_key, kv_store_path, - 0, nonce_byte_count, compressed); + std::shared_ptr sender_db; + if (params.use_kvdb) { + sender_db = std::make_shared( + psi_params, oprf_key, kv_store_path, 0, nonce_byte_count, compressed); + } else { + sender_db = std::make_shared( + psi_params, oprf_key, 0, nonce_byte_count, compressed); + } std::vector receiver_items = GenerateData(rd(), params.nr); @@ -143,22 +150,11 @@ TEST_P(LabelPsiTest, Works) { const auto setdb_start = std::chrono::system_clock::now(); - std::vector sender_items = GenerateSenderData( + std::vector sender_items = GenerateSenderData( rd(), item_count, absl::MakeSpan(receiver_items), &intersection_idx); - // sender_db->SetData(sender_items); - - std::vector sender_data(sender_items.size()); - for (size_t i = 0; i < sender_items.size(); ++i) { - sender_data[i].reserve(sender_items[i].value().size()); - sender_data[i].append(absl::string_view( - reinterpret_cast(sender_items[i].value().data()), - sender_items[i].value().size())); - } - SPDLOG_INFO("sender_data[0].length:{}", sender_data[0].length()); - std::shared_ptr batch_provider = - std::make_shared(sender_data); + std::make_shared(sender_items); sender_db->SetData(batch_provider); @@ -249,10 +245,10 @@ INSTANTIATE_TEST_SUITE_P(Works_Instances, LabelPsiTest, TestParams{4096, 100000}, // 4096-100K TestParams{10000, 100000}, // 10000-100K #else - TestParams{1, 10000}, // 1-10K - TestParams{1, 100000}, // 1-100K - TestParams{256, 100000}, // 256-100K - TestParams{256, 100000}) // 256-100K + TestParams{1, 10000}, // 1-10K + TestParams{1, 10000, false}, // 1-10K memdb + TestParams{10, 10000}, // 1-10K + TestParams{100, 100000, false}) // 100-100K #endif ); diff --git a/libspu/psi/core/labeled_psi/kv_test.cc b/libspu/psi/core/labeled_psi/kv_test.cc index a0592111..239cc2eb 100644 --- a/libspu/psi/core/labeled_psi/kv_test.cc +++ b/libspu/psi/core/labeled_psi/kv_test.cc @@ -31,6 
+31,7 @@ #include "libspu/psi/core/labeled_psi/psi_params.h" #include "libspu/psi/core/labeled_psi/receiver.h" #include "libspu/psi/core/labeled_psi/sender.h" +#include "libspu/psi/core/labeled_psi/sender_kvdb.h" #include "libspu/psi/utils/utils.h" namespace spu::psi { @@ -186,10 +187,10 @@ TEST_P(LabelPsiTest, Works) { }); bool compressed = false; - std::shared_ptr sender_db = - std::make_shared(psi_params, oprf_key, kv_store_path, - label_byte_count, nonce_byte_count, - compressed); + std::shared_ptr sender_db = + std::make_shared(psi_params, oprf_key, + kv_store_path, label_byte_count, + nonce_byte_count, compressed); std::vector receiver_items = GenerateData(rd(), params.nr); diff --git a/libspu/psi/core/labeled_psi/psi_params.cc b/libspu/psi/core/labeled_psi/psi_params.cc index e3ab2c22..c1cd0b27 100644 --- a/libspu/psi/core/labeled_psi/psi_params.cc +++ b/libspu/psi/core/labeled_psi/psi_params.cc @@ -16,6 +16,7 @@ #include #include +#include #include #include "spdlog/spdlog.h" @@ -45,6 +46,52 @@ std::vector kSealParams = { {8192, 65537, 0, {56, 56, 30}}, // 14 }; +std::map kPolynomialParams = { + {20, {0, {1, 2, 5, 8, 9, 10}}}, + {35, {5, {1, 2, 3, 4, 5, 6, 18, 30, 42, 54, 60}}}, + {36, {0, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36}}}, + {42, {0, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42}}}, + {51, {3, {1, 2, 3, 4, 12, 20, 24}}}, + {55, {0, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, + 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55}}}, + {64, {0, {1, 3, 4, 9, 11, 16, 21, 23, 28, 29, 31, 32}}}, + {72, {0, {1, 3, 4, 9, 11, 16, 20, 25, 27, 32, 33, 35, 36}}}, + {98, {8, {1, 3, 4, 9, 27}}}, + {125, {5, {1, 2, 3, 4, 5, 6, 18, 30, 
42, 54, 60}}}, + {128, + {0, {1, 3, 4, 5, 8, 14, 20, 26, 32, 38, 44, 50, 56, 59, 60, 61, 63, 64}}}, + {180, {0, {1, 3, 4, 6, 10, 13, 15, 21, 29, 37, 45, + 53, 61, 69, 77, 81, 83, 86, 87, 90, 92, 96}}}, + {228, {0, {1, 3, 8, 19, 33, 39, 92, 102}}}, + {512, + {0, {1, 3, 4, 6, 10, 13, 15, 21, 29, 37, 45, 53, 61, 69, + 75, 77, 80, 84, 86, 87, 89, 90, 181, 183, 188, 190, 195, 197, + 206, 213, 214, 222, 230, 238, 246, 254, 261, 337, 338, 345, 353, 361, + 370, 372, 377, 379, 384, 386, 477, 479, 486, 487, 495, 503, 511}}}, + {726, {0, {1, 4, 13, 18, 51, 92, 163, 208, 223}}}, + {770, {26, {1, 5, 8, 27, 135}}}, + {782, {26, {1, 5, 8, 27, 135}}}, + {1016, {0, {1, 6, 8, 21, 60, 93, 104, 154, 378, 414}}}, + {1024, + {0, + {1, 3, 4, 6, 10, 13, 15, 21, 29, 37, 45, 53, 61, 69, + 75, 77, 80, 84, 86, 87, 89, 90, 181, 183, 188, 190, 195, 197, + 206, 213, 214, 222, 230, 238, 246, 254, 261, 337, 338, 345, 353, 361, + 370, 372, 377, 379, 384, 386, 477, 479, 486, 487, 495, 503, 511, 519, + 521, 528, 533, 540, 548, 553, 590, 591, 593, 597, 601, 613, 616, 624, + 639, 642, 651, 662, 673, 684, 695, 701, 711, 722, 733, 744, 755, 766, + 777, 788, 799, 811, 822, 833, 844, 855, 866, 877, 888, 899, 901, 911, + 922, 933, 944, 955, 966, 977, 988, 993, 999, 1011, 1021, 1022, 1023}}}, + {1304, {44, {1, 3, 11, 18, 45, 225}}}, + {4000, {310, {1, 4, 10, 11, 28, 33, 78, 118, 143, 311, 1555}}}, + {8100, {310, {1, 4, 10, 11, 28, 33, 78, 118, 143, 311, 1555}}} // +}; } // namespace yacl::Buffer PsiParamsToBuffer(const apsi::PSIParams &psi_params) { @@ -170,7 +217,7 @@ SEALParams GetSealParams(size_t nr, size_t ns) { return kSealParams[12]; } -apsi::PSIParams GetPsiParams(size_t nr, size_t ns) { +apsi::PSIParams GetPsiParams(size_t nr, size_t ns, size_t max_items_per_bin) { SEALParams seal_params = GetSealParams(nr, ns); apsi::PSIParams::ItemParams item_params; @@ -249,6 +296,13 @@ apsi::PSIParams GetPsiParams(size_t nr, size_t ns) { apsi_seal_params.set_coeff_modulus(seal::CoeffModulus::Create( 
seal_params.poly_modulus_degree, seal_params.coeff_modulus_bits)); + if ((table_params.hash_func_count > 1) && (max_items_per_bin != 0) && + kPolynomialParams.count(max_items_per_bin)) { + table_params.max_items_per_bin = max_items_per_bin; + + query_params = kPolynomialParams[max_items_per_bin]; + } + apsi::PSIParams psi_params(item_params, table_params, query_params, apsi_seal_params); diff --git a/libspu/psi/core/labeled_psi/psi_params.h b/libspu/psi/core/labeled_psi/psi_params.h index d26c55e3..f710297d 100644 --- a/libspu/psi/core/labeled_psi/psi_params.h +++ b/libspu/psi/core/labeled_psi/psi_params.h @@ -56,7 +56,8 @@ struct SEALParams { * @param ns sender's items size * @return apsi::PSIParams */ -apsi::PSIParams GetPsiParams(size_t nr, size_t ns); +apsi::PSIParams GetPsiParams(size_t nr, size_t ns, + size_t max_items_per_bin = 0); /** * @brief Serialize apsi::PSIParams to yacl::Buffer diff --git a/libspu/psi/core/labeled_psi/receiver.cc b/libspu/psi/core/labeled_psi/receiver.cc index 6066ce37..f5cd2eaa 100644 --- a/libspu/psi/core/labeled_psi/receiver.cc +++ b/libspu/psi/core/labeled_psi/receiver.cc @@ -269,7 +269,7 @@ LabelPsiReceiver::RequestQuery( // plain_modulus parts std::vector alg_item = apsi::util::bits_to_field_elts( item_bits, psi_params_.seal_params().plain_modulus()); - copy(alg_item.cbegin(), alg_item.cend(), back_inserter(alg_items)); + std::copy(alg_item.cbegin(), alg_item.cend(), back_inserter(alg_items)); } // Now that we have the algebraized items for this bundle index, we diff --git a/libspu/psi/core/labeled_psi/sender.cc b/libspu/psi/core/labeled_psi/sender.cc index e75d91c7..0c0e11b4 100644 --- a/libspu/psi/core/labeled_psi/sender.cc +++ b/libspu/psi/core/labeled_psi/sender.cc @@ -31,6 +31,7 @@ #include "libspu/psi/core/labeled_psi/package.h" #include "libspu/psi/core/labeled_psi/sender_db.h" +#include "libspu/psi/core/labeled_psi/sender_kvdb.h" #include "libspu/psi/core/labeled_psi/serializable.pb.h" @@ -44,7 +45,7 @@ class 
QueryRequest { std::unordered_map< uint32_t, std::vector>> &encrypted_powers, - const std::shared_ptr &sender_db) { + const std::shared_ptr &sender_db) { auto seal_context = sender_db->GetSealContext(); for (auto &q : encrypted_powers) { @@ -118,7 +119,7 @@ uint32_t reset_powers_dag(apsi::PowersDag *pd, const apsi::PSIParams ¶ms, } // namespace -LabelPsiSender::LabelPsiSender(std::shared_ptr sender_db) +LabelPsiSender::LabelPsiSender(std::shared_ptr sender_db) : sender_db_(std::move(sender_db)) { apsi::PSIParams params(sender_db_->GetParams()); @@ -186,7 +187,7 @@ void LabelPsiSender::RunOPRF( std::vector> SenderRunQuery( const QueryRequest &query, - const std::shared_ptr &sender_db, + const std::shared_ptr &sender_db, const apsi::PowersDag &pd); void LabelPsiSender::RunQuery( @@ -238,21 +239,29 @@ void LabelPsiSender::RunQuery( SenderRunQuery(request, sender_db_, pd_); proto::QueryResponseProto response_proto; + + std::vector> futures; + apsi::ThreadPoolMgr tpm; + for (auto &result : query_result) { proto::QueryResultProto *result_proto = response_proto.add_results(); - result_proto->set_bundle_idx(result->bundle_idx); - std::vector temp; - temp.resize(result->psi_result.save_size(compr_mode_)); - auto size = result->psi_result.save(temp, compr_mode_); - result_proto->set_ciphertext(temp.data(), temp.size()); - result_proto->set_label_byte_count(result->label_byte_count); - result_proto->set_nonce_byte_count(result->nonce_byte_count); - - for (auto &r : result->label_result) { - temp.resize(r.save_size(compr_mode_)); - size = r.save(temp, compr_mode_); - result_proto->add_label_results(temp.data(), size); - } + futures.push_back(tpm.thread_pool().enqueue([&, result, result_proto]() { + result_proto->set_bundle_idx(result->bundle_idx); + std::vector temp(result->psi_result.save_size(compr_mode_)); + auto size = result->psi_result.save(temp, compr_mode_); + result_proto->set_ciphertext(temp.data(), temp.size()); + 
result_proto->set_label_byte_count(result->label_byte_count); + result_proto->set_nonce_byte_count(result->nonce_byte_count); + + for (auto &r : result->label_result) { + temp.resize(r.save_size(compr_mode_)); + size = r.save(temp, compr_mode_); + result_proto->add_label_results(temp.data(), size); + } + })); + } + for (auto &f : futures) { + f.get(); } yacl::Buffer response_buffer(response_proto.ByteSizeLong()); @@ -267,14 +276,14 @@ void LabelPsiSender::RunQuery( fmt::format("send query response size:{}", response_buffer.size())); } -void ComputePowers(const std::shared_ptr &sender_db, +void ComputePowers(const std::shared_ptr &sender_db, const apsi::CryptoContext &crypto_context, std::vector *all_powers, const apsi::PowersDag &pd, uint32_t bundle_idx, seal::MemoryPoolHandle *pool); void ProcessBinBundleCache( - const std::shared_ptr &sender_db, + const std::shared_ptr &sender_db, const apsi::CryptoContext &crypto_context, const std::shared_ptr &bundle, std::vector *all_powers, uint32_t bundle_idx, @@ -283,7 +292,7 @@ void ProcessBinBundleCache( std::vector> SenderRunQuery( const QueryRequest &query, - const std::shared_ptr &sender_db, + const std::shared_ptr &sender_db, const apsi::PowersDag &pd) { // We use a custom SEAL memory that is freed after the query is done auto pool = seal::MemoryManager::GetPool(seal::mm_force_new); @@ -364,7 +373,8 @@ std::vector> SenderRunQuery( for (size_t cache_idx = 0; cache_idx < cache_count; ++cache_idx) { std::shared_ptr bundle = - sender_db->GetCacheAt(static_cast(bundle_idx), cache_idx); + sender_db->GetBinBundleAt(static_cast(bundle_idx), + cache_idx); query_results.push_back(std::make_shared()); std::shared_ptr result = *(query_results.rbegin()); @@ -388,7 +398,7 @@ std::vector> SenderRunQuery( return query_results; } -void ComputePowers(const std::shared_ptr &sender_db, +void ComputePowers(const std::shared_ptr &sender_db, const apsi::CryptoContext &crypto_context, std::vector *all_powers, const apsi::PowersDag &pd, 
uint32_t bundle_idx, @@ -479,7 +489,7 @@ void ComputePowers(const std::shared_ptr &sender_db, } void ProcessBinBundleCache( - const std::shared_ptr &sender_db, + const std::shared_ptr &sender_db, const apsi::CryptoContext &crypto_context, const std::shared_ptr &bundle, std::vector *all_powers, uint32_t bundle_idx, diff --git a/libspu/psi/core/labeled_psi/sender.h b/libspu/psi/core/labeled_psi/sender.h index 479fc0ce..6be6caed 100644 --- a/libspu/psi/core/labeled_psi/sender.h +++ b/libspu/psi/core/labeled_psi/sender.h @@ -32,7 +32,7 @@ namespace spu::psi { class LabelPsiSender { public: - explicit LabelPsiSender(std::shared_ptr sender_db); + explicit LabelPsiSender(std::shared_ptr sender_db); /** * @brief Receive PsiParams Request and Send PsiParams Response @@ -60,7 +60,7 @@ class LabelPsiSender { void RunQuery(const std::shared_ptr& link_ctx); private: - std::shared_ptr sender_db_; + std::shared_ptr sender_db_; apsi::CryptoContext crypto_context_; seal::compr_mode_type compr_mode_ = seal::Serialization::compr_mode_default; diff --git a/libspu/psi/core/labeled_psi/sender_db.cc b/libspu/psi/core/labeled_psi/sender_db.cc index 37f285d8..d27bfc1c 100644 --- a/libspu/psi/core/labeled_psi/sender_db.cc +++ b/libspu/psi/core/labeled_psi/sender_db.cc @@ -60,6 +60,10 @@ namespace { using DurationMillis = std::chrono::duration; +} + +namespace labeled_psi { + /** Creates and returns the vector of hash functions similarly to how Kuku 2.x sets them internally. @@ -115,595 +119,18 @@ std::pair UnpackCuckooIdx(size_t cuckoo_idx, return {bin_idx, bundle_idx}; } -/** -Converts each given Item-Label pair in between the given iterators into its -algebraic form, i.e., a sequence of felt-felt pairs. Also computes each Item's -cuckoo index. 
-*/ -std::vector> PreprocessLabeledData( - const std::vector>::const_iterator begin, - const std::vector< - std::pair>::const_iterator end, - const apsi::PSIParams ¶ms) { - STOPWATCH(sender_stopwatch, "preprocess_labeled_data"); - SPDLOG_DEBUG("Start preprocessing {} labeled items", - std::distance(begin, end)); - - // Some variables we'll need - size_t bins_per_item = params.item_params().felts_per_item; - size_t item_bit_count = params.item_bit_count(); - - // Set up Kuku hash functions - auto hash_funcs = HashFunctions(params); - - // Calculate the cuckoo indices for each item. Store every pair of - // (item-label, cuckoo_idx) in a vector. Later, we're gonna sort this vector - // by cuckoo_idx and use the result to parallelize the work of inserting the - // items into BinBundles. - std::vector> data_with_indices; - for (auto it = begin; it != end; it++) { - const std::pair &item_label_pair = - *it; - - // Serialize the data into field elements - const apsi::HashedItem &item = item_label_pair.first; - const apsi::EncryptedLabel &label = item_label_pair.second; - apsi::util::AlgItemLabel alg_item_label = algebraize_item_label( - item, label, item_bit_count, params.seal_params().plain_modulus()); - - // Get the cuckoo table locations for this item and add to data_with_indices - for (auto location : AllLocations(hash_funcs, item)) { - // The current hash value is an index into a table of Items. In reality - // our BinBundles are tables of bins, which contain chunks of items. How - // many chunks? bins_per_item many chunks - size_t bin_idx = location * bins_per_item; - - // Store the data along with its index - data_with_indices.emplace_back(alg_item_label, bin_idx); - } - } - - SPDLOG_DEBUG("Finished preprocessing {} labeled items", - std::distance(begin, end)); - - return data_with_indices; -} - -/** -Converts each given Item into its algebraic form, i.e., a sequence of -felt-monostate pairs. Also computes each Item's cuckoo index. 
-*/ -std::vector> PreprocessUnlabeledData( - const std::vector::const_iterator begin, - const std::vector::const_iterator end, - const apsi::PSIParams ¶ms) { - STOPWATCH(sender_stopwatch, "preprocess_unlabeled_data"); - SPDLOG_DEBUG("Start preprocessing {} unlabeled items", - std::distance(begin, end)); - - // Some variables we'll need - size_t bins_per_item = params.item_params().felts_per_item; - size_t item_bit_count = params.item_bit_count(); - - // Set up Kuku hash functions - auto hash_funcs = HashFunctions(params); - - // Calculate the cuckoo indices for each item. Store every pair of - // (item-label, cuckoo_idx) in a vector. Later, we're gonna sort this vector - // by cuckoo_idx and use the result to parallelize the work of inserting the - // items into BinBundles. - std::vector> data_with_indices; - for (auto it = begin; it != end; it++) { - const apsi::HashedItem &item = *it; - - // Serialize the data into field elements - apsi::util::AlgItem alg_item = algebraize_item( - item, item_bit_count, params.seal_params().plain_modulus()); - - // Get the cuckoo table locations for this item and add to data_with_indices - for (auto location : AllLocations(hash_funcs, item)) { - // The current hash value is an index into a table of Items. In reality - // our BinBundles are tables of bins, which contain chunks of items. How - // many chunks? 
bins_per_item many chunks - size_t bin_idx = location * bins_per_item; - - // Store the data along with its index - data_with_indices.emplace_back(alg_item, bin_idx); - } - } - - SPDLOG_DEBUG("Finished preprocessing {} unlabeled items", - std::distance(begin, end)); - - return data_with_indices; -} - -std::vector> PreprocessLabeledData( - const std::pair &item_label_pair, - const apsi::PSIParams ¶ms) { - SPDLOG_DEBUG("Start preprocessing {} labeled items", distance(begin, end)); - - // Some variables we'll need - size_t bins_per_item = params.item_params().felts_per_item; - size_t item_bit_count = params.item_bit_count(); - - // Set up Kuku hash functions - auto hash_funcs = HashFunctions(params); - - std::vector> data_with_indices; - - // Serialize the data into field elements - const apsi::HashedItem &item = item_label_pair.first; - const apsi::EncryptedLabel &label = item_label_pair.second; - apsi::util::AlgItemLabel alg_item_label = algebraize_item_label( - item, label, item_bit_count, params.seal_params().plain_modulus()); - - // Get the cuckoo table locations for this item and add to data_with_indices - for (auto location : AllLocations(hash_funcs, item)) { - // The current hash value is an index into a table of Items. In reality - // our BinBundles are tables of bins, which contain chunks of items. How - // many chunks? 
bins_per_item many chunks - size_t bin_idx = location * bins_per_item; - - // Store the data along with its index - data_with_indices.push_back(std::make_pair(alg_item_label, bin_idx)); - } - - return data_with_indices; -} - -std::vector> PreprocessUnlabeledData( - const apsi::HashedItem &hashed_item, const apsi::PSIParams ¶ms) { - // Some variables we'll need - size_t bins_per_item = params.item_params().felts_per_item; - size_t item_bit_count = params.item_bit_count(); - - // Set up Kuku hash functions - auto hash_funcs = HashFunctions(params); - - std::vector> data_with_indices; - - // Serialize the data into field elements - apsi::util::AlgItem alg_item = algebraize_item( - hashed_item, item_bit_count, params.seal_params().plain_modulus()); - - // Get the cuckoo table locations for this item and add to data_with_indices - for (auto location : AllLocations(hash_funcs, hashed_item)) { - // The current hash value is an index into a table of Items. In reality - // our BinBundles are tables of bins, which contain chunks of items. How - // many chunks? bins_per_item many chunks - size_t bin_idx = location * bins_per_item; - - // Store the data along with its index - data_with_indices.emplace_back(std::make_pair(alg_item, bin_idx)); - } - - return data_with_indices; -} - -/** -Inserts the given items and corresponding labels into bin_bundles at their -respective cuckoo indices. It will only insert the data with bundle index in the -half-open range range indicated by work_range. If inserting into a BinBundle -would make the number of items in a bin larger than max_bin_size, this function -will create and insert a new BinBundle. If overwrite is set, this will overwrite -the labels if it finds an AlgItemLabel that matches the input perfectly. 
-*/ -template -void InsertOrAssignWorker( - const std::vector> &data_with_indices, - std::vector> *bundles_store, - std::vector *bundles_store_idx, - const apsi::CryptoContext &crypto_context, uint32_t bundle_index, - uint32_t bins_per_bundle, size_t label_size, size_t max_bin_size, - size_t ps_low_degree, bool overwrite, bool compressed) { - STOPWATCH(sender_stopwatch, "insert_or_assign_worker"); - SPDLOG_DEBUG( - "Insert-or-Assign worker for bundle index {}; mode of operation: {}", - bundle_index, overwrite ? "overwriting existing" : "inserting new"); - - // Create the bundle set at the given bundle index - std::vector bundle_set; - - // Iteratively insert each item-label pair at the given cuckoo index - for (auto &data_with_idx : data_with_indices) { - const T &data = data_with_idx.first; - - // Get the bundle index - size_t cuckoo_idx = data_with_idx.second; - size_t bin_idx; - size_t bundle_idx; - std::tie(bin_idx, bundle_idx) = - UnpackCuckooIdx(cuckoo_idx, bins_per_bundle); - - // If the bundle_idx isn't in the prescribed range, don't try to insert this - // data - if (bundle_idx != bundle_index) { - // Dealing with this bundle index is not our job - continue; - } - - // Try to insert or overwrite these field elements in an existing BinBundle - // at this bundle index. Keep track of whether or not we succeed. - bool written = false; - - for (auto bundle_it = bundle_set.rbegin(); bundle_it != bundle_set.rend(); - bundle_it++) { - // If we're supposed to overwrite, try to overwrite. One of these - // BinBundles has to have the data we're trying to overwrite. 
- if (overwrite) { - // If we successfully overwrote, we're done with this bundle - written = bundle_it->try_multi_overwrite(data, bin_idx); - if (written) { - break; - } - } - - // Do a dry-run insertion and see if the new largest bin size in the range - // exceeds the limit - int32_t new_largest_bin_size = - bundle_it->multi_insert_dry_run(data, bin_idx); - - // Check if inserting would violate the max bin size constraint - if (new_largest_bin_size > 0 && - seal::util::safe_cast(new_largest_bin_size) < max_bin_size) { - // All good - bundle_it->multi_insert_for_real(data, bin_idx); - written = true; - break; - } - } - - // We tried to overwrite an item that doesn't exist. This should never - // happen - if (overwrite && !written) { - SPDLOG_ERROR( - "Insert-or-Assign worker: " - "failed to overwrite item at bundle index {} because the item was " - "not found", - bundle_idx); - SPU_THROW("tried to overwrite non-existent item"); - } - - // If we had conflicts everywhere when trying to insert, then we need to - // make a new BinBundle and insert the data there - if (!written) { - // Make a fresh BinBundle and insert - apsi::sender::BinBundle new_bin_bundle( - crypto_context, label_size, max_bin_size, ps_low_degree, - bins_per_bundle, compressed, false); - int res = new_bin_bundle.multi_insert_for_real(data, bin_idx); - - // If even that failed, I don't know what could've happened - if (res < 0) { - SPDLOG_ERROR( - "Insert-or-Assign worker: " - "failed to insert item into a new BinBundle at bundle index {}", - bundle_idx); - SPU_THROW("failed to insert item into a new BinBundle"); - } - - // Push a new BinBundle to the set of BinBundles at this bundle index - bundle_set.push_back(std::move(new_bin_bundle)); - } - } - - const auto db_save_start = std::chrono::system_clock::now(); - - for (auto &bundle : bundle_set) { - std::stringstream stream; - - // Generate the BinBundle caches - // bundle.regen_cache(); - bundle.strip(); - - size_t store_idx = 
(*bundles_store_idx)[bundle_index]++; - bundle.save(stream, store_idx); - - (*bundles_store)[bundle_index]->Put(store_idx, stream.str()); - } - - const auto db_save_end = std::chrono::system_clock::now(); - const DurationMillis db_save_duration = db_save_end - db_save_start; - SPDLOG_INFO("*** step leveldb put duration:{}", db_save_duration.count()); - - SPDLOG_DEBUG("Insert-or-Assign worker: finished processing bundle index {}", - bundle_index); -} - -void InsertOrAssignWorker( - const std::shared_ptr &indices_store, - size_t indices_count, - std::vector> *bundles_store, - std::vector *bundles_store_idx, bool is_labeled, - const apsi::CryptoContext &crypto_context, uint32_t bundle_index, - uint32_t bins_per_bundle, size_t label_size, size_t max_bin_size, - - size_t ps_low_degree, bool overwrite, bool compressed) { - STOPWATCH(sender_stopwatch, "insert_or_assign_worker"); - - SPDLOG_DEBUG( - "Insert-or-Assign worker for bundle index {}; mode of operation: {}", - bundle_index, overwrite ? 
"overwriting existing" : "inserting new"); - - // Create the bundle set at the given bundle index - std::vector bundle_set; - - // Iteratively insert each item-label pair at the given cuckoo index - for (size_t i = 0; i < indices_count; ++i) { - yacl::Buffer value; - bool get_status = indices_store->Get(i, &value); - SPU_ENFORCE(get_status, "get_status:{}", get_status); - - size_t cuckoo_idx; - - std::pair datalabel_with_idx; - std::pair data_with_idx; - - if (is_labeled) { - datalabel_with_idx = DeserializeDataLabelWithIndices(std::string_view( - reinterpret_cast(value.data()), value.size())); - - cuckoo_idx = datalabel_with_idx.second; - } else { - data_with_idx = DeserializeDataWithIndices(std::string_view( - reinterpret_cast(value.data()), value.size())); - cuckoo_idx = data_with_idx.second; - } - - // const apsi::util::AlgItem &data = data_with_idx.first; - - // Get the bundle index - size_t bin_idx, bundle_idx; - std::tie(bin_idx, bundle_idx) = - UnpackCuckooIdx(cuckoo_idx, bins_per_bundle); - - // If the bundle_idx isn't in the prescribed range, don't try to insert - // this data - if (bundle_idx != bundle_index) { - // Dealing with this bundle index is not our job - continue; - } - - // Try to insert or overwrite these field elements in an existing - // BinBundle at this bundle index. Keep track of whether or not we - // succeed. - bool written = false; - - for (auto bundle_it = bundle_set.rbegin(); bundle_it != bundle_set.rend(); - bundle_it++) { - // If we're supposed to overwrite, try to overwrite. One of these - // BinBundles has to have the data we're trying to overwrite. 
- if (overwrite) { - // If we successfully overwrote, we're done with this bundle - // written = bundle_it->try_multi_overwrite(data, bin_idx); - if (is_labeled) { - written = - bundle_it->try_multi_overwrite(datalabel_with_idx.first, bin_idx); - - } else { - written = - bundle_it->try_multi_overwrite(data_with_idx.first, bin_idx); - } - - if (written) { - break; - } - } - - // Do a dry-run insertion and see if the new largest bin size in the - // range exceeds the limit - // int32_t new_largest_bin_size = bundle_it->multi_insert_dry_run(data, - // bin_idx); - int32_t new_largest_bin_size; - if (is_labeled) { - new_largest_bin_size = - bundle_it->multi_insert_dry_run(datalabel_with_idx.first, bin_idx); - - } else { - new_largest_bin_size = - bundle_it->multi_insert_dry_run(data_with_idx.first, bin_idx); - } - - // Check if inserting would violate the max bin size constraint - if (new_largest_bin_size > 0 && - seal::util::safe_cast(new_largest_bin_size) < max_bin_size) { - // All good - // bundle_it->multi_insert_for_real(data, bin_idx); - if (is_labeled) { - bundle_it->multi_insert_for_real(datalabel_with_idx.first, bin_idx); - - } else { - bundle_it->multi_insert_for_real(data_with_idx.first, bin_idx); - } - written = true; - break; - } - } - - // We tried to overwrite an item that doesn't exist. 
This should never - // happen - if (overwrite && !written) { - SPDLOG_ERROR( - "Insert-or-Assign worker: " - "failed to overwrite item at bundle index {} because the item was " - "not found", - bundle_idx); - SPU_THROW("tried to overwrite non-existent item"); - } - - // If we had conflicts everywhere when trying to insert, then we need to - // make a new BinBundle and insert the data there - if (!written) { - // Make a fresh BinBundle and insert - apsi::sender::BinBundle new_bin_bundle( - crypto_context, label_size, max_bin_size, ps_low_degree, - bins_per_bundle, compressed, false); - - // int res = new_bin_bundle.multi_insert_for_real(data, bin_idx); - int res; - if (is_labeled) { - res = new_bin_bundle.multi_insert_for_real(datalabel_with_idx.first, - bin_idx); - - } else { - res = - new_bin_bundle.multi_insert_for_real(data_with_idx.first, bin_idx); - } - - // If even that failed, I don't know what could've happened - if (res < 0) { - SPDLOG_ERROR( - "Insert-or-Assign worker: " - "failed to insert item into a new BinBundle at bundle index {}", - bundle_idx); - SPU_THROW("failed to insert item into a new BinBundle"); - } - - // Push a new BinBundle to the set of BinBundles at this bundle index - bundle_set.push_back(std::move(new_bin_bundle)); - } - } - - const auto db_save_start = std::chrono::system_clock::now(); - - for (auto &bundle : bundle_set) { - std::stringstream stream; - - // Generate the BinBundle caches - // bundle.regen_cache(); - bundle.strip(); - - size_t store_idx = (*bundles_store_idx)[bundle_index]++; - bundle.save(stream, store_idx); - - (*bundles_store)[bundle_index]->Put(store_idx, stream.str()); - } - const auto db_save_end = std::chrono::system_clock::now(); - const DurationMillis db_save_duration = db_save_end - db_save_start; - SPDLOG_INFO("*** step leveldb put duration:{}", db_save_duration.count()); - - SPDLOG_DEBUG("Insert-or-Assign worker: finished processing bundle index {}", - bundle_index); -} - -/** -Takes algebraized data to be 
inserted, splits it up, and distributes it so -that thread_count many threads can all insert in parallel. If overwrite is -set, this will overwrite the labels if it finds an AlgItemLabel that matches -the input perfectly. -*/ -template -void DispatchInsertOrAssign( - const std::vector> &data_with_indices, - std::vector> *bundles_store, - std::vector *bundles_store_idx, - const apsi::CryptoContext &crypto_context, uint32_t bins_per_bundle, - size_t label_size, uint32_t max_bin_size, uint32_t ps_low_degree, - bool overwrite, bool compressed) { - apsi::ThreadPoolMgr tpm; - - // Collect the bundle indices and partition them into thread_count many - // partitions. By some uniformity assumption, the number of things to insert - // per partition should be roughly the same. Note that the contents of - // bundle_indices is always sorted (increasing order). - std::set bundle_indices_set; - for (auto &data_with_idx : data_with_indices) { - size_t cuckoo_idx = data_with_idx.second; - size_t bin_idx; - size_t bundle_idx; - std::tie(bin_idx, bundle_idx) = - UnpackCuckooIdx(cuckoo_idx, bins_per_bundle); - bundle_indices_set.insert(bundle_idx); - } +} // namespace labeled_psi - // Copy the set of indices into a vector and sort so each thread processes a - // range of indices - std::vector bundle_indices; - bundle_indices.reserve(bundle_indices_set.size()); - std::copy(bundle_indices_set.begin(), bundle_indices_set.end(), - std::back_inserter(bundle_indices)); - std::sort(bundle_indices.begin(), bundle_indices.end()); - - // Run the threads on the partitions - std::vector> futures(bundle_indices.size()); - SPDLOG_INFO("Launching {} insert-or-assign worker tasks", - bundle_indices.size()); - size_t future_idx = 0; - for (auto &bundle_idx : bundle_indices) { - futures[future_idx++] = tpm.thread_pool().enqueue([&, bundle_idx]() { - InsertOrAssignWorker(data_with_indices, bundles_store, bundles_store_idx, - crypto_context, static_cast(bundle_idx), - bins_per_bundle, label_size, 
max_bin_size, - ps_low_degree, overwrite, compressed); - }); - } - - // Wait for the tasks to finish - for (auto &f : futures) { - f.get(); - } - - SPDLOG_INFO("Finished insert-or-assign worker tasks"); -} - -void DispatchInsertOrAssign( - const std::shared_ptr &indices_store, - size_t indices_count, const std::set &bundle_indices_set, - std::vector> *bundles_store, - std::vector *bundles_store_idx, bool is_labeled, - const apsi::CryptoContext &crypto_context, uint32_t bins_per_bundle, - size_t label_size, uint32_t max_bin_size, uint32_t ps_low_degree, - bool overwrite, bool compressed) { - apsi::ThreadPoolMgr tpm; - - std::vector bundle_indices; - bundle_indices.reserve(bundle_indices_set.size()); - std::copy(bundle_indices_set.begin(), bundle_indices_set.end(), - std::back_inserter(bundle_indices)); - std::sort(bundle_indices.begin(), bundle_indices.end()); - - // Run the threads on the partitions - std::vector> futures(bundle_indices.size()); - SPDLOG_INFO("Launching {} insert-or-assign worker tasks", - bundle_indices.size()); - size_t future_idx = 0; - for (auto &bundle_idx : bundle_indices) { - futures[future_idx++] = tpm.thread_pool().enqueue([&, bundle_idx]() { - InsertOrAssignWorker(indices_store, indices_count, bundles_store, - bundles_store_idx, is_labeled, crypto_context, - static_cast(bundle_idx), bins_per_bundle, - label_size, max_bin_size, ps_low_degree, overwrite, - compressed); - }); - } - - // Wait for the tasks to finish - for (auto &f : futures) { - f.get(); - } - - SPDLOG_INFO("Finished insert-or-assign worker tasks"); -} - -constexpr char kMetaInfoStoreName[] = "meta_info"; -constexpr char kServerDataCount[] = "server_data_count"; - -constexpr char kMemoryStoreFlag[] = "::memory"; - -} // namespace - -SenderDB::SenderDB(const apsi::PSIParams ¶ms, - std::string_view kv_store_path, size_t label_byte_count, - size_t nonce_byte_count, bool compressed) +ISenderDB::ISenderDB(const apsi::PSIParams ¶ms, + yacl::ByteContainerView oprf_key, + std::size_t 
label_byte_count, std::size_t nonce_byte_count, + bool compressed) : params_(params), crypto_context_(params_), label_byte_count_(label_byte_count), nonce_byte_count_(label_byte_count_ != 0 ? nonce_byte_count : 0), item_count_(0), - compressed_(compressed), - kv_store_path_(kv_store_path) { + compressed_(compressed) { // The labels cannot be more than 1 KB. if (label_byte_count_ > 1024) { SPDLOG_ERROR("Requested label byte count {} exceeds the maximum (1024)", @@ -734,60 +161,6 @@ SenderDB::SenderDB(const apsi::PSIParams ¶ms, // Set the evaluator. This will be used for BatchedPlaintextPolyn::eval. crypto_context_.set_evaluator(); - // Reset the SenderDB data structures - clear(); - - bundles_store_.resize(params_.bundle_idx_count()); - bundles_store_idx_.resize(params_.bundle_idx_count()); - - if (kv_store_path == kMemoryStoreFlag) { - meta_info_store_ = std::make_shared(); - - for (size_t i = 0; i < bundles_store_.size(); ++i) { - std::shared_ptr kv_store = - std::make_shared(); - - bundles_store_[i] = std::make_shared(kv_store); - } - } else { - std::string meta_store_name = - fmt::format("{}/{}", kv_store_path, kMetaInfoStoreName); - meta_info_store_ = - std::make_shared(false, meta_store_name); - - for (size_t i = 0; i < bundles_store_.size(); ++i) { - std::string bundle_store_name = - fmt::format("{}/bundle_{}", kv_store_path_, i); - - std::shared_ptr kv_store = - std::make_shared(false, bundle_store_name); - - bundles_store_[i] = std::make_shared(kv_store); - } - } - - try { - yacl::Buffer temp_value; - - meta_info_store_->Get(kServerDataCount, &temp_value); - item_count_ = std::stoul(std::string(std::string_view( - reinterpret_cast(temp_value.data()), temp_value.size()))); - - for (size_t i = 0; i < bundles_store_idx_.size(); ++i) { - bundles_store_idx_[i] = bundles_store_[i]->Count(); - } - - } catch (const std::exception &e) { - SPDLOG_INFO("key item_count no value"); - } -} - -SenderDB::SenderDB(const apsi::PSIParams ¶ms, - yacl::ByteContainerView 
oprf_key, - std::string_view kv_store_path, size_t label_byte_count, - size_t nonce_byte_count, bool compressed) - : SenderDB(params, kv_store_path, label_byte_count, nonce_byte_count, - compressed) { oprf_key_.resize(oprf_key.size()); std::memcpy(oprf_key_.data(), oprf_key.data(), oprf_key.size()); @@ -796,89 +169,7 @@ SenderDB::SenderDB(const apsi::PSIParams ¶ms, oprf_server_->SetCompareLength(kEccKeySize); } -SenderDB::SenderDB(SenderDB &&source) noexcept - : params_(source.params_), - crypto_context_(source.crypto_context_), - label_byte_count_(source.label_byte_count_), - nonce_byte_count_(source.nonce_byte_count_), - item_count_(source.item_count_), - compressed_(source.compressed_), - stripped_(source.stripped_) { - // Lock the source before moving stuff over - auto lock = source.GetWriterLock(); - - hashed_items_ = std::move(source.hashed_items_); - // bin_bundles_ = std::move(source.bin_bundles_); - - std::vector oprf_key = source.GetOprfKey(); - oprf_key_.resize(oprf_key.size()); - std::memcpy(oprf_key_.data(), oprf_key.data(), oprf_key.size()); - - oprf_server_ = - CreateEcdhOprfServer(oprf_key, OprfType::Basic, CurveType::CURVE_FOURQ); - - // Reset the source data structures - source.ClearInternal(); -} - -SenderDB &SenderDB::operator=(SenderDB &&source) noexcept { - // Do nothing if moving to self - if (&source == this) { - return *this; - } - - // Lock the current SenderDB - auto this_lock = GetWriterLock(); - - params_ = source.params_; - crypto_context_ = source.crypto_context_; - label_byte_count_ = source.label_byte_count_; - nonce_byte_count_ = source.nonce_byte_count_; - item_count_ = source.item_count_; - compressed_ = source.compressed_; - stripped_ = source.stripped_; - - // Lock the source before moving stuff over - auto source_lock = source.GetWriterLock(); - - hashed_items_ = std::move(source.hashed_items_); - // bin_bundles_ = std::move(source.bin_bundles_); - - std::vector oprf_key = source.GetOprfKey(); - 
oprf_key_.resize(oprf_key.size()); - std::memcpy(oprf_key_.data(), oprf_key.data(), oprf_key.size()); - - oprf_server_ = - CreateEcdhOprfServer(oprf_key, OprfType::Basic, CurveType::CURVE_FOURQ); - - // Reset the source data structures - source.ClearInternal(); - - return *this; -} - -size_t SenderDB::GetBinBundleCount(uint32_t bundle_idx) const { - // Lock the database for reading - auto lock = GetReaderLock(); - - // return bin_bundles_.at(seal::util::safe_cast(bundle_idx)).size(); - return bundles_store_idx_[seal::util::safe_cast(bundle_idx)]; -} - -size_t SenderDB::GetBinBundleCount() const { - // Lock the database for reading - auto lock = GetReaderLock(); - - // Compute the total number of BinBundles - // return std::accumulate(bin_bundles_.cbegin(), bin_bundles_.cend(), - // static_cast(0), - // [&](auto &a, auto &b) { return a + b.size(); }); - return std::accumulate(bundles_store_idx_.cbegin(), bundles_store_idx_.cend(), - static_cast(0), - [&](auto &a, auto &b) { return a + b; }); -} - -double SenderDB::GetPackingRate() const { +double ISenderDB::GetPackingRate() const { // Lock the database for reading auto lock = GetReaderLock(); @@ -895,88 +186,7 @@ double SenderDB::GetPackingRate() const { : 0.0; } -void SenderDB::ClearInternal() { - // Assume the SenderDB is already locked for writing - - // Clear the set of inserted items - hashed_items_.clear(); - item_count_ = 0; - - // Reset the stripped_ flag - stripped_ = false; -} - -void SenderDB::clear() { - if (!hashed_items_.empty()) { - SPDLOG_INFO("Removing {} items pairs from SenderDB", hashed_items_.size()); - } - - // Lock the database for writing - auto lock = GetWriterLock(); - - ClearInternal(); -} - -void SenderDB::GenerateCaches() { - STOPWATCH(sender_stopwatch, "SenderDB::GenerateCaches"); - SPDLOG_INFO("Start generating bin bundle caches"); - - SPDLOG_INFO("Finished generating bin bundle caches"); -} - -std::shared_ptr SenderDB::GetCacheAt( - uint32_t bundle_idx, size_t cache_idx) { - 
yacl::Buffer value; - - bool get_status = bundles_store_[bundle_idx]->Get(cache_idx, &value); - - SPU_ENFORCE(get_status); - - size_t label_size = - ComputeLabelSize(nonce_byte_count_ + label_byte_count_, params_); - - uint32_t bins_per_bundle = params_.bins_per_bundle(); - uint32_t max_bin_size = params_.table_params().max_items_per_bin; - uint32_t ps_low_degree = params_.query_params().ps_low_degree; - - bool compressed = false; - - std::shared_ptr load_bin_bundle = - std::make_shared( - crypto_context_, label_size, max_bin_size, ps_low_degree, - bins_per_bundle, compressed, false); - - gsl::span value_span = { - reinterpret_cast(value.data()), - gsl::narrow_cast::size_type>(value.size())}; - std::pair load_ret = - load_bin_bundle->load(value_span); - - SPU_ENFORCE(load_ret.first == cache_idx); - - // check cache is valid - if (load_bin_bundle->cache_invalid()) { - load_bin_bundle->regen_cache(); - } - - return load_bin_bundle; -} - -void SenderDB::strip() { - // Lock the database for writing - auto lock = GetWriterLock(); - - stripped_ = true; - - memset(oprf_key_.data(), 0, oprf_key_.size()); - hashed_items_.clear(); - - apsi::ThreadPoolMgr tpm; - - SPDLOG_INFO("SenderDB has been stripped"); -} - -std::vector SenderDB::GetOprfKey() const { +std::vector ISenderDB::GetOprfKey() const { if (stripped_) { SPDLOG_ERROR("Cannot return the OPRF key from a stripped SenderDB"); SPU_THROW("failed to return OPRF key"); @@ -984,357 +194,4 @@ std::vector SenderDB::GetOprfKey() const { return oprf_key_; } -void SenderDB::InsertOrAssign( - const std::vector> &data) { - if (stripped_) { - SPDLOG_ERROR("Cannot insert data to a stripped SenderDB"); - SPU_THROW("failed to insert data"); - } - if (!IsLabeled()) { - SPDLOG_ERROR( - "Attempted to insert labeled data but this is an unlabeled SenderDB"); - SPU_THROW("failed to insert data"); - } - - SPDLOG_INFO("Start inserting {} items in SenderDB", data.size()); - - // First compute the hashes for the input data - // auto hashed_data 
= OPRFSender::ComputeHashes( - // data, oprf_key_, label_byte_count_, nonce_byte_count_); - std::vector data_str(data.size()); - for (size_t i = 0; i < data.size(); ++i) { - std::string temp_str(data[i].first.value().size(), '\0'); - - std::memcpy(temp_str.data(), data[i].first.value().data(), - data[i].first.value().size()); - data_str[i] = temp_str; - } - std::vector oprf_out = oprf_server_->FullEvaluate(data_str); - std::vector> hashed_data; - for (size_t i = 0; i < oprf_out.size(); ++i) { - apsi::HashedItem hashed_item; - std::memcpy(hashed_item.value().data(), oprf_out[i].data(), - hashed_item.value().size()); - - apsi::LabelKey key; - std::memcpy(key.data(), &oprf_out[i][hashed_item.value().size()], - key.size()); - - apsi::EncryptedLabel encrypted_label = apsi::util::encrypt_label( - data[i].second, key, label_byte_count_, nonce_byte_count_); - - hashed_data.emplace_back(hashed_item, encrypted_label); - } - - // Lock the database for writing - auto lock = GetWriterLock(); - - // We need to know which items are new and which are old, since we have to - // tell dispatch_insert_or_assign when to have an overwrite-on-collision - // versus add-binbundle-on-collision policy. - auto new_data_end = std::remove_if( - hashed_data.begin(), hashed_data.end(), [&](const auto &item_label_pair) { - bool found = - hashed_items_.find(item_label_pair.first) != hashed_items_.end(); - if (!found) { - // Add to hashed_items_ already at this point! 
- hashed_items_.insert(item_label_pair.first); - item_count_++; - } - - // Remove those that were found - return found; - }); - - // Dispatch the insertion, first for the new data, then for the data we're - // gonna overwrite - uint32_t bins_per_bundle = params_.bins_per_bundle(); - uint32_t max_bin_size = params_.table_params().max_items_per_bin; - uint32_t ps_low_degree = params_.query_params().ps_low_degree; - - // Compute the label size; this ceil(effective_label_bit_count / - // item_bit_count) - size_t label_size = - ComputeLabelSize(nonce_byte_count_ + label_byte_count_, params_); - - auto new_item_count = std::distance(hashed_data.begin(), new_data_end); - auto existing_item_count = std::distance(new_data_end, hashed_data.end()); - - if (existing_item_count != 0) { - SPDLOG_INFO("Found {} existing items to replace in SenderDB", - existing_item_count); - - // Break the data into field element representation. Also compute the - // items' cuckoo indices. - std::vector> data_with_indices = - PreprocessLabeledData(new_data_end, hashed_data.end(), params_); - - DispatchInsertOrAssign(data_with_indices, &bundles_store_, - &bundles_store_idx_, crypto_context_, - bins_per_bundle, label_size, max_bin_size, - ps_low_degree, true, /* overwrite items */ - compressed_); - - // Release memory that is no longer needed - hashed_data.erase(new_data_end, hashed_data.end()); - } - - if (new_item_count != 0) { - SPDLOG_INFO("Found {} new items to insert in SenderDB", new_item_count); - - // Process and add the new data. Break the data into field element - // representation. Also compute the items' cuckoo indices. 
- std::vector> data_with_indices = - PreprocessLabeledData(hashed_data.begin(), hashed_data.end(), params_); - - DispatchInsertOrAssign(data_with_indices, &bundles_store_, - &bundles_store_idx_, crypto_context_, - bins_per_bundle, label_size, max_bin_size, - ps_low_degree, false, /* don't overwrite items */ - compressed_); - } - - SPDLOG_INFO("Finished inserting {} items in SenderDB", data.size()); -} - -void SenderDB::InsertOrAssign(const std::vector &data) { - if (stripped_) { - SPDLOG_ERROR("Cannot insert data to a stripped SenderDB"); - SPU_THROW("failed to insert data"); - } - if (IsLabeled()) { - SPDLOG_ERROR( - "Attempted to insert unlabeled data but this is a labeled SenderDB"); - SPU_THROW("failed to insert data"); - } - - STOPWATCH(sender_stopwatch, "SenderDB::insert_or_assign (unlabeled)"); - SPDLOG_INFO("Start inserting {} items in SenderDB", data.size()); - - // First compute the hashes for the input data - // auto hashed_data = OPRFSender::ComputeHashes(data, oprf_key_); - std::vector data_str(data.size()); - for (size_t i = 0; i < data.size(); ++i) { - std::string item_str(data[i].value().size(), '\0'); - std::memcpy(item_str.data(), data[i].value().data(), - data[i].value().size()); - data_str[i] = item_str; - } - std::vector oprf_out = oprf_server_->FullEvaluate(data_str); - std::vector hashed_data; - for (const auto &out : oprf_out) { - apsi::Item::value_type value{}; - std::memcpy(value.data(), out.data(), value.size()); - - hashed_data.emplace_back(value); - } - - // Lock the database for writing - auto lock = GetWriterLock(); - - // We are not going to insert items that already appear in the database. - auto new_data_end = std::remove_if( - hashed_data.begin(), hashed_data.end(), [&](const auto &item) { - bool found = hashed_items_.find(item) != hashed_items_.end(); - if (!found) { - // Add to hashed_items_ already at this point! 
- hashed_items_.insert(item); - item_count_++; - } - - // Remove those that were found - return found; - }); - - // Erase the previously existing items from hashed_data; in unlabeled case - // there is nothing to do - hashed_data.erase(new_data_end, hashed_data.end()); - - SPDLOG_INFO("Found {} new items to insert in SenderDB", hashed_data.size()); - - // Break the new data down into its field element representation. Also - // compute the items' cuckoo indices. - std::vector> data_with_indices = - PreprocessUnlabeledData(hashed_data.begin(), hashed_data.end(), params_); - - // Dispatch the insertion - uint32_t bins_per_bundle = params_.bins_per_bundle(); - uint32_t max_bin_size = params_.table_params().max_items_per_bin; - uint32_t ps_low_degree = params_.query_params().ps_low_degree; - - DispatchInsertOrAssign( - data_with_indices, &bundles_store_, &bundles_store_idx_, crypto_context_, - bins_per_bundle, 0, /* label size */ - max_bin_size, ps_low_degree, false, /* don't overwrite items */ - compressed_); - - // Generate the BinBundle caches - GenerateCaches(); - - SPDLOG_INFO("Finished inserting {} items in SenderDB", data.size()); -} - -void SenderDB::InsertOrAssign( - const std::shared_ptr &batch_provider, size_t batch_size) { - [[maybe_unused]] size_t batch_count = 0; - size_t indices_count = 0; - - std::shared_ptr kv_store; - - if (kv_store_path_ == kMemoryStoreFlag) { - kv_store = std::make_shared(); - } else { - kv_store = std::make_shared(true); - } - - std::shared_ptr items_oprf_store = - std::make_shared(kv_store); - - std::set bundle_indices_set; - - // Dispatch the insertion - uint32_t bins_per_bundle = params_.bins_per_bundle(); - uint32_t max_bin_size = params_.table_params().max_items_per_bin; - uint32_t ps_low_degree = params_.query_params().ps_low_degree; - - while (true) { - std::vector batch_items; - std::vector batch_labels; - - if (IsLabeled()) { - std::tie(batch_items, batch_labels) = - batch_provider->ReadNextBatchWithLabel(batch_size); - } 
else { - batch_items = batch_provider->ReadNextBatch(batch_size); - } - - if (batch_items.empty()) { - break; - } - - std::vector oprf_out = oprf_server_->FullEvaluate(batch_items); - - std::vector>> - data_with_indices_vec; - - if (IsLabeled()) { - data_with_indices_vec.resize(oprf_out.size()); - size_t key_offset_pos = sizeof(apsi::Item::value_type); - - yacl::parallel_for( - 0, oprf_out.size(), 1, [&](int64_t begin, int64_t end) { - for (int64_t idx = begin; idx < end; ++idx) { - apsi::Item::value_type value{}; - std::memcpy(value.data(), &oprf_out[idx][0], value.size()); - - apsi::HashedItem hashed_item(value); - - apsi::LabelKey key; - std::memcpy(key.data(), &oprf_out[idx][key_offset_pos], - apsi::label_key_byte_count); - - apsi::Label label_with_padding = - PaddingData(batch_labels[idx], label_byte_count_); - - apsi::EncryptedLabel encrypted_label = apsi::util::encrypt_label( - label_with_padding, key, label_byte_count_, - nonce_byte_count_); - - std::pair - item_label_pair = - std::make_pair(hashed_item, encrypted_label); - - data_with_indices_vec[idx] = - PreprocessLabeledData(item_label_pair, params_); - } - }); - - for (size_t i = 0; i < oprf_out.size(); ++i) { - for (size_t j = 0; j < data_with_indices_vec[i].size(); ++j) { - std::string indices_buffer = - SerializeDataLabelWithIndices(data_with_indices_vec[i][j]); - - items_oprf_store->Put(indices_count + j, indices_buffer); - - size_t cuckoo_idx = data_with_indices_vec[i][j].second; - size_t bin_idx, bundle_idx; - std::tie(bin_idx, bundle_idx) = - UnpackCuckooIdx(cuckoo_idx, bins_per_bundle); - bundle_indices_set.insert(bundle_idx); - } - - indices_count += data_with_indices_vec[i].size(); - } - } else { - for (size_t i = 0; i < oprf_out.size(); ++i) { - // - apsi::Item::value_type value{}; - std::memcpy(value.data(), &oprf_out[i][0], value.size()); - - apsi::HashedItem hashed_item(value); - - std::vector> data_with_indices = - PreprocessUnlabeledData(hashed_item, params_); - - for (size_t j = 0; j < 
data_with_indices.size(); ++j) { - std::string indices_buffer = - SerializeDataWithIndices(data_with_indices[j]); - - items_oprf_store->Put(indices_count + j, indices_buffer); - - size_t cuckoo_idx = data_with_indices[j].second; - - size_t bin_idx, bundle_idx; - std::tie(bin_idx, bundle_idx) = - UnpackCuckooIdx(cuckoo_idx, bins_per_bundle); - - bundle_indices_set.insert(bundle_idx); - } - indices_count += data_with_indices.size(); - } - } - item_count_ += batch_items.size(); - - batch_count++; - } - - size_t label_size = 0; - if (IsLabeled()) { - label_size = - ComputeLabelSize(nonce_byte_count_ + label_byte_count_, params_); - } - - DispatchInsertOrAssign( - items_oprf_store, indices_count, bundle_indices_set, &bundles_store_, - &bundles_store_idx_, IsLabeled(), crypto_context_, bins_per_bundle, - label_size, /* label size */ - max_bin_size, ps_low_degree, false, /* don't overwrite items */ - compressed_); - - SPDLOG_INFO("Finished inserting {} items in SenderDB", indices_count); -} - -bool SenderDB::HasItem(const apsi::Item &item) const { - if (stripped_) { - SPDLOG_ERROR( - "Cannot retrieve the presence of an item from a stripped SenderDB"); - SPU_THROW("failed to retrieve the presence of item"); - } - - // First compute the hash for the input item - // auto hashed_item = OPRFSender::ComputeHashes({&item, 1}, oprf_key_)[0]; - std::string item_str; - item_str.reserve(item.value().size()); - std::memcpy(item_str.data(), item.value().data(), item.value().size()); - std::string oprf_out = oprf_server_->FullEvaluate(item_str); - apsi::HashedItem hashed_item; - std::memcpy(hashed_item.value().data(), oprf_out.data(), - hashed_item.value().size()); - - // Lock the database for reading - auto lock = GetReaderLock(); - - return hashed_items_.find(hashed_item) != hashed_items_.end(); -} - } // namespace spu::psi diff --git a/libspu/psi/core/labeled_psi/sender_db.h b/libspu/psi/core/labeled_psi/sender_db.h index e2e473b4..54c081a4 100644 --- 
a/libspu/psi/core/labeled_psi/sender_db.h +++ b/libspu/psi/core/labeled_psi/sender_db.h @@ -73,210 +73,104 @@ SenderDB. The downside of in-memory compression is a performance reduction from decompressing parts of the data when they are used, and recompressing them if they are updated. */ -class SenderDB { +class ISenderDB { public: /** Creates a new SenderDB. */ - explicit SenderDB(const apsi::PSIParams ¶ms, - std::string_view kv_store_path, - std::size_t label_byte_count = 0, - std::size_t nonce_byte_count = 16, bool compressed = true); + ISenderDB(const apsi::PSIParams ¶ms, yacl::ByteContainerView oprf_key, + std::size_t label_byte_count = 0, std::size_t nonce_byte_count = 16, + bool compressed = false); - /** - Creates a new SenderDB. - */ - SenderDB(const apsi::PSIParams ¶ms, yacl::ByteContainerView oprf_key, - std::string_view kv_store_path = "", - std::size_t label_byte_count = 0, std::size_t nonce_byte_count = 16, - bool compressed = true); - - /** - Creates a new SenderDB by moving from an existing one. - */ - SenderDB(SenderDB &&source) noexcept; - - SenderDB(const SenderDB ©) = delete; - - /** - Moves an existing SenderDB to the current one. - */ - SenderDB &operator=(SenderDB &&source) noexcept; - - /** - Clears the database. Every item and label will be removed. The OPRF key is - unchanged. - */ - void clear(); + virtual ~ISenderDB() { std::memset(oprf_key_.data(), 0, oprf_key_.size()); } /** Returns whether this is a labeled SenderDB. */ - bool IsLabeled() const { return 0 != label_byte_count_; } + virtual bool IsLabeled() const { return 0 != label_byte_count_; } /** Returns the label byte count. A zero value indicates an unlabeled SenderDB. */ - std::size_t GetLabelByteCount() const { return label_byte_count_; } + virtual std::size_t GetLabelByteCount() const { return label_byte_count_; } /** Returns the nonce byte count used for encrypting labels. 
*/ - std::size_t GetNonceByteCount() const { return nonce_byte_count_; } + virtual std::size_t GetNonceByteCount() const { return nonce_byte_count_; } /** Indicates whether SEAL plaintexts are compressed in memory. */ - bool IsCompressed() const { return compressed_; } + virtual bool IsCompressed() const { return compressed_; } /** Indicates whether the SenderDB has been stripped of all information not needed for serving a query. */ - bool IsStripped() const { return stripped_; } - - /** - Strips the SenderDB of all information not needed for serving a query. - Returns a copy of the OPRF key and clears it from the SenderDB. - */ - void strip(); - - /** - Inserts the given data into the database. This function can be used only on - a labeled SenderDB instance. If an item already exists in the database, its - label is overwritten with the new label. - */ - void InsertOrAssign( - const std::vector> &data); - - /** - Inserts the given (hashed) item-label pair into the database. This function - can be used only on a labeled SenderDB instance. If the item already exists - in the database, its label is overwritten with the new label. - */ - void InsertOrAssign(const std::pair &data) { - std::vector> data_singleton{data}; - InsertOrAssign(data_singleton); - } - - /** - Inserts the given data into the database. This function can be used only on - an unlabeled SenderDB instance. - */ - void InsertOrAssign(const std::vector &data); - - /** - * @brief Insert data from BatchProvider - * - * @param batch_provider - */ - void InsertOrAssign(const std::shared_ptr &batch_provider, - size_t batch_size); - - /** - Clears the database and inserts the given data. This function can be used - only on a labeled SenderDB instance. - */ - void SetData(const std::vector> &data) { - clear(); - InsertOrAssign(data); - } - - /** - Clears the database and inserts the given data. This function can be used - only on an unlabeled SenderDB instance. 
- */ - void SetData(const std::vector &data) { - clear(); - InsertOrAssign(data); - } + virtual bool IsStripped() const { return stripped_; } - void SetData(const std::shared_ptr &batch_provider, - size_t batch_size = 40960) { - clear(); - InsertOrAssign(batch_provider, batch_size); - } - - /** - Returns whether the given item has been inserted in the SenderDB. - */ - bool HasItem(const apsi::Item &item) const; + virtual void SetData(const std::shared_ptr &batch_provider, + size_t batch_size = 500000) = 0; /** Returns the bundle at the given bundle index. */ - std::shared_ptr GetCacheAt(std::uint32_t bundle_idx, - size_t cache_idx); + virtual std::shared_ptr GetBinBundleAt( + std::uint32_t bundle_idx, size_t cache_idx) = 0; /** Returns a reference to the PSI parameters for this SenderDB. */ - const apsi::PSIParams &GetParams() const { return params_; } + virtual const apsi::PSIParams &GetParams() const { return params_; } /** Returns a reference to the CryptoContext for this SenderDB. */ - const apsi::CryptoContext &GetCryptoContext() const { + virtual const apsi::CryptoContext &GetCryptoContext() const { return crypto_context_; } /** Returns a reference to the SEALContext for this SenderDB. */ - std::shared_ptr GetSealContext() const { + virtual std::shared_ptr GetSealContext() const { return crypto_context_.seal_context(); } - /** - Returns a reference to a set of item hashes already existing in the - SenderDB. - */ - const std::unordered_set &GetHashedItems() const { - return hashed_items_; - } - /** Returns the number of items in this SenderDB. */ - size_t GetItemCount() const { return item_count_; } + virtual size_t GetItemCount() const { return item_count_; } /** Returns the total number of bin bundles at a specific bundle index. */ - std::size_t GetBinBundleCount(std::uint32_t bundle_idx) const; + virtual std::size_t GetBinBundleCount(std::uint32_t bundle_idx) const = 0; /** Returns the total number of bin bundles. 
*/ - std::size_t GetBinBundleCount() const; + virtual std::size_t GetBinBundleCount() const = 0; /** Returns how efficiently the SenderDB is packaged. A higher rate indicates better performance and a lower communication cost in a query execution. */ - double GetPackingRate() const; + virtual double GetPackingRate() const; /** Obtains a scoped lock preventing the SenderDB from being changed. */ - seal::util::ReaderLock GetReaderLock() const { + virtual seal::util::ReaderLock GetReaderLock() const { return db_lock_.acquire_read(); } - std::vector GetOprfKey() const; + virtual std::vector GetOprfKey() const; - private: + protected: seal::util::WriterLock GetWriterLock() { return db_lock_.acquire_write(); } - void ClearInternal(); - - void GenerateCaches(); - - /** - The set of all items that have been inserted into the database - */ - std::unordered_set hashed_items_; - /** The PSI parameters define the SEAL parameters, base field, item size, table size, etc. @@ -323,19 +217,6 @@ class SenderDB { */ bool stripped_; - /** - All the BinBundles in the database, indexed by bundle index. The set - (represented by a vector internally) at bundle index i contains all the - BinBundles with bundle index i. - */ - // std::vector> bin_bundles_; - - std::string kv_store_path_; - std::shared_ptr meta_info_store_; - - std::vector> bundles_store_; - std::vector bundles_store_idx_; - /** Holds the OPRF key for this SenderDB. 
*/ @@ -343,4 +224,18 @@ class SenderDB { std::unique_ptr oprf_server_; }; // class SenderDB +namespace labeled_psi { + +std::vector HashFunctions(const apsi::PSIParams ¶ms); + +std::unordered_set AllLocations( + const std::vector &hash_funcs, const apsi::HashedItem &item); + +size_t ComputeLabelSize(size_t label_byte_count, const apsi::PSIParams ¶ms); + +std::pair UnpackCuckooIdx(size_t cuckoo_idx, + size_t bins_per_bundle); + +} // namespace labeled_psi + } // namespace spu::psi diff --git a/libspu/psi/core/labeled_psi/sender_kvdb.cc b/libspu/psi/core/labeled_psi/sender_kvdb.cc new file mode 100644 index 00000000..2989edf0 --- /dev/null +++ b/libspu/psi/core/labeled_psi/sender_kvdb.cc @@ -0,0 +1,844 @@ +// Copyright 2021 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// code reference https://github.com/microsoft/APSI/sender/sender_db.cpp +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. 
+ +// we are using our own OPRF, the reason is we wanna make the oprf +// switchable between secp256k1, sm2 or other types + +// STD +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// APSI +#include "apsi/psi_params.h" +#include "apsi/thread_pool_mgr.h" +#include "apsi/util/db_encoding.h" +#include "apsi/util/label_encryptor.h" +#include "apsi/util/utils.h" +#include "spdlog/spdlog.h" + +#include "libspu/core/prelude.h" +#include "libspu/psi/core/ecdh_oprf/ecdh_oprf_selector.h" +#include "libspu/psi/core/labeled_psi/sender_kvdb.h" +#include "libspu/psi/core/labeled_psi/serialize.h" +#include "libspu/psi/utils/utils.h" + +// Kuku +#include "kuku/locfunc.h" + +// SEAL +#include "absl/strings/escaping.h" +#include "seal/util/common.h" +#include "seal/util/streambuf.h" +#include "yacl/crypto/utils/rand.h" +#include "yacl/utils/parallel.h" + +namespace spu::psi { + +namespace { + +using DurationMillis = std::chrono::duration; + +std::vector> PreprocessUnlabeledData( + const apsi::HashedItem &hashed_item, const apsi::PSIParams ¶ms) { + // Some variables we'll need + size_t bins_per_item = params.item_params().felts_per_item; + size_t item_bit_count = params.item_bit_count(); + + // Set up Kuku hash functions + auto hash_funcs = labeled_psi::HashFunctions(params); + + std::vector> data_with_indices; + + // Serialize the data into field elements + apsi::util::AlgItem alg_item = algebraize_item( + hashed_item, item_bit_count, params.seal_params().plain_modulus()); + + // Get the cuckoo table locations for this item and add to data_with_indices + for (auto location : labeled_psi::AllLocations(hash_funcs, hashed_item)) { + // The current hash value is an index into a table of Items. In reality + // our BinBundles are tables of bins, which contain chunks of items. How + // many chunks? 
bins_per_item many chunks + size_t bin_idx = location * bins_per_item; + + // Store the data along with its index + data_with_indices.emplace_back(std::make_pair(alg_item, bin_idx)); + } + + return data_with_indices; +} + +std::vector> PreprocessLabeledData( + const std::pair &item_label_pair, + const apsi::PSIParams ¶ms, + const std::vector &hash_funcs) { + SPDLOG_DEBUG("Start preprocessing {} labeled items", distance(begin, end)); + + // Some variables we'll need + size_t bins_per_item = params.item_params().felts_per_item; + size_t item_bit_count = params.item_bit_count(); + + std::vector> data_with_indices; + + // Serialize the data into field elements + const apsi::HashedItem &item = item_label_pair.first; + const apsi::EncryptedLabel &label = item_label_pair.second; + apsi::util::AlgItemLabel alg_item_label = algebraize_item_label( + item, label, item_bit_count, params.seal_params().plain_modulus()); + + std::set loc_set; + + // Get the cuckoo table locations for this item and add to data_with_indices + for (auto location : spu::psi::labeled_psi::AllLocations(hash_funcs, item)) { + // The current hash value is an index into a table of Items. In reality + // our BinBundles are tables of bins, which contain chunks of items. How + // many chunks? bins_per_item many chunks + if (loc_set.find(location) == loc_set.end()) { + size_t bin_idx = location * bins_per_item; + + // Store the data along with its index + data_with_indices.emplace_back(alg_item_label, bin_idx); + loc_set.insert(location); + } + } + + return data_with_indices; +} + +/** +Inserts the given items and corresponding labels into bin_bundles at their +respective cuckoo indices. It will only insert the data with bundle index in the +half-open range range indicated by work_range. If inserting into a BinBundle +would make the number of items in a bin larger than max_bin_size, this function +will create and insert a new BinBundle. 
If overwrite is set, this will overwrite +the labels if it finds an AlgItemLabel that matches the input perfectly. +*/ +template +void InsertOrAssignWorker( + const std::vector> &data_with_indices, + std::vector> *bundles_store, + std::vector *bundles_store_idx, + const apsi::CryptoContext &crypto_context, uint32_t bundle_index, + uint32_t bins_per_bundle, size_t label_size, size_t max_bin_size, + size_t ps_low_degree, bool overwrite, bool compressed) { + STOPWATCH(sender_stopwatch, "insert_or_assign_worker"); + SPDLOG_DEBUG( + "Insert-or-Assign worker for bundle index {}; mode of operation: {}", + bundle_index, overwrite ? "overwriting existing" : "inserting new"); + + // Create the bundle set at the given bundle index + std::vector bundle_set; + + // Iteratively insert each item-label pair at the given cuckoo index + for (auto &data_with_idx : data_with_indices) { + const T &data = data_with_idx.first; + + // Get the bundle index + size_t cuckoo_idx = data_with_idx.second; + size_t bin_idx; + size_t bundle_idx; + std::tie(bin_idx, bundle_idx) = + labeled_psi::UnpackCuckooIdx(cuckoo_idx, bins_per_bundle); + + // If the bundle_idx isn't in the prescribed range, don't try to insert this + // data + if (bundle_idx != bundle_index) { + // Dealing with this bundle index is not our job + continue; + } + + // Try to insert or overwrite these field elements in an existing BinBundle + // at this bundle index. Keep track of whether or not we succeed. + bool written = false; + + for (auto bundle_it = bundle_set.rbegin(); bundle_it != bundle_set.rend(); + bundle_it++) { + // If we're supposed to overwrite, try to overwrite. One of these + // BinBundles has to have the data we're trying to overwrite. 
+ if (overwrite) { + // If we successfully overwrote, we're done with this bundle + written = bundle_it->try_multi_overwrite(data, bin_idx); + if (written) { + break; + } + } + + // Do a dry-run insertion and see if the new largest bin size in the range + // exceeds the limit + int32_t new_largest_bin_size = + bundle_it->multi_insert_dry_run(data, bin_idx); + + // Check if inserting would violate the max bin size constraint + if (new_largest_bin_size > 0 && + seal::util::safe_cast(new_largest_bin_size) < max_bin_size) { + // All good + bundle_it->multi_insert_for_real(data, bin_idx); + written = true; + break; + } + } + + // We tried to overwrite an item that doesn't exist. This should never + // happen + if (overwrite && !written) { + SPDLOG_ERROR( + "Insert-or-Assign worker: " + "failed to overwrite item at bundle index {} because the item was " + "not found", + bundle_idx); + SPU_THROW("tried to overwrite non-existent item"); + } + + // If we had conflicts everywhere when trying to insert, then we need to + // make a new BinBundle and insert the data there + if (!written) { + // Make a fresh BinBundle and insert + apsi::sender::BinBundle new_bin_bundle( + crypto_context, label_size, max_bin_size, ps_low_degree, + bins_per_bundle, compressed, false); + int res = new_bin_bundle.multi_insert_for_real(data, bin_idx); + + // If even that failed, I don't know what could've happened + if (res < 0) { + SPDLOG_ERROR( + "Insert-or-Assign worker: " + "failed to insert item into a new BinBundle at bundle index {}", + bundle_idx); + SPU_THROW("failed to insert item into a new BinBundle"); + } + + // Push a new BinBundle to the set of BinBundles at this bundle index + bundle_set.push_back(std::move(new_bin_bundle)); + } + } + + const auto db_save_start = std::chrono::system_clock::now(); + + for (auto &bundle : bundle_set) { + size_t store_idx = (*bundles_store_idx)[bundle_index]++; + + // Generate the BinBundle caches + // bundle.regen_cache(); + bundle.strip(); + + 
std::stringstream stream; + bundle.save(stream, store_idx); + + (*bundles_store)[bundle_index]->Put(store_idx, stream.str()); + } + + const auto db_save_end = std::chrono::system_clock::now(); + const DurationMillis db_save_duration = db_save_end - db_save_start; + SPDLOG_INFO("*** step leveldb put duration:{}", db_save_duration.count()); + + SPDLOG_DEBUG("Insert-or-Assign worker: finished processing bundle index {}", + bundle_index); +} + +void InsertOrAssignWorker( + const std::shared_ptr &indices_store, + size_t indices_count, + std::vector> *bundles_store, + std::vector *bundles_store_idx, bool is_labeled, + const apsi::CryptoContext &crypto_context, uint32_t bundle_index, + uint32_t bins_per_bundle, size_t label_size, size_t max_bin_size, + + size_t ps_low_degree, bool overwrite, bool compressed) { + STOPWATCH(sender_stopwatch, "insert_or_assign_worker"); + + SPDLOG_DEBUG( + "Insert-or-Assign worker for bundle index {}; mode of operation: {}", + bundle_index, overwrite ? "overwriting existing" : "inserting new"); + + // Create the bundle set at the given bundle index + std::vector bundle_set; + + // Iteratively insert each item-label pair at the given cuckoo index + for (size_t i = 0; i < indices_count; ++i) { + yacl::Buffer value; + bool get_status = indices_store->Get(i, &value); + SPU_ENFORCE(get_status, "get_status:{}", get_status); + + size_t cuckoo_idx; + + std::pair datalabel_with_idx; + std::pair data_with_idx; + + if (is_labeled) { + datalabel_with_idx = DeserializeDataLabelWithIndices(std::string_view( + reinterpret_cast(value.data()), value.size())); + + cuckoo_idx = datalabel_with_idx.second; + } else { + data_with_idx = DeserializeDataWithIndices(std::string_view( + reinterpret_cast(value.data()), value.size())); + cuckoo_idx = data_with_idx.second; + } + + // const apsi::util::AlgItem &data = data_with_idx.first; + + // Get the bundle index + size_t bin_idx, bundle_idx; + std::tie(bin_idx, bundle_idx) = + labeled_psi::UnpackCuckooIdx(cuckoo_idx, 
bins_per_bundle); + + // If the bundle_idx isn't in the prescribed range, don't try to insert + // this data + if (bundle_idx != bundle_index) { + // Dealing with this bundle index is not our job + continue; + } + + // Try to insert or overwrite these field elements in an existing + // BinBundle at this bundle index. Keep track of whether or not we + // succeed. + bool written = false; + + for (auto bundle_it = bundle_set.rbegin(); bundle_it != bundle_set.rend(); + bundle_it++) { + // If we're supposed to overwrite, try to overwrite. One of these + // BinBundles has to have the data we're trying to overwrite. + if (overwrite) { + // If we successfully overwrote, we're done with this bundle + // written = bundle_it->try_multi_overwrite(data, bin_idx); + if (is_labeled) { + written = + bundle_it->try_multi_overwrite(datalabel_with_idx.first, bin_idx); + + } else { + written = + bundle_it->try_multi_overwrite(data_with_idx.first, bin_idx); + } + + if (written) { + break; + } + } + + // Do a dry-run insertion and see if the new largest bin size in the + // range exceeds the limit + // int32_t new_largest_bin_size = bundle_it->multi_insert_dry_run(data, + // bin_idx); + int32_t new_largest_bin_size; + if (is_labeled) { + new_largest_bin_size = + bundle_it->multi_insert_dry_run(datalabel_with_idx.first, bin_idx); + + } else { + new_largest_bin_size = + bundle_it->multi_insert_dry_run(data_with_idx.first, bin_idx); + } + + // Check if inserting would violate the max bin size constraint + if (new_largest_bin_size > 0 && + seal::util::safe_cast(new_largest_bin_size) < max_bin_size) { + // All good + // bundle_it->multi_insert_for_real(data, bin_idx); + if (is_labeled) { + bundle_it->multi_insert_for_real(datalabel_with_idx.first, bin_idx); + } else { + bundle_it->multi_insert_for_real(data_with_idx.first, bin_idx); + } + written = true; + break; + } + } + + // We tried to overwrite an item that doesn't exist. 
This should never + // happen + if (overwrite && !written) { + SPDLOG_ERROR( + "Insert-or-Assign worker: " + "failed to overwrite item at bundle index {} because the item was " + "not found", + bundle_idx); + SPU_THROW("tried to overwrite non-existent item"); + } + + // If we had conflicts everywhere when trying to insert, then we need to + // make a new BinBundle and insert the data there + if (!written) { + // Make a fresh BinBundle and insert + apsi::sender::BinBundle new_bin_bundle( + crypto_context, label_size, max_bin_size, ps_low_degree, + bins_per_bundle, compressed, false); + + // int res = new_bin_bundle.multi_insert_for_real(data, bin_idx); + int res; + if (is_labeled) { + res = new_bin_bundle.multi_insert_for_real(datalabel_with_idx.first, + bin_idx); + + } else { + res = + new_bin_bundle.multi_insert_for_real(data_with_idx.first, bin_idx); + } + + // If even that failed, I don't know what could've happened + if (res < 0) { + SPDLOG_ERROR( + "Insert-or-Assign worker: " + "failed to insert item into a new BinBundle at bundle index {}", + bundle_idx); + SPU_THROW("failed to insert item into a new BinBundle"); + } + + // Push a new BinBundle to the set of BinBundles at this bundle index + bundle_set.push_back(std::move(new_bin_bundle)); + } + } + + const auto db_save_start = std::chrono::system_clock::now(); + + for (auto &bundle : bundle_set) { + size_t store_idx = (*bundles_store_idx)[bundle_index]++; + + SPDLOG_INFO( + "Polynomial Interpolate and HE Plaintext Encode, bundle_indx:{}, " + "store_idx:{}", + bundle_index, store_idx); + + // Generate the BinBundle caches + // bundle.regen_cache(); + bundle.strip(); + + std::stringstream stream; + bundle.save(stream, store_idx); + + (*bundles_store)[bundle_index]->Put(store_idx, stream.str()); + } + const auto db_save_end = std::chrono::system_clock::now(); + const DurationMillis db_save_duration = db_save_end - db_save_start; + SPDLOG_INFO("*** step leveldb put duration:{}", db_save_duration.count()); + + 
SPDLOG_DEBUG("Insert-or-Assign worker: finished processing bundle index {}", + bundle_index); +} + +/** +Takes algebraized data to be inserted, splits it up, and distributes it so +that thread_count many threads can all insert in parallel. If overwrite is +set, this will overwrite the labels if it finds an AlgItemLabel that matches +the input perfectly. +*/ +template +void DispatchInsertOrAssign( + const std::vector> &data_with_indices, + std::vector> *bundles_store, + std::vector *bundles_store_idx, + const apsi::CryptoContext &crypto_context, uint32_t bins_per_bundle, + size_t label_size, uint32_t max_bin_size, uint32_t ps_low_degree, + bool overwrite, bool compressed) { + apsi::ThreadPoolMgr tpm; + + // Collect the bundle indices and partition them into thread_count many + // partitions. By some uniformity assumption, the number of things to insert + // per partition should be roughly the same. Note that the contents of + // bundle_indices is always sorted (increasing order). + std::set bundle_indices_set; + for (auto &data_with_idx : data_with_indices) { + size_t cuckoo_idx = data_with_idx.second; + size_t bin_idx; + size_t bundle_idx; + std::tie(bin_idx, bundle_idx) = + labeled_psi::UnpackCuckooIdx(cuckoo_idx, bins_per_bundle); + bundle_indices_set.insert(bundle_idx); + } + + // Copy the set of indices into a vector and sort so each thread processes a + // range of indices + std::vector bundle_indices; + bundle_indices.reserve(bundle_indices_set.size()); + std::copy(bundle_indices_set.begin(), bundle_indices_set.end(), + std::back_inserter(bundle_indices)); + std::sort(bundle_indices.begin(), bundle_indices.end()); + + // Run the threads on the partitions + std::vector> futures(bundle_indices.size()); + SPDLOG_INFO("Launching {} insert-or-assign worker tasks", + bundle_indices.size()); + size_t future_idx = 0; + for (auto &bundle_idx : bundle_indices) { + futures[future_idx++] = tpm.thread_pool().enqueue([&, bundle_idx]() { + 
InsertOrAssignWorker(data_with_indices, bundles_store, bundles_store_idx, + crypto_context, static_cast(bundle_idx), + bins_per_bundle, label_size, max_bin_size, + ps_low_degree, overwrite, compressed); + }); + } + + // Wait for the tasks to finish + for (auto &f : futures) { + f.get(); + } + + SPDLOG_INFO("Finished insert-or-assign worker tasks"); +} + +void DispatchInsertOrAssign( + const std::shared_ptr &indices_store, + size_t indices_count, const std::set &bundle_indices_set, + std::vector> *bundles_store, + std::vector *bundles_store_idx, bool is_labeled, + const apsi::CryptoContext &crypto_context, uint32_t bins_per_bundle, + size_t label_size, uint32_t max_bin_size, uint32_t ps_low_degree, + bool overwrite, bool compressed) { + apsi::ThreadPoolMgr tpm; + + std::vector bundle_indices; + bundle_indices.reserve(bundle_indices_set.size()); + std::copy(bundle_indices_set.begin(), bundle_indices_set.end(), + std::back_inserter(bundle_indices)); + std::sort(bundle_indices.begin(), bundle_indices.end()); + + // Run the threads on the partitions + std::vector> futures(bundle_indices.size()); + SPDLOG_INFO("Launching {} insert-or-assign worker tasks", + bundle_indices.size()); + size_t future_idx = 0; + for (auto &bundle_idx : bundle_indices) { + futures[future_idx++] = tpm.thread_pool().enqueue([&, bundle_idx]() { + InsertOrAssignWorker(indices_store, indices_count, bundles_store, + bundles_store_idx, is_labeled, crypto_context, + static_cast(bundle_idx), bins_per_bundle, + label_size, max_bin_size, ps_low_degree, overwrite, + compressed); + }); + } + + // Wait for the tasks to finish + for (auto &f : futures) { + f.get(); + } + + SPDLOG_INFO("Finished insert-or-assign worker tasks"); +} + +constexpr char kMetaInfoStoreName[] = "db_meta_info"; +constexpr char kServerDataCount[] = "server_data_count"; + +constexpr char kMemoryStoreFlag[] = "::memory"; + +} // namespace + +SenderKvDB::SenderKvDB(const apsi::PSIParams ¶ms, + yacl::ByteContainerView oprf_key, + 
std::string_view kv_store_path, size_t label_byte_count, + size_t nonce_byte_count, bool compressed) + : ISenderDB(params, oprf_key, label_byte_count, nonce_byte_count, + compressed), + kv_store_path_(kv_store_path) { + // Reset the SenderDB data structures + clear(); + + bundles_store_.resize(params_.bundle_idx_count()); + bundles_store_idx_.resize(params_.bundle_idx_count()); + + if (kv_store_path == kMemoryStoreFlag) { + meta_info_store_ = std::make_shared(); + + for (size_t i = 0; i < bundles_store_.size(); ++i) { + std::shared_ptr kv_store = + std::make_shared(); + + bundles_store_[i] = std::make_shared(kv_store); + } + } else { + std::string meta_store_name = + fmt::format("{}/{}", kv_store_path, kMetaInfoStoreName); + + meta_info_store_ = + std::make_shared(false, meta_store_name); + + for (size_t i = 0; i < bundles_store_.size(); ++i) { + std::string bundle_store_name = + fmt::format("{}/bundle_{}", kv_store_path_, i); + + std::shared_ptr kv_store = + std::make_shared(false, bundle_store_name); + + bundles_store_[i] = std::make_shared(kv_store); + } + } + + try { + yacl::Buffer temp_value; + + meta_info_store_->Get(kServerDataCount, &temp_value); + item_count_ = std::stoul(std::string(std::string_view( + reinterpret_cast(temp_value.data()), temp_value.size()))); + + for (size_t i = 0; i < bundles_store_idx_.size(); ++i) { + bundles_store_idx_[i] = bundles_store_[i]->Count(); + } + + } catch (const std::exception &e) { + SPDLOG_INFO("key item_count no value"); + } +} + +size_t SenderKvDB::GetBinBundleCount(uint32_t bundle_idx) const { + // Lock the database for reading + auto lock = GetReaderLock(); + + // return bin_bundles_.at(seal::util::safe_cast(bundle_idx)).size(); + return bundles_store_idx_[seal::util::safe_cast(bundle_idx)]; +} + +size_t SenderKvDB::GetBinBundleCount() const { + // Lock the database for reading + auto lock = GetReaderLock(); + + // Compute the total number of BinBundles + // return std::accumulate(bin_bundles_.cbegin(), 
bin_bundles_.cend(), + // static_cast(0), + // [&](auto &a, auto &b) { return a + b.size(); }); + return std::accumulate(bundles_store_idx_.cbegin(), bundles_store_idx_.cend(), + static_cast(0), + [&](auto &a, auto &b) { return a + b; }); +} + +void SenderKvDB::ClearInternal() { + item_count_ = 0; + + // Reset the stripped_ flag + stripped_ = false; + // TODO(changjun): delete kv store +} + +void SenderKvDB::clear() { + // Lock the database for writing + auto lock = GetWriterLock(); + + ClearInternal(); +} + +void SenderKvDB::GenerateCaches() { + STOPWATCH(sender_stopwatch, "SenderDB::GenerateCaches"); + SPDLOG_INFO("Start generating bin bundle caches"); + + SPDLOG_INFO("Finished generating bin bundle caches"); +} + +std::shared_ptr SenderKvDB::GetBinBundleAt( + uint32_t bundle_idx, size_t cache_idx) { + yacl::Buffer value; + + bool get_status = bundles_store_[bundle_idx]->Get(cache_idx, &value); + + SPU_ENFORCE(get_status); + + size_t label_size = labeled_psi::ComputeLabelSize( + nonce_byte_count_ + label_byte_count_, params_); + + uint32_t bins_per_bundle = params_.bins_per_bundle(); + uint32_t max_bin_size = params_.table_params().max_items_per_bin; + uint32_t ps_low_degree = params_.query_params().ps_low_degree; + + bool compressed = false; + + std::shared_ptr load_bin_bundle = + std::make_shared( + crypto_context_, label_size, max_bin_size, ps_low_degree, + bins_per_bundle, compressed, false); + + gsl::span value_span = { + reinterpret_cast(value.data()), + gsl::narrow_cast::size_type>(value.size())}; + std::pair load_ret = + load_bin_bundle->load(value_span); + + SPU_ENFORCE(load_ret.first == cache_idx); + + // check cache is valid + if (load_bin_bundle->cache_invalid()) { + load_bin_bundle->regen_cache(); + } + + return load_bin_bundle; +} + +void SenderKvDB::strip() { + // Lock the database for writing + auto lock = GetWriterLock(); + + stripped_ = true; + + memset(oprf_key_.data(), 0, oprf_key_.size()); + + SPDLOG_INFO("SenderDB has been stripped"); +} + 
+void SenderKvDB::InsertOrAssign( + const std::shared_ptr &batch_provider, size_t batch_size) { + size_t batch_count = 0; + size_t indices_count = 0; + + std::shared_ptr kv_store; + + if (kv_store_path_ == kMemoryStoreFlag) { + kv_store = std::make_shared(); + } else { + kv_store = std::make_shared(true); + } + + std::shared_ptr items_oprf_store = + std::make_shared(kv_store); + + std::set bundle_indices_set; + + // Dispatch the insertion + uint32_t bins_per_bundle = params_.bins_per_bundle(); + uint32_t max_bin_size = params_.table_params().max_items_per_bin; + uint32_t ps_low_degree = params_.query_params().ps_low_degree; + + while (true) { + std::vector batch_items; + std::vector batch_labels; + + if (IsLabeled()) { + std::tie(batch_items, batch_labels) = + batch_provider->ReadNextBatchWithLabel(batch_size); + } else { + batch_items = batch_provider->ReadNextBatch(batch_size); + } + + if (batch_items.empty()) { + SPDLOG_INFO( + "OPRF FullEvaluate and EncryptLabel Last batch count: " + "{}, item_count:{}", + batch_count, item_count_); + break; + } + SPDLOG_INFO("OPRF FullEvaluate and EncryptLabel batch_count: {}", + batch_count); + + std::vector oprf_out = oprf_server_->FullEvaluate(batch_items); + + std::vector>> + data_with_indices_vec; + + if (IsLabeled()) { + data_with_indices_vec.resize(oprf_out.size()); + size_t key_offset_pos = sizeof(apsi::Item::value_type); + + // Set up Kuku hash functions + auto hash_funcs = labeled_psi::HashFunctions(params_); + + yacl::parallel_for( + 0, oprf_out.size(), 1, [&](int64_t begin, int64_t end) { + for (int64_t idx = begin; idx < end; ++idx) { + apsi::Item::value_type value{}; + std::memcpy(value.data(), &oprf_out[idx][0], value.size()); + + apsi::HashedItem hashed_item(value); + + apsi::LabelKey key; + std::memcpy(key.data(), &oprf_out[idx][key_offset_pos], + apsi::label_key_byte_count); + + apsi::Label label_with_padding = + PaddingData(batch_labels[idx], label_byte_count_); + + apsi::EncryptedLabel encrypted_label = 
apsi::util::encrypt_label( + label_with_padding, key, label_byte_count_, + nonce_byte_count_); + + std::pair + item_label_pair = + std::make_pair(hashed_item, encrypted_label); + + data_with_indices_vec[idx] = + PreprocessLabeledData(item_label_pair, params_, hash_funcs); + } + }); + + for (size_t i = 0; i < oprf_out.size(); ++i) { + for (size_t j = 0; j < data_with_indices_vec[i].size(); ++j) { + std::string indices_buffer = + SerializeDataLabelWithIndices(data_with_indices_vec[i][j]); + + items_oprf_store->Put(indices_count + j, indices_buffer); + + size_t cuckoo_idx = data_with_indices_vec[i][j].second; + size_t bin_idx, bundle_idx; + std::tie(bin_idx, bundle_idx) = + labeled_psi::UnpackCuckooIdx(cuckoo_idx, bins_per_bundle); + bundle_indices_set.insert(bundle_idx); + } + + indices_count += data_with_indices_vec[i].size(); + } + + } else { + for (size_t i = 0; i < oprf_out.size(); ++i) { + // + apsi::Item::value_type value{}; + std::memcpy(value.data(), &oprf_out[i][0], value.size()); + + apsi::HashedItem hashed_item(value); + + std::vector> data_with_indices = + PreprocessUnlabeledData(hashed_item, params_); + + for (size_t j = 0; j < data_with_indices.size(); ++j) { + std::string indices_buffer = + SerializeDataWithIndices(data_with_indices[j]); + + items_oprf_store->Put(indices_count + j, indices_buffer); + + size_t cuckoo_idx = data_with_indices[j].second; + + size_t bin_idx, bundle_idx; + std::tie(bin_idx, bundle_idx) = + labeled_psi::UnpackCuckooIdx(cuckoo_idx, bins_per_bundle); + + bundle_indices_set.insert(bundle_idx); + } + indices_count += data_with_indices.size(); + } + } + item_count_ += batch_items.size(); + + batch_count++; + } + meta_info_store_->Put(kServerDataCount, std::to_string(item_count_)); + + size_t label_size = 0; + if (IsLabeled()) { + label_size = labeled_psi::ComputeLabelSize( + nonce_byte_count_ + label_byte_count_, params_); + } + + DispatchInsertOrAssign( + items_oprf_store, indices_count, bundle_indices_set, &bundles_store_, + 
&bundles_store_idx_, IsLabeled(), crypto_context_, bins_per_bundle, + label_size, /* label size */ + max_bin_size, ps_low_degree, false, /* don't overwrite items */ + compressed_); + + SPDLOG_INFO("Finished inserting {} items in SenderDB", item_count_); +} + +} // namespace spu::psi diff --git a/libspu/psi/core/labeled_psi/sender_kvdb.h b/libspu/psi/core/labeled_psi/sender_kvdb.h new file mode 100644 index 00000000..febc1f52 --- /dev/null +++ b/libspu/psi/core/labeled_psi/sender_kvdb.h @@ -0,0 +1,164 @@ +// Copyright 2021 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// code reference https://github.com/microsoft/APSI/sender/sender_db.h +// Licensed under the MIT license. 
+ +#pragma once + +// STD +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// GSL +#include "gsl/span" + +// APSI +#include "apsi/bin_bundle.h" +#include "apsi/crypto_context.h" +#include "apsi/item.h" +#include "apsi/psi_params.h" +#include "yacl/base/byte_container_view.h" +#include "yacl/io/kv/leveldb_kvstore.h" +#include "yacl/io/kv/memory_kvstore.h" + +#include "libspu/psi/core/ecdh_oprf/basic_ecdh_oprf.h" +#include "libspu/psi/core/labeled_psi/sender_db.h" +#include "libspu/psi/utils/batch_provider.h" + +// SEAL +#include "seal/plaintext.h" +#include "seal/util/locks.h" +#include "spdlog/spdlog.h" + +namespace spu::psi { + +/** +A SenderDB maintains an in-memory representation of the sender's set of items +and labels (in labeled mode). This data is not simply copied into the SenderDB +data structures, but also preprocessed heavily to allow for faster online +computation time. Since inserting a large number of new items into a SenderDB +can take time, it is not recommended to recreate the SenderDB when the +database changes a little bit. Instead, the class supports fast update and +deletion operations that should be preferred: SenderDB::InsertOrAssign and +SenderDB::remove. + +The SenderDB constructor allows the label byte count to be specified; +unlabeled mode is activated by setting the label byte count to zero. It is +possible to optionally specify the size of the nonce used in encrypting the +labels, but this is best left to its default value unless the user is +absolutely sure of what they are doing. + +The SenderDB requires substantially more memory than the raw data would. Part +of that memory can automatically be compressed when it is not in use; this +feature is enabled by default, and can be disabled when constructing the +SenderDB. The downside of in-memory compression is a performance reduction +from decompressing parts of the data when they are used, and recompressing +them if they are updated. 
+*/ +class SenderKvDB : public ISenderDB { + public: + /** + Creates a new SenderDB. + */ + SenderKvDB(const apsi::PSIParams ¶ms, yacl::ByteContainerView oprf_key, + std::string_view kv_store_path = "", + std::size_t label_byte_count = 0, + std::size_t nonce_byte_count = 16, bool compressed = true); + + /** + Clears the database. Every item and label will be removed. The OPRF key is + unchanged. + */ + void clear(); + + /** + Strips the SenderDB of all information not needed for serving a query. + Returns a copy of the OPRF key and clears it from the SenderDB. + */ + void strip(); + + /** + * @brief Insert data from BatchProvider + * + * @param batch_provider + */ + void InsertOrAssign(const std::shared_ptr &batch_provider, + size_t batch_size); + + /** + Clears the database and inserts the given data. This function can be used + only on a labeled SenderDB instance. + */ + void SetData(const std::vector &keys, + const std::vector &labels) { + std::shared_ptr batch_provider = + std::make_shared(keys, labels); + + SetData(batch_provider); + } + + /** + Clears the database and inserts the given data. This function can be used + only on an unlabeled SenderDB instance. + */ + void SetData(const std::vector &data) { + std::shared_ptr batch_provider = + std::make_shared(data); + + SetData(batch_provider); + } + + void SetData(const std::shared_ptr &batch_provider, + size_t batch_size = 500000) override { + clear(); + InsertOrAssign(batch_provider, batch_size); + } + + /** + Returns the bundle at the given bundle index. + */ + std::shared_ptr GetBinBundleAt( + std::uint32_t bundle_idx, size_t cache_idx) override; + + /** + Returns the total number of bin bundles at a specific bundle index. + */ + std::size_t GetBinBundleCount(std::uint32_t bundle_idx) const override; + + /** + Returns the total number of bin bundles. 
+ */ + std::size_t GetBinBundleCount() const override; + + private: + void ClearInternal(); + + void GenerateCaches(); + + std::string kv_store_path_; + std::shared_ptr meta_info_store_; + + std::vector> bundles_store_; + std::vector bundles_store_idx_; +}; // class SenderDB + +} // namespace spu::psi diff --git a/libspu/psi/core/labeled_psi/sender_memdb.cc b/libspu/psi/core/labeled_psi/sender_memdb.cc new file mode 100644 index 00000000..91f05c4b --- /dev/null +++ b/libspu/psi/core/labeled_psi/sender_memdb.cc @@ -0,0 +1,958 @@ +// Copyright 2021 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// code reference https://github.com/microsoft/APSI/sender/sender_db.cpp +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. 
+ +// we are using our own OPRF, the reason is we wanna make the oprf +// switchable between secp256k1, sm2 or other types + +// STD +#include +#include +#include +#include +#include +#include +#include +#include + +// APSI +#include "apsi/psi_params.h" +#include "apsi/thread_pool_mgr.h" +#include "apsi/util/db_encoding.h" +#include "apsi/util/label_encryptor.h" +#include "apsi/util/utils.h" +#include "spdlog/spdlog.h" + +#include "libspu/psi/core/ecdh_oprf/ecdh_oprf_selector.h" +#include "libspu/psi/core/labeled_psi/sender_memdb.h" +#include "libspu/psi/utils/utils.h" + +// Kuku +#include "kuku/locfunc.h" + +// SEAL +#include "absl/strings/escaping.h" +#include "seal/util/common.h" +#include "seal/util/streambuf.h" +#include "yacl/utils/parallel.h" + +namespace spu::psi { + +namespace { + +/** +Converts each given Item-Label pair in between the given iterators into its +algebraic form, i.e., a sequence of felt-felt pairs. Also computes each Item's +cuckoo index. +*/ +std::vector> PreprocessLabeledData( + const std::vector>::const_iterator begin, + const std::vector< + std::pair>::const_iterator end, + const apsi::PSIParams ¶ms) { + STOPWATCH(sender_stopwatch, "preprocess_labeled_data"); + SPDLOG_DEBUG("Start preprocessing {} labeled items", distance(begin, end)); + + // Some variables we'll need + size_t bins_per_item = params.item_params().felts_per_item; + size_t item_bit_count = params.item_bit_count(); + + // Set up Kuku hash functions + auto hash_funcs = labeled_psi::HashFunctions(params); + + // Calculate the cuckoo indices for each item. Store every pair of + // (item-label, cuckoo_idx) in a vector. Later, we're gonna sort this vector + // by cuckoo_idx and use the result to parallelize the work of inserting the + // items into BinBundles. 
+ std::vector> data_with_indices; + + for (auto it = begin; it != end; it++) { + const std::pair &item_label_pair = + *it; + + // Serialize the data into field elements + const apsi::HashedItem &item = item_label_pair.first; + const apsi::EncryptedLabel &label = item_label_pair.second; + apsi::util::AlgItemLabel alg_item_label = algebraize_item_label( + item, label, item_bit_count, params.seal_params().plain_modulus()); + + std::vector> temp_data; + std::set loc_set; + + // Get the cuckoo table locations for this item and add to + // data_with_indices + for (auto location : labeled_psi::AllLocations(hash_funcs, item)) { + // The current hash value is an index into a table of Items. In + // reality our BinBundles are tables of bins, which contain chunks + // of items. How many chunks? bins_per_item many chunks + if (loc_set.find(location) == loc_set.end()) { + size_t bin_idx = location * bins_per_item; + + // Store the data along with its index + temp_data.emplace_back(alg_item_label, bin_idx); + loc_set.insert(location); + } + } + + data_with_indices.insert(data_with_indices.end(), temp_data.begin(), + temp_data.end()); + } + + SPDLOG_DEBUG("Finished preprocessing {} labeled items", distance(begin, end)); + + return data_with_indices; +} + +/** +Converts each given Item into its algebraic form, i.e., a sequence of +felt-monostate pairs. Also computes each Item's cuckoo index. +*/ +std::vector> PreprocessUnlabeledData( + const std::vector::const_iterator begin, + const std::vector::const_iterator end, + const apsi::PSIParams ¶ms) { + STOPWATCH(sender_stopwatch, "preprocess_unlabeled_data"); + SPDLOG_DEBUG("Start preprocessing {} unlabeled items", distance(begin, end)); + + // Some variables we'll need + size_t bins_per_item = params.item_params().felts_per_item; + size_t item_bit_count = params.item_bit_count(); + + // Set up Kuku hash functions + auto hash_funcs = labeled_psi::HashFunctions(params); + + // Calculate the cuckoo indices for each item. 
Store every pair of + // (item-label, cuckoo_idx) in a vector. Later, we're gonna sort this vector + // by cuckoo_idx and use the result to parallelize the work of inserting the + // items into BinBundles. + std::vector> data_with_indices; + for (auto it = begin; it != end; it++) { + const apsi::HashedItem &item = *it; + + // Serialize the data into field elements + apsi::util::AlgItem alg_item = algebraize_item( + item, item_bit_count, params.seal_params().plain_modulus()); + + // Get the cuckoo table locations for this item and add to data_with_indices + for (auto location : labeled_psi::AllLocations(hash_funcs, item)) { + // The current hash value is an index into a table of Items. In reality + // our BinBundles are tables of bins, which contain chunks of items. How + // many chunks? bins_per_item many chunks + size_t bin_idx = location * bins_per_item; + + // Store the data along with its index + data_with_indices.emplace_back(alg_item, bin_idx); + } + } + + SPDLOG_DEBUG("Finished preprocessing {} unlabeled items", + distance(begin, end)); + + return data_with_indices; +} + +/** +Converts given Item into its algebraic form, i.e., a sequence of felt-monostate +pairs. Also computes the Item's cuckoo index. +*/ +std::vector> PreprocessUnlabeledData( + const apsi::HashedItem &item, const apsi::PSIParams ¶ms) { + std::vector item_singleton{item}; + return PreprocessUnlabeledData(item_singleton.begin(), item_singleton.end(), + params); +} + +/** +Inserts the given items and corresponding labels into bin_bundles at their +respective cuckoo indices. It will only insert the data with bundle index in the +half-open range range indicated by work_range. If inserting into a BinBundle +would make the number of items in a bin larger than max_bin_size, this function +will create and insert a new BinBundle. If overwrite is set, this will overwrite +the labels if it finds an AlgItemLabel that matches the input perfectly. 
+*/ +template +void InsertOrAssignWorker( + const std::vector> &data_with_indices, + std::vector>> + *bin_bundles, + const apsi::CryptoContext &crypto_context, uint32_t bundle_index, + uint32_t bins_per_bundle, size_t label_size, size_t max_bin_size, + size_t ps_low_degree, bool overwrite, bool compressed) { + STOPWATCH(sender_stopwatch, "insert_or_assign_worker"); + SPDLOG_DEBUG( + "Insert-or-Assign worker for bundle index {}; mode of operation: {}", + bundle_index, overwrite ? "overwriting existing" : "inserting new"); + + // Iteratively insert each item-label pair at the given cuckoo index + for (auto &data_with_idx : data_with_indices) { + const T &data = data_with_idx.first; + + // Get the bundle index + size_t cuckoo_idx = data_with_idx.second; + size_t bin_idx; + size_t bundle_idx; + std::tie(bin_idx, bundle_idx) = + labeled_psi::UnpackCuckooIdx(cuckoo_idx, bins_per_bundle); + + // If the bundle_idx isn't in the prescribed range, don't try to insert this + // data + if (bundle_idx != bundle_index) { + // Dealing with this bundle index is not our job + continue; + } + + // Get the bundle set at the given bundle index + std::vector> &bundle_set = + (*bin_bundles)[bundle_idx]; + + // Try to insert or overwrite these field elements in an existing BinBundle + // at this bundle index. Keep track of whether or not we succeed. + bool written = false; + for (auto bundle_it = bundle_set.rbegin(); bundle_it != bundle_set.rend(); + bundle_it++) { + // If we're supposed to overwrite, try to overwrite. One of these + // BinBundles has to have the data we're trying to overwrite. 
+ if (overwrite) { + // If we successfully overwrote, we're done with this bundle + written = (*bundle_it)->try_multi_overwrite(data, bin_idx); + if (written) { + break; + } + } + + // Do a dry-run insertion and see if the new largest bin size in the range + // exceeds the limit + int32_t new_largest_bin_size = + (*bundle_it)->multi_insert_dry_run(data, bin_idx); + + // Check if inserting would violate the max bin size constraint + if (new_largest_bin_size > 0 && + seal::util::safe_cast(new_largest_bin_size) < max_bin_size) { + // All good + (*bundle_it)->multi_insert_for_real(data, bin_idx); + written = true; + break; + } + } + + // We tried to overwrite an item that doesn't exist. This should never + // happen + if (overwrite && !written) { + SPDLOG_ERROR( + "Insert-or-Assign worker: " + "failed to overwrite item at bundle index {} because the item was " + "not found", + bundle_idx); + SPU_THROW("tried to overwrite non-existent item"); + } + + // If we had conflicts everywhere when trying to insert, then we need to + // make a new BinBundle and insert the data there + if (!written) { + // Make a fresh BinBundle and insert + std::shared_ptr new_bin_bundle = + std::make_shared( + crypto_context, label_size, max_bin_size, ps_low_degree, + bins_per_bundle, compressed, false); + int res = new_bin_bundle->multi_insert_for_real(data, bin_idx); + + // If even that failed, I don't know what could've happened + if (res < 0) { + SPDLOG_ERROR( + "Insert-or-Assign worker: " + "failed to insert item into a new BinBundle at bundle index {}", + bundle_idx); + SPU_THROW("failed to insert item into a new BinBundle"); + } + + // Push a new BinBundle to the set of BinBundles at this bundle index + bundle_set.push_back(new_bin_bundle); + } + } + + SPDLOG_DEBUG("Insert-or-Assign worker: finished processing bundle index {}", + bundle_index); +} + +/** +Takes algebraized data to be inserted, splits it up, and distributes it so that +thread_count many threads can all insert in parallel. 
If overwrite is set, this +will overwrite the labels if it finds an AlgItemLabel that matches the input +perfectly. +*/ +template +void DispatchInsertOrAssign( + const std::vector> &data_with_indices, + std::vector>> + *bin_bundles, + const apsi::CryptoContext &crypto_context, uint32_t bins_per_bundle, + size_t label_size, uint32_t max_bin_size, uint32_t ps_low_degree, + bool overwrite, bool compressed) { + apsi::ThreadPoolMgr tpm; + + // Collect the bundle indices and partition them into thread_count many + // partitions. By some uniformity assumption, the number of things to insert + // per partition should be roughly the same. Note that the contents of + // bundle_indices is always sorted (increasing order). + std::set bundle_indices_set; + for (auto &data_with_idx : data_with_indices) { + size_t cuckoo_idx = data_with_idx.second; + size_t bin_idx; + size_t bundle_idx; + std::tie(bin_idx, bundle_idx) = + labeled_psi::UnpackCuckooIdx(cuckoo_idx, bins_per_bundle); + bundle_indices_set.insert(bundle_idx); + } + + // Copy the set of indices into a vector and sort so each thread processes a + // range of indices + std::vector bundle_indices; + bundle_indices.reserve(bundle_indices_set.size()); + copy(bundle_indices_set.begin(), bundle_indices_set.end(), + back_inserter(bundle_indices)); + std::sort(bundle_indices.begin(), bundle_indices.end()); + + // Run the threads on the partitions + std::vector> futures(bundle_indices.size()); + SPDLOG_INFO("Launching {} insert-or-assign worker tasks", + bundle_indices.size()); + size_t future_idx = 0; + for (auto &bundle_idx : bundle_indices) { + futures[future_idx++] = tpm.thread_pool().enqueue([&, bundle_idx]() { + InsertOrAssignWorker(data_with_indices, bin_bundles, crypto_context, + static_cast(bundle_idx), bins_per_bundle, + label_size, max_bin_size, ps_low_degree, overwrite, + compressed); + }); + } + + // Wait for the tasks to finish + for (auto &f : futures) { + f.get(); + } + + SPDLOG_INFO("Finished insert-or-assign 
worker tasks"); +} + +/** +Removes the given items and corresponding labels from bin_bundles at their +respective cuckoo indices. +*/ +void RemoveWorker( + const std::vector> + &data_with_indices, + std::vector>> + *bin_bundles, + uint32_t bundle_index, uint32_t bins_per_bundle) { + STOPWATCH(sender_stopwatch, "remove_worker"); + SPDLOG_INFO("Remove worker [{}]", bundle_index); + + // Iteratively remove each item-label pair at the given cuckoo index + for (const auto &data_with_idx : data_with_indices) { + // Get the bundle index + size_t cuckoo_idx = data_with_idx.second; + size_t bin_idx; + size_t bundle_idx; + std::tie(bin_idx, bundle_idx) = + labeled_psi::UnpackCuckooIdx(cuckoo_idx, bins_per_bundle); + + // If the bundle_idx isn't in the prescribed range, don't try to remove this + // data + if (bundle_idx != bundle_index) { + // Dealing with this bundle index is not our job + continue; + } + + // Get the bundle set at the given bundle index + std::vector> &bundle_set = + (*bin_bundles)[bundle_idx]; + + // Try to remove these field elements from an existing BinBundle at this + // bundle index. Keep track of whether or not we succeed. + bool removed = false; + for (auto &bundle : bundle_set) { + // If we successfully removed, we're done with this bundle + removed = bundle->try_multi_remove(data_with_idx.first, bin_idx); + if (removed) { + break; + } + } + + // We may have produced some empty BinBundles so just remove them all + auto rem_it = std::remove_if(bundle_set.begin(), bundle_set.end(), + [](auto &bundle) { return bundle->empty(); }); + bundle_set.erase(rem_it, bundle_set.end()); + + // We tried to remove an item that doesn't exist. 
This should never happen + if (!removed) { + SPDLOG_ERROR( + "Remove worker: " + "failed to remove item at bundle index {} because the item was not " + "found", + bundle_idx); + SPU_THROW("failed to remove item"); + } + } + + SPDLOG_INFO("Remove worker: finished processing bundle index {}", + bundle_index); +} + +/** +Takes algebraized data to be removed, splits it up, and distributes it so that +thread_count many threads can all remove in parallel. +*/ +void DispatchRemove( + const std::vector> + &data_with_indices, + std::vector>> + *bin_bundles, + uint32_t bins_per_bundle) { + apsi::ThreadPoolMgr tpm; + + // Collect the bundle indices and partition them into thread_count many + // partitions. By some uniformity assumption, the number of things to remove + // per partition should be roughly the same. Note that the contents of + // bundle_indices is always sorted (increasing order). + std::set bundle_indices_set; + for (const auto &data_with_idx : data_with_indices) { + size_t cuckoo_idx = data_with_idx.second; + size_t bin_idx; + size_t bundle_idx; + std::tie(bin_idx, bundle_idx) = + labeled_psi::UnpackCuckooIdx(cuckoo_idx, bins_per_bundle); + bundle_indices_set.insert(bundle_idx); + } + + // Copy the set of indices into a vector and sort so each thread processes a + // range of indices + std::vector bundle_indices; + bundle_indices.reserve(bundle_indices_set.size()); + copy(bundle_indices_set.begin(), bundle_indices_set.end(), + back_inserter(bundle_indices)); + sort(bundle_indices.begin(), bundle_indices.end()); + + // Run the threads on the partitions + std::vector> futures(bundle_indices.size()); + SPDLOG_INFO("Launching {} remove worker tasks", bundle_indices.size()); + size_t future_idx = 0; + for (auto &bundle_idx : bundle_indices) { + futures[future_idx++] = tpm.thread_pool().enqueue([&]() { + RemoveWorker(data_with_indices, bin_bundles, + static_cast(bundle_idx), bins_per_bundle); + }); + } + + // Wait for the tasks to finish + for (auto &f : futures) { 
+ f.get(); + } +} + +} // namespace + +SenderMemDB::SenderMemDB(const apsi::PSIParams ¶ms, + yacl::ByteContainerView oprf_key, + size_t label_byte_count, size_t nonce_byte_count, + bool compressed) + : ISenderDB(params, oprf_key, label_byte_count, nonce_byte_count, + compressed) { + // Reset the SenderDB data structures + clear(); +} + +size_t SenderMemDB::GetBinBundleCount(uint32_t bundle_idx) const { + // Lock the database for reading + auto lock = GetReaderLock(); + + return bin_bundles_.at(seal::util::safe_cast(bundle_idx)).size(); +} + +size_t SenderMemDB::GetBinBundleCount() const { + // Lock the database for reading + auto lock = GetReaderLock(); + + // Compute the total number of BinBundles + return std::accumulate(bin_bundles_.cbegin(), bin_bundles_.cend(), + static_cast(0), + [&](auto &a, auto &b) { return a + b.size(); }); +} + +void SenderMemDB::ClearInternal() { + // Assume the SenderDB is already locked for writing + + // Clear the set of inserted items + hashed_items_.clear(); + item_count_ = 0; + + // Clear the BinBundles + bin_bundles_.clear(); + bin_bundles_.resize(params_.bundle_idx_count()); + + // Reset the stripped_ flag + stripped_ = false; +} + +void SenderMemDB::clear() { + if (!hashed_items_.empty()) { + SPDLOG_INFO("Removing {} items pairs from SenderDB", hashed_items_.size()); + } + + // Lock the database for writing + auto lock = GetWriterLock(); + + ClearInternal(); +} + +void SenderMemDB::GenerateCaches() { + STOPWATCH(sender_stopwatch, "SenderDB::GenerateCaches"); + SPDLOG_INFO("Start generating bin bundle caches"); + + apsi::ThreadPoolMgr tpm; + + std::vector> futures; + + for (auto &bundle_idx : bin_bundles_) { + futures.push_back(tpm.thread_pool().enqueue([&bundle_idx]() { + for (auto &bb : bundle_idx) { + bb->regen_cache(); + } + })); + } + + // Wait for the tasks to finish + for (auto &f : futures) { + f.get(); + } + + SPDLOG_INFO("Finished generating bin bundle caches"); +} + +void SenderMemDB::strip() { + // Lock the database 
for writing + auto lock = GetWriterLock(); + + stripped_ = true; + + memset(oprf_key_.data(), 0, oprf_key_.size()); + hashed_items_.clear(); + + apsi::ThreadPoolMgr tpm; + + std::vector> futures; + for (auto &bundle_idx : bin_bundles_) { + for (auto &bb : bundle_idx) { + futures.push_back(tpm.thread_pool().enqueue([&bb]() { bb->strip(); })); + } + } + + // Wait for the tasks to finish + for (auto &f : futures) { + f.get(); + } + + SPDLOG_INFO("SenderDB has been stripped"); +} + +void SenderMemDB::InsertOrAssign(const std::vector &keys, + const std::vector &labels) { + if (stripped_) { + SPDLOG_ERROR("Cannot insert data to a stripped SenderDB"); + SPU_THROW("failed to insert data"); + } + if (!IsLabeled()) { + SPDLOG_ERROR( + "Attempted to insert labeled data but this is an unlabeled SenderDB"); + SPU_THROW("failed to insert data"); + } + + SPDLOG_INFO("Start inserting {} items in SenderDB", keys.size()); + + // First compute the hashes for the input data + std::vector oprf_out = oprf_server_->FullEvaluate(keys); + std::vector> hashed_data( + oprf_out.size()); + + yacl::parallel_for(0, oprf_out.size(), 1, [&](int64_t begin, int64_t end) { + for (int64_t i = begin; i < end; ++i) { + apsi::HashedItem hashed_item; + std::memcpy(hashed_item.value().data(), oprf_out[i].data(), + hashed_item.value().size()); + + apsi::LabelKey key; + std::memcpy(key.data(), &oprf_out[i][hashed_item.value().size()], + key.size()); + + apsi::Label label_with_padding = + PaddingData(labels[i], label_byte_count_); + + apsi::EncryptedLabel encrypted_label = encrypt_label( + label_with_padding, key, label_byte_count_, nonce_byte_count_); + + hashed_data[i] = std::make_pair(hashed_item, encrypted_label); + } + }); + + // Lock the database for writing + auto lock = GetWriterLock(); + + // We need to know which items are new and which are old, since we have to + // tell dispatch_insert_or_assign when to have an overwrite-on-collision + // versus add-binbundle-on-collision policy. 
+ auto new_data_end = std::remove_if( + hashed_data.begin(), hashed_data.end(), [&](const auto &item_label_pair) { + bool found = + hashed_items_.find(item_label_pair.first) != hashed_items_.end(); + if (!found) { + // Add to hashed_items_ already at this point! + hashed_items_.insert(item_label_pair.first); + item_count_++; + } + + // Remove those that were found + return found; + }); + + // Dispatch the insertion, first for the new data, then for the data we're + // gonna overwrite + uint32_t bins_per_bundle = params_.bins_per_bundle(); + uint32_t max_bin_size = params_.table_params().max_items_per_bin; + uint32_t ps_low_degree = params_.query_params().ps_low_degree; + + // Compute the label size; this ceil(effective_label_bit_count / + // item_bit_count) + size_t label_size = labeled_psi::ComputeLabelSize( + nonce_byte_count_ + label_byte_count_, params_); + + auto new_item_count = distance(hashed_data.begin(), new_data_end); + auto existing_item_count = distance(new_data_end, hashed_data.end()); + + if (existing_item_count != 0) { + SPDLOG_INFO("Found {} existing items to replace in SenderDB", + existing_item_count); + + // Break the data into field element representation. Also compute the items' + // cuckoo indices. + std::vector> data_with_indices = + PreprocessLabeledData(new_data_end, hashed_data.end(), params_); + + DispatchInsertOrAssign(data_with_indices, &bin_bundles_, crypto_context_, + bins_per_bundle, label_size, max_bin_size, + ps_low_degree, true, /* overwrite items */ + compressed_); + + // Release memory that is no longer needed + hashed_data.erase(new_data_end, hashed_data.end()); + } + + if (new_item_count != 0) { + SPDLOG_INFO("Found {} new items to insert in SenderDB", new_item_count); + + // Process and add the new data. Break the data into field element + // representation. Also compute the items' cuckoo indices. 
+ std::vector> data_with_indices = + PreprocessLabeledData(hashed_data.begin(), hashed_data.end(), params_); + + DispatchInsertOrAssign(data_with_indices, &bin_bundles_, crypto_context_, + bins_per_bundle, label_size, max_bin_size, + ps_low_degree, false, /* don't overwrite items */ + compressed_); + } + + SPDLOG_INFO("Finished inserting {} items in SenderDB", keys.size()); +} + +void SenderMemDB::InsertOrAssign(const std::vector &data) { + if (stripped_) { + SPDLOG_ERROR("Cannot insert data to a stripped SenderDB"); + SPU_THROW("failed to insert data"); + } + if (IsLabeled()) { + SPDLOG_ERROR( + "Attempted to insert unlabeled data but this is a labeled SenderDB"); + SPU_THROW("failed to insert data"); + } + + STOPWATCH(sender_stopwatch, "SenderMemDB::insert_or_assign (unlabeled)"); + SPDLOG_INFO("Start inserting {} items in SenderMemDB", data.size()); + + // First compute the hashes for the input data + std::vector oprf_out = oprf_server_->FullEvaluate(data); + + std::vector hashed_data; + for (const auto &out : oprf_out) { + apsi::Item::value_type value{}; + std::memcpy(value.data(), out.data(), value.size()); + + hashed_data.emplace_back(value); + } + + // Lock the database for writing + auto lock = GetWriterLock(); + + // We are not going to insert items that already appear in the database. + auto new_data_end = std::remove_if( + hashed_data.begin(), hashed_data.end(), [&](const auto &item) { + bool found = hashed_items_.find(item) != hashed_items_.end(); + if (!found) { + // Add to hashed_items_ already at this point! + hashed_items_.insert(item); + item_count_++; + } + + // Remove those that were found + return found; + }); + + // Erase the previously existing items from hashed_data; in unlabeled case + // there is nothing to do + hashed_data.erase(new_data_end, hashed_data.end()); + + SPDLOG_INFO("Found {} new items to insert in SenderDB", hashed_data.size()); + + // Break the new data down into its field element representation. 
Also compute + // the items' cuckoo indices. + std::vector> data_with_indices = + PreprocessUnlabeledData(hashed_data.begin(), hashed_data.end(), params_); + + // Dispatch the insertion + uint32_t bins_per_bundle = params_.bins_per_bundle(); + uint32_t max_bin_size = params_.table_params().max_items_per_bin; + uint32_t ps_low_degree = params_.query_params().ps_low_degree; + + DispatchInsertOrAssign(data_with_indices, &bin_bundles_, crypto_context_, + bins_per_bundle, 0, /* label size */ + max_bin_size, ps_low_degree, + false, /* don't overwrite items */ + compressed_); + + SPDLOG_INFO("Finished inserting {} items in SenderDB", data.size()); +} + +void SenderMemDB::remove(const std::vector &data) { + if (stripped_) { + SPDLOG_ERROR("Cannot remove data from a stripped SenderDB"); + SPU_THROW("failed to remove data"); + } + + STOPWATCH(sender_stopwatch, "SenderDB::remove"); + SPDLOG_INFO("Start removing {} items from SenderDB", data.size()); + + // First compute the hashes for the input data + // auto hashed_data = OPRFSender::ComputeHashes(data, oprf_key_); + std::vector data_str(data.size()); + for (size_t i = 0; i < data.size(); ++i) { + data_str[i].reserve(data[i].value().size()); + std::memcpy(data_str[i].data(), data[i].value().data(), + data[i].value().size()); + } + std::vector oprf_out = oprf_server_->FullEvaluate(data_str); + std::vector hashed_data; + for (const auto &out : oprf_out) { + apsi::Item::value_type value{}; + std::memcpy(value.data(), out.data(), value.size()); + + hashed_data.emplace_back(value); + } + + // Lock the database for writing + auto lock = GetWriterLock(); + + // Remove items that do not exist in the database. + auto existing_data_end = std::remove_if( + hashed_data.begin(), hashed_data.end(), [&](const auto &item) { + bool found = hashed_items_.find(item) != hashed_items_.end(); + if (found) { + // Remove from hashed_items_ already at this point! 
+ hashed_items_.erase(item); + item_count_--; + } + + // Remove those that were not found + return !found; + }); + + // This distance is always non-negative + auto existing_item_count = + static_cast(distance(existing_data_end, hashed_data.end())); + if (existing_item_count != 0) { + SPDLOG_WARN("Ignoring {} items that are not present in the SenderDB", + existing_item_count); + } + + // Break the data down into its field element representation. Also compute the + // items' cuckoo indices. + std::vector> data_with_indices = + PreprocessUnlabeledData(hashed_data.begin(), hashed_data.end(), params_); + + // Dispatch the removal + uint32_t bins_per_bundle = params_.bins_per_bundle(); + DispatchRemove(data_with_indices, &bin_bundles_, bins_per_bundle); + + // Generate the BinBundle caches + GenerateCaches(); + + SPDLOG_INFO("Finished removing {} items from SenderDB", data.size()); +} + +bool SenderMemDB::HasItem(const apsi::Item &item) const { + if (stripped_) { + SPDLOG_ERROR( + "Cannot retrieve the presence of an item from a stripped SenderDB"); + SPU_THROW("failed to retrieve the presence of item"); + } + + // First compute the hash for the input item + // auto hashed_item = OPRFSender::ComputeHashes({&item, 1}, oprf_key_)[0]; + std::string item_str; + item_str.reserve(item.value().size()); + std::memcpy(item_str.data(), item.value().data(), item.value().size()); + std::string oprf_out = oprf_server_->FullEvaluate(item_str); + apsi::HashedItem hashed_item; + std::memcpy(hashed_item.value().data(), oprf_out.data(), + hashed_item.value().size()); + + // Lock the database for reading + auto lock = GetReaderLock(); + + return hashed_items_.find(hashed_item) != hashed_items_.end(); +} + +apsi::Label SenderMemDB::GetLabel(const apsi::Item &item) const { + if (stripped_) { + SPDLOG_ERROR("Cannot retrieve a label from a stripped SenderDB"); + SPU_THROW("failed to retrieve label"); + } + if (!IsLabeled()) { + SPDLOG_ERROR( + "Attempted to retrieve a label but this is an 
unlabeled SenderDB"); + SPU_THROW("failed to retrieve label"); + } + + // First compute the hash for the input item + apsi::HashedItem hashed_item; + apsi::LabelKey key; + // tie(hashed_item, key) = OPRFSender::GetItemHash(item, oprf_key_); + + std::string item_str; + item_str.reserve(item.value().size()); + std::memcpy(item_str.data(), item.value().data(), item.value().size()); + std::string oprf_out = oprf_server_->FullEvaluate(item_str); + std::memcpy(hashed_item.value().data(), oprf_out.data(), + hashed_item.value().size()); + + // Lock the database for reading + auto lock = GetReaderLock(); + + // Check if this item is in the DB. If not, throw an exception + if (hashed_items_.find(hashed_item) == hashed_items_.end()) { + SPDLOG_ERROR( + "Cannot retrieve label for an item that is not in the SenderDB"); + SPU_THROW("failed to retrieve label"); + } + + uint32_t bins_per_bundle = params_.bins_per_bundle(); + + // Preprocess a single element. This algebraizes the item and gives back its + // field element representation as well as its cuckoo hash. We only read one + // of the locations because the labels are the same in each location. + apsi::util::AlgItem alg_item; + size_t cuckoo_idx; + std::tie(alg_item, cuckoo_idx) = + PreprocessUnlabeledData(hashed_item, params_)[0]; + + // Now figure out where to look to get the label + size_t bin_idx; + size_t bundle_idx; + std::tie(bin_idx, bundle_idx) = + labeled_psi::UnpackCuckooIdx(cuckoo_idx, bins_per_bundle); + + // Retrieve the algebraic labels from one of the BinBundles at this index + const std::vector> &bundle_set = + bin_bundles_[bundle_idx]; + std::vector alg_label; + bool got_labels = false; + for (const auto &bundle : bundle_set) { + // Try to retrieve the contiguous labels from this BinBundle + if (bundle->try_get_multi_label(alg_item, bin_idx, alg_label)) { + got_labels = true; + break; + } + } + + // It shouldn't be possible to have items in your set but be unable to + // retrieve the associated label. 
Throw an exception because something is + // terribly wrong. + if (!got_labels) { + SPDLOG_ERROR( + "Failed to retrieve label for an item that was supposed to be in the " + "SenderDB"); + SPU_THROW("failed to retrieve label"); + } + + // All good. Now just reconstruct the big label from its split-up parts + apsi::EncryptedLabel encrypted_label = dealgebraize_label( + alg_label, + alg_label.size() * static_cast(params_.item_bit_count_per_felt()), + params_.seal_params().plain_modulus()); + + // Resize down to the effective byte count + encrypted_label.resize(nonce_byte_count_ + label_byte_count_); + + // Decrypt the label + return decrypt_label(encrypted_label, key, nonce_byte_count_); +} + +void SenderMemDB::SetData(const std::shared_ptr &batch_provider, + size_t batch_size) { + size_t batch_count = 0; + + while (true) { + std::vector batch_items; + std::vector batch_labels; + + if (IsLabeled()) { + std::tie(batch_items, batch_labels) = + batch_provider->ReadNextBatchWithLabel(batch_size); + } else { + batch_items = batch_provider->ReadNextBatch(batch_size); + } + if (batch_items.empty()) { + SPDLOG_INFO("count: {}, item_count:{}", batch_count, item_count_); + break; + } + + if (IsLabeled()) { + InsertOrAssign(batch_items, batch_labels); + } else { + InsertOrAssign(batch_items); + } + + batch_count++; + } + + // Generate the BinBundle caches + GenerateCaches(); +} + +std::shared_ptr SenderMemDB::GetBinBundleAt( + std::uint32_t bundle_idx, size_t cache_idx) { + return bin_bundles_[bundle_idx][cache_idx]; +} + +} // namespace spu::psi diff --git a/libspu/psi/core/labeled_psi/sender_memdb.h b/libspu/psi/core/labeled_psi/sender_memdb.h new file mode 100644 index 00000000..827cebf3 --- /dev/null +++ b/libspu/psi/core/labeled_psi/sender_memdb.h @@ -0,0 +1,228 @@ +// Copyright 2021 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// code reference https://github.com/microsoft/APSI/sender/sender_db.h +// Licensed under the MIT license. + +#pragma once + +// STD +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// GSL +#include "gsl/span" + +// APSI +#include "apsi/bin_bundle.h" +#include "apsi/crypto_context.h" +#include "apsi/item.h" +#include "apsi/psi_params.h" +#include "yacl/base/byte_container_view.h" + +#include "libspu/psi/core/ecdh_oprf/basic_ecdh_oprf.h" +#include "libspu/psi/core/labeled_psi/sender_db.h" + +// SEAL +#include "seal/plaintext.h" +#include "seal/util/locks.h" +#include "spdlog/spdlog.h" + +namespace spu::psi { + +/** +A SenderDB maintains an in-memory representation of the sender's set of items +and labels (in labeled mode). This data is not simply copied into the SenderDB +data structures, but also preprocessed heavily to allow for faster online +computation time. Since inserting a large number of new items into a SenderDB +can take time, it is not recommended to recreate the SenderDB when the database +changes a little bit. Instead, the class supports fast update and deletion +operations that should be preferred: SenderDB::InsertOrAssign and +SenderDB::remove. + +The SenderDB constructor allows the label byte count to be specified; unlabeled +mode is activated by setting the label byte count to zero. It is possible to +optionally specify the size of the nonce used in encrypting the labels, but this +is best left to its default value unless the user is absolutely sure of what +they are doing. 
+ +The SenderDB requires substantially more memory than the raw data would. Part of +that memory can automatically be compressed when it is not in use; this feature +is enabled by default, and can be disabled when constructing the SenderDB. The +downside of in-memory compression is a performance reduction from decompressing +parts of the data when they are used, and recompressing them if they are +updated. +*/ +class SenderMemDB : public ISenderDB { + public: + /** + Creates a new SenderDB. + */ + SenderMemDB(const apsi::PSIParams ¶ms, yacl::ByteContainerView oprf_key, + std::size_t label_byte_count = 0, + std::size_t nonce_byte_count = 16, bool compressed = true); + + /** + Clears the database. Every item and label will be removed. The OPRF key is + unchanged. + */ + void clear(); + + /** + Strips the SenderDB of all information not needed for serving a query. + */ + void strip(); + + /** + Inserts the given data into the database. This function can be used only on a + labeled SenderDB instance. If an item already exists in the database, its + label is overwritten with the new label. + */ + void InsertOrAssign(const std::vector &keys, + const std::vector &labels); + + /** + Inserts the given (hashed) item-label pair into the database. This function + can be used only on a labeled SenderDB instance. If the item already exists in + the database, its label is overwritten with the new label. + */ + void InsertOrAssign(const std::string &key, const std::string &label) { + std::vector key_singleton{key}; + std::vector label_singleton{label}; + InsertOrAssign(key_singleton, label_singleton); + } + + /** + Inserts the given data into the database. This function can be used only on an + unlabeled SenderDB instance. + */ + void InsertOrAssign(const std::vector &data); + + /** + Inserts the given (hashed) item into the database. This function can be used + only on an unlabeled SenderDB instance. 
+ */ + void InsertOrAssign(const std::string &data) { + std::vector data_singleton{data}; + InsertOrAssign(data_singleton); + } + + /** + Clears the database and inserts the given data. This function can be used only + on a labeled SenderDB instance. + */ + void SetData(const std::vector &keys, + const std::vector &labels) { + clear(); + InsertOrAssign(keys, labels); + // Generate the BinBundle caches + GenerateCaches(); + } + + void SetData(const std::shared_ptr &batch_provider, + size_t batch_size = 500000) override; + + /** + Clears the database and inserts the given data. This function can be used only + on an unlabeled SenderDB instance. + */ + void SetData(const std::vector &data) { + clear(); + InsertOrAssign(data); + // Generate the BinBundle caches + GenerateCaches(); + } + + /** + Removes the given data from the database, using at most thread_count threads. + */ + void remove(const std::vector &data); + + /** + Removes the given (hashed) item from the database. + */ + void remove(const apsi::Item &data) { + std::vector data_singleton{data}; + remove(data_singleton); + } + + /** + Returns whether the given item has been inserted in the SenderDB. + */ + bool HasItem(const apsi::Item &item) const; + + /** + Returns the label associated to the given item in the database. Throws + std::invalid_argument if the item does not appear in the database. + */ + apsi::Label GetLabel(const apsi::Item &item) const; + + /** + Returns a set of cache references corresponding to the bundles at the given + bundle index. Even though this function returns a vector, the order has no + significance. This function is meant for internal use. + */ + std::shared_ptr GetBinBundleAt( + std::uint32_t bundle_idx, size_t cache_idx) override; + + /** + Returns a reference to a set of item hashes already existing in the SenderDB. + */ + const std::unordered_set &GetHashedItems() const { + return hashed_items_; + } + + /** + Returns the number of items in this SenderDB. 
+ */ + size_t GetItemCount() const { return item_count_; } + + /** + Returns the total number of bin bundles at a specific bundle index. + */ + std::size_t GetBinBundleCount(std::uint32_t bundle_idx) const override; + + /** + Returns the total number of bin bundles. + */ + std::size_t GetBinBundleCount() const override; + + private: + void ClearInternal(); + + void GenerateCaches(); + + /** + The set of all items that have been inserted into the database + */ + std::unordered_set hashed_items_; + + /** + All the BinBundles in the database, indexed by bundle index. The set + (represented by a vector internally) at bundle index i contains all the + BinBundles with bundle index i. + */ + std::vector>> + bin_bundles_; + +}; // class SenderDB + +} // namespace spu::psi diff --git a/libspu/psi/core/labeled_psi/serializable.proto b/libspu/psi/core/labeled_psi/serializable.proto index 0211a61e..b8580682 100644 --- a/libspu/psi/core/labeled_psi/serializable.proto +++ b/libspu/psi/core/labeled_psi/serializable.proto @@ -69,7 +69,7 @@ message AlgItemProto { message AlgItemLabelPairProto { uint64 item = 1; - repeated uint64 label = 2; + bytes label_data = 2; } message AlgItemLabelProto { diff --git a/libspu/psi/core/labeled_psi/serialize.h b/libspu/psi/core/labeled_psi/serialize.h index 8310e86d..b949f3a9 100644 --- a/libspu/psi/core/labeled_psi/serialize.h +++ b/libspu/psi/core/labeled_psi/serialize.h @@ -44,8 +44,6 @@ inline proto::AlgItemProto SerializeAlgItem( inline std::string SerializeAlgItemToString( const apsi::util::AlgItem& alg_items) { - std::string ret; - proto::AlgItemProto proto = SerializeAlgItem(alg_items); std::string item_string(proto.ByteSizeLong(), '\0'); @@ -81,9 +79,9 @@ inline void SerializeAlgItemLabel( proto::AlgItemLabelPairProto* pair_proto = proto->add_item_label(); pair_proto->set_item(item_label_pair[i].first); - for (auto& label : item_label_pair[i].second) { - pair_proto->add_label(label); - } + pair_proto->set_label_data( + 
item_label_pair[i].second.data(), + item_label_pair[i].second.size() * sizeof(uint64_t)); } } @@ -111,10 +109,13 @@ inline apsi::util::AlgItemLabel DeserializeAlgItemLabel( for (int i = 0; i < proto.item_label_size(); ++i) { auto pair_proto = proto.item_label(i); - std::vector labels(pair_proto.label_size()); - for (int j = 0; j < pair_proto.label_size(); ++j) { - labels[j] = pair_proto.label(j); - } + + auto label_data = pair_proto.label_data(); + + std::vector labels(label_data.size() / + sizeof(apsi::util::felt_t)); + + std::memcpy(labels.data(), label_data.data(), label_data.size()); item_label_pair.emplace_back(pair_proto.item(), labels); } @@ -142,10 +143,6 @@ inline std::string SerializeDataWithIndices( std::string item_string(proto.ByteSizeLong(), '\0'); proto.SerializePartialToArray(item_string.data(), proto.ByteSizeLong()); - // SPDLOG_INFO("*** debug item_size:{} bytes:{} proto size:{}", - // item_proto->item_size(), item_proto->ByteSizeLong(), - // proto.ByteSizeLong()); - return item_string; } @@ -161,8 +158,6 @@ inline std::pair DeserializeDataWithIndices( inline std::string SerializeDataLabelWithIndices( const std::pair& data_with_indices) { - std::string ret; - proto::DataLabelWithIndicesProto proto; proto::AlgItemLabelProto* item_proto = new proto::AlgItemLabelProto(); diff --git a/sml/lr/simple_lr.py b/sml/lr/simple_lr.py index 73abab21..f4bfca5d 100644 --- a/sml/lr/simple_lr.py +++ b/sml/lr/simple_lr.py @@ -1,129 +1,8 @@ import numpy as np import jax.numpy as jnp from enum import Enum +from ..utils.fxp_approx import sigmoid, SigType -def t1_sig(x, limit: bool = True): - ''' - taylor series referenced from: - https://mortendahl.github.io/2017/04/17/private-deep-learning-with-mpc/ - ''' - T0 = 1.0 / 2 - T1 = 1.0 / 4 - ret = T0 + x * T1 - if limit: - return jnp.select([ret < 0, ret > 1], [0, 1], ret) - else: - return ret - - -def t3_sig(x, limit: bool = True): - ''' - taylor series referenced from: - 
https://mortendahl.github.io/2017/04/17/private-deep-learning-with-mpc/ - ''' - T3 = -1.0 / 48 - ret = t1_sig(x, False) + jnp.power(x, 3) * T3 - if limit: - return jnp.select([x < -2, x > 2], [0, 1], ret) - else: - return ret - - -def t5_sig(x, limit: bool = True): - ''' - taylor series referenced from: - https://mortendahl.github.io/2017/04/17/private-deep-learning-with-mpc/ - ''' - T5 = 1.0 / 480 - ret = t3_sig(x, False) + jnp.power(x, 5) * T5 - if limit: - return jnp.select([ret < 0, ret > 1], [0, 1], ret) - else: - return ret - - -def seg3_sig(x): - ''' - f(x) = 0.5 + 0.125x if -4 <= x <= 4 - 1 if x > 4 - 0 if -4 > x - ''' - return jnp.select([x < -4, x > 4], [0, 1], 0.5 + x * 0.125) - - -def df_sig(x): - ''' - https://dergipark.org.tr/en/download/article-file/54559 - Dataflow implementation of sigmoid function: - F(x) = 0.5 * ( x / ( 1 + |x| ) ) + 0.5 - df_sig has higher precision than sr_sig if x in [-2, 2] - ''' - return 0.5 * (x / (1 + jnp.abs(x))) + 0.5 - - -def sr_sig(x): - ''' - https://en.wikipedia.org/wiki/Sigmoid_function#Examples - Square Root approximation functions: - F(x) = 0.5 * ( x / ( 1 + x^2 )^0.5 ) + 0.5 - sr_sig almost perfect fit to sigmoid if x out of range [-3,3] - ''' - return 0.5 * (x / jnp.sqrt(1 + jnp.square(x))) + 0.5 - - -def ls7_sig(x): - '''Polynomial fitting''' - return ( - 5.00052959e-01 - + 2.35176260e-01 * x - - 3.97212202e-05 * jnp.power(x, 2) - - 1.23407424e-02 * jnp.power(x, 3) - + 4.04588962e-06 * jnp.power(x, 4) - + 3.94330487e-04 * jnp.power(x, 5) - - 9.74060972e-08 * jnp.power(x, 6) - - 4.74674505e-06 * jnp.power(x, 7) - ) - - -def mix_sig(x): - ''' - mix ls7 & sr sig, use ls7 if |x| < 4 , else use sr. - has higher precision in all input range. - NOTICE: this method is very expensive, only use for hessian matrix. 
- ''' - ls7 = ls7_sig(x) - sr = sr_sig(x) - return jnp.select([x < -4, x > 4], [sr, sr], ls7) - - -def real_sig(x): - return 1 / (1 + jnp.exp(-x)) - -def sigmoid(x, sig_type): - if sig_type is SigType.REAL: - return real_sig(x) - elif sig_type is SigType.T1: - return t1_sig(x) - elif sig_type is SigType.T3: - return t3_sig(x) - elif sig_type is SigType.T5: - return t5_sig(x) - elif sig_type is SigType.DF: - return df_sig(x) - elif sig_type is SigType.SR: - return sr_sig(x) - elif sig_type is SigType.MIX: - return mix_sig(x) - -class SigType(Enum): - REAL = 'real' - T1 = 't1' - T3 = 't3' - T5 = 't5' - DF = 'df' - SR = 'sr' - # DO NOT use this except in hessian case. - MIX = 'mix' class Penalty(Enum): NONE = 'None' @@ -132,27 +11,67 @@ class Penalty(Enum): Elastic = 'elasticnet' # not supported class MultiClass(Enum): - Ovr = 'ovr' # binary problem - Multy = 'multinomial' # multi_class problem not supported + Binary = 'binary' + Ovr = 'ovr' # not supported yet + Multy = 'multinomial' # not supported yet -class SGDClassifier: +class Solver(Enum): + SGD = 'sgd' + +class LogisticRegression: + """ + Logistic Regression (aka logit, MaxEnt) classifier. + + IMPORTANT: Something different between `LogisticRegression` in sklearn: + 1. sigmoid will be computed with approximation + 2. you must define multi_class because we can not inspect y to decision the problem type + 3. for now, only 0-1 binary classification is supported; so if your label is {-1,1}, you must change it first! + + Parameters + ---------- + penalty: Specify the norm of the penalty: + {'l1', 'l2', 'elasticnet', 'None'}, default='l2' (current only support l2) + + solver: Algorithm to use in the optimization problem, default='sgd'. 
+ + multi_class: specify whether and which multi-classification form + {'binary', 'ovr', 'multinomial'}, default='binary' (current only support binary) + - binary: binary problem + - ovr: for each label, will fit a binary problem + - multinomial: the loss minimised is the multinomial loss that fit across the entire probability distribution + + class_weight: not support yet, for multi-class tasks, default=None + + sig_type: the approximation method for sigmoid function, default='sr' + for all choices, refer to `SigType` + + l2_norm: the strength of L2 norm, must be a positive float, default=0.01 + + epochs, learning_rate, batch_size: hyper-parameters for sgd solver + epochs: default=20 + learning_rate: default=0.1 + batch_size: default=512 + + """ def __init__( self, - epochs: int, - learning_rate: float, - batch_size: int, - penalty: str, - sig_type: str, - l2_norm: float, - class_weight: None, - multi_class: str, + penalty: str='l2', + solver: str='sgd', + multi_class: str='binary', + class_weight=None, + sig_type: str='sr', + l2_norm: float=0.01, + epochs: int=20, + learning_rate: float=0.1, + batch_size: int=512, ): # parameter check. 
assert epochs > 0, f"epochs should >0" assert learning_rate > 0, f"learning_rate should >0" assert batch_size > 0, f"batch_size should >0" assert penalty == 'l2', "only support L2 penalty for now" + assert solver == 'sgd', "only support sgd solver for now" if penalty == Penalty.L2: assert l2_norm > 0, f"l2_norm should >0 if use L2 penalty" assert penalty in [ @@ -162,7 +81,7 @@ def __init__( e.value for e in SigType ], f"sig_type should in {[e.value for e in SigType]}, but got {sig_type}" assert class_weight == None, f"not support class_weight for now" - assert multi_class == 'ovr', f"only support binary problem for now" + assert multi_class == 'binary', f"only support binary problem for now" self._epochs = epochs self._learning_rate = learning_rate @@ -212,9 +131,12 @@ def _update_weights( ) grad = grad + w_with_zero_bias * self._l2_norm elif self._penalty == Penalty.L1: - pass + raise NotImplementedError elif self._penalty == Penalty.Elastic: - pass + raise NotImplementedError + else: + # None penalty + raise NotImplementedError step = (self._learning_rate * grad) / batch_size @@ -230,7 +152,7 @@ def fit(self, x, y): X : {array-like}, shape (n_samples, n_features) Training data. - y : ndarray of shape (n_samples,) + y : ndarray of shape (n_samples,), value MUST between {0,1,...k-1} (k is the number of classes) Target values. Returns @@ -248,9 +170,9 @@ def fit(self, x, y): # not support class_weight for now if isinstance(self._class_weight, dict): - pass + raise NotImplementedError elif self._class_weight == 'balanced': - pass + raise NotImplementedError # do train for _ in range(self._epochs): @@ -277,9 +199,43 @@ def predict_proba(self, x): ------- ndarray of shape (n_samples, n_classes) Returns the probability of the sample for each class in the model, - where classes are ordered as they are in `self.classes_`. 
""" - if self._multi_class == MultiClass.Ovr: + pred = self.decision_function(x) + + if self._multi_class == MultiClass.Binary: + prob = sigmoid(pred, self._sig_type) + else: + raise NotImplementedError + + return prob + + + def predict(self, x): + """ + Predict class labels for samples in X. + + Parameters + ---------- + X : {array-like, } of shape (n_samples, n_features) + The data matrix for which we want to get the predictions. + + Returns + ------- + y_pred : ndarray of shape (n_samples,) + Vector containing the class labels for each sample. + """ + pred = self.decision_function(x) + + if self._multi_class == MultiClass.Binary: + # for binary task, only check whether logit > 0 (prob > 0.5) + label = jnp.select([pred>0], [1], 0) + else: + raise NotImplementedError + + return label + + def decision_function(self, x): + if self._multi_class == MultiClass.Binary: num_feat = x.shape[1] w = self._weights assert w.shape[0] == num_feat + 1, f"w shape is mismatch to x={x.shape}" @@ -289,8 +245,9 @@ def predict_proba(self, x): bias = w[-1, 0] w = jnp.resize(w, (num_feat, 1)) pred = jnp.matmul(x, w) + bias - pred = sigmoid(pred, self._sig_type) return pred - elif self._multi_class == MultiClass.Multy: - # not support multi_class problem for now - pass + elif self._multi_class == MultiClass.Ovr: + raise NotImplementedError + else: + # Multy model here + raise NotImplementedError \ No newline at end of file diff --git a/sml/lr/simple_lr_emul.py b/sml/lr/simple_lr_emul.py index c029395f..7acb13a9 100644 --- a/sml/lr/simple_lr_emul.py +++ b/sml/lr/simple_lr_emul.py @@ -22,24 +22,32 @@ # Add the library directory to the path sys.path.append(os.path.join(os.path.dirname(__file__), '../../')) -from sml.lr.simple_lr import SGDClassifier +from sml.lr.simple_lr import LogisticRegression import sml.utils.emulation as emulation + # TODO: design the enumation framework, just like py.unittest # all emulation action should begin with `emul_` (for reflection) -def 
emul_SGDClassifier(mode: emulation.Mode.MULTIPROCESS): +def emul_LogisticRegression(mode: emulation.Mode.MULTIPROCESS): + # Test SGDClassifier def proc(x, y): - model = SGDClassifier( + model = LogisticRegression( epochs=3, learning_rate=0.1, batch_size=8, + solver='sgd', penalty='l2', sig_type='sr', l2_norm=1.0, class_weight=None, - multi_class='ovr' + multi_class='binary', ) - return model.fit(x, y).predict_proba(x) + + model = model.fit(x, y) + + prob = model.predict_proba(x) + pred = model.predict(x) + return prob, pred try: # bandwidth and latency only work for docker mode @@ -54,17 +62,22 @@ def proc(x, y): cols = X.columns X = scalar.fit_transform(X) X = pd.DataFrame(X, columns=cols) - + # mark these data to be protected in SPU - X_spu,y_spu = emulator.seal(X.values, y.values.reshape(-1, 1)) # X, y should be two-dimension array + X_spu, y_spu = emulator.seal( + X.values, y.values.reshape(-1, 1) + ) # X, y should be two-dimension array # Run result = emulator.run(proc)(X_spu, y_spu) - print("Predict result: ", result) - print("ROC Score: ", roc_auc_score(y.values, result)) + print("Predict result prob: ", result[0]) + print("Predict result label: ", result[1]) + + print("ROC Score: ", roc_auc_score(y.values, result[0])) finally: emulator.down() + if __name__ == "__main__": - emul_SGDClassifier(emulation.Mode.MULTIPROCESS) + emul_LogisticRegression(emulation.Mode.MULTIPROCESS) diff --git a/sml/lr/simple_lr_test.py b/sml/lr/simple_lr_test.py index ebf61406..a8ee9b7a 100644 --- a/sml/lr/simple_lr_test.py +++ b/sml/lr/simple_lr_test.py @@ -26,7 +26,7 @@ # Add the library directory to the path sys.path.append(os.path.join(os.path.dirname(__file__), '../../')) -from sml.lr.simple_lr import SGDClassifier +from sml.lr.simple_lr import LogisticRegression class UnitTests(unittest.TestCase): def test_simple(self): @@ -36,17 +36,23 @@ def test_simple(self): # Test SGDClassifier def proc(x, y): - model = SGDClassifier( + model = LogisticRegression( epochs=3, 
learning_rate=0.1, batch_size=8, + solver='sgd', penalty='l2', sig_type='sr', l2_norm=1.0, class_weight=None, - multi_class='ovr' + multi_class='binary' ) - return model.fit(x, y).predict_proba(x) + + model = model.fit(x,y) + + prob = model.predict_proba(x) + pred = model.predict(x) + return prob, pred # Create dataset X, y = load_breast_cancer(return_X_y=True, as_frame=True) @@ -57,8 +63,11 @@ def proc(x, y): # Run result = spsim.sim_jax(sim, proc)(X.values, y.values.reshape(-1, 1)) # X, y should be two-dimension array - print("Predict result: ", result) - print("ROC Score: ", roc_auc_score(y.values, result)) + print("Predict result prob: ", result[0]) + print("Predict result label: ", result[1]) + + + print("ROC Score: ", roc_auc_score(y.values, result[0])) diff --git a/sml/utils/fxp_approx.py b/sml/utils/fxp_approx.py index e44789b4..5de59386 100644 --- a/sml/utils/fxp_approx.py +++ b/sml/utils/fxp_approx.py @@ -13,8 +13,21 @@ # limitations under the License. import jax.numpy as jnp +from enum import Enum +class SigType(Enum): + T1 = 't1' + T3 = 't3' + T5 = 't5' + SEG3 = 'seg3' + DF = 'df' + SR = 'sr' + LS7 = 'ls7' + # DO NOT use this except in hessian case. + MIX = 'mix' + REAL = 'real' + # taylor series referenced from: # https://mortendahl.github.io/2017/04/17/private-deep-learning-with-mpc/ def sigmoid_t1(x): @@ -55,3 +68,53 @@ def sigmoid_df(x): # highly recommended use this appr as GDBT's default sigmoid method. def sigmoid_sr(x): return 0.5 * (x / jnp.sqrt(1 + jnp.power(x, 2))) + 0.5 + + +# polynomial fitting of degree 7 +def sigmoid_ls7(x): + return ( + 5.00052959e-01 + + 2.35176260e-01 * x + - 3.97212202e-05 * jnp.power(x, 2) + - 1.23407424e-02 * jnp.power(x, 3) + + 4.04588962e-06 * jnp.power(x, 4) + + 3.94330487e-04 * jnp.power(x, 5) + - 9.74060972e-08 * jnp.power(x, 6) + - 4.74674505e-06 * jnp.power(x, 7) + ) + + +# mix ls7 & sr sig, use ls7 if |x| < 4 , else use sr. +# has higher precision in all input range. 
+# NOTICE: this method is very expensive, only use for hessian matrix. +def sigmoid_mix(x): + ls7 = sigmoid_ls7(x) + sr = sigmoid_sr(x) + return jnp.select([x < -4, x > 4], [sr, sr], ls7) + + +# real computation of sigmoid +# NOTICE: should make sure x not too small +def sigmoid_real(x): + return 1 / (1 + jnp.exp(-x)) + + +def sigmoid(x, sig_type): + if sig_type is SigType.T1: + return sigmoid_t1(x) + elif sig_type is SigType.T3: + return sigmoid_t3(x) + elif sig_type is SigType.T5: + return sigmoid_t5(x) + elif sig_type is SigType.SEG3: + return sigmoid_seg3(x) + elif sig_type is SigType.DF: + return sigmoid_df(x) + elif sig_type is SigType.SR: + return sigmoid_sr(x) + elif sig_type is SigType.LS7: + return sigmoid_ls7(x) + elif sig_type is SigType.MIX: + return sigmoid_mix(x) + elif sig_type is SigType.REAL: + return sigmoid_real(x) \ No newline at end of file diff --git a/spu/libspu.cc b/spu/libspu.cc index 82c9b907..193a1072 100644 --- a/spu/libspu.cc +++ b/spu/libspu.cc @@ -552,6 +552,9 @@ void BindLibs(py::module& m) { pir::PirSetupConfig config; SPU_ENFORCE(config.ParseFromString(config_pb)); + config.set_bucket_size(1000000); + config.set_compressed(false); + auto r = pir::PirSetup(config); return r.SerializeAsString(); }, @@ -577,6 +580,9 @@ void BindLibs(py::module& m) { SPU_ENFORCE(config.ParseFromString(config_pb)); SPU_ENFORCE(config.setup_path() == "::memory"); + config.set_bucket_size(1000000); + config.set_compressed(false); + auto r = pir::PirMemoryServer(lctx, config); return r.SerializeAsString(); }, diff --git a/spu/ops/groupby/BUILD.bazel b/spu/ops/groupby/BUILD.bazel new file mode 100644 index 00000000..f44ad836 --- /dev/null +++ b/spu/ops/groupby/BUILD.bazel @@ -0,0 +1,32 @@ +# Copyright 2023 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_python//python:defs.bzl", "py_binary", "py_library", "py_test") + +package(default_visibility = ["//visibility:public"]) + +py_library( + name = "groupby", + srcs = ["groupby.py"], +) + +py_test( + name = "groupby_test", + srcs = ["groupby_test.py"], + deps = [ + ":groupby", + "//spu:init", + "//spu/utils:simulation", + ], +) diff --git a/spu/ops/groupby/groupby.py b/spu/ops/groupby/groupby.py new file mode 100644 index 00000000..d3294150 --- /dev/null +++ b/spu/ops/groupby/groupby.py @@ -0,0 +1,255 @@ +# Copyright 2023 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +from typing import List, Tuple + +import jax +import jax.numpy as jnp +import numpy as np + +# Conceptually we want to do the following +# However, due to limitations of design, we cannot do this +# we will use other groupby accumulators to do groupby some +# and treat (target_columns_sorted: List[jnp.ndarray], segment_ids: List[jnp.ndarray]) as the groupby object. 
# Conceptually we would like a GroupByObject wrapping
# (target_columns_sorted, segment_ids) with methods such as
#     jax.ops.segment_sum(x, segment_ids, num_segments=group_num)
# Due to limitations of the design we cannot do this; instead the pair
# (target_columns_sorted, segment_ids) acts as "the groupby object" and the
# free functions below act as its accumulators.


def segment_aware_addition(row1, row2):
    """Segment-restarting addition combiner for jax.lax.associative_scan."""
    return segment_aware_ops(row1, row2, jnp.add)


def segment_aware_max(row1, row2):
    """Segment-restarting maximum combiner for jax.lax.associative_scan."""
    return segment_aware_ops(row1, row2, jnp.maximum)


def segment_aware_ops(row1, row2, ops):
    """Associative combiner over rows shaped (n, 1 + n_cols).

    Column 0 is a "continues previous segment" flag; the accumulated values in
    the remaining columns are combined with ``ops`` only while the flag is 1,
    and restart whenever a new segment begins.
    """
    # combine only where row2 still belongs to the same segment as row1
    cum_part = jnp.where((row2[:, 0] == 1).reshape(-1, 1), ops(row1, row2), row2)[:, 1:]
    # the combined row continues a segment only if both inputs do
    lead_part = (row1[:, 0] * row2[:, 0]).reshape(-1, 1)
    return jnp.c_[lead_part, cum_part]


def groupby_agg_via_shuffle(
    cols,
    seg_end_marks,
    segment_ids,
    secret_random_order: jnp.ndarray,
    segment_aware_ops,
):
    """Groupby aggregation with shuffled outputs.

    trick: output segment_end_marks and group aggregations in shuffled state,
    filter to get the group aggregations in cleartext.

    shuffle to protect number of group elements.

    The returns of this function are supposed to be ok to be opened.

    Returns:
        segment_ids_shuffled: shuffled segment ids.
        shuffled_group_end_masks: shuffled group end masks.
        shuffled_group_agg_matrix: shape (n_samples, n_cols); group
            aggregations, shuffled and padded with zeros.
    """
    # flag is 0 on the first row of each segment, 1 elsewhere
    group_mask = jnp.ones(seg_end_marks.shape) - jnp.roll(seg_end_marks, 1)

    X = jnp.vstack([group_mask] + list(cols)).T

    # per-segment prefix aggregation; each group's total lands on its last row
    X_prefix_sum = jax.lax.associative_scan(segment_aware_ops, X, axis=0)
    # zero out every row except the segment ends
    X_prefix_sum_masked = seg_end_marks.reshape(-1, 1) * X_prefix_sum
    segment_ids_masked = seg_end_marks * segment_ids
    # sorting by the secret random key acts as an oblivious shuffle
    shuffled_cols = jax.lax.sort(
        [secret_random_order]
        + [segment_ids_masked]
        + [seg_end_marks]
        + [X_prefix_sum_masked[:, i] for i in range(1, X_prefix_sum_masked.shape[1])],
        num_keys=1,
    )
    return [
        shuffled_cols[1],
        shuffled_cols[2],
        jnp.vstack(shuffled_cols[3:]).T,
    ]


def groupby_sum_via_shuffle(
    cols, seg_end_marks, segment_ids, secret_random_order: jnp.ndarray
):
    """Shuffled groupby sum; see groupby_agg_via_shuffle."""
    return groupby_agg_via_shuffle(
        cols,
        seg_end_marks,
        segment_ids,
        secret_random_order,
        segment_aware_ops=segment_aware_addition,
    )


def groupby_max_via_shuffle(
    cols, seg_end_marks, segment_ids, secret_random_order: jnp.ndarray
):
    """Shuffled groupby max; see groupby_agg_via_shuffle."""
    return groupby_agg_via_shuffle(
        cols,
        seg_end_marks,
        segment_ids,
        secret_random_order,
        segment_aware_ops=segment_aware_max,
    )


# cleartext function
def groupby_agg_postprocess(
    segment_ids, seg_end_marks, group_agg_matrix, group_num: int
):
    """Recover per-group aggregates from the opened, shuffled outputs.

    Args:
        segment_ids: opened shuffled (masked) segment ids.
        seg_end_marks: opened shuffled group-end marks.
        group_agg_matrix: opened shuffled aggregation matrix.
        group_num: group count; must be revealed and passed as a static int.

    Returns:
        Aggregates ordered by segment id, shape (group_num, n_cols).
    """
    assert (
        isinstance(group_num, int) and group_num > 0
    ), f"group num must be a positive integer. got {group_num}, {type(group_num)}"
    if group_num > 1:
        # keep only the rows that carry a group's aggregate
        filter_mask = seg_end_marks == 1
        segment_ids = segment_ids[filter_mask]
        group_agg_matrix = group_agg_matrix[filter_mask]
        sorted_results = jax.lax.sort(
            [segment_ids]
            + [group_agg_matrix[:, i] for i in range(group_agg_matrix.shape[1])],
            num_keys=1,
        )[1:]
        return jnp.vstack(sorted_results).T
    else:
        # with a single group, the (only) aggregate ends up in the last row
        return group_agg_matrix[-1]


def batch_product(list_of_cols, multiplier_col):
    """Multiply each column in the list elementwise by multiplier_col."""
    return list(
        map(
            lambda x: x * multiplier_col,
            list_of_cols,
        )
    )


def view_key(
    key_columns_sorted: List[jnp.ndarray],
    seg_end_marks: jnp.ndarray,
    secret_random_order: jnp.ndarray,
):
    """The secret_random_order must be secret to all parties
    trick: open a shuffled array and unique in cleartext
    """
    assert len(key_columns_sorted) > 0, "number of keys must be non-empty"
    # mask keys so only the group-end rows keep their values
    keys = batch_product(key_columns_sorted, seg_end_marks)
    assert (
        secret_random_order.shape == key_columns_sorted[0].shape
    ), "the secret_random_order should be the same shape as each of the key columns."
    keys_shuffled = jax.lax.sort([secret_random_order] + keys, num_keys=1)[1:]
    return keys_shuffled


# function operating on cleartext, used to postprocess opened results.
def view_key_postprocessing(keys, group_num: int):
    """Deduplicate opened shuffled keys; drop the all-zero padding row if present."""
    keys = np.unique(np.vstack(keys).T, axis=0)
    if keys.shape[0] > group_num:
        # the extra row is the zero padding produced by masking
        keys = keys[1:, :]
    return keys


def groupby(
    key_columns: List[jnp.ndarray],
    target_columns: List[jnp.ndarray],
) -> Tuple[List[jnp.ndarray], List[jnp.ndarray], jnp.ndarray, jnp.ndarray]:
    """GroupBy
    Given a matrix X, it has multiple columns.
    We want to calculate some statistics of target columns grouped by some columns as keys.
    This operator completes the first step of GroupBy statistics: transform the matrix X into a form
    that is suitable for subsequent statistics.

    Parameters
    ----------

    key_columns : List[jnp.ndarray]
        List of columns that are used as keys, these should be arrays of the same shape.

    target_columns : List[jnp.ndarray]
        List of columns to aggregate, arrays of the same shape as the key columns.


    Returns
    -------
    key_columns_sorted : List[jnp.ndarray]
    target_columns_sorted : List[jnp.ndarray]
    segment_ids : jnp.ndarray
    seg_end_marks : jnp.ndarray
    """
    # parameter check.
    assert isinstance(key_columns, List)
    assert isinstance(target_columns, List)
    assert len(key_columns) > 0, "There should be at least one key_column."
    assert len(target_columns) > 0, "There should be at least one target_column."
    assert (
        len(set(map(lambda x: x.shape, key_columns + target_columns))) == 1
    ), f"Columns' shape should be consistent. {set(map(lambda x: x.shape, key_columns + target_columns))}"
    sorted_columns = jax.lax.sort(
        key_columns + target_columns, num_keys=len(key_columns)
    )
    key_columns_sorted = sorted_columns[: len(key_columns)]
    target_columns_sorted = sorted_columns[len(key_columns) :]
    key_columns_sorted_rolled = rotate_cols(key_columns_sorted)
    seg_end_marks = get_segment_marks(key_columns_sorted, key_columns_sorted_rolled)
    # cumulative count of closed groups; subtracting the mark yields a
    # 0-based group id for every row
    mark_accumulated = associative_scan(seg_end_marks)
    segment_ids = mark_accumulated - seg_end_marks
    return key_columns_sorted, target_columns_sorted, segment_ids, seg_end_marks


# the method is simple: open segment_ids and do count in cleartext
# in SPU all NaN values are encoded as 0, so count becomes trivial.
# cleartext function:
# further if a query includes count, the shuffle ops in groupby sum can be skipped
def groupby_count(opened_segment_ids):
    """Count rows per group from the opened segment ids."""
    _, counts = jnp.unique(opened_segment_ids, return_counts=True)
    return counts


def rotate_cols(key_columns_sorted) -> List[jnp.ndarray]:
    """Roll each column up by one position (row i sees row i+1's value)."""
    return list(map(lambda x: jnp.roll(x, -1), key_columns_sorted))


def get_segment_marks(key_columns_sorted, key_columns_sorted_rolled):
    """Return a 0/1 array with 1 on the last row of each group.

    A row ends a group when ANY key column differs from the next row's key;
    the final row always closes the last group.
    """
    tuple_list = list(zip(key_columns_sorted, key_columns_sorted_rolled))
    equal = [a - b == 0 for (a, b) in tuple_list]
    c = ~functools.reduce(lambda x, y: x & y, equal)
    c = c.astype(int)
    # force-close the last group (the roll wraps around to row 0)
    result = jnp.r_[c[: c.size - 1], [1]]
    return result


def associative_scan(seg_end_marks):
    """Inclusive prefix sum of the segment end marks."""
    return jax.lax.associative_scan(jnp.add, seg_end_marks)
import os
import sys
import time
import unittest

import numpy as np
import pandas as pd

import spu.spu_pb2 as spu_pb2
import spu.utils.simulation as spsim

# Add the ops directory to the path
sys.path.append(os.path.join(os.path.dirname(__file__), '../../'))

from groupby import (
    groupby,
    groupby_agg_postprocess,
    groupby_sum_via_shuffle,
    groupby_max_via_shuffle,
    groupby_count,
    view_key,
    view_key_postprocessing,
)


def _proc_groupby(x1, x2, y):
    # group by (x1[:, 2], x2[:, 3]) with a single target column y
    return groupby([x1[:, 2], x2[:, 3]], [y])


def _proc_view_key(key_cols, segment_end_marks, key):
    return view_key(key_cols, segment_end_marks, key)


def _make_sim():
    """3-party ABY3 simulator over FM64, shared by all tests."""
    return spsim.Simulator.simple(
        3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64
    )


def _make_data(n_rows, n_cols=10):
    """Random discrete test data.

    groupby category only supports discrete values; here we take the
    shortcut of making every column discrete (in reality only the keys
    need to be, with differences large enough that their fxp reprs differ).
    The value range is shrunk so groups stay small (about 15 rows each).
    """
    x1 = (np.random.random((n_rows, n_cols)) * 10 / 3).astype(int)
    x2 = (np.random.random((n_rows, n_cols)) * 10 / 3).astype(int)
    y = (np.random.random((n_rows,)) * 10 / 3).astype(int)
    return x1, x2, y


def _pandas_reference(x1, x2, y):
    """DataFrame with the two key columns and the target, for pandas cross-check."""
    X = np.zeros((x1.shape[0], 3))
    X[:, 0] = x1[:, 2]
    X[:, 1] = x2[:, 3]
    X[:, 2] = y
    return pd.DataFrame(X, columns=[f'col{i}' for i in range(3)])


def _secret_random_order(n_rows):
    # how to produce a secret random array of shape (row_num,) is another
    # question (not addressed here); emulate two parties' additive shares.
    p1_random_order = np.random.random((n_rows,))
    p2_random_order = np.random.random((n_rows,))
    return p1_random_order + p2_random_order


class UnitTests(unittest.TestCase):
    def _run_agg_test(self, shuffle_agg_fn, pandas_agg_name):
        """Shared driver for the shuffled aggregation tests (sum and max).

        Runs groupby in simulation, performs the shuffled aggregation with
        `shuffle_agg_fn`, and cross-checks both the aggregates and the
        revealed keys against pandas `pandas_agg_name`.
        """
        sim = _make_sim()

        def proc_agg(cols, seg_end_marks, segment_ids, secret_random_order):
            return shuffle_agg_fn(
                cols, seg_end_marks, segment_ids, secret_random_order
            )

        x1, x2, y = _make_data(30000)
        start = time.perf_counter()
        keys, target_cols, segment_ids, segment_end_marks = spsim.sim_jax(
            sim, _proc_groupby
        )(x1, x2, y)
        print("groupby takes time", time.perf_counter() - start)

        df = _pandas_reference(x1, x2, y)
        expected = getattr(
            df.groupby([df.columns[0], df.columns[1]])[df.columns[2]],
            pandas_agg_name,
        )()
        # num_groups can also be obtained by revealing segment_ids[-1]
        num_groups = expected.shape[0]

        start = time.perf_counter()
        agg_result = spsim.sim_jax(sim, proc_agg)(
            target_cols,
            segment_end_marks,
            segment_ids,
            _secret_random_order(x1.shape[0]),
        )
        agg_result = groupby_agg_postprocess(
            agg_result[0], agg_result[1], agg_result[2], num_groups
        )
        print(f"{pandas_agg_name} takes time", time.perf_counter() - start)
        assert (
            np.max(abs(expected.values.reshape(agg_result.shape) - agg_result))
            < 0.001
        ), f"{expected}, ours: \n {agg_result}"

        correct_keys = np.array([[*a] for a in expected.index.to_numpy()])

        # we open shuffled keys and take set(keys)
        start = time.perf_counter()
        keys = spsim.sim_jax(sim, _proc_view_key)(
            keys, segment_end_marks, _secret_random_order(x1.shape[0])
        )
        keys = view_key_postprocessing(keys, num_groups)
        print("view key takes time", time.perf_counter() - start)
        assert (
            np.max(abs(correct_keys - keys)) < 0.001
        ), f"value{np.max(abs(correct_keys - keys))}, correct_keys, {correct_keys}, keys{keys}"

    def test_sum(self):
        self._run_agg_test(groupby_sum_via_shuffle, "sum")

    def test_max(self):
        self._run_agg_test(groupby_max_via_shuffle, "max")

    def test_count(self):
        # for count we can open segment_ids and sorted keys directly,
        # so the shuffle round used by sum/max can be skipped.
        sim = _make_sim()

        x1, x2, y = _make_data(3000)
        start = time.perf_counter()
        keys, _, segment_ids, _ = spsim.sim_jax(sim, _proc_groupby)(x1, x2, y)
        print("groupby takes time", time.perf_counter() - start)

        df = _pandas_reference(x1, x2, y)
        expected = df.groupby([df.columns[0], df.columns[1]])[df.columns[2]].count()

        start = time.perf_counter()
        count_result = groupby_count(segment_ids)
        print("count takes time", time.perf_counter() - start)
        assert (
            np.max(
                abs(expected.values.reshape(count_result.shape) - count_result)
            )
            < 0.001
        ), f"{expected}, ours: \n {count_result}"

        correct_keys = np.array([[*a] for a in expected.index.to_numpy()])

        # keys were masked at group ends; dedupe in cleartext
        keys = view_key_postprocessing(keys, segment_ids[-1] + 1)
        assert (
            np.max(abs(correct_keys - keys)) < 0.001
        ), f"value{np.max(abs(correct_keys - keys))}, correct_keys, {correct_keys}, keys{keys}"


if __name__ == "__main__":
    unittest.main()