diff --git a/.bazelrc b/.bazelrc index dc995af1..fcd46d22 100644 --- a/.bazelrc +++ b/.bazelrc @@ -59,24 +59,6 @@ build:macos --host_macos_minimum_os=11.0 build:linux --copt=-fopenmp build:linux --linkopt=-fopenmp -build:asan --strip=never -build:asan --copt -fno-sanitize-recover=all -build:asan --copt -fsanitize=address -build:asan --copt -Og -build:asan --copt -g -build:asan --copt -fno-omit-frame-pointer -build:asan --linkopt -fsanitize=address -build:asan --linkopt -static-libasan - -build:ubsan --strip=never -build:ubsan --copt -fno-sanitize-recover=all -build:ubsan --copt -fsanitize=undefined -build:ubsan --copt -Og -build:ubsan --copt -g -build:ubsan --copt -fno-omit-frame-pointer -build:ubsan --linkopt -fsanitize=undefined -build:ubsan --linkopt -static-libubsan - -build:macos-asan --features=asan -build:macos-ubsan --features=ubsan +build:asan --features=asan +build:ubsan --features=ubsan diff --git a/CHANGELOG.md b/CHANGELOG.md index 789ab4b1..52a5d94b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ - [API] Add intrinsic support - [Feature] Support half type - [Feature] Add Psi Progress +- [Feature] Add SineOp/CosineOp support ## 20230705 diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 05bb0593..66fc7c96 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -58,7 +58,7 @@ docker exec -it spu-dev-$(whoami) bash #### Linux ```sh -Install gcc>=11.2, cmake>=3.18, ninja, nasm>=2.15, python==3.8, bazel==6.2.1, golang +Install gcc>=11.2, cmake>=3.18, ninja, nasm>=2.15, python>=3.8, bazel==6.2.1, golang python3 -m pip install -r requirements.txt python3 -m pip install -r requirements-dev.txt diff --git a/WORKSPACE b/WORKSPACE index afff4c93..cd0db8ac 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -38,6 +38,12 @@ rules_foreign_cc_dependencies( register_preinstalled_tools = True, ) +load("@rules_cuda//cuda:repositories.bzl", "register_detected_cuda_toolchains", "rules_cuda_dependencies") + +rules_cuda_dependencies() + +register_detected_cuda_toolchains() + load("@xla//:workspace4.bzl", "xla_workspace4") xla_workspace4() @@ -64,3 +70,9 @@ python_configure( name = "local_config_python", python_version = "3", ) + +load("@rules_proto_grpc//:repositories.bzl", "rules_proto_grpc_repos", "rules_proto_grpc_toolchains") + +rules_proto_grpc_toolchains() + +rules_proto_grpc_repos() diff --git a/bazel/patches/seal.patch b/bazel/patches/seal.patch index 5959e4ec..a48cba4d 100644 --- a/bazel/patches/seal.patch +++ b/bazel/patches/seal.patch @@ -67,7 +67,7 @@ index 31e07441..c34d0a45 100644 // Write the poly_modulus_degree. Note that it will always be positive. *param_data_ptr++ = static_cast(poly_modulus_degree_); -+ *param_data_ptr++ = static_cast(use_special_prime_); ++ *param_data_ptr++ = static_cast(use_special_prime_); for (const auto &mod : coeff_modulus_) { *param_data_ptr++ = mod.value(); @@ -118,3 +118,104 @@ index 9e1fbe48..eb71c4ac 100644 parms_id_type parms_id_ = parms_id_zero; }; } // namespace seal + +diff --git a/native/src/seal/evaluator.cpp b/native/src/seal/evaluator.cpp +index dabd3bab..afaa71dc 100644 +--- a/native/src/seal/evaluator.cpp ++++ b/native/src/seal/evaluator.cpp +@@ -2382,6 +2382,7 @@ namespace seal + size_t encrypted_size = encrypted.size(); + // Use key_context_data where permutation tables exist since previous runs. 
+ auto galois_tool = context_.key_context_data()->galois_tool(); ++ bool is_ntt_form = encrypted.is_ntt_form(); + + // Size check + if (!product_fits_in(coeff_count, coeff_modulus_size)) +@@ -2412,7 +2413,7 @@ namespace seal + // DO NOT CHANGE EXECUTION ORDER OF FOLLOWING SECTION + // BEGIN: Apply Galois for each ciphertext + // Execution order is sensitive, since apply_galois is not inplace! +- if (parms.scheme() == scheme_type::bfv) ++ if (not is_ntt_form) + { + // !!! DO NOT CHANGE EXECUTION ORDER!!! + +@@ -2426,7 +2427,7 @@ namespace seal + // Next transform encrypted.data(1) + galois_tool->apply_galois(encrypted_iter[1], coeff_modulus_size, galois_elt, coeff_modulus, temp); + } +- else if (parms.scheme() == scheme_type::ckks || parms.scheme() == scheme_type::bgv) ++ else + { + // !!! DO NOT CHANGE EXECUTION ORDER!!! + +@@ -2440,10 +2441,6 @@ namespace seal + // Next transform encrypted.data(1) + galois_tool->apply_galois_ntt(encrypted_iter[1], coeff_modulus_size, galois_elt, temp); + } +- else +- { +- throw logic_error("scheme not implemented"); +- } + + // Wipe encrypted.data(1) + set_zero_poly(coeff_count, coeff_modulus_size, encrypted.data(1)); +@@ -2530,6 +2527,7 @@ namespace seal + auto &key_context_data = *context_.key_context_data(); + auto &key_parms = key_context_data.parms(); + auto scheme = parms.scheme(); ++ bool is_ntt_form = encrypted.is_ntt_form(); + + // Verify parameters. + if (!is_metadata_valid_for(encrypted, context_) || !is_buffer_valid(encrypted)) +@@ -2559,14 +2557,6 @@ namespace seal + { + throw invalid_argument("pool is uninitialized"); + } +- if (scheme == scheme_type::bfv && encrypted.is_ntt_form()) +- { +- throw invalid_argument("BFV encrypted cannot be in NTT form"); +- } +- if (scheme == scheme_type::ckks && !encrypted.is_ntt_form()) +- { +- throw invalid_argument("CKKS encrypted must be in NTT form"); +- } + if (scheme == scheme_type::bgv && !encrypted.is_ntt_form()) + { + throw invalid_argument("BGV encrypted must be in NTT form"); +@@ -2605,7 +2595,7 @@ namespace seal + set_uint(target_iter, decomp_modulus_size * coeff_count, t_target); + + // In CKKS or BGV, t_target is in NTT form; switch back to normal form +- if (scheme == scheme_type::ckks || scheme == scheme_type::bgv) ++ if (is_ntt_form) + { + inverse_ntt_negacyclic_harvey(t_target, decomp_modulus_size, key_ntt_tables); + } +@@ -2632,7 +2622,7 @@ namespace seal + ConstCoeffIter t_operand; + + // RNS-NTT form exists in input +- if ((scheme == scheme_type::ckks || scheme == scheme_type::bgv) && (I == J)) ++ if (is_ntt_form && (I == J)) + { + t_operand = target_iter[J]; + } +@@ -2789,7 +2779,7 @@ namespace seal + SEAL_ITERATE(t_ntt, coeff_count, [fix](auto &K) { K += fix; }); + + uint64_t qi_lazy = qi << 1; // some multiples of qi +- if (scheme == scheme_type::ckks) ++ if (is_ntt_form) + { + // This ntt_negacyclic_harvey_lazy results in [0, 4*qi). 
+ ntt_negacyclic_harvey_lazy(t_ntt, get<2>(J)); +@@ -2802,7 +2792,7 @@ namespace seal + qi_lazy = qi << 2; + #endif + } +- else if (scheme == scheme_type::bfv) ++ else + { + inverse_ntt_negacyclic_harvey_lazy(get<0, 1>(J), get<2>(J)); + } \ No newline at end of file diff --git a/bazel/patches/xla.patch b/bazel/patches/xla.patch new file mode 100644 index 00000000..dbb5a1bc --- /dev/null +++ b/bazel/patches/xla.patch @@ -0,0 +1,22 @@ +diff --git a/third_party/tsl/workspace1.bzl b/third_party/tsl/workspace1.bzl +index 4cfb6da82..0e3774834 100644 +--- a/third_party/tsl/workspace1.bzl ++++ b/third_party/tsl/workspace1.bzl +@@ -3,7 +3,7 @@ + load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") + load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps") + load("@io_bazel_rules_closure//closure:defs.bzl", "closure_repositories") +-load("@rules_cuda//cuda:dependencies.bzl", "rules_cuda_dependencies") ++# load("@rules_cuda//cuda:dependencies.bzl", "rules_cuda_dependencies") + load("@rules_pkg//:deps.bzl", "rules_pkg_dependencies") + + # buildifier: disable=unnamed-macro +@@ -14,7 +14,7 @@ def workspace(with_rules_cc = True): + with_rules_cc: whether to load and patch rules_cc repository. + """ + native.register_toolchains("@local_config_python//:py_toolchain") +- rules_cuda_dependencies(with_rules_cc) ++ # rules_cuda_dependencies(with_rules_cc) + rules_pkg_dependencies() + + closure_repositories() \ No newline at end of file diff --git a/bazel/repositories.bzl b/bazel/repositories.bzl index 72a87b53..9bde00cc 100644 --- a/bazel/repositories.bzl +++ b/bazel/repositories.bzl @@ -21,6 +21,8 @@ SECRETFLOW_GIT = "https://github.com/secretflow" YACL_COMMIT_ID = "a9c1d7d119c80eb75d5ec63ee6cd77145dff18c2" def spu_deps(): + _rules_cuda() + _rules_proto_grpc() _bazel_platform() _upb() _com_github_xtensor_xtensor() @@ -61,6 +63,24 @@ def spu_deps(): path = "/opt/homebrew/opt/libomp/", ) +def _rules_proto_grpc(): + http_archive( + name = "rules_proto_grpc", + sha256 = "928e4205f701b7798ce32f3d2171c1918b363e9a600390a25c876f075f1efc0a", + strip_prefix = "rules_proto_grpc-4.4.0", + urls = [ + "https://github.com/rules-proto-grpc/rules_proto_grpc/releases/download/4.4.0/rules_proto_grpc-4.4.0.tar.gz", + ], + ) + +def _rules_cuda(): + http_archive( + name = "rules_cuda", + sha256 = "fa1462c4c3104de44489800a1da055f55afa57795789539c835e069818786f71", + strip_prefix = "rules_cuda-cab1fa2dd0e1f8489f566c91a5025856cf5ae572", + urls = ["https://github.com/bazel-contrib/rules_cuda/archive/cab1fa2dd0e1f8489f566c91a5025856cf5ae572.tar.gz"], + ) + def _bazel_platform(): http_archive( name = "platforms", @@ -138,8 +158,8 @@ def _com_github_xtensor_xtl(): ) def _com_github_openxla_xla(): - OPENXLA_COMMIT = "f23282956725050dd07996d3b80d6788ad8aaeba" - OPENXLA_SHA256 = "6e87093f550573de12af247edfa590d76adc6e931da6b551e78dc2c0c2bbd04d" + OPENXLA_COMMIT = "0c99beffabc5d43fa29f121674eb59e14a22c779" + OPENXLA_SHA256 = "d4c7511a496aeb917976c0d8a65de374c395546f0c3d4077d9dfd4df780d7ea8" SKYLIB_VERSION = "1.3.0" @@ -160,6 +180,10 @@ def _com_github_openxla_xla(): sha256 = OPENXLA_SHA256, strip_prefix = "xla-" + OPENXLA_COMMIT, type = ".tar.gz", + patch_args = ["-p1"], + patches = [ + "@spulib//bazel:patches/xla.patch", + ], urls = [ "https://github.com/openxla/xla/archive/{commit}.tar.gz".format(commit = OPENXLA_COMMIT), ], diff --git a/bazel/spu.bzl b/bazel/spu.bzl index 42c2efda..b96729f0 100644 --- a/bazel/spu.bzl +++ b/bazel/spu.bzl @@ -48,6 +48,7 @@ def spu_cc_binary( cc_binary( linkopts = linkopts, copts = 
copts + _spu_copts(), + linkstatic = True, **kargs ) @@ -66,6 +67,7 @@ def spu_cc_library( local_defines = local_defines + [ "SPU_BUILD", ], + linkstatic = True, **kargs ) @@ -111,5 +113,6 @@ def spu_cc_test( local_defines = local_defines + [ "SPU_BUILD", ], + linkstatic = True, **kwargs ) diff --git a/docs/reference/xla_status.md b/docs/reference/xla_status.md index 000f6fa1..e1cf948d 100644 --- a/docs/reference/xla_status.md +++ b/docs/reference/xla_status.md @@ -30,7 +30,7 @@ Count: Total = 4, fully supported = 2 | `ceil` | fully | | `convert` | fully | | `count_leading_zeros`| no | -| `cosine` | no | +| `cosine` | fully | | `exponential` | fully | | `exponential_minus_one`| fully | | `floor` | fully | @@ -46,11 +46,11 @@ Count: Total = 4, fully supported = 2 | `round_nearest_afz`| not | | `rsqrt` | fully | | `sign` | partial | -| `sine` | not | +| `sine` | fully | | `sqrt` | fully | | `tanh` | fully | -Count: Total = 24, fully supported = 12, partial = 0 +Count: Total = 24, fully supported = 16, partial = 1 ### XLA binary element-wise ops diff --git a/examples/python/ml/jax_lr/jax_lr.py b/examples/python/ml/jax_lr/jax_lr.py index 45cce6ad..19003cfd 100644 --- a/examples/python/ml/jax_lr/jax_lr.py +++ b/examples/python/ml/jax_lr/jax_lr.py @@ -138,11 +138,11 @@ def save_and_load_model(): print(W_r, b_r) x_test, y_test = dsutil.breast_cancer(slice(None, None, None), False) - print( - "AUC(save_and_load_model)={}".format( - metrics.roc_auc_score(y_test, predict(x_test, W_r, b_r)) - ) - ) + + score = metrics.roc_auc_score(y_test, predict(x_test, W_r, b_r)) + print("AUC(save_and_load_model)={}".format(score)) + + return score def compute_score(W_r, b_r, type): @@ -163,10 +163,7 @@ def train(x1, x2, y): x2, _ = ppd.device("P2")(dsutil.breast_cancer)(slice(15, None), True) W, b = train(x1, x2, y) - W_r, b_r = ppd.get(W), ppd.get(b) - print(W_r, b_r) - - return W_r, b_r + return W, b if __name__ == "__main__": @@ -176,5 +173,6 @@ def train(x1, x2, y): compute_score(w[1], b[1], 'cpu, manual_grad') print('Run on SPU\n------\n') w, b = run_on_spu() - compute_score(w, b, 'spu') + w_r, b_r = ppd.get(w), ppd.get(b) + compute_score(w_r, b_r, 'spu') save_and_load_model() diff --git a/examples/python/ml/ml_test.py b/examples/python/ml/ml_test.py index 807f466b..f9ae8a10 100644 --- a/examples/python/ml/ml_test.py +++ b/examples/python/ml/ml_test.py @@ -145,7 +145,7 @@ def test_jax_lr(self): from examples.python.ml.jax_lr import jax_lr w, b = profile_test_point(jax_lr.run_on_spu) - score = jax_lr.compute_score(w, b, 'spu') + score = jax_lr.compute_score(ppd.get(w), ppd.get(b), 'spu') self.assertGreater(score, 0.95) @@ -222,6 +222,13 @@ def test_torch_experiment(self): score = torch_experiment.run_inference_on_spu(model) self.assertGreater(score, 0.9) + def test_save_and_load_model(self): + from examples.python.ml.jax_lr import jax_lr + + score = jax_lr.save_and_load_model() + self.assertGreater(score, 0.9) + pass + def suite(): suite = unittest.TestSuite() @@ -239,6 +246,7 @@ def suite(): suite.addTest(UnitTests('test_stax_nn')) suite.addTest(UnitTests('test_tf_experiment')) suite.addTest(UnitTests('test_torch_experiment')) + suite.addTest(UnitTests('test_save_and_load_model')) return suite diff --git a/libspu/BUILD.bazel b/libspu/BUILD.bazel index 4007d4aa..34cd8159 100644 --- a/libspu/BUILD.bazel +++ b/libspu/BUILD.bazel @@ -14,6 +14,7 @@ load("@rules_proto//proto:defs.bzl", "proto_library") load("@rules_cc//cc:defs.bzl", "cc_proto_library") +load("@rules_proto_grpc//python:defs.bzl", 
"python_proto_compile") package(default_visibility = ["//visibility:public"]) @@ -26,3 +27,24 @@ cc_proto_library( name = "spu_cc_proto", deps = [":spu_proto"], ) + +python_proto_compile( + name = "spu_py_proto", + output_mode = "NO_PREFIX", + prefix_path = "..", + protos = ["//libspu:spu_proto"], +) + +python_proto_compile( + name = "psi_py_proto", + output_mode = "NO_PREFIX", + prefix_path = "..", + protos = ["//libspu/psi:psi_proto"], +) + +python_proto_compile( + name = "pir_py_proto", + output_mode = "NO_PREFIX", + prefix_path = "..", + protos = ["//libspu/pir:pir_proto"], +) diff --git a/libspu/compiler/passes/hlo_legalize_to_pphlo.cc b/libspu/compiler/passes/hlo_legalize_to_pphlo.cc index a0e2b054..8ecfd621 100644 --- a/libspu/compiler/passes/hlo_legalize_to_pphlo.cc +++ b/libspu/compiler/passes/hlo_legalize_to_pphlo.cc @@ -1569,6 +1569,8 @@ struct HloLegalizeToPPHlo HloToPPHloOpConverter, HloToPPHloOpConverter, HloToPPHloOpConverter, + HloToPPHloOpConverter, + HloToPPHloOpConverter, HloToPPHloOpConverter, HloToPPHloOpConverter, HloToPPHloOpConverter, diff --git a/libspu/compiler/passes/map_stablehlo_to_pphlo_op.h b/libspu/compiler/passes/map_stablehlo_to_pphlo_op.h index 40ce514c..86650227 100644 --- a/libspu/compiler/passes/map_stablehlo_to_pphlo_op.h +++ b/libspu/compiler/passes/map_stablehlo_to_pphlo_op.h @@ -83,6 +83,8 @@ MAP_HLO_TO_PPHLO(SelectOp) MAP_HLO_TO_PPHLO(ShiftLeftOp) MAP_HLO_TO_PPHLO(ShiftRightArithmeticOp) MAP_HLO_TO_PPHLO(ShiftRightLogicalOp) +MAP_HLO_TO_PPHLO(SineOp) +MAP_HLO_TO_PPHLO(CosineOp) MAP_HLO_TO_PPHLO(SignOp) MAP_HLO_TO_PPHLO(SliceOp) MAP_HLO_TO_PPHLO(SortOp) diff --git a/libspu/core/config.cc b/libspu/core/config.cc index 03bb32d1..ae0dfe40 100644 --- a/libspu/core/config.cc +++ b/libspu/core/config.cc @@ -78,6 +78,10 @@ void populateRuntimeConfig(RuntimeConfig& cfg) { } } + if (cfg.sine_cosine_iters() == 0) { + cfg.set_sine_cosine_iters(10); // Default + } + // inter op concurrency if (cfg.experimental_enable_inter_op_par()) { cfg.set_experimental_inter_op_concurrency( diff --git a/libspu/core/ndarray_ref.h b/libspu/core/ndarray_ref.h index ffa22b72..7d3fb28f 100644 --- a/libspu/core/ndarray_ref.h +++ b/libspu/core/ndarray_ref.h @@ -401,6 +401,8 @@ class NdArrayView { } } + int64_t numel() const { return arr_->numel(); } + T& operator[](size_t idx) { if (arr_->canUseFastIndexing()) { return *reinterpret_cast(arr_->data() + diff --git a/libspu/device/BUILD.bazel b/libspu/device/BUILD.bazel index e0625bc5..f3dc3705 100644 --- a/libspu/device/BUILD.bazel +++ b/libspu/device/BUILD.bazel @@ -13,8 +13,6 @@ # limitations under the License. 
load("//bazel:spu.bzl", "spu_cc_library", "spu_cc_test") -load("@rules_proto//proto:defs.bzl", "proto_library") -load("@rules_cc//cc:defs.bzl", "cc_proto_library") package(default_visibility = ["//visibility:public"]) diff --git a/libspu/device/pphlo/pphlo_executor.cc b/libspu/device/pphlo/pphlo_executor.cc index 4f5f8d40..a3520afe 100644 --- a/libspu/device/pphlo/pphlo_executor.cc +++ b/libspu/device/pphlo/pphlo_executor.cc @@ -237,6 +237,8 @@ STANDARD_UNARY_OP_EXEC_IMPL(NotOp, Not) STANDARD_UNARY_OP_EXEC_IMPL(RsqrtOp, Rsqrt) STANDARD_UNARY_OP_EXEC_IMPL(SqrtOp, Sqrt) STANDARD_UNARY_OP_EXEC_IMPL(RoundOp, Round_AFZ) +STANDARD_UNARY_OP_EXEC_IMPL(SineOp, Sine) +STANDARD_UNARY_OP_EXEC_IMPL(CosineOp, Cosine) #undef STANDARD_UNARY_OP_EXEC_IMPL diff --git a/libspu/device/pphlo/pphlo_executor_test.cc b/libspu/device/pphlo/pphlo_executor_test.cc index 00373049..f397d152 100644 --- a/libspu/device/pphlo/pphlo_executor_test.cc +++ b/libspu/device/pphlo/pphlo_executor_test.cc @@ -424,6 +424,64 @@ func.func @main(%arg0: tensor<4x6x!pphlo.pub>) -> (tensor<2x2x!pphlo.pub(GetParam()), std::get<1>(GetParam()), + std::get<2>(GetParam())); + + const xt::xarray in1 = {{1, 2}, {3, 4}, {5, 6}}; + r.addInput(in1); + + r.run(r.compileMHlo(R"( +func.func @main(%arg0: tensor<3x2xi32>) -> (tensor<2x2xi32>) { + %0 = "mhlo.constant"() {value = dense<0> : tensor} : () -> tensor + %1 = "mhlo.reduce_window"(%arg0, %0) ( { + ^bb0(%arg1: tensor, %arg2: tensor): // no predecessors + %2 = "mhlo.maximum"(%arg1, %arg2) : (tensor, tensor) -> tensor + "mhlo.return"(%2) : (tensor) -> () + }) { + base_dilations = dense<[2, 1]> : tensor<2xi64>, + padding = dense<[[2, 1], [0, 0]]> : tensor<2x2xi64>, + window_dilations = dense<[3, 1]> : tensor<2xi64>, + window_dimensions = dense<[2, 1]> : tensor<2xi64>, + window_strides = dense<[ 4, 1 ]> : tensor<2xi64> + } : (tensor<3x2xi32>, tensor) -> tensor<2x2xi32> + return %1 : tensor<2x2xi32> +})", + {Visibility::VIS_PUBLIC})); + + xt::xarray expect = {{0, 0}, {3, 4}}; + r.verifyOutput(expect.data()); +} + +TEST_P(ExecutorTest, ReduceWindowStableHloTest2) { + Runner r(std::get<0>(GetParam()), std::get<1>(GetParam()), + std::get<2>(GetParam())); + + const xt::xarray in1 = {{1, 2}, {3, 4}, {5, 6}}; + r.addInput(in1); + + r.run(r.compileMHlo(R"( +func.func @main(%arg0: tensor<3x2xi32>) -> (tensor<1x2xi32>) { + %0 = "mhlo.constant"() {value = dense<0> : tensor} : () -> tensor + %1 = "mhlo.reduce_window"(%arg0, %0) ( { + ^bb0(%arg1: tensor, %arg2: tensor): // no predecessors + %2 = "mhlo.maximum"(%arg1, %arg2) : (tensor, tensor) -> tensor + "mhlo.return"(%2) : (tensor) -> () + }) { + base_dilations = dense<[2, 1]> : tensor<2xi64>, + padding = dense<[[2, 1], [0, 0]]> : tensor<2x2xi64>, + window_dilations = dense<[3, 1]> : tensor<2xi64>, + window_dimensions = dense<[3, 1]> : tensor<2xi64>, + window_strides = dense<[4, 1]> : tensor<2xi64> + } : (tensor<3x2xi32>, tensor) -> tensor<1x2xi32> + return %1 : tensor<1x2xi32> +})", + {Visibility::VIS_PUBLIC})); + + xt::xarray expect = {{5, 6}}; + r.verifyOutput(expect.data()); +} + TEST_P(ExecutorTest, ReduceWindowDefaultStrides) { Runner r(std::get<0>(GetParam()), std::get<1>(GetParam()), std::get<2>(GetParam())); diff --git a/libspu/device/pphlo/pphlo_verifier.cc b/libspu/device/pphlo/pphlo_verifier.cc index cfd09fa6..60537a5c 100644 --- a/libspu/device/pphlo/pphlo_verifier.cc +++ b/libspu/device/pphlo/pphlo_verifier.cc @@ -263,6 +263,8 @@ UNARY_VERIFIER(FloorOp, evalFloorOp) UNARY_VERIFIER(CeilOp, evalCeilOp) UNARY_VERIFIER(LogisticOp, evalLogisticOp) 
UNARY_VERIFIER(TanhOp, evalTanhOp) +UNARY_VERIFIER(SineOp, evalSineOp) +UNARY_VERIFIER(CosineOp, evalCosineOp) UNARY_VERIFIER(NotOp, evalNotOp) UNARY_VERIFIER(ExpOp, evalExponentialOp) UNARY_VERIFIER(RsqrtOp, evalRsqrtOp) diff --git a/libspu/device/pphlo/pphlo_verifier.h b/libspu/device/pphlo/pphlo_verifier.h index 5e35980a..c0ad262c 100644 --- a/libspu/device/pphlo/pphlo_verifier.h +++ b/libspu/device/pphlo/pphlo_verifier.h @@ -49,6 +49,8 @@ class PPHloVerifier { VERIFY_DECL(CeilOp) VERIFY_DECL(LogisticOp) VERIFY_DECL(TanhOp) + VERIFY_DECL(SineOp) + VERIFY_DECL(CosineOp) VERIFY_DECL(NotOp) VERIFY_DECL(ExpOp) VERIFY_DECL(Expm1Op) diff --git a/libspu/dialect/pphlo_ops.td b/libspu/dialect/pphlo_ops.td index cb933264..eb3985be 100644 --- a/libspu/dialect/pphlo_ops.td +++ b/libspu/dialect/pphlo_ops.td @@ -137,6 +137,28 @@ def PPHLO_TanhOp }]; } +def PPHLO_SineOp + : PPHLO_UnaryElementwiseOp<"sine", [Pure, SameOperandsAndResultType], PPHLO_FpTensor> { + let summary = "sine operator"; + let description = [{ + Returns `sine(operand)` element-wise. + + See + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#sine. + }]; +} + +def PPHLO_CosineOp + : PPHLO_UnaryElementwiseOp<"cosine", [Pure, SameOperandsAndResultType], PPHLO_FpTensor> { + let summary = "cosine operator"; + let description = [{ + Returns `cosine(operand)` element-wise. + + See + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#cosine. + }]; +} + def PPHLO_NegOp : PPHLO_UnaryElementwiseOp<"negate", [Pure, SameOperandsAndResultType], PPHLO_Tensor> { let summary = "Negation operator"; diff --git a/libspu/kernel/hal/constants.cc b/libspu/kernel/hal/constants.cc index 9485bf08..c5c8bf04 100644 --- a/libspu/kernel/hal/constants.cc +++ b/libspu/kernel/hal/constants.cc @@ -64,7 +64,7 @@ bool kCastFlags[DT_F64+1][DT_F64+1] = { // NOLINTEND // clang-format on -bool canImplicitCastTo(DataType frm, DataType to) { +[[maybe_unused]] bool canImplicitCastTo(DataType frm, DataType to) { return kCastFlags[frm][to]; } @@ -74,13 +74,6 @@ Value constant(SPUContext* ctx, PtBufferView init, DataType dtype, const Shape& shape) { SPU_TRACE_HAL_DISP(ctx, init, dtype, shape); - const auto init_dtype = getEncodeType(init.pt_type); - - // FIXME: semantically, this is casting from literal to type, not from one - // type to another type. 
- SPU_ENFORCE(canImplicitCastTo(init_dtype, dtype), "cast from {} to {} failed", - init_dtype, dtype); - auto result = make_pub2k(ctx, init).setDtype(dtype, true); if (shape.numel() == 0) { diff --git a/libspu/kernel/hal/fxp_approx.cc b/libspu/kernel/hal/fxp_approx.cc index 6b03a7b1..edb54e83 100644 --- a/libspu/kernel/hal/fxp_approx.cc +++ b/libspu/kernel/hal/fxp_approx.cc @@ -615,4 +615,56 @@ Value f_sigmoid(SPUContext* ctx, const Value& x) { } } +// Ref: +// https://github.com/facebookresearch/CrypTen/blob/f07c6b7e7a9d2e6ddce260929ef9aaec4addce3a/crypten/common/functions/approximations.py#L224 +std::pair sine_cosine_apprix(SPUContext* ctx, const Value& x) { + int64_t iteration = ctx->config().sine_cosine_iters(); + + // Normalize input to [-pi, pi] + // theta - TWO_PI * Math.floor((theta + Math.PI) / TWO_PI) + auto pi = constant(ctx, M_PI, x.dtype(), x.shape()); + auto two_pi = constant(ctx, 2 * M_PI, x.dtype(), x.shape()); + auto normalized = f_div(ctx, f_add(ctx, x, pi), two_pi); + normalized = f_mul(ctx, f_floor(ctx, normalized), two_pi); + normalized = f_sub(ctx, x, normalized); + + auto re = constant(ctx, 1.0F, x.dtype(), x.shape()); + auto im = + f_div(ctx, normalized, + constant(ctx, std::pow(2.0F, iteration), x.dtype(), x.shape())); + for (int64_t idx = 0; idx < iteration; ++idx) { + auto a2 = f_square(ctx, re); + auto b2 = f_square(ctx, im); + im = f_mul(ctx, im, re); + im = _mul(ctx, im, constant(ctx, 2, DT_I32, x.shape())).setDtype(x.dtype()); + re = f_sub(ctx, a2, b2); + } + + return {re, im}; +} + +Value f_sine(SPUContext* ctx, const Value& x) { + SPU_TRACE_HAL_DISP(ctx, x); + + SPU_ENFORCE(x.isFxp()); + + if (x.isPublic()) { + return f_sine_p(ctx, x); + } + + return sine_cosine_apprix(ctx, x).second; +} + +Value f_cosine(SPUContext* ctx, const Value& x) { + SPU_TRACE_HAL_DISP(ctx, x); + + SPU_ENFORCE(x.isFxp()); + + if (x.isPublic()) { + return f_cosine_p(ctx, x); + } + + return sine_cosine_apprix(ctx, x).first; +} + } // namespace spu::kernel::hal diff --git a/libspu/kernel/hal/fxp_approx.h b/libspu/kernel/hal/fxp_approx.h index c9179b05..9a7648f5 100644 --- a/libspu/kernel/hal/fxp_approx.h +++ b/libspu/kernel/hal/fxp_approx.h @@ -49,6 +49,10 @@ Value f_exp2(SPUContext* ctx, const Value& x); Value f_tanh(SPUContext* ctx, const Value& x); +Value f_sine(SPUContext* ctx, const Value& x); + +Value f_cosine(SPUContext* ctx, const Value& x); + Value f_rsqrt(SPUContext* ctx, const Value& x); Value f_sqrt(SPUContext* ctx, const Value& x); diff --git a/libspu/kernel/hal/fxp_approx_test.cc b/libspu/kernel/hal/fxp_approx_test.cc index df4269ae..94df626f 100644 --- a/libspu/kernel/hal/fxp_approx_test.cc +++ b/libspu/kernel/hal/fxp_approx_test.cc @@ -310,4 +310,66 @@ TEST(FxpTest, Sqrt) { } } +TEST(FxpTest, Sine) { + // GIVEN + SPUContext ctx = test::makeSPUContext(); + + xt::xarray x = {0., 0.52359878, 0.78539816, 1.04719755, 1.57079633}; + // public sin + { + Value a = constant(&ctx, x, DT_F32); + Value c = f_sine(&ctx, a); + EXPECT_EQ(c.dtype(), DT_F32); + + auto y = dump_public_as(&ctx, c); + EXPECT_TRUE(xt::allclose(xt::sin(x), y, 0.01, 0.001)) + << xt::sin(x) << std::endl + << y; + } + + // secret sin + { + Value a = test::makeValue(&ctx, x, VIS_SECRET); + Value c = f_sine(&ctx, a); + EXPECT_EQ(c.dtype(), DT_F32); + + auto y = dump_public_as(&ctx, reveal(&ctx, c)); + // low precision + EXPECT_TRUE(xt::allclose(xt::sin(x), y, 0.01, 0.001)) + << xt::sin(x) << std::endl + << y; + } +} + +TEST(FxpTest, Cosine) { + // GIVEN + SPUContext ctx = test::makeSPUContext(); + + xt::xarray x 
= {0., 0.52359878, 0.78539816, 1.04719755, 1.57079633}; + // public cos + { + Value a = constant(&ctx, x, DT_F32); + Value c = f_cosine(&ctx, a); + EXPECT_EQ(c.dtype(), DT_F32); + + auto y = dump_public_as(&ctx, c); + EXPECT_TRUE(xt::allclose(xt::cos(x), y, 0.01, 0.001)) + << xt::cos(x) << std::endl + << y; + } + + // secret cos + { + Value a = test::makeValue(&ctx, x, VIS_SECRET); + Value c = f_cosine(&ctx, a); + EXPECT_EQ(c.dtype(), DT_F32); + + auto y = dump_public_as(&ctx, reveal(&ctx, c)); + // low precision + EXPECT_TRUE(xt::allclose(xt::cos(x), y, 0.01, 0.001)) + << xt::cos(x) << std::endl + << y; + } +} + } // namespace spu::kernel::hal diff --git a/libspu/kernel/hal/fxp_base.cc b/libspu/kernel/hal/fxp_base.cc index eea3ee40..0a2ace18 100644 --- a/libspu/kernel/hal/fxp_base.cc +++ b/libspu/kernel/hal/fxp_base.cc @@ -318,7 +318,7 @@ Value f_less(SPUContext* ctx, const Value& x, const Value& y) { Value f_square(SPUContext* ctx, const Value& x) { SPU_TRACE_HAL_LEAF(ctx, x); - SPU_ENFORCE(x.isFxp()); + SPU_ENFORCE(x.isFxp(), "{}", x); return _trunc(ctx, _mul(ctx, x, x), ctx->getFxpBits(), SignType::Positive) .setDtype(x.dtype()); } diff --git a/libspu/kernel/hal/fxp_cleartext.cc b/libspu/kernel/hal/fxp_cleartext.cc index eeb33ffb..7dc5e936 100644 --- a/libspu/kernel/hal/fxp_cleartext.cc +++ b/libspu/kernel/hal/fxp_cleartext.cc @@ -101,4 +101,14 @@ Value f_div_p(SPUContext* ctx, const Value& x, const Value& y) { [](float a, float b) { return a / b; }); } +Value f_sine_p(SPUContext* ctx, const Value& in) { + SPU_TRACE_HAL_DISP(ctx, in); + return applyFloatingPointFn(ctx, in, [](float x) { return std::sin(x); }); +} + +Value f_cosine_p(SPUContext* ctx, const Value& in) { + SPU_TRACE_HAL_DISP(ctx, in); + return applyFloatingPointFn(ctx, in, [](float x) { return std::cos(x); }); +} + } // namespace spu::kernel::hal diff --git a/libspu/kernel/hal/fxp_cleartext.h b/libspu/kernel/hal/fxp_cleartext.h index 3a09aa47..e9a75430 100644 --- a/libspu/kernel/hal/fxp_cleartext.h +++ b/libspu/kernel/hal/fxp_cleartext.h @@ -31,4 +31,8 @@ Value f_exp_p(SPUContext* ctx, const Value& in); Value f_div_p(SPUContext* ctx, const Value& x, const Value& y); +Value f_sine_p(SPUContext* ctx, const Value& x); + +Value f_cosine_p(SPUContext* ctx, const Value& x); + } // namespace spu::kernel::hal diff --git a/libspu/kernel/hal/polymorphic.cc b/libspu/kernel/hal/polymorphic.cc index 679dbd5a..6d688c71 100644 --- a/libspu/kernel/hal/polymorphic.cc +++ b/libspu/kernel/hal/polymorphic.cc @@ -459,6 +459,22 @@ Value tanh(SPUContext* ctx, const Value& x) { return f_tanh(ctx, x); } +Value sine(SPUContext* ctx, const Value& x) { + SPU_TRACE_HAL_DISP(ctx, x); + + SPU_ENFORCE(x.isFxp()); + + return f_sine(ctx, x); +} + +Value cosine(SPUContext* ctx, const Value& x) { + SPU_TRACE_HAL_DISP(ctx, x); + + SPU_ENFORCE(x.isFxp()); + + return f_cosine(ctx, x); +} + Value rsqrt(SPUContext* ctx, const Value& x) { SPU_TRACE_HAL_DISP(ctx, x); diff --git a/libspu/kernel/hal/polymorphic.h b/libspu/kernel/hal/polymorphic.h index 13e7abe4..35673a51 100644 --- a/libspu/kernel/hal/polymorphic.h +++ b/libspu/kernel/hal/polymorphic.h @@ -200,6 +200,14 @@ Value exp2(SPUContext* ctx, const Value& x); // @param in, the input value Value tanh(SPUContext* ctx, const Value& x); +/// element-wise sine, i.e. x -> sin(x) +// @param in, the input value +Value sine(SPUContext* ctx, const Value& x); + +/// element-wise cosine, i.e. 
x -> cos(x) +// @param in, the input value +Value cosine(SPUContext* ctx, const Value& x); + /// element-wise reciprocal of square root operation, i.e. x - > 1.0 / sqrt(x) // @param in, the input value Value rsqrt(SPUContext* ctx, const Value& x); diff --git a/libspu/kernel/hlo/basic_unary.cc b/libspu/kernel/hlo/basic_unary.cc index c973883a..b14b13de 100644 --- a/libspu/kernel/hlo/basic_unary.cc +++ b/libspu/kernel/hlo/basic_unary.cc @@ -39,6 +39,8 @@ SIMPLE_UNARY_KERNEL_DEFN(Logistic, hal::logistic) SIMPLE_UNARY_KERNEL_DEFN(Tanh, hal::tanh) SIMPLE_UNARY_KERNEL_DEFN(Rsqrt, hal::rsqrt) SIMPLE_UNARY_KERNEL_DEFN(Sqrt, hal::sqrt) +SIMPLE_UNARY_KERNEL_DEFN(Sine, hal::sine) +SIMPLE_UNARY_KERNEL_DEFN(Cosine, hal::cosine) #undef SIMPLE_UNARY_KERNEL_DEFN diff --git a/libspu/kernel/hlo/basic_unary.h b/libspu/kernel/hlo/basic_unary.h index 67e3188b..8048cb95 100644 --- a/libspu/kernel/hlo/basic_unary.h +++ b/libspu/kernel/hlo/basic_unary.h @@ -32,6 +32,8 @@ SIMPLE_UNARY_KERNEL_DECL(Ceil) SIMPLE_UNARY_KERNEL_DECL(Abs) SIMPLE_UNARY_KERNEL_DECL(Logistic) SIMPLE_UNARY_KERNEL_DECL(Tanh) +SIMPLE_UNARY_KERNEL_DECL(Sine) +SIMPLE_UNARY_KERNEL_DECL(Cosine) SIMPLE_UNARY_KERNEL_DECL(Not) SIMPLE_UNARY_KERNEL_DECL(Rsqrt) SIMPLE_UNARY_KERNEL_DECL(Sqrt) diff --git a/libspu/kernel/hlo/convolution.cc b/libspu/kernel/hlo/convolution.cc index 62bcdb06..4dd7f56b 100644 --- a/libspu/kernel/hlo/convolution.cc +++ b/libspu/kernel/hlo/convolution.cc @@ -57,38 +57,6 @@ spu::Value Convolution2D(SPUContext *ctx, const spu::Value &input, SPU_ENFORCE_EQ(hh, (H - h) / sh + 1); SPU_ENFORCE_EQ(ww, (W - w) / sw + 1); - if (ctx->config().protocol() == ProtocolKind::CHEETAH && // - input.isSecret() && kernel.isSecret()) { - // NOTE(juhou): ad-hoc optimization for the current 2PC conv2d - // implementation. When N is large or small kernel size, it would - // be better to compute im2col because the current conv2d implementation - // needs `N` iterations to handle batched input. - if (N <= h * w) { - // ad-hoc optimization for strided conv2d when h=1 - auto in = input; - { - Strides strides = {1, 1, 1, 1}; - if (kernel.shape()[0] == 1) { - strides[1] = sh; - } - if (kernel.shape()[1] == 1) { - strides[2] = sw; - } - - if (std::any_of(strides.begin(), strides.end(), - [](int64_t s) { return s > 1; })) { - const Index start = {0, 0, 0, 0}; - const Index end = Index(in.shape().begin(), in.shape().end()); - in = hal::slice(ctx, in, start, end, strides); - sh = 1; - sw = 1; - } - } - - return hal::conv2d(ctx, in, kernel, {sh, sw}); - } - } - // Fallback, use im2col + dot to implement convolution { // expand the image according to the kernel size. diff --git a/libspu/kernel/hlo/reduce.cc b/libspu/kernel/hlo/reduce.cc index 1881fd8a..b3280b35 100644 --- a/libspu/kernel/hlo/reduce.cc +++ b/libspu/kernel/hlo/reduce.cc @@ -210,55 +210,73 @@ std::vector ReduceWindowImpl( SPU_ENFORCE(!last_operand_is_window_mask); const int64_t ndims = inputs[0].shape().size(); - std::vector window_index(ndims, 0); int64_t nargs = inputs.size(); - // Init... 
- std::vector rets(nargs); - for (int64_t idx = 0; idx < nargs; ++idx) { - rets[idx] = hal::expand(ctx, init_values[idx], ret_shape); + SPU_ENFORCE_EQ(ndims, static_cast(config.window_padding.size())); + SPU_ENFORCE_EQ(ndims, static_cast(config.window_dilations.size())); + + Sizes padding_lo(ndims); + Sizes padding_hi(ndims); + Sizes padding_in(ndims); + + for (size_t idx = 0; idx < config.window_padding.size(); idx++) { + padding_lo[idx] = config.window_padding[idx].first; + padding_hi[idx] = config.window_padding[idx].second; + padding_in[idx] = config.base_dilations[idx] - 1; } - // For each resulting dimension, calculate and assign computed value. - auto evaluate_impl = - [&](absl::Span output_index) -> std::vector { - std::vector ret; - RunOnWindowIndex( - config.window_shape, config.window_strides, config.window_dilations, - config.window_padding, inputs[0].shape(), config.base_dilations, - output_index, window_index, [&](const Index &operand_index) { - for (int64_t idx = 0; idx < nargs; ++idx) { - auto element = - hal::slice_scalar_at(ctx, inputs[idx], operand_index); - ret.emplace_back(std::move(element)); - } - }); - return ret; - }; - - // For each window index - std::vector batches(nargs); + const Strides &S = config.window_strides; + const Shape &W = config.window_shape; + + // padding + std::vector padded_inputs; for (int64_t idx = 0; idx < nargs; ++idx) { - batches[idx] = - hal::expand(ctx, hal::slice_scalar_at(ctx, inputs[idx], {}), ret_shape); + padded_inputs.emplace_back(hal::pad(ctx, inputs[idx], init_values[idx], + padding_lo, padding_hi, padding_in)); } + // iterate windows to reduce + // reduce dims + const auto in_shape = padded_inputs[0].shape(); + Axes reduce_dims(in_shape.size(), 0); + std::iota(reduce_dims.begin(), reduce_dims.end(), 0); + + Index window_index(ndims, 0); + std::vector> reduced_rets; do { - // Collect one element from each window - Index output_index(ret_shape.size(), 0); - do { - auto r = evaluate_impl(output_index); - if (!r.empty()) { - for (int64_t idx = 0; idx < nargs; ++idx) { - batches[idx].data().update_slice(r[idx].data(), output_index); - } - } - } while (bumpIndices(ret_shape, absl::MakeSpan(output_index))); + Index start(ndims); + Index end(ndims); + for (int64_t dim = 0; dim < ndims; dim++) { + start[dim] = window_index[dim] * S[dim]; + end[dim] = start[dim] + W[dim] + + (W[dim] - 1) * (config.window_dilations[dim] - 1); + } + std::vector windows(nargs); + for (int64_t idx = 0; idx < nargs; ++idx) { + windows[idx] = hal::slice(ctx, padded_inputs[idx], start, end, + (Strides)config.window_dilations); + } + reduced_rets.emplace_back( + Reduce(ctx, windows, init_values, reduce_dims, reducer)); + } while (bumpIndices(ret_shape, absl::MakeSpan(window_index))); + + SPU_ENFORCE_EQ(static_cast(reduced_rets.size()), ret_shape.numel()); + + std::vector rets; + for (int64_t input_idx = 0; input_idx < nargs; ++input_idx) { + std::vector reduced_values; - // Now run the batch - rets = reducer(rets, batches); + for (size_t widx = 0; widx < reduced_rets.size(); ++widx) { + Shape new_shape = reduced_rets[widx][input_idx].shape(); + new_shape.insert(new_shape.begin(), 1); - } while (bumpIndices(config.window_shape, absl::MakeSpan(window_index))); + reduced_values.emplace_back( + hal::reshape(ctx, reduced_rets[widx][input_idx], new_shape)); + } + + rets.emplace_back( + hal::reshape(ctx, hal::concatenate(ctx, reduced_values, 0), ret_shape)); + } return rets; } diff --git a/libspu/kernel/hlo/utils.h b/libspu/kernel/hlo/utils.h index b94436f0..0e9f67aa 
100644 --- a/libspu/kernel/hlo/utils.h +++ b/libspu/kernel/hlo/utils.h @@ -67,70 +67,6 @@ void forEachIndex(absl::Span shape, /*count=*/shape, incr, visitor_function); } -// return false if it's in padding or dilation area -inline bool getBaseIndexFromWindowIndex( - absl::Span window_shape, - absl::Span window_strides, - absl::Span window_dilation, - absl::Span> window_padding, - absl::Span base_shape, - absl::Span base_dilation, - absl::Span window_count_index, - absl::Span window_index, - absl::Span base_index // out parameter -) { - const int64_t ndim = base_shape.size(); - - for (int64_t dim = 0; dim < ndim; ++dim) { - // Padding is applied to the dilated base. Say that padding is 3 and - // dilation is 2 for some dimension. After applying base dilation and - // padding, the dimension looks like: - // P P P E D D E D D ... E D D E P P P - // where E are the elements and D are the holes. So, the elements are - // located in indices: padding + k*base_dilation for k = {0, 1, 2, ...}. - // We are accessing elements in the transformed base at indices: - // window_count_index * stride + window_index * window_dilation. - // Solving for k gives us - // (win_count_i * stride + win_i * win_dilation - pad) / base_dilation - // When this is a natural number, we index an original element. - // Otherwise, we index a 0 (pad or hole), and we don't need to apply - // the callback f. - base_index[dim] = window_count_index[dim] * window_strides[dim] + - window_index[dim] * window_dilation[dim] - - window_padding[dim].first; - if (base_index[dim] % base_dilation[dim] != 0) { - // out of bound - return true; - } - base_index[dim] /= base_dilation[dim]; - if (base_index[dim] < 0 || base_index[dim] >= base_shape[dim]) { - // out of bound - return true; - } - } - return false; -} - -inline void RunOnWindowIndex( - absl::Span window_shape, - absl::Span window_strides, - absl::Span window_dilation, - absl::Span> window_padding, - absl::Span base_shape, - absl::Span base_dilation, - absl::Span window_count_index, - absl::Span window_index, - const std::function &f) { - Index base_index(base_shape.size()); - bool out_of_bound = getBaseIndexFromWindowIndex( - window_shape, window_strides, window_dilation, window_padding, base_shape, - base_dilation, window_count_index, window_index, - absl::MakeSpan(base_index)); - if (!out_of_bound) { - f(base_index); - } -} - /// Expand the base according to window // // let base = (B0, B1, ..., Bn) diff --git a/libspu/mpc/cheetah/BUILD.bazel b/libspu/mpc/cheetah/BUILD.bazel index 0c68e670..e51244c1 100644 --- a/libspu/mpc/cheetah/BUILD.bazel +++ b/libspu/mpc/cheetah/BUILD.bazel @@ -40,13 +40,15 @@ spu_cc_library( spu_cc_library( name = "boolean", - srcs = ["boolean.cc"], + srcs = [ + "boolean.cc", + "boolean_semi2k.cc", + ], hdrs = ["boolean.h"], copts = AES_COPT_FLAGS, deps = [ ":state", ":type", - "//libspu/mpc/semi2k:boolean", ], ) @@ -58,19 +60,24 @@ spu_cc_library( deps = [ ":state", ":type", - "//libspu/mpc/semi2k:conversion", + "//libspu/core:vectorize", + "//libspu/mpc:ab_api", + "//libspu/mpc:kernel", + "//libspu/mpc/common:communicator", ], ) spu_cc_library( name = "arithmetic", - srcs = ["arithmetic.cc"], + srcs = [ + "arithmetic.cc", + "arithmetic_semi2k.cc", + ], hdrs = ["arithmetic.h"], copts = AES_COPT_FLAGS, deps = [ ":state", ":type", - "//libspu/mpc/semi2k:arithmetic", ], ) @@ -123,37 +130,17 @@ spu_cc_library( hdrs = ["io.h"], deps = [ ":type", - "//libspu/mpc/semi2k:io", + "//libspu/mpc:io_interface", "//libspu/mpc/utils:ring_ops", ], ) spu_cc_library( name 
= "type", + srcs = ["type.cc"], + hdrs = ["type.h"], deps = [ - "//libspu/mpc/semi2k:type", - ], -) - -spu_cc_library( - name = "array_ref", - srcs = ["array_ref.cc"], - hdrs = ["array_ref.h"], - deps = [ - "//libspu/core:bit_utils", - "//libspu/core:ndarray_ref", - "//libspu/core:parallel_utils", - "//libspu/core:prelude", "//libspu/core:type", - "//libspu/core:vectorize", - "@yacl//yacl/base:buffer", - ], -) - -spu_cc_test( - name = "array_ref_test", - srcs = ["array_ref_test.cc"], - deps = [ - ":array_ref", + "//libspu/mpc/common:pv2k", ], ) diff --git a/libspu/mpc/cheetah/arith/BUILD.bazel b/libspu/mpc/cheetah/arith/BUILD.bazel index b824313c..7e6fc334 100644 --- a/libspu/mpc/cheetah/arith/BUILD.bazel +++ b/libspu/mpc/cheetah/arith/BUILD.bazel @@ -30,8 +30,9 @@ spu_cc_library( hdrs = ["cheetah_dot.h"], deps = [ ":arith_comm", - ":conv2d_prot", ":matmat_prot", + "//libspu/mpc/cheetah/rlwe:packlwes", + "@yacl//yacl/utils:elapsed_timer", ], ) @@ -51,24 +52,6 @@ spu_cc_library( ], ) -spu_cc_library( - name = "conv2d_prot", - srcs = [ - "conv2d_helper.cc", - "conv2d_prot.cc", - "tensor_encoder.cc", - ], - hdrs = [ - "conv2d_helper.h", - "conv2d_prot.h", - "tensor_encoder.h", - ], - deps = [ - ":arith_comm", - "//libspu/core:ndarray_ref", - ], -) - spu_cc_library( name = "arith_comm", srcs = [ @@ -81,7 +64,6 @@ spu_cc_library( ], deps = [ "//libspu/core:prelude", - "//libspu/mpc/cheetah:array_ref", "//libspu/mpc/cheetah/rlwe:cheetah_rlwe", "@yacl//yacl/link", ], @@ -96,16 +78,6 @@ spu_cc_test( ], ) -spu_cc_test( - name = "conv2d_prot_test", - size = "large", - srcs = ["conv2d_prot_test.cc"], - deps = [ - ":conv2d_prot", - "@com_github_xtensor_xtensor//:xtensor", - ], -) - spu_cc_test( name = "cheetah_mul_test", srcs = ["cheetah_mul_test.cc"], @@ -128,15 +100,3 @@ spu_cc_test( "@com_github_xtensor_xtensor//:xtensor", ], ) - -spu_cc_test( - name = "cheetah_conv2d_test", - size = "large", - srcs = ["cheetah_conv2d_test.cc"], - deps = [ - ":cheetah_dot", - "//libspu/mpc/utils:ring_ops", - "//libspu/mpc/utils:simulate", - "@com_github_xtensor_xtensor//:xtensor", - ], -) diff --git a/libspu/mpc/cheetah/arith/cheetah_dot.cc b/libspu/mpc/cheetah/arith/cheetah_dot.cc index f9f51d34..c7c44bdd 100644 --- a/libspu/mpc/cheetah/arith/cheetah_dot.cc +++ b/libspu/mpc/cheetah/arith/cheetah_dot.cc @@ -13,6 +13,7 @@ // limitations under the License. 
#include "libspu/mpc/cheetah/arith/cheetah_dot.h" +#include #include #include @@ -22,6 +23,7 @@ #include "seal/decryptor.h" #include "seal/encryptor.h" #include "seal/evaluator.h" +#include "seal/galoiskeys.h" #include "seal/keygenerator.h" #include "seal/publickey.h" #include "seal/secretkey.h" @@ -30,11 +32,12 @@ #include "seal/util/rlwe.h" #include "spdlog/spdlog.h" #include "yacl/link/link.h" +#include "yacl/utils/elapsed_timer.h" #include "libspu/mpc/cheetah/arith/common.h" -#include "libspu/mpc/cheetah/arith/conv2d_prot.h" #include "libspu/mpc/cheetah/arith/matmat_prot.h" #include "libspu/mpc/cheetah/rlwe/modswitch_helper.h" +#include "libspu/mpc/cheetah/rlwe/packlwes.h" #include "libspu/mpc/cheetah/rlwe/utils.h" #include "libspu/mpc/utils/ring_ops.h" @@ -42,68 +45,67 @@ namespace spu::mpc::cheetah { struct CheetahDot::Impl : public EnableCPRNG { public: + enum class CipherPackingType { + rlwes, + none, + }; + const bool kUseModDownOptimization = true; - static constexpr size_t kParallelStride = 1; - static constexpr size_t kCtAsyncParallel = 8; + static constexpr size_t kCtAsyncParallel = 16; - explicit Impl(std::shared_ptr lctx) - : lctx_(std::move(lctx)) {} + explicit Impl(std::shared_ptr lctx, + bool enable_matmul_pack) + : lctx_(std::move(lctx)), disable_pack_(!enable_matmul_pack) {} ~Impl() = default; - std::unique_ptr Fork(); - - struct Conv2DMeta { - Conv2DProtocol::Meta prot_meta; - bool is_tensor; - size_t n_tensor_poly; - size_t n_kernel_poly; - size_t n_output_poly; - }; - // Compute C = A*B where |A|=dims[0]xdims[1], |B|=dims[1]xdims[2 - ArrayRef DotOLE(const ArrayRef &prv, yacl::link::Context *conn, - const Shape3D &dim3, bool is_lhs); - - ArrayRef Conv2dOLE(const ArrayRef &inp, yacl::link::Context *conn, - int64_t input_batch, const Shape3D &tensor_shape, - int64_t num_kernels, const Shape3D &kernel_shape, - const Shape2D &window_strides, bool is_tensor); - - ArrayRef doConv2dOLEForEncryptor(FieldType field, - absl::Span poly_ntt, - const Conv2DMeta &meta, - const Conv2DProtocol &conv2d, - yacl::link::Context *conn); - - template - void doConv2dOLECtPtMul(absl::Span tensors, - absl::Span kernels, const Conv2DMeta &meta, - const Conv2DProtocol &conv2d, absl::Span out); - - ArrayRef doConv2dOLEForEvaluator(FieldType field, - absl::Span poly_ntt, - const Conv2DMeta &meta, - const Conv2DProtocol &conv2d, - yacl::link::Context *conn); - - void encodeBatchInput(const ArrayRef &batch_inp, const Conv2DMeta &meta, - const Conv2DProtocol &conv2d, bool need_encrypt, - absl::Span out); - - ArrayRef parseBatchedConv2dResult(FieldType field, const Conv2DMeta &meta, - const Conv2DProtocol &prot, - absl::Span polys); - - static seal::EncryptionParameters DecideSEALParameters(uint32_t ring_bitlen) { + NdArrayRef DotOLE(const NdArrayRef &prv_mat, yacl::link::Context *conn, + const Shape3D &dim3, bool is_self_lhs); + + // dim4 = [B, M, K, L] + // LHS.shape BxMxK, RHS.shape BxKxL + // Out.shape BxMxL + NdArrayRef BatchDotOLE(const NdArrayRef &prv_mat, yacl::link::Context *conn, + const Shape4D &dim4, bool is_self_lhs); + + NdArrayRef doDotOLE(const NdArrayRef &prv_mat, yacl::link::Context *conn, + const Shape3D &dim3, bool is_self_lhs); + + NdArrayRef doBatchDotOLE(const NdArrayRef &prv_mat, yacl::link::Context *conn, + const Shape4D &dim4, bool is_self_lhs); + + void doDotOLESenderSendStep(const NdArrayRef &prv_mat, const Shape3D &dim3, + bool is_self_lhs, CipherPackingType cptype, + yacl::link::Context *conn); + + void doDotOLEReceiverRecvStep(const NdArrayRef &prv_mat, const Shape3D 
&dim3, + bool is_self_lhs, CipherPackingType cptype, + absl::Span result_cts, + yacl::link::Context *conn); + + NdArrayRef doDotOLESenderRecvStep(FieldType field, size_t batch_size, + MatMatProtocol::Meta meta, + size_t num_ct_to_recv, + CipherPackingType cptype, + yacl::link::Context *conn); + + NdArrayRef doDotOLEReceiverSendStep(FieldType field, size_t batch_size, + MatMatProtocol::Meta meta, + absl::Span ct_array_to_pack, + CipherPackingType cptype, + yacl::link::Context *conn, + size_t bytes_recv); + + seal::EncryptionParameters DecideSEALParameters(uint32_t ring_bitlen) const { size_t poly_deg; std::vector modulus_bits; - // NOTE(juhou): we need Q=sum(modulus_bits) > 2*k for multiplying two k-bit - // elements. - // 1. We need the (N, Q) pair satisfies the security. + // NOTE(lwj): we need Q=sum(modulus_bits) > 2*k for multiplying two + // k-bit elements. + // 1. We need the (N, Q) pair satisifies the security. // Check `seal/util/globals.cpp` for the recommendation HE parameters. - // 2. We prefer modulus_bits[i] to be around 49-bit aiming to use AVX512 for - // acceleration if available. + // 2. We prefer modulus_bits[i] to be around 49-bit aiming to use AVX512 + // for acceleration if avaiable. // 3. We need some bits for margin. That is Q > 2*k + margin for errorless // w.h.p. We set margin=32bit if (ring_bitlen <= 32) { @@ -112,8 +114,8 @@ struct CheetahDot::Impl : public EnableCPRNG { modulus_bits = {59, 37}; } else if (ring_bitlen <= 64) { poly_deg = 8192; - // ~ 128 + 32 bit - modulus_bits = {57, 57, 45}; + // ~ 128 + 32 bit + modulus_bits = {59, 55, 49}; } else { poly_deg = 16384; // ~ 256 + 30 bit @@ -129,73 +131,9 @@ struct CheetahDot::Impl : public EnableCPRNG { return parms; } - static inline uint32_t FieldBitLen(FieldType f) { return 8 * SizeOf(f); } - - void LazyInit(size_t field_bitlen); - - static void SubPlainInplace(RLWECt &ct, const RLWEPt &pt, - const seal::SEALContext &context) { - SPU_ENFORCE(ct.parms_id() == pt.parms_id()); - auto cntxt_dat = context.get_context_data(ct.parms_id()); - SPU_ENFORCE(cntxt_dat != nullptr); - const auto &parms = cntxt_dat->parms(); - const auto &modulus = parms.coeff_modulus(); - size_t num_coeff = ct.poly_modulus_degree(); - size_t num_modulus = ct.coeff_modulus_size(); - - for (size_t l = 0; l < num_modulus; ++l) { - auto *op0 = ct.data(0) + l * num_coeff; - const auto *op1 = pt.data() + l * num_coeff; - seal::util::sub_poly_coeffmod(op0, op1, num_coeff, modulus[l], op0); - } - } - - void KeepModulus(RLWECt &ct, size_t num_to_keep, - const seal::SEALContext &context) const { - SPU_ENFORCE(num_to_keep >= 1 && num_to_keep <= ct.coeff_modulus_size()); - if (num_to_keep == ct.coeff_modulus_size()) { - // nothing to do - return; - } - - auto cntxt = context.get_context_data(ct.parms_id()); - YACL_ENFORCE(cntxt != nullptr); - size_t index = cntxt->chain_index(); - YACL_ENFORCE((index + 1) >= num_to_keep); - - auto target_context = cntxt; - auto pool = - seal::MemoryManager::GetPool(seal::mm_prof_opt::mm_force_thread_local); - - while (target_context->chain_index() >= num_to_keep) { - using namespace seal::util; - auto rns_tool = target_context->rns_tool(); - auto ntt_tables = target_context->small_ntt_tables(); - if (ct.is_ntt_form()) { - SEAL_ITERATE(iter(ct), ct.size(), [&](auto I) { - rns_tool->divide_and_round_q_last_ntt_inplace(I, ntt_tables, pool); - }); - } else { - SEAL_ITERATE(iter(ct), ct.size(), [&](auto I) { - rns_tool->divide_and_round_q_last_inplace(I, pool); - }); - } - - auto next_context = 
target_context->next_context_data(); - SPU_ENFORCE(next_context != nullptr); + void LazyInit(size_t field_bitlen, bool need_galois_key = false); - RLWECt next_ct(pool); - next_ct.resize(context, next_context->parms_id(), ct.size()); - SEAL_ITERATE(iter(ct, next_ct), ct.size(), [&](auto I) { - set_poly(get<0>(I), ct.poly_modulus_degree(), - ct.coeff_modulus_size() - 1, get<1>(I)); - }); - next_ct.is_ntt_form() = ct.is_ntt_form(); - target_context = next_context; - std::swap(next_ct, ct); - } - SPU_ENFORCE_EQ(num_to_keep, ct.coeff_modulus_size()); - } + void LazyInitGaloisKey(size_t field_bitlen); // Enc(Delta*m) -> Enc(Delta*m - r), r where r is sampled from Rq // One share of `m` is round(r/Delta) mod t @@ -208,20 +146,26 @@ struct CheetahDot::Impl : public EnableCPRNG { SPU_ENFORCE(num_poly > 0); SPU_ENFORCE_EQ(rnd_mask.size(), num_poly); + constexpr int64_t heuristic_group = 4; yacl::parallel_for( - 0, num_poly, kParallelStride, [&](size_t bgn, size_t end) { + 0, num_poly, heuristic_group, [&](size_t bgn, size_t end) { + RLWECt zero_ct; for (size_t idx = bgn; idx < end; ++idx) { - // decrease q <- q' for efficiency - if (ct[idx].is_ntt_form()) { - evaluator.transform_from_ntt_inplace(ct[idx]); - } - KeepModulus(ct[idx], target_modulus_size, context); + // NOTE(lwj): we hope the final ct is in the non-ntt form + // We perform the intt before the modulus down which is faster + // than modulus down then intt. + InvNttInplace(ct[idx], context); - // TODO(juhou): improve the performance of pk encryption of zero. + ModulusSwtichInplace(ct[idx], target_modulus_size, context); + + // TODO(lwj): improve the performance of pk encryption of zero. // ct <- ct + enc(0) - RLWECt zero_ct; - seal::util::encrypt_zero_asymmetric(pk, context, ct[idx].parms_id(), - ct[idx].is_ntt_form(), zero_ct); + if (0 == zero_ct.size()) { + seal::util::encrypt_zero_asymmetric( + pk, context, ct[idx].parms_id(), ct[idx].is_ntt_form(), + zero_ct); + } + evaluator.add_inplace(ct[idx], zero_ct); SPU_ENFORCE(!ct[idx].is_ntt_form()); @@ -235,45 +179,65 @@ struct CheetahDot::Impl : public EnableCPRNG { private: std::shared_ptr lctx_; + bool disable_pack_ = false; mutable std::shared_mutex context_lock_; // field_bitlen -> functor mapping std::unordered_map> seal_cntxts_; + std::unordered_map galoi_cntxts_; std::unordered_map> secret_keys_; std::unordered_map> peer_pub_keys_; + std::unordered_map> + peer_galois_keys_; // ModulusSwitchHelper for encoding std::unordered_map> ecd_mswh_; // ModulusSwitchHelper for decoding std::unordered_map> dcd_mswh_; - std::unordered_map> sym_encryptors_; std::unordered_map> decryptors_; }; -std::unique_ptr CheetahDot::Impl::Fork() { - auto f = std::make_unique(lctx_->Spawn()); - if (seal_cntxts_.size() == 0) return f; - std::unique_lock guard(context_lock_); - - f->seal_cntxts_ = seal_cntxts_; - f->secret_keys_ = secret_keys_; - f->peer_pub_keys_ = peer_pub_keys_; - f->ecd_mswh_ = ecd_mswh_; - f->dcd_mswh_ = dcd_mswh_; - f->sym_encryptors_ = sym_encryptors_; - f->decryptors_ = decryptors_; - return f; -} +void CheetahDot::Impl::LazyInitGaloisKey(size_t field_bitlen) { + // NOTE: make sure context_lock_ is obtained. 
+ if (galoi_cntxts_.find(field_bitlen) != galoi_cntxts_.end()) { + return; + } + auto kv = seal_cntxts_.find(field_bitlen); + SPU_ENFORCE(kv != seal_cntxts_.end()); + const auto &this_context = *kv->second; + const auto &this_rlwe_sk = *secret_keys_.find(field_bitlen)->second; -void CheetahDot::Impl::LazyInit(size_t field_bitlen) { - { - std::shared_lock guard(context_lock_); - if (seal_cntxts_.find(field_bitlen) != seal_cntxts_.end()) { - return; - } + seal::GaloisKeys gk; + auto gk_parms = this_context.key_context_data()->parms(); + gk_parms.set_use_special_prime(true); + + seal::SEALContext gk_context(gk_parms, true, seal::sec_level_type::none); + GenerateGaloisKeyForPacking(gk_context, this_rlwe_sk, + /*seed*/ true, &gk); + galoi_cntxts_.emplace(field_bitlen, gk_context); + + auto gk_buf = EncodeSEALObject(gk); + int nxt_rank = lctx_->NextRank(); + std::shared_ptr peer_galois_key; + if (nxt_rank == 0) { + lctx_->Send(nxt_rank, gk_buf, "Rank0 send galois key"); + auto recv_gk = lctx_->Recv(nxt_rank, "Rank0 recv galois key"); + peer_galois_key = std::make_shared(); + DecodeSEALObject(recv_gk, this_context, peer_galois_key.get()); + } else { + auto recv_gk = lctx_->Recv(nxt_rank, "Rank1 recv galois key"); + lctx_->Send(nxt_rank, gk_buf, "Rank0 send galois key"); + peer_galois_key = std::make_shared(); + DecodeSEALObject(recv_gk, this_context, peer_galois_key.get()); } - // double-checking - std::unique_lock guard(context_lock_); + peer_galois_keys_.emplace(field_bitlen, peer_galois_key); +} + +void CheetahDot::Impl::LazyInit(size_t field_bitlen, bool need_galois_keys) { + std::unique_lock guard(context_lock_); if (seal_cntxts_.find(field_bitlen) != seal_cntxts_.end()) { + if (need_galois_keys) { + LazyInitGaloisKey(field_bitlen); + } return; } @@ -283,30 +247,34 @@ void CheetahDot::Impl::LazyInit(size_t field_bitlen) { seal::KeyGenerator keygen(*this_context); auto *rlwe_sk = new seal::SecretKey(keygen.secret_key()); + // exchange the public keys + int nxt_rank = lctx_->NextRank(); auto pk = keygen.create_public_key(); - // NOTE(juhou): we patched seal/util/serializable.h auto pk_buf = EncodeSEALObject(pk.obj()); - // exchange the public key - int nxt_rank = lctx_->NextRank(); - lctx_->SendAsync(nxt_rank, pk_buf, "send Pk"); - pk_buf = lctx_->Recv(nxt_rank, "recv pk"); auto peer_public_key = std::make_shared(); - DecodeSEALObject(pk_buf, *this_context, peer_public_key.get()); + if (nxt_rank == 0) { + lctx_->Send(nxt_rank, pk_buf, "Rank1 send public key"); + auto recv_pk = lctx_->Recv(nxt_rank, "Rank1 recv public key"); + DecodeSEALObject(recv_pk, *this_context, peer_public_key.get()); + } else { + auto recv_pk = lctx_->Recv(nxt_rank, "Rank0 recv public key"); + lctx_->Send(nxt_rank, pk_buf, "Rank1 send public key"); + DecodeSEALObject(recv_pk, *this_context, peer_public_key.get()); + } auto modulus = parms.coeff_modulus(); - if (parms.use_special_prime()) { - modulus.pop_back(); - } size_t ecd_modulus_sze = modulus.size(); parms.set_coeff_modulus(modulus); seal::SEALContext ecd_ms_context(parms, false, seal::sec_level_type::none); + if (kUseModDownOptimization) { - // NOTE: we can drop some modulus before H2A modulus.pop_back(); if (field_bitlen > 64) { + // Keep 3 primes are enough for FM128. 
modulus.pop_back(); } } + size_t dcd_modulus_sze = modulus.size(); parms.set_coeff_modulus(modulus); seal::SEALContext dcd_ms_context(parms, false, seal::sec_level_type::none); @@ -314,459 +282,506 @@ void CheetahDot::Impl::LazyInit(size_t field_bitlen) { seal_cntxts_.emplace(field_bitlen, this_context); secret_keys_.emplace(field_bitlen, rlwe_sk); peer_pub_keys_.emplace(field_bitlen, peer_public_key); + ecd_mswh_.emplace(field_bitlen, new ModulusSwitchHelper(ecd_ms_context, field_bitlen)); dcd_mswh_.emplace(field_bitlen, new ModulusSwitchHelper(dcd_ms_context, field_bitlen)); - sym_encryptors_.emplace(field_bitlen, - new seal::Encryptor(*this_context, *rlwe_sk)); decryptors_.emplace(field_bitlen, new seal::Decryptor(*this_context, *rlwe_sk)); - - SPDLOG_INFO("CheetahDot uses {}@{} modulus {} degree for {} bit ring", - ecd_modulus_sze, dcd_modulus_sze, parms.poly_modulus_degree(), - field_bitlen); -} - -ArrayRef CheetahDot::Impl::DotOLE(const ArrayRef &prv_mat, - yacl::link::Context *conn, - const Shape3D &dim3, bool is_lhs) { - if (conn == nullptr) { - conn = lctx_.get(); + if (need_galois_keys) { + LazyInitGaloisKey(field_bitlen); } - int nxt_rank = conn->NextRank(); - auto eltype = prv_mat.eltype(); - SPU_ENFORCE(eltype.isa(), "must be ring_type, got={}", eltype); - SPU_ENFORCE(prv_mat.numel() > 0); - if (is_lhs) { - SPU_ENFORCE_EQ(prv_mat.numel(), dim3[0] * dim3[1]); - } else { - SPU_ENFORCE_EQ(prv_mat.numel(), dim3[1] * dim3[2]); + if (lctx_->Rank() == 0) { + SPDLOG_INFO("CheetahDot uses {}@{} modulus {} degree for {} bit ring", + ecd_modulus_sze, dcd_modulus_sze, parms.poly_modulus_degree(), + field_bitlen); } +} - auto field = eltype.as()->field(); - const size_t field_bitlen = FieldBitLen(field); - LazyInit(field_bitlen); +void CheetahDot::Impl::doDotOLEReceiverRecvStep(const NdArrayRef &prv_mat, + const Shape3D &dim3, + bool is_self_lhs, + CipherPackingType cptype, + absl::Span result_cts, + yacl::link::Context *conn) { + int next_rank = conn->NextRank(); + auto eltype = prv_mat.eltype(); + auto field = eltype.template as()->field(); + const size_t field_bitlen = SizeOf(field) * 8; const auto &this_context = *seal_cntxts_.find(field_bitlen)->second; - auto &this_encryptor = sym_encryptors_.find(field_bitlen)->second; - auto &this_decryptor = decryptors_.find(field_bitlen)->second; - auto &this_ecd_ms = ecd_mswh_.find(field_bitlen)->second; - auto &this_dcd_ms = dcd_mswh_.find(field_bitlen)->second; - seal::Evaluator evaluator(this_context); + const auto &this_ecd_msh = *ecd_mswh_.find(field_bitlen)->second; + bool disable_pack = cptype == CipherPackingType::none; + + MatMatProtocol matmat_prot(this_context, this_ecd_msh, disable_pack); - MatMatProtocol matmat(this_context, *this_ecd_ms, /*mont*/ true); MatMatProtocol::Meta meta; meta.dims = dim3; - auto subshape = matmat.GetSubMatShape(meta); - size_t lhs_n = matmat.GetLeftSize(meta, subshape); - size_t rhs_n = matmat.GetRightSize(meta, subshape); - size_t out_n = matmat.GetOutSize(meta, subshape); - bool to_encrypt_lhs = lhs_n < rhs_n; - bool need_encrypt = (is_lhs ^ to_encrypt_lhs) == 0; - - std::vector encoded_mat(is_lhs ? lhs_n : rhs_n); - if (is_lhs) { - matmat.EncodeLHS(prv_mat, meta, need_encrypt, absl::MakeSpan(encoded_mat)); + auto subshape = matmat_prot.GetSubMatShape(meta); + const size_t lhs_n = matmat_prot.GetLeftSize(meta, subshape); + const size_t rhs_n = matmat_prot.GetRightSize(meta, subshape); + const size_t out_n = matmat_prot.GetOutSize(meta, subshape); + SPU_ENFORCE_EQ(out_n, result_cts.size()); + + // 1. 
launch IO task to recv ct from peer + std::vector enc_mat(is_self_lhs ? rhs_n : lhs_n); + auto io_task = std::async(std::launch::async, [&]() { + for (auto &ct : enc_mat) { + auto ct_s = conn->Recv(next_rank, "recv encrypted mat"); + DecodeSEALObject(ct_s, this_context, &ct); + } + }); + + // 2. encode the matrix for multiplication + std::vector plain_mat(is_self_lhs ? lhs_n : rhs_n); + if (is_self_lhs) { + matmat_prot.EncodeLHS(prv_mat, meta, false, absl::MakeSpan(plain_mat)); } else { - matmat.EncodeRHS(prv_mat, meta, need_encrypt, absl::MakeSpan(encoded_mat)); + matmat_prot.EncodeRHS(prv_mat, meta, false, absl::MakeSpan(plain_mat)); } - // convert local poly to NTT form to perform encryption / multiplication. - yacl::parallel_for(0, encoded_mat.size(), kParallelStride, + yacl::parallel_for(0, plain_mat.size(), CalculateWorkLoad(plain_mat.size()), [&](size_t bgn, size_t end) { for (size_t i = bgn; i < end; ++i) { - NttInplace(encoded_mat[i], this_context); - if (not need_encrypt) { - matmat.Montgomerize({&encoded_mat[i], 1}); - } + NttInplace(plain_mat[i], this_context); } }); + io_task.get(); - if (need_encrypt) { - // send ct - for (size_t i = 0; i < encoded_mat.size(); ++i) { - auto ct = this_encryptor->encrypt_symmetric(encoded_mat[i]).obj(); - auto ct_s = EncodeSEALObject(ct); - conn->SendAsync(nxt_rank, ct_s, "send encrypted mat"); - } + // 3. HE multiplications + if (is_self_lhs) { + matmat_prot.Compute(plain_mat, enc_mat, meta, result_cts); + } else { + matmat_prot.Compute(enc_mat, plain_mat, meta, result_cts); + } +} - // wait for result - std::vector recv_ct(kCtAsyncParallel); - std::vector result_poly(out_n); - for (size_t i = 0; i < out_n; i += kCtAsyncParallel) { - size_t this_batch = std::min(out_n - i, kCtAsyncParallel); - for (size_t j = 0; j < this_batch; ++j) { - auto ct_s = conn->Recv(nxt_rank, "recv result mat"); - DecodeSEALObject(ct_s, this_context, &recv_ct[j]); - } - - yacl::parallel_for( - 0, this_batch, kParallelStride, [&](size_t bgn, size_t end) { - for (size_t j = bgn; j < end; ++j) { - if (not recv_ct[j].is_ntt_form()) { - evaluator.transform_to_ntt_inplace(recv_ct[j]); - } - this_decryptor->decrypt(recv_ct[j], result_poly[i + j]); - } - }); - } +NdArrayRef CheetahDot::Impl::doDotOLEReceiverSendStep( + FieldType field, size_t batch_size, MatMatProtocol::Meta meta, + absl::Span ct_array_to_pack, CipherPackingType cptype, + yacl::link::Context *conn, size_t bytes_recv) { + int next_rank = conn->NextRank(); + const size_t field_bitlen = SizeOf(field) * 8; + const auto &this_public_key = *(peer_pub_keys_.find(field_bitlen)->second); + const auto &this_dcd_msh = *dcd_mswh_.find(field_bitlen)->second; + const auto &this_context = *seal_cntxts_.find(field_bitlen)->second; + const auto &this_ecd_msh = *ecd_mswh_.find(field_bitlen)->second; + bool disable_pack = cptype == CipherPackingType::none; - // non-ntt form for ParseResult - yacl::parallel_for(0, out_n, kParallelStride, [&](size_t bgn, size_t end) { - for (size_t i = bgn; i < end; ++i) { - InvNttInplace(result_poly[i], this_context); - } - }); + MatMatProtocol matmat_prot(this_context, this_ecd_msh, disable_pack); - auto ret = matmat.ParseResult(field, meta, absl::MakeSpan(result_poly), - *this_dcd_ms); - return ret; - } + auto subshape = matmat_prot.GetSubMatShape(meta); - // recv ct from peer - std::vector encrypted_mat(is_lhs ? 
rhs_n : lhs_n);
-  for (size_t i = 0; i < encrypted_mat.size(); ++i) {
-    auto ct_s = conn->Recv(nxt_rank, "recv encrypted mat");
-    DecodeSEALObject(ct_s, this_context, &encrypted_mat[i]);
-  }
+  size_t out_n = ct_array_to_pack.size();
+  size_t num_ct_response = 0;
 
-  std::vector result_ct(out_n);
-  if (is_lhs) {
-    matmat.Compute(encoded_mat, encrypted_mat, meta, absl::MakeSpan(result_ct));
+  double pack_time = 0.0;
+  if (cptype == CipherPackingType::none) {
+    SPU_ENFORCE(batch_size == 1, "impossible dispatch here");
+    num_ct_response = out_n;
   } else {
-    matmat.Compute(encrypted_mat, encoded_mat, meta, absl::MakeSpan(result_ct));
-  }
-
-  const auto &this_pk = peer_pub_keys_.find(field_bitlen)->second;
+    yacl::ElapsedTimer _timer;
+    const auto &this_galois_context = galoi_cntxts_.find(field_bitlen)->second;
+    const auto &this_galois_key =
+        *(peer_galois_keys_.find(field_bitlen)->second);
+
+    const size_t gap = subshape[1];
+    const size_t pack_stride = gap;
+
+    PackingHelper pack_helper(gap, this_galois_key, this_galois_context,
+                              this_context);
+
+    for (size_t i = 0; i < out_n; i += pack_stride) {
+      size_t this_batch = std::min(out_n - i, pack_stride);
+      size_t packed_idx = i / pack_stride;
+      pack_helper.PackingWithModulusDrop(
+          ct_array_to_pack.subspan(i, this_batch),
+          ct_array_to_pack[packed_idx]);
+    }
+    pack_time = _timer.CountMs();
 
-  std::vector mask_mat(out_n);
-  H2A(absl::MakeSpan(result_ct), absl::MakeSpan(mask_mat),
-      this_dcd_ms->coeff_modulus_size(), *this_pk, this_context);
-  matmat.ExtractLWEsInplace(meta, absl::MakeSpan(result_ct));
+    num_ct_response = CeilDiv(out_n, pack_stride);
+  }
 
-  for (size_t i = 0; i < out_n; i += kCtAsyncParallel) {
-    // NOTE(juhou): we do not send too much ct with Async
-    size_t this_batch = std::min(out_n - i, kCtAsyncParallel);
-    auto ct_s = EncodeSEALObject(result_ct[i]);
-    conn->Send(nxt_rank, ct_s, "send result mat");
+  // 4. Random masking to convert HE to AShr
+  std::vector rnd_polys(num_ct_response);
+  H2A({ct_array_to_pack.data(), num_ct_response}, absl::MakeSpan(rnd_polys),
+      this_dcd_msh.coeff_modulus_size(), this_public_key, this_context);
 
-    for (size_t j = 1; j < this_batch; ++j) {
-      auto ct_s = EncodeSEALObject(result_ct[i + j]);
-      conn->SendAsync(nxt_rank, ct_s, "send result mat");
-    }
+  if (cptype == CipherPackingType::none) {
+    // If no packing, then clean up unused coefficients
+    // NOTE(lwj): we place Extract **after** H2A for a smaller communication
+    matmat_prot.ExtractLWEsInplace(meta, ct_array_to_pack);
   }
 
-  return matmat.ParseResult(field, meta, absl::MakeSpan(mask_mat),
-                            *this_dcd_ms);
-}
+  size_t bytes_sent = conn->GetStats()->sent_bytes;
+  for (size_t i = 0; i < num_ct_response; ++i) {
+    conn->SendAsync(next_rank, EncodeSEALObject(ct_array_to_pack[i]), "");
+  }
+  bytes_sent = conn->GetStats()->sent_bytes - bytes_sent;
+
+  if (conn->Rank() == 0) {
+    SPDLOG_INFO(
+        "{}@{}x{}x{} => {}x{}x{} Recv {} MiB, Response {} MiB Pack {} ms",
+        batch_size, meta.dims[0], meta.dims[1], meta.dims[2], subshape[0],
+        subshape[1], subshape[2],
+        std::roundf(bytes_recv / 1024. / 1024. * 1000) / 1000.,
+        std::roundf(bytes_sent / 1024. / 1024. 
* 1000) / 1000., + std::roundf(pack_time * 1000) / 1000.); + } -void CheetahDot::Impl::encodeBatchInput(const ArrayRef &batch_inp, - const Conv2DMeta &meta, - const Conv2DProtocol &conv2d, - bool need_encrypt, - absl::Span out) { - size_t input_batch = meta.prot_meta.input_batch; - size_t num_poly_per_input = meta.n_tensor_poly / input_batch; - SPU_ENFORCE_EQ(out.size(), meta.n_tensor_poly); - - const size_t numel = calcNumel(meta.prot_meta.input_shape); - yacl::parallel_for( - 0, input_batch, kParallelStride, [&](size_t bgn, size_t end) { - for (size_t ib = bgn; ib < end; ++ib) { - absl::Span one_input_polys{ - out.data() + ib * num_poly_per_input, num_poly_per_input}; - conv2d.EncodeInput(batch_inp.slice(ib * numel, (ib + 1) * numel), - meta.prot_meta, need_encrypt, one_input_polys); - } - }); + switch (cptype) { + case CipherPackingType::none: + return matmat_prot.ParseResult( + field, meta, absl::MakeConstSpan(rnd_polys), this_dcd_msh); + case CipherPackingType::rlwes: + default: + return matmat_prot.ParseBatchPackedResult(field, batch_size, meta, + absl::MakeConstSpan(rnd_polys), + this_dcd_msh); + } } -template -void CheetahDot::Impl::doConv2dOLECtPtMul(absl::Span tensors, - absl::Span kernels, - const Conv2DMeta &meta, - const Conv2DProtocol &conv2d, - absl::Span out) { - const size_t input_batch = meta.prot_meta.input_batch; - const size_t n_poly_per_tensor = meta.n_tensor_poly / input_batch; - const size_t n_poly_per_out = meta.n_output_poly / input_batch; - yacl::parallel_for( - 0, input_batch, kParallelStride, [&](size_t bgn, size_t end) { - for (size_t ib = bgn; ib < end; ++ib) { - absl::Span one_input_polys{ - tensors.data() + ib * n_poly_per_tensor, n_poly_per_tensor}; - - absl::Span one_output_polys{out.data() + ib * n_poly_per_out, - n_poly_per_out}; - - conv2d.Compute(one_input_polys, kernels, meta.prot_meta, - one_output_polys); - } - }); -} +void CheetahDot::Impl::doDotOLESenderSendStep(const NdArrayRef &prv_mat, + const Shape3D &dim3, + bool is_self_lhs, + CipherPackingType cptype, + yacl::link::Context *conn) { + int next_rank = conn->NextRank(); + auto eltype = prv_mat.eltype(); + auto field = eltype.template as()->field(); + const size_t field_bitlen = SizeOf(field) * 8; -ArrayRef CheetahDot::Impl::doConv2dOLEForEvaluator( - FieldType field, absl::Span poly_ntt, const Conv2DMeta &meta, - const Conv2DProtocol &conv2d, yacl::link::Context *conn) { - const size_t field_bitlen = FieldBitLen(field); const auto &this_context = *seal_cntxts_.find(field_bitlen)->second; - const auto &this_dcd_ms = dcd_mswh_.find(field_bitlen)->second; - const int nxt_rank = conn->NextRank(); - - SPU_ENFORCE_EQ(poly_ntt.size(), - meta.is_tensor ? meta.n_tensor_poly : meta.n_kernel_poly); + auto &this_secret_key = *secret_keys_.find(field_bitlen)->second; + auto &this_ecd_msh = *ecd_mswh_.find(field_bitlen)->second; + bool disable_pack = cptype == CipherPackingType::none; + MatMatProtocol matmat_prot(this_context, this_ecd_msh, disable_pack); - std::vector recv_ct(meta.is_tensor ? 
meta.n_kernel_poly - : meta.n_tensor_poly); + MatMatProtocol::Meta meta; + meta.dims = dim3; + auto subshape = matmat_prot.GetSubMatShape(meta); - for (size_t i = 0; i < recv_ct.size(); ++i) { - auto ct_s = conn->Recv(nxt_rank, "recv encrypted mat"); - DecodeSEALObject(ct_s, this_context, &recv_ct[i]); - } + const size_t lhs_n = matmat_prot.GetLeftSize(meta, subshape); + const size_t rhs_n = matmat_prot.GetRightSize(meta, subshape); - std::vector result_ct(meta.n_output_poly); - if (meta.is_tensor) { - doConv2dOLECtPtMul(poly_ntt, absl::MakeSpan(recv_ct), meta, - conv2d, absl::MakeSpan(result_ct)); + std::vector _encoded_mat(is_self_lhs ? lhs_n : rhs_n); + auto encoded_mat = absl::MakeSpan(_encoded_mat); + if (is_self_lhs) { + matmat_prot.EncodeLHS(prv_mat, meta, true, encoded_mat); } else { - doConv2dOLECtPtMul(absl::MakeSpan(recv_ct), poly_ntt, meta, - conv2d, absl::MakeSpan(result_ct)); + matmat_prot.EncodeRHS(prv_mat, meta, true, encoded_mat); } - const auto &this_pk = peer_pub_keys_.find(field_bitlen)->second; - - std::vector mask_tensor(meta.n_output_poly); - H2A(absl::MakeSpan(result_ct), absl::MakeSpan(mask_tensor), - this_dcd_ms->coeff_modulus_size(), *this_pk, this_context); - - size_t n_poly_per_out = meta.n_output_poly / meta.prot_meta.input_batch; - for (int64_t ib = 0; ib < meta.prot_meta.input_batch; ++ib) { - absl::Span one_output_polys{result_ct.data() + ib * n_poly_per_out, - n_poly_per_out}; - conv2d.ExtractLWEsInplace(meta.prot_meta, one_output_polys); - } + size_t num_ct_to_send = encoded_mat.size(); + for (size_t i = 0; i < num_ct_to_send; i += kCtAsyncParallel) { + size_t this_batch = std::min(num_ct_to_send - i, kCtAsyncParallel); + std::vector enc_mat(this_batch); + std::vector ct_s(this_batch); + + SymmetricRLWEEncrypt(this_secret_key, this_context, + encoded_mat.subspan(i, this_batch), + /*ntt*/ true, + /*seed*/ true, absl::MakeSpan(enc_mat)); + + yacl::parallel_for(0, this_batch, CalculateWorkLoad(this_batch), + [&](size_t bgn, size_t end) { + for (size_t j = bgn; j < end; ++j) { + ct_s[j] = EncodeSEALObject(enc_mat[j]); + } + }); - for (size_t i = 0; i < meta.n_output_poly; i += kCtAsyncParallel) { - size_t this_batch = std::min(meta.n_output_poly - i, kCtAsyncParallel); - auto ct_s = EncodeSEALObject(result_ct[i]); - conn->Send(nxt_rank, ct_s, "send result tensor"); for (size_t j = 1; j < this_batch; ++j) { - auto ct_s = EncodeSEALObject(result_ct[i + j]); - conn->SendAsync(nxt_rank, ct_s, "send result tensor"); + conn->SendAsync(next_rank, ct_s[j - 1], "send encrypted mat"); } + conn->Send(next_rank, ct_s[this_batch - 1], "send encrypted mat"); } - - return parseBatchedConv2dResult(field, meta, conv2d, - absl::MakeSpan(mask_tensor)); } -ArrayRef CheetahDot::Impl::doConv2dOLEForEncryptor( - FieldType field, absl::Span poly_ntt, const Conv2DMeta &meta, - const Conv2DProtocol &conv2d, yacl::link::Context *conn) { - const size_t field_bitlen = FieldBitLen(field); +NdArrayRef CheetahDot::Impl::doDotOLESenderRecvStep(FieldType field, + size_t batch_size, + MatMatProtocol::Meta meta, + size_t num_ct_to_recv, + CipherPackingType cptype, + yacl::link::Context *conn) { + SPU_ENFORCE(batch_size > 0); + int next_rank = conn->NextRank(); + const size_t field_bitlen = SizeOf(field) * 8; + const auto &this_context = *seal_cntxts_.find(field_bitlen)->second; - auto &this_encryptor = sym_encryptors_.find(field_bitlen)->second; auto &this_decryptor = decryptors_.find(field_bitlen)->second; - seal::Evaluator evaluator(this_context); - - const size_t num_ct_to_send = 
poly_ntt.size(); - const int nxt_rank = conn->NextRank(); - - // send ct - { - std::vector ct_to_send(kCtAsyncParallel); - for (size_t i = 0; i < num_ct_to_send; i += kCtAsyncParallel) { - size_t this_batch = std::min(num_ct_to_send - i, kCtAsyncParallel); - yacl::parallel_for( - 0, this_batch, kParallelStride, [&](size_t bgn, size_t end) { - for (size_t k = bgn; k < end; ++k) { - auto ct = - this_encryptor->encrypt_symmetric(poly_ntt[i + k]).obj(); - ct_to_send[k] = EncodeSEALObject(ct); - } - }); + auto &this_ecd_msh = *ecd_mswh_.find(field_bitlen)->second; + auto &this_dcd_msh = *dcd_mswh_.find(field_bitlen)->second; - conn->Send(nxt_rank, ct_to_send[0], "Conv2dOLE::Send ct"); - for (size_t k = 1; k < this_batch; ++k) { - conn->SendAsync(nxt_rank, ct_to_send[k], "Conv2dOLE::Send ct"); - } - } - } + MatMatProtocol matmat_prot(this_context, this_ecd_msh, + cptype == CipherPackingType::none); - // wait for result std::vector recv_ct(kCtAsyncParallel); - std::vector result_poly(meta.n_output_poly); - for (size_t i = 0; i < meta.n_output_poly; i += kCtAsyncParallel) { - size_t this_batch = std::min(meta.n_output_poly - i, kCtAsyncParallel); + std::vector result_poly(num_ct_to_recv); + + for (size_t i = 0; i < num_ct_to_recv; i += kCtAsyncParallel) { + size_t this_batch = std::min(num_ct_to_recv - i, kCtAsyncParallel); for (size_t j = 0; j < this_batch; ++j) { - auto ct_s = conn->Recv(nxt_rank, "Conv2dOLE::Recv"); + auto ct_s = conn->Recv(next_rank, "recv result mat"); DecodeSEALObject(ct_s, this_context, &recv_ct[j]); } - yacl::parallel_for( - 0, this_batch, kParallelStride, [&](size_t bgn, size_t end) { - for (size_t j = bgn; j < end; ++j) { - evaluator.transform_to_ntt_inplace(recv_ct[j]); - this_decryptor->decrypt(recv_ct[j], result_poly[i + j]); - } - }); + yacl::parallel_for(0, this_batch, CalculateWorkLoad(this_batch), + [&](size_t bgn, size_t end) { + for (size_t j = bgn; j < end; ++j) { + NttInplace(recv_ct[j], this_context); + this_decryptor->decrypt(recv_ct[j], + result_poly[i + j]); + // non-ntt form for ParseResult + InvNttInplace(result_poly[i + j], this_context); + } + }); } - // non-ntt form for parseBatchedConv2dResult - yacl::parallel_for(0, meta.n_output_poly, kParallelStride, - [&](size_t bgn, size_t end) { - for (size_t i = bgn; i < end; ++i) { - InvNttInplace(result_poly[i], this_context); - } - }); - - return parseBatchedConv2dResult(field, meta, conv2d, - absl::MakeSpan(result_poly)); + switch (cptype) { + case CipherPackingType::none: + return matmat_prot.ParseResult( + field, meta, absl::MakeConstSpan(result_poly), this_dcd_msh); + case CipherPackingType::rlwes: + default: + return matmat_prot.ParseBatchPackedResult( + field, batch_size, meta, absl::MakeConstSpan(result_poly), + this_dcd_msh); + } } -ArrayRef CheetahDot::Impl::Conv2dOLE( - const ArrayRef &inp, yacl::link::Context *conn, int64_t input_batch, - const Shape3D &tensor_shape, int64_t num_kernels, - const Shape3D &kernel_shape, const Shape2D &window_strides, - bool is_tensor) { +NdArrayRef CheetahDot::Impl::DotOLE(const NdArrayRef &prv_mat, + yacl::link::Context *conn, + const Shape3D &dim3, bool is_self_lhs) { if (conn == nullptr) { conn = lctx_.get(); } - - auto eltype = inp.eltype(); + auto eltype = prv_mat.eltype(); SPU_ENFORCE(eltype.isa(), "must be ring_type, got={}", eltype); - SPU_ENFORCE(input_batch > 0 && num_kernels > 0); - if (is_tensor) { - SPU_ENFORCE_EQ(inp.numel(), calcNumel(tensor_shape) * input_batch); + SPU_ENFORCE(prv_mat.numel() > 0 && prv_mat.ndim() == 2); + + if (is_self_lhs) { + 
SPU_ENFORCE_EQ(prv_mat.numel(), dim3[0] * dim3[1]);
   } else {
-    SPU_ENFORCE_EQ(inp.numel(), calcNumel(kernel_shape) * num_kernels);
+    SPU_ENFORCE_EQ(prv_mat.numel(), dim3[1] * dim3[2]);
   }
 
-  auto field = eltype.as()->field();
-  const size_t field_bitlen = FieldBitLen(field);
-  LazyInit(field_bitlen);
+  return doDotOLE(prv_mat, conn, dim3, is_self_lhs);
+}
 
-  const auto &this_context = *seal_cntxts_.find(field_bitlen)->second;
-  auto &this_ecd_ms = ecd_mswh_.find(field_bitlen)->second;
-  Conv2DProtocol conv2d(this_context, *this_ecd_ms);
-
-  Conv2DMeta meta;
-  meta.prot_meta.input_batch = input_batch;
-  meta.prot_meta.num_kernels = num_kernels;
-  meta.prot_meta.input_shape = tensor_shape;
-  meta.prot_meta.kernel_shape = kernel_shape;
-  meta.prot_meta.window_strides = window_strides;
-
-  auto subshape = conv2d.GetSubTensorShape(meta.prot_meta);
-  meta.n_tensor_poly =
-      input_batch * conv2d.GetInputSize(meta.prot_meta, subshape);
-  meta.n_output_poly =
-      input_batch * conv2d.GetOutSize(meta.prot_meta, subshape);
-  meta.n_kernel_poly = conv2d.GetKernelSize(meta.prot_meta, subshape);
-  meta.is_tensor = is_tensor;
-
-  bool to_encrypt_tensor = meta.n_tensor_poly < meta.n_kernel_poly;
-  bool need_encrypt = !(is_tensor ^ to_encrypt_tensor);
-  std::vector encoded_poly;
-  if (is_tensor) {
-    encoded_poly.resize(meta.n_tensor_poly);
-    encodeBatchInput(inp, meta, conv2d, need_encrypt,
-                     absl::MakeSpan(encoded_poly));
+NdArrayRef CheetahDot::Impl::BatchDotOLE(const NdArrayRef &prv_mat,
+                                         yacl::link::Context *conn,
+                                         const Shape4D &dim4,
+                                         bool is_self_lhs) {
+  if (conn == nullptr) {
+    conn = lctx_.get();
+  }
+  auto eltype = prv_mat.eltype();
+  SPU_ENFORCE(eltype.isa(), "must be ring_type, got={}", eltype);
+  SPU_ENFORCE(prv_mat.numel() > 0 && prv_mat.ndim() == 3);
+
+  if (is_self_lhs) {
+    SPU_ENFORCE_EQ(prv_mat.numel(), dim4[0] * dim4[1] * dim4[2]);
   } else {
-    encoded_poly.resize(meta.n_kernel_poly);
-    conv2d.EncodeKernels(inp, meta.prot_meta, need_encrypt,
-                         absl::MakeSpan(encoded_poly));
+    SPU_ENFORCE_EQ(prv_mat.numel(), dim4[0] * dim4[2] * dim4[3]);
   }
 
-  // convert local poly to NTT form to perform multiplication.
-  yacl::parallel_for(0, encoded_poly.size(), kParallelStride,
-                     [&](size_t bgn, size_t end) {
-                       for (size_t i = bgn; i < end; ++i) {
-                         NttInplace(encoded_poly[i], this_context);
-                       }
-                     });
-  if (need_encrypt) {
-    return doConv2dOLEForEncryptor(field, encoded_poly, meta, conv2d, conn);
+  if (eltype.template as()->field() == FM32) {
+    // FM32 does not support packing.
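+    // Each per-batch DotOLE below still returns additive shares; the per-batch
+    // outputs are copied into one flat buffer and reshaped to BxMxL at the end.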
+ // Thus just call multiple DotOLEs + Shape3D dim3 = {dim4[1], dim4[2], dim4[3]}; + int64_t out_numel = dim4[1] * dim4[3]; + NdArrayRef out(eltype, {dim4[0] * out_numel}); + const auto &mat_shape = prv_mat.shape(); + + for (int64_t b = 0; b < dim4[0]; ++b) { + auto one_mat = + prv_mat + .slice({b, 0, 0}, {b + 1, mat_shape[1], mat_shape[2]}, {1, 1, 1}) + .reshape({mat_shape[1], mat_shape[2]}); + auto one_out = DotOLE(one_mat, conn, dim3, is_self_lhs); + auto out_slice = + out.slice({b * out_numel}, {b * out_numel + out_numel}, {1}); + pforeach(0, one_out.numel(), [&](int64_t i) { + std::memcpy(&out_slice.at(i), &one_out.at(i), one_out.elsize()); + }); + } + + return out.reshape({dim4[0], dim4[1], dim4[3]}); } - return doConv2dOLEForEvaluator(field, encoded_poly, meta, conv2d, conn); + + return doBatchDotOLE(prv_mat, conn, dim4, is_self_lhs); } -ArrayRef CheetahDot::Impl::parseBatchedConv2dResult( - FieldType field, const Conv2DMeta &meta, const Conv2DProtocol &prot, - absl::Span polys) { - SPU_ENFORCE_EQ(polys.size(), meta.n_output_poly); - const auto &this_dcd_ms = dcd_mswh_.find(FieldBitLen(field))->second; - - std::array oshape; - oshape[0] = meta.prot_meta.input_batch; - for (int d : {0, 1}) { - oshape[d + 1] = - (meta.prot_meta.input_shape[d] - meta.prot_meta.kernel_shape[d] + - meta.prot_meta.window_strides[d]) / - meta.prot_meta.window_strides[d]; +NdArrayRef CheetahDot::Impl::doBatchDotOLE(const NdArrayRef &prv_mat, + yacl::link::Context *conn, + const Shape4D &dim4, + bool is_self_lhs) { + auto eltype = prv_mat.eltype(); + auto field = eltype.template as()->field(); + SPU_ENFORCE(field != FM32, "Not support BatchDotOLE for FM32"); + + const size_t field_bitlen = SizeOf(field) * 8; + size_t poly_deg = DecideSEALParameters(field_bitlen).poly_modulus_degree(); + + const Shape3D dim3 = {dim4[1], dim4[2], dim4[3]}; + const size_t batch_size = dim4[0]; + MatMatProtocol::Meta meta = {.dims = dim3}; + + CipherPackingType cptype = CipherPackingType::rlwes; + auto subshape = + MatMatProtocol::GetSubMatShape(meta, poly_deg, /*disable_pack*/ false); + + size_t blk0 = CeilDiv(meta.dims[0], subshape[0]); + size_t blk1 = CeilDiv(meta.dims[1], subshape[1]); + size_t blk2 = CeilDiv(meta.dims[2], subshape[2]); + size_t lhs_poly_n = blk0 * blk1; + size_t rhs_poly_n = blk1 * blk2; + bool to_encrypt_lhs = lhs_poly_n <= rhs_poly_n; + bool act_as_encryptor = (is_self_lhs ^ to_encrypt_lhs) == 0; + + LazyInit(field_bitlen, cptype != CipherPackingType::none); + auto mat_shape = prv_mat.shape(); + + if (act_as_encryptor) { + for (int64_t b = 0; b < (int64_t)batch_size; ++b) { + auto one_mat = + prv_mat + .slice({b, 0, 0}, {b + 1, mat_shape[1], mat_shape[2]}, {1, 1, 1}) + .reshape({mat_shape[1], mat_shape[2]}); + doDotOLESenderSendStep(one_mat, dim3, is_self_lhs, cptype, conn); + } + + size_t num_ct_to_recv = 0; + switch (cptype) { + case CipherPackingType::rlwes: + num_ct_to_recv = CeilDiv(batch_size * blk0 * blk2, subshape[1]); + break; + default: + num_ct_to_recv = blk0 * blk2; + } + return doDotOLESenderRecvStep(field, batch_size, meta, num_ct_to_recv, + cptype, conn); } - oshape[3] = meta.prot_meta.num_kernels; - - const size_t n_poly_per_out = meta.n_output_poly / meta.prot_meta.input_batch; - ArrayRef ret = flatten(ring_zeros(field, {calcNumel(oshape)})); - int64_t offset = 0; - for (int64_t ib = 0; ib < meta.prot_meta.input_batch; ++ib) { - // NxHxWxC layout for tensor - // One slice is 1xHxWxC - absl::Span subpoly = {polys.data() + ib * n_poly_per_out, - n_poly_per_out}; - auto slice = 
prot.ParseResult(field, meta.prot_meta, subpoly, *this_dcd_ms); - int64_t n = slice.numel(); - DISPATCH_ALL_FIELDS(field, "", [&]() { - ArrayView xret(ret); - ArrayView xslice(slice); - pforeach(0, n, [&](int64_t i) { xret[i + offset] = xslice[i]; }); - }); - - offset += n; + + size_t num_ct_per_batch = blk0 * blk2; + std::vector _ct_array_to_pack(batch_size * num_ct_per_batch); + auto ct_array_to_pack = absl::MakeSpan(_ct_array_to_pack); + + size_t bytes_recv = conn->GetStats()->recv_bytes; + for (int64_t b = 0; b < (int64_t)batch_size; ++b) { + auto one_mat = + prv_mat.slice({b, 0, 0}, {b + 1, mat_shape[1], mat_shape[2]}, {1, 1, 1}) + .reshape({mat_shape[1], mat_shape[2]}); + doDotOLEReceiverRecvStep( + one_mat, dim3, is_self_lhs, cptype, + ct_array_to_pack.subspan(b * num_ct_per_batch, num_ct_per_batch), conn); } - return ret; + bytes_recv = conn->GetStats()->recv_bytes - bytes_recv; + + return doDotOLEReceiverSendStep(field, batch_size, meta, ct_array_to_pack, + cptype, conn, bytes_recv); } -CheetahDot::CheetahDot(std::shared_ptr lctx) { - impl_ = std::make_unique(lctx); +NdArrayRef CheetahDot::Impl::doDotOLE(const NdArrayRef &prv_mat, + yacl::link::Context *conn, + const Shape3D &dim3, bool is_self_lhs) { + auto eltype = prv_mat.eltype(); + auto field = eltype.template as()->field(); + const size_t field_bitlen = SizeOf(field) * 8; + size_t poly_deg = DecideSEALParameters(field_bitlen).poly_modulus_degree(); + + MatMatProtocol::Meta meta = {.dims = dim3}; + + // No cipher packing for small HE + CipherPackingType cptype = (field == FM32 || disable_pack_) + ? CipherPackingType::none + : CipherPackingType::rlwes; + Shape3D subshape; + size_t blk[3]; + if (cptype != CipherPackingType::none) { + // attempt to calculate the cost with packing + subshape = MatMatProtocol::GetSubMatShape(meta, poly_deg, false); + for (int i : {0, 1, 2}) { + blk[i] = CeilDiv(meta.dims[i], subshape[i]); + } + // If there is only 1 resultant RLWE; then we just skip any packing + cptype = blk[0] * blk[2] <= 1 ? 
CipherPackingType::none + : CipherPackingType::rlwes; + } + + LazyInit(field_bitlen, cptype != CipherPackingType::none); + + // Update subshape + subshape = MatMatProtocol::GetSubMatShape(meta, poly_deg, + cptype == CipherPackingType::none); + + for (int i : {0, 1, 2}) { + blk[i] = CeilDiv(meta.dims[i], subshape[i]); + } + + size_t lhs_poly_n = blk[0] * blk[1]; + size_t rhs_poly_n = blk[1] * blk[2]; + bool to_encrypt_lhs = lhs_poly_n <= rhs_poly_n; + bool act_as_encryptor = (is_self_lhs ^ to_encrypt_lhs) == 0; + + if (act_as_encryptor) { + doDotOLESenderSendStep(prv_mat, dim3, is_self_lhs, cptype, conn); + + size_t num_ct_to_recv = 0; + if (cptype == CipherPackingType::rlwes) { + num_ct_to_recv = CeilDiv(blk[0] * blk[2], subshape[1]); + } else { + num_ct_to_recv = blk[0] * blk[2]; + } + + return doDotOLESenderRecvStep(field, /*batch*/ 1, meta, num_ct_to_recv, + cptype, conn) + .reshape({dim3[0], dim3[2]}); + } + + std::vector _ct_array_to_pack(blk[0] * blk[2]); + auto ct_array_to_pack = absl::MakeSpan(_ct_array_to_pack); + size_t bytes_recv = conn->GetStats()->recv_bytes; + doDotOLEReceiverRecvStep(prv_mat, dim3, is_self_lhs, cptype, ct_array_to_pack, + conn); + bytes_recv = conn->GetStats()->recv_bytes - bytes_recv; + + return doDotOLEReceiverSendStep(field, /*batch*/ 1, meta, ct_array_to_pack, + cptype, conn, bytes_recv) + .reshape({dim3[0], dim3[2]}); } -CheetahDot::~CheetahDot() = default; +////////////////////////////////////////////// -ArrayRef CheetahDot::DotOLE(const ArrayRef &inp, yacl::link::Context *conn, - const Shape3D &dim3, bool is_lhs) { - SPU_ENFORCE(impl_ != nullptr); - SPU_ENFORCE(conn != nullptr); - return impl_->DotOLE(inp, conn, dim3, is_lhs); +CheetahDot::CheetahDot(const std::shared_ptr &lctx, + bool enable_matmul_pack) { + impl_ = std::make_unique(lctx, enable_matmul_pack); } -ArrayRef CheetahDot::DotOLE(const ArrayRef &inp, const Shape3D &dim3, - bool is_lhs) { +CheetahDot::~CheetahDot() = default; + +NdArrayRef CheetahDot::DotOLE(const NdArrayRef &inp, yacl::link::Context *conn, + const Shape3D &dim3, bool is_self_lhs) { SPU_ENFORCE(impl_ != nullptr); - return impl_->DotOLE(inp, nullptr, dim3, is_lhs); + SPU_ENFORCE(conn != nullptr); + return impl_->DotOLE(inp, conn, dim3, is_self_lhs); } -ArrayRef CheetahDot::Conv2dOLE(const ArrayRef &inp, yacl::link::Context *conn, - int64_t num_input, const Shape3D &tensor_shape, - int64_t num_kernels, const Shape3D &kernel_shape, - const Shape2D &window_strides, bool is_tensor) { +NdArrayRef CheetahDot::DotOLE(const NdArrayRef &inp, const Shape3D &dim3, + bool is_self_lhs) { SPU_ENFORCE(impl_ != nullptr); - SPU_ENFORCE(conn != nullptr); - return impl_->Conv2dOLE(inp, conn, num_input, tensor_shape, num_kernels, - kernel_shape, window_strides, is_tensor); + return impl_->DotOLE(inp, nullptr, dim3, is_self_lhs); } -ArrayRef CheetahDot::Conv2dOLE(const ArrayRef &inp, int64_t num_input, - const Shape3D &tensor_shape, int64_t num_kernels, - const Shape3D &kernel_shape, - const Shape2D &window_strides, bool is_tensor) { +NdArrayRef CheetahDot::BatchDotOLE(const NdArrayRef &inp, + yacl::link::Context *conn, + const Shape4D &dim4, bool is_self_lhs) { SPU_ENFORCE(impl_ != nullptr); - return impl_->Conv2dOLE(inp, nullptr, num_input, tensor_shape, num_kernels, - kernel_shape, window_strides, is_tensor); + return impl_->BatchDotOLE(inp, conn, dim4, is_self_lhs); } -} // namespace spu::mpc::cheetah +} // namespace spu::mpc::cheetah \ No newline at end of file diff --git a/libspu/mpc/cheetah/arith/cheetah_dot.h 
b/libspu/mpc/cheetah/arith/cheetah_dot.h index 8808c967..cf9bf36d 100644 --- a/libspu/mpc/cheetah/arith/cheetah_dot.h +++ b/libspu/mpc/cheetah/arith/cheetah_dot.h @@ -18,8 +18,8 @@ #include "yacl/link/context.h" +#include "libspu/core/ndarray_ref.h" #include "libspu/mpc/cheetah/arith/common.h" -#include "libspu/mpc/cheetah/array_ref.h" namespace spu::mpc::cheetah { @@ -29,7 +29,8 @@ namespace spu::mpc::cheetah { // https://eprint.iacr.org/2022/207.pdf class CheetahDot { public: - explicit CheetahDot(std::shared_ptr lctx); + explicit CheetahDot(const std::shared_ptr& lctx, + bool enable_matmul_pack = true); ~CheetahDot(); @@ -39,30 +40,21 @@ class CheetahDot { CheetahDot(CheetahDot&&) = delete; - ArrayRef DotOLE(const ArrayRef& inp, yacl::link::Context* conn, - const Shape3D& dim3, bool is_left_hand_side); + NdArrayRef DotOLE(const NdArrayRef& inp, const Shape3D& dim3, + bool is_self_lhs); - ArrayRef DotOLE(const ArrayRef& inp, const Shape3D& dim3, - bool is_left_hand_side); + // LHS.shape MxK, RHS.shape KxL => MxL + NdArrayRef DotOLE(const NdArrayRef& inp, yacl::link::Context* conn, + const Shape3D& dim3, bool is_self_lhs); - ArrayRef Conv2dOLE(const ArrayRef& inp, yacl::link::Context* conn, - int64_t num_input, const Shape3D& tensor_shape, - int64_t num_kernels, const Shape3D& kernel_shape, - const Shape2D& window_strides, bool is_tensor); - - ArrayRef Conv2dOLE(const ArrayRef& inp, int64_t num_input, - const Shape3D& tensor_shape, int64_t num_kernels, - const Shape3D& kernel_shape, const Shape2D& window_strides, - bool is_tensor); - - std::shared_ptr GetLink() const { return lctx_; } + // LHS.shape BxMxK, RHS.shape BxKxL => BxMxL + NdArrayRef BatchDotOLE(const NdArrayRef& inp, yacl::link::Context* conn, + const Shape4D& dim4, bool is_self_lhs); private: struct Impl; std::unique_ptr impl_{nullptr}; - - std::shared_ptr lctx_{nullptr}; }; -} // namespace spu::mpc::cheetah +} // namespace spu::mpc::cheetah \ No newline at end of file diff --git a/libspu/mpc/cheetah/arith/cheetah_dot_test.cc b/libspu/mpc/cheetah/arith/cheetah_dot_test.cc index 6254587e..7d083411 100644 --- a/libspu/mpc/cheetah/arith/cheetah_dot_test.cc +++ b/libspu/mpc/cheetah/arith/cheetah_dot_test.cc @@ -27,11 +27,11 @@ class CheetahDotTest INSTANTIATE_TEST_SUITE_P( Cheetah, CheetahDotTest, - testing::Combine(testing::Values(FieldType::FM32, FieldType::FM64, - FieldType::FM128), + testing::Combine(testing::Values(FieldType::FM64, FieldType::FM128), testing::Values(Shape3D{8, 7, 5}, Shape3D{57, 30, 1}, Shape3D{30, 57, 1}, Shape3D{18, 8, 41}, - Shape3D{25088, 5, 25})), + Shape3D{500, 13, 25}, + Shape3D{18, 768, 78})), [](const testing::TestParamInfo& p) { return fmt::format("{}x{}x{}x{}", std::get<0>(std::get<1>(p.param)), std::get<1>(std::get<1>(p.param)), @@ -44,32 +44,32 @@ TEST_P(CheetahDotTest, Basic) { auto field = std::get<0>(GetParam()); auto dim3 = std::get<1>(GetParam()); - std::vector mat(kWorldSize); + std::vector mat(kWorldSize); // NOTE(juhou): now Cheetah supports strided ArrayRef for (int64_t stride : {1, 3}) { - mat[0] = flatten(ring_rand(field, {stride * dim3[0] * dim3[1]})); - mat[0] = mat[0].slice(0, mat[0].numel(), stride); + mat[0] = ring_rand(field, {stride * dim3[0] * dim3[1]}); + mat[0] = mat[0].slice({0}, {mat[0].numel()}, {stride}); + mat[0] = mat[0].reshape({dim3[0], dim3[1]}); - mat[1] = flatten(ring_rand(field, {1 + stride * dim3[1] * dim3[2]})); - mat[1] = mat[1].slice(1, mat[1].numel(), stride); + mat[1] = ring_rand(field, {1 + stride * dim3[1] * dim3[2]}); + mat[1] = mat[1].slice({1}, 
{mat[1].numel()}, {stride}); + mat[1] = mat[1].reshape({dim3[1], dim3[2]}); - std::vector result(kWorldSize); + std::vector result(kWorldSize); utils::simulate(kWorldSize, [&](std::shared_ptr lctx) { int rank = lctx->Rank(); auto dot = std::make_shared(lctx); result[rank] = dot->DotOLE(mat[rank], dim3, rank == 0); }); - auto expected = flatten(ring_mmul(unflatten(mat[0], {dim3[0], dim3[1]}), - unflatten(mat[1], {dim3[1], dim3[2]}))); - auto computed = - flatten(ring_add(toNdArray(result[0]), toNdArray(result[1]))); + auto expected = ring_mmul(mat[0], mat[1]); + auto computed = ring_add(result[0], result[1]); EXPECT_EQ(expected.numel(), computed.numel()); const int64_t kMaxDiff = 1; DISPATCH_ALL_FIELDS(field, "_", [&]() { - auto e = ArrayView(expected); - auto c = ArrayView(computed); + auto e = NdArrayView(expected); + auto c = NdArrayView(computed); for (auto idx = 0; idx < expected.numel(); idx++) { EXPECT_NEAR(e[idx], c[idx], kMaxDiff); @@ -78,4 +78,60 @@ TEST_P(CheetahDotTest, Basic) { } } +TEST_P(CheetahDotTest, BatchDot) { + size_t kWorldSize = 2; + auto field = std::get<0>(GetParam()); + auto dim3 = std::get<1>(GetParam()); + int64_t B = 12; + Shape4D dim4 = {B, dim3[0], dim3[1], dim3[2]}; + + std::vector mat(kWorldSize); + mat[0] = ring_rand(field, {B, dim3[0], dim3[1]}); + mat[1] = ring_rand(field, {B, dim3[1], dim3[2]}); + + std::vector result(kWorldSize); + utils::simulate(kWorldSize, [&](std::shared_ptr lctx) { + int rank = lctx->Rank(); + auto dot = std::make_shared(lctx); + result[rank] = dot->BatchDotOLE(mat[rank], lctx.get(), dim4, rank == 0); + }); + + NdArrayRef expected = ring_zeros(field, {B * dim3[0] * dim3[2]}); + + mat[0] = mat[0].reshape({B * dim3[0] * dim3[1]}); + mat[1] = mat[1].reshape({B * dim3[1] * dim3[2]}); + for (int64_t b = 0; b < B; ++b) { + auto lhs = mat[0].slice({b * dim3[0] * dim3[1]}, + {(1 + b) * dim3[0] * dim3[1]}, {1}); + auto rhs = mat[1].slice({b * dim3[1] * dim3[2]}, + {(1 + b) * dim3[1] * dim3[2]}, {1}); + auto slice = + expected + .slice({b * dim3[0] * dim3[2]}, {(1 + b) * dim3[0] * dim3[2]}, {1}) + .reshape({dim3[0], dim3[2]}); + + ring_mmul_(slice, lhs.reshape({dim3[0], dim3[1]}), + rhs.reshape({dim3[1], dim3[2]})); + } + auto computed = ring_add(result[0], result[1]); + EXPECT_EQ(expected.numel(), computed.numel()); + + [[maybe_unused]] constexpr int64_t kMaxDiff = 1; + int64_t max_diff = 0; + DISPATCH_ALL_FIELDS(field, "_", [&]() { + auto e = NdArrayView(expected); + auto c = NdArrayView(computed); + + for (auto idx = 0; idx < expected.numel(); idx++) { + if (e[idx] != c[idx]) { + std::cout << fmt::format("expected {} got {} at {}\n", e[idx], c[idx], + idx); + } + int64_t diff = e[idx] < c[idx] ? c[idx] - e[idx] : e[idx] - c[idx]; + max_diff = std::max(max_diff, diff); + ASSERT_LE(diff, kMaxDiff); + } + }); +} + } // namespace spu::mpc::cheetah diff --git a/libspu/mpc/cheetah/arith/cheetah_mul.cc b/libspu/mpc/cheetah/arith/cheetah_mul.cc index 97ed796a..b0c46494 100644 --- a/libspu/mpc/cheetah/arith/cheetah_mul.cc +++ b/libspu/mpc/cheetah/arith/cheetah_mul.cc @@ -13,6 +13,7 @@ // limitations under the License. 
#include "libspu/mpc/cheetah/arith/cheetah_mul.h" +#include #include #include #include @@ -31,35 +32,51 @@ #include "seal/util/polyarithsmallmod.h" #include "seal/valcheck.h" #include "spdlog/spdlog.h" -#include "xtensor/xview.hpp" #include "yacl/link/link.h" #include "yacl/utils/parallel.h" -#include "libspu/core/xt_helper.h" #include "libspu/mpc/cheetah/arith/common.h" #include "libspu/mpc/cheetah/rlwe/modswitch_helper.h" #include "libspu/mpc/cheetah/rlwe/utils.h" #include "libspu/mpc/utils/ring_ops.h" +struct Options { + size_t ring_bitlen; + size_t msg_bitlen; // msg_bitlen <= ring_bitlen +}; + +bool operator==(const Options &lhs, const Options &rhs) { + return lhs.ring_bitlen == rhs.ring_bitlen && lhs.msg_bitlen == rhs.msg_bitlen; +} + +template <> +struct std::hash { + size_t operator()(Options const &s) const noexcept { + return std::hash{}( + fmt::format("{}_{}", s.ring_bitlen, s.msg_bitlen)); + } +}; + namespace spu::mpc::cheetah { struct CheetahMul::Impl : public EnableCPRNG { public: - // RLWE parameters: N = 8192, Q \approx 2^{109}, t \approx 2^{40} - // NOTE(juhou): Under this parameters, the Mul() might introduce 1-bit error + // RLWE parameters: N = 4096, Q \approx 2^{109}, t \approx 2^{40} + // NOTE(lwj): Under this parameters, the Mul() might introduce 1-bit error // within a small chance Pr = 2^{-32}. - static constexpr size_t kPolyDegree = 8192; + static constexpr size_t kPolyDegree = 4096; static constexpr size_t kCipherModulusBits = 109; - static constexpr uint32_t kSmallPrimeBitLen = 40; - static constexpr int kNoiseFloodRandomBits = 50; + static constexpr int64_t kCtAsyncParallel = 16; - static constexpr size_t kParallelGrain = 1; - static constexpr size_t kCtAsyncParallel = 8; + const uint32_t small_crt_prime_len_; - explicit Impl(std::shared_ptr lctx) - : lctx_(std::move(lctx)) { - parms_ = DecideSEALParameters(kSmallPrimeBitLen); + explicit Impl(std::shared_ptr lctx, + bool allow_high_prob_one_bit_error) + : small_crt_prime_len_(allow_high_prob_one_bit_error ? 45 : 42), + lctx_(std::move(lctx)), + allow_high_prob_one_bit_error_(allow_high_prob_one_bit_error) { + parms_ = DecideSEALParameters(); } ~Impl() = default; @@ -67,130 +84,112 @@ struct CheetahMul::Impl : public EnableCPRNG { constexpr size_t OLEBatchSize() const { return kPolyDegree; } int Rank() const { return lctx_->Rank(); } - static seal::EncryptionParameters DecideSEALParameters(uint32_t ring_bitlen) { + + static seal::EncryptionParameters DecideSEALParameters() { size_t poly_deg = kPolyDegree; auto scheme_type = seal::scheme_type::bfv; auto parms = seal::EncryptionParameters(scheme_type); std::vector modulus_bits; - // NOTE(juhou): We set the 2nd modulus a bit larger than + // NOTE(lwj): We set the 2nd modulus a bit larger than // `kSmallPrimeBitLen`. We will drop the 2nd modulus during the H2A step. // Also, it helps reducing the noise in the BFV ciphertext. - modulus_bits = {60, 49}; + // Slightly larger than the recommand 109bit modulus in SEAL. 
+ modulus_bits = {60, 52}; parms.set_use_special_prime(false); parms.set_poly_modulus_degree(poly_deg); parms.set_coeff_modulus(seal::CoeffModulus::Create(poly_deg, modulus_bits)); return parms; } - size_t num_slots() const { return parms_.poly_modulus_degree(); } + int64_t num_slots() const { return parms_.poly_modulus_degree(); } - void LazyExpandSEALContexts(uint32_t field_bitlen, + void LazyExpandSEALContexts(const Options &options, yacl::link::Context *conn = nullptr); - ArrayRef MulOLE(const ArrayRef &shr, yacl::link::Context *conn, - bool evaluator); + NdArrayRef MulOLE(const NdArrayRef &shr, yacl::link::Context *conn, + bool evaluator, uint32_t msg_width_hint); protected: void LocalExpandSEALContexts(size_t target); - static inline uint32_t FieldBitLen(FieldType f) { return 8 * SizeOf(f); } - - static inline uint32_t TotalCRTBitLen(uint32_t field_bitlen) { - if (field_bitlen < 26) { - // 1 <= k < 26 uses P = 80bit - return 2 * kSmallPrimeBitLen; - } else if (field_bitlen < 64) { - // 26 <= k < 64 uses P = 120bit - return 3 * kSmallPrimeBitLen; - } else if (field_bitlen < 128) { - // 64 <= k < 128 uses P = 160bit - // The Pr(1bit error) < 2^{-32} - return 4 * kSmallPrimeBitLen; - } - // k == 128 uses P = 280bit - SPU_ENFORCE_EQ(field_bitlen, 128U); - return 7 * kSmallPrimeBitLen; + inline uint32_t TotalCRTBitLen(const Options &options) const { + // Let P be the product of small CRT primes. + // When P > 2^{2k - 1}, the Pr(1-bit error) is about 2^{2k-4}/P. + // If allowing high prob of 1-bit error, we just set P ~ 2^{2k} + // Otherwise we set P ~ 2^{2k + 37} so that Pr(1-bit error) is about + // 2^{-40}. + auto bits = options.msg_bitlen + options.ring_bitlen + + (allow_high_prob_one_bit_error_ ? 4UL : 37UL); + auto nprimes = CeilDiv(bits, small_crt_prime_len_); + nprimes = std::min(7UL, nprimes); // Slightly reduce the margin for FM128 + return nprimes * small_crt_prime_len_; } - void LazyInitModSwitchHelper(uint32_t field_bitlen); + void LazyInitModSwitchHelper(const Options &options); - inline uint32_t WorkingContextSize(uint32_t field_bitlen) const { - uint32_t target_bitlen = TotalCRTBitLen(field_bitlen); + inline uint32_t WorkingContextSize(const Options &options) const { + uint32_t target_bitlen = TotalCRTBitLen(options); SPU_ENFORCE(target_bitlen <= current_crt_plain_bitlen_, - "Call ExpandSEALContexts first"); - return CeilDiv(target_bitlen, kSmallPrimeBitLen); + "Call LazyExpandSEALContexts first"); + return CeilDiv(target_bitlen, small_crt_prime_len_); } - struct Options { - size_t max_pack = 0; - bool scale_delta = false; - }; + void EncodeArray(const NdArrayRef &array, bool need_encrypt, + const Options &options, absl::Span out); - // The array will be partitioned into sub-array of `options.max_pack` length. - // If `options.max_pack = 0`, set it to `num_slots`. - // If `options.scale_delta = true`, scale up it by Delta. - void EncodeArray(const ArrayRef &array, const Options &options, - absl::Span out); - - void EncodeArray(const ArrayRef &array, const Options &options, - std::vector *out) { - size_t num_elts = array.numel(); + void EncodeArray(const NdArrayRef &array, bool need_encrypt, + const Options &options, std::vector *out) { + int64_t num_elts = array.numel(); auto eltype = array.eltype(); SPU_ENFORCE(num_elts > 0, "empty array"); SPU_ENFORCE(eltype.isa(), "array must be ring_type, got={}", eltype); - auto field = eltype.as()->field(); - size_t max_pack = options.max_pack > 0 ? 
options.max_pack : num_slots(); - size_t num_splits = CeilDiv(num_elts, max_pack); - size_t num_seal_ctx = WorkingContextSize(FieldBitLen(field)); - size_t num_polys = num_seal_ctx * num_splits; + int64_t num_splits = CeilDiv(num_elts, num_slots()); + int64_t num_seal_ctx = WorkingContextSize(options); + int64_t num_polys = num_seal_ctx * num_splits; out->resize(num_polys); absl::Span wrap(out->data(), out->size()); - EncodeArray(array, options, wrap); + EncodeArray(array, need_encrypt, options, wrap); } // out = EncodeArray(array) if out is not null // return the payload size (absl::Buffer) - size_t EncryptArrayThenSend(const ArrayRef &array, - std::vector *out = nullptr, + size_t EncryptArrayThenSend(const NdArrayRef &array, const Options &options, yacl::link::Context *conn = nullptr); // Sample random array `r` of `size` elements in the field. // Then compute ciphers*plains + r and response the result to the peer. // Return teh sampled array `r`. - ArrayRef MuThenResponse(FieldType field, size_t num_elts, - absl::Span ciphers, - absl::Span plains, - yacl::link::Context *conn = nullptr); - - ArrayRef PrepareRandomMask(FieldType field, size_t size, - const Options &options, - std::vector *encoded_mask); - - ArrayRef PrepareRandomMask(FieldType field, size_t size, - std::vector *encoded_mask) { - Options options; - options.max_pack = num_slots(); - return PrepareRandomMask(field, size, options, encoded_mask); - } + void MulThenResponse(FieldType field, int64_t num_elts, + const Options &options, + absl::Span ciphers, + absl::Span plains, + absl::Span rnd_mask, + yacl::link::Context *conn = nullptr); + + NdArrayRef PrepareRandomMask(FieldType field, int64_t size, + const Options &options, + std::vector *encoded_mask); - ArrayRef DecryptArray(FieldType field, size_t size, - const std::vector &ct_array); + NdArrayRef DecryptArray(FieldType field, int64_t size, const Options &options, + const std::vector &ct_array); void NoiseFloodCiphertext(RLWECt &ct, const seal::SEALContext &context); - void RandomizeCipherForDecryption(RLWECt &ct, size_t cidx); + void RandomizeCipherForDecryption(RLWECt &ct, size_t context_id); private: std::shared_ptr lctx_; + bool allow_high_prob_one_bit_error_ = false; + seal::EncryptionParameters parms_; uint32_t current_crt_plain_bitlen_{0}; // SEAL's contexts for ZZ_{2^k} - mutable std::mutex context_lock_; + mutable std::shared_mutex context_lock_; std::vector seal_cntxts_; // own secret key @@ -198,7 +197,7 @@ struct CheetahMul::Impl : public EnableCPRNG { // the public key received from the opposite party std::shared_ptr pair_public_key_; - std::unordered_map ms_helpers_; + std::unordered_map ms_helpers_; std::vector> sym_encryptors_; std::vector> decryptors_; @@ -206,17 +205,17 @@ struct CheetahMul::Impl : public EnableCPRNG { std::vector> bfv_encoders_; }; -void CheetahMul::Impl::LazyInitModSwitchHelper(uint32_t field_bitlen) { - std::unique_lock guard(context_lock_); - if (ms_helpers_.count(field_bitlen) > 0) { +void CheetahMul::Impl::LazyInitModSwitchHelper(const Options &options) { + std::unique_lock guard(context_lock_); + if (ms_helpers_.count(options) > 0) { return; } - uint32_t target_plain_bitlen = TotalCRTBitLen(field_bitlen); + uint32_t target_plain_bitlen = TotalCRTBitLen(options); SPU_ENFORCE(current_crt_plain_bitlen_ >= target_plain_bitlen); std::vector crt_modulus; - uint32_t accum_plain_bitlen = 0; + uint32_t accum_plain_bitlen = 0; for (size_t idx = 0; accum_plain_bitlen < target_plain_bitlen; ++idx) { auto crt_moduli = 
seal_cntxts_[idx].key_context_data()->parms().plain_modulus(); @@ -224,15 +223,15 @@ void CheetahMul::Impl::LazyInitModSwitchHelper(uint32_t field_bitlen) { crt_modulus.push_back(crt_moduli); } - // NOTE(juhou): we use ckks for this crt_context + // NOTE(lwj): we use ckks for this crt_context auto parms = seal::EncryptionParameters(seal::scheme_type::ckks); parms.set_poly_modulus_degree(parms_.poly_modulus_degree()); parms.set_coeff_modulus(crt_modulus); seal::SEALContext crt_context(parms, false, seal::sec_level_type::none); - SPU_ENFORCE(crt_context.parameters_set()); - ms_helpers_.emplace(field_bitlen, - ModulusSwitchHelper(crt_context, field_bitlen)); + + ms_helpers_.emplace(options, + ModulusSwitchHelper(crt_context, options.ring_bitlen)); } void CheetahMul::Impl::LocalExpandSEALContexts(size_t target) { @@ -267,28 +266,27 @@ void CheetahMul::Impl::LocalExpandSEALContexts(size_t target) { std::make_shared(seal_cntxts_[target], pk)); } -void CheetahMul::Impl::LazyExpandSEALContexts(uint32_t field_bitlen, +void CheetahMul::Impl::LazyExpandSEALContexts(const Options &options, yacl::link::Context *conn) { - uint32_t target_plain_bitlen = TotalCRTBitLen(field_bitlen); - std::unique_lock guard(context_lock_); + uint32_t target_plain_bitlen = TotalCRTBitLen(options); + std::unique_lock guard(context_lock_); if (current_crt_plain_bitlen_ >= target_plain_bitlen) { return; } - uint32_t num_seal_ctx = CeilDiv(target_plain_bitlen, kSmallPrimeBitLen); - std::vector crt_moduli_bits(num_seal_ctx, kSmallPrimeBitLen); - - auto crt_modulus = + uint32_t num_seal_ctx = CeilDiv(target_plain_bitlen, small_crt_prime_len_); + std::vector crt_moduli_bits(num_seal_ctx, small_crt_prime_len_); + std::vector crt_modulus = seal::CoeffModulus::Create(parms_.poly_modulus_degree(), crt_moduli_bits); - // NOTE(juhou): sort the primes to make sure new primes are placed in the back + // Sort the primes to make sure new primes are placed in the back std::sort(crt_modulus.begin(), crt_modulus.end(), [](const seal::Modulus &p, const seal::Modulus &q) { return p.value() > q.value(); }); - uint32_t current_num_ctx = seal_cntxts_.size(); + uint32_t current_num_ctx = seal_cntxts_.size(); for (uint32_t i = current_num_ctx; i < num_seal_ctx; ++i) { - uint64_t new_plain_modulus = crt_modulus[current_num_ctx].value(); + uint64_t new_plain_modulus = crt_modulus.at(i).value(); // new plain modulus should be co-prime to all the previous plain modulus for (uint32_t j = 0; j < current_num_ctx; ++j) { uint32_t prev_plain_modulus = @@ -303,14 +301,14 @@ void CheetahMul::Impl::LazyExpandSEALContexts(uint32_t field_bitlen, for (uint32_t idx = current_num_ctx; idx < num_seal_ctx; ++idx) { parms_.set_plain_modulus(crt_modulus[idx]); - seal_cntxts_.emplace_back(parms_, true, seal::sec_level_type::tc128); + seal_cntxts_.emplace_back(parms_, true, seal::sec_level_type::none); if (idx == 0) { seal::KeyGenerator keygen(seal_cntxts_[0]); secret_key_ = std::make_shared(keygen.secret_key()); auto pk = keygen.create_public_key(); - // NOTE(juhou): we patched seal/util/serializable.h + // NOTE(lwj): we patched seal/util/serializable.h auto pk_buf_send = EncodeSEALObject(pk.obj()); // exchange the public key int nxt_rank = conn->NextRank(); @@ -341,132 +339,141 @@ void CheetahMul::Impl::LazyExpandSEALContexts(uint32_t field_bitlen, std::make_shared(seal_cntxts_.back())); } current_crt_plain_bitlen_ = target_plain_bitlen; - SPDLOG_INFO( - "BeaverCheetah::Mul uses {} modulus ({} bit each) for {} bit ring", - num_seal_ctx, kSmallPrimeBitLen, 
field_bitlen);
+  SPDLOG_INFO("CheetahMul uses {} modulus for {} bit input over {} bit ring",
+              num_seal_ctx, options.msg_bitlen, options.ring_bitlen);
 }
 
-ArrayRef CheetahMul::Impl::MulOLE(const ArrayRef &shr,
-                                  yacl::link::Context *conn, bool evaluator) {
+NdArrayRef CheetahMul::Impl::MulOLE(const NdArrayRef &shr,
+                                    yacl::link::Context *conn, bool evaluator,
+                                    uint32_t msg_width_hint) {
   if (conn == nullptr) {
     conn = lctx_.get();
   }
 
   auto eltype = shr.eltype();
   SPU_ENFORCE(eltype.isa(), "must be ring_type, got={}", eltype);
+  SPU_ENFORCE(shr.shape().size() == 1, "need 1D Array");
   SPU_ENFORCE(shr.numel() > 0);
 
   auto field = eltype.as()->field();
-  LazyExpandSEALContexts(FieldBitLen(field), conn);
-  LazyInitModSwitchHelper(FieldBitLen(field));
+  Options options;
+  options.ring_bitlen = SizeOf(field) * 8;
+  options.msg_bitlen =
+      msg_width_hint == 0 ? options.ring_bitlen : msg_width_hint;
+  SPU_ENFORCE(options.msg_bitlen > 0 &&
+              options.msg_bitlen <= options.ring_bitlen);
+  LazyExpandSEALContexts(options, conn);
+  LazyInitModSwitchHelper(options);
 
   size_t numel = shr.numel();
   int nxt_rank = conn->NextRank();
 
   std::vector encoded_shr;
 
   if (evaluator) {
-    Options options;
-    options.max_pack = num_slots();
-    options.scale_delta = false;
-    EncodeArray(shr, options, &encoded_shr);
+    // NOTE(lwj):
+    // 1. Alice & Bob encode their local shares to polynomials
+    // 2. Alice sends ciphertexts to Bob
+    // 3. Bob multiplies
+    // 4. Bob samples polys for random masking
+    // 5. Bob responds with the masked ciphertexts to Alice
+    // We can overlap Step 2 (IO) and Step 4 (local computation)
+    EncodeArray(shr, false, options, &encoded_shr);
    size_t payload_sze = encoded_shr.size();
     std::vector recv_ct(payload_sze);
-    for (size_t idx = 0; idx < payload_sze; ++idx) {
-      recv_ct[idx] = conn->Recv(nxt_rank, "");
-    }
-    return MuThenResponse(field, numel, recv_ct, encoded_shr, conn);
+    auto io_task = std::async(std::launch::async, [&]() {
+      for (size_t idx = 0; idx < payload_sze; ++idx) {
+        recv_ct[idx] = conn->Recv(nxt_rank, "");
+      }
+    });
+
+    std::vector encoded_mask;
+    auto random_mask =
+        PrepareRandomMask(field, shr.numel(), options, &encoded_mask);
+    SPU_ENFORCE(encoded_mask.size() == payload_sze,
+                "BeaverCheetah: random mask poly size mismatch");
+    // wait for IO
+    io_task.get();
+    MulThenResponse(field, numel, options, recv_ct, encoded_shr,
+                    absl::MakeConstSpan(encoded_mask), conn);
+    return random_mask;
   }
-  size_t payload_sze = EncryptArrayThenSend(shr, nullptr, conn);
+
+  size_t payload_sze = EncryptArrayThenSend(shr, options, conn);
   std::vector recv_ct(payload_sze);
   for (size_t idx = 0; idx < payload_sze; ++idx) {
     recv_ct[idx] = conn->Recv(nxt_rank, "");
   }
-  return DecryptArray(field, numel, recv_ct);
+  return DecryptArray(field, numel, options, recv_ct);
 }
 
-size_t CheetahMul::Impl::EncryptArrayThenSend(const ArrayRef &array,
-                                              std::vector *out,
+size_t CheetahMul::Impl::EncryptArrayThenSend(const NdArrayRef &array,
+                                              const Options &options,
                                               yacl::link::Context *conn) {
-  size_t num_elts = array.numel();
+  int64_t num_elts = array.numel();
   auto eltype = array.eltype();
   SPU_ENFORCE(num_elts > 0, "empty array");
   SPU_ENFORCE(eltype.isa(), "array must be ring_type, got={}", eltype);
 
-  Options options;
-  options.max_pack = num_slots();
-  options.scale_delta = true;
+  int64_t num_splits = CeilDiv(num_elts, num_slots());
+  int64_t num_seal_ctx = WorkingContextSize(options);
+  int64_t num_polys = num_seal_ctx * num_splits;
 
-  auto field = eltype.as()->field();
-  size_t field_bitlen = FieldBitLen(field);
-  size_t num_splits = 
CeilDiv(num_elts, options.max_pack); - size_t num_seal_ctx = WorkingContextSize(field_bitlen); - size_t num_polys = num_seal_ctx * num_splits; - - std::vector encoded_array; - absl::Span _wrap; - if (out != nullptr) { - out->resize(num_polys); - _wrap = {out->data(), num_polys}; - } else { - encoded_array.resize(num_polys); - _wrap = {encoded_array.data(), num_polys}; - } - EncodeArray(array, options, _wrap); + std::vector encoded_array(num_polys); + EncodeArray(array, /*scale_up*/ true, options, absl::MakeSpan(encoded_array)); std::vector payload(num_polys); - yacl::parallel_for( - 0, num_seal_ctx, kParallelGrain, [&](size_t cntxt_bgn, size_t cntxt_end) { - for (size_t c = cntxt_bgn; c < cntxt_end; ++c) { - size_t offset = c * num_splits; - for (size_t idx = 0; idx < num_splits; ++idx) { - // erase the random from memory - AutoMemGuard guard(&encoded_array[offset + idx]); - auto ct = - sym_encryptors_[c]->encrypt_symmetric(_wrap[offset + idx]); - payload.at(offset + idx) = EncodeSEALObject(ct.obj()); - } - } - }); + + yacl::parallel_for(0, num_polys, CalculateWorkLoad(num_polys), + [&](int64_t job_bgn, int64_t job_end) { + for (int64_t job_id = job_bgn; job_id < job_end; + ++job_id) { + int64_t cntxt_id = job_id / num_splits; + auto ct = sym_encryptors_[cntxt_id]->encrypt_symmetric( + encoded_array.at(job_id)); + payload.at(job_id) = EncodeSEALObject(ct.obj()); + } + }); if (conn == nullptr) { conn = lctx_.get(); } int nxt_rank = conn->NextRank(); - for (size_t i = 0; i < payload.size(); i += kCtAsyncParallel) { - size_t this_batch = std::min(payload.size() - i, kCtAsyncParallel); - conn->Send(nxt_rank, payload[i], ""); - for (size_t j = 1; j < this_batch; ++j) { - conn->SendAsync(nxt_rank, payload[i + j], ""); + for (int64_t i = 0; i < num_polys; i += kCtAsyncParallel) { + int64_t this_batch = std::min(num_polys - i, kCtAsyncParallel); + conn->Send(nxt_rank, payload[i], + fmt::format("CheetahMul::Send ct[{}] to rank={}", i, nxt_rank)); + for (int64_t j = 1; j < this_batch; ++j) { + conn->SendAsync( + nxt_rank, payload[i + j], + fmt::format("CheetahMul::Send ct[{}] to rank={}", i + j, nxt_rank)); } } return payload.size(); } -ArrayRef CheetahMul::Impl::PrepareRandomMask( - FieldType field, size_t size, const Options &options, +NdArrayRef CheetahMul::Impl::PrepareRandomMask( + FieldType field, int64_t size, const Options &options, std::vector *encoded_mask) { - const size_t max_pack = options.max_pack; - const size_t num_splits = CeilDiv(size, max_pack); - const size_t field_bitlen = FieldBitLen(field); - const size_t num_seal_ctx = WorkingContextSize(field_bitlen); - const size_t num_polys = num_seal_ctx * num_splits; - SPU_ENFORCE(ms_helpers_.count(field_bitlen) > 0); + const int64_t num_splits = CeilDiv(size, num_slots()); + const int64_t num_seal_ctx = WorkingContextSize(options); + const int64_t num_polys = num_seal_ctx * num_splits; + SPU_ENFORCE(ms_helpers_.count(options) > 0); encoded_mask->resize(num_polys); // sample r from [0, P) in the RNS format // Ref: the one-bit approximate re-sharing in Cheetah's paper (eprint ver). 
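   // (Each RNS component of r is sampled below; r is later mapped back to
   //  Z_{2^k} via round(2^k * x / P), matching the approximate re-sharing.)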
std::vector random_rns(size * num_seal_ctx); - for (size_t cidx = 0; cidx < num_seal_ctx; ++cidx) { + for (int64_t cidx = 0; cidx < num_seal_ctx; ++cidx) { const auto &plain_mod = seal_cntxts_[cidx].key_context_data()->parms().plain_modulus(); - size_t offset = cidx * num_splits; + int64_t offset = cidx * num_splits; std::vector u64tmp(num_slots(), 0); - for (size_t j = 0; j < num_splits; ++j) { - size_t bgn = j * max_pack; - size_t end = std::min(bgn + max_pack, size); - size_t len = end - bgn; + for (int64_t j = 0; j < num_splits; ++j) { + int64_t bgn = j * num_slots(); + int64_t end = std::min(bgn + num_slots(), size); + int64_t len = end - bgn; // sample the RNS component of r from [0, p_i) uint64_t *dst_ptr = random_rns.data() + cidx * size + bgn; @@ -481,107 +488,99 @@ ArrayRef CheetahMul::Impl::PrepareRandomMask( } // convert x \in [0, P) to [0, 2^k) by round(2^k*x/P) - auto &ms_helper = ms_helpers_.find(field_bitlen)->second; + auto &ms_helper = ms_helpers_.find(options)->second; absl::Span inp(random_rns.data(), random_rns.size()); - return ms_helper.ModulusDownRNS(field, inp); + return ms_helper.ModulusDownRNS(field, {size}, inp); } -void CheetahMul::Impl::EncodeArray(const ArrayRef &array, +void CheetahMul::Impl::EncodeArray(const NdArrayRef &array, bool need_encrypt, const Options &options, absl::Span out) { - size_t num_elts = array.numel(); + int64_t num_elts = array.numel(); auto eltype = array.eltype(); SPU_ENFORCE(num_elts > 0, "empty array"); + SPU_ENFORCE(array.shape().size() == 1, "need 1D array"); SPU_ENFORCE(eltype.isa(), "array must be ring_type, got={}", eltype); - auto field = eltype.as()->field(); - size_t max_pack = options.max_pack > 0 ? options.max_pack : num_slots(); - size_t num_splits = CeilDiv(num_elts, max_pack); - size_t field_bitlen = FieldBitLen(field); - size_t num_seal_ctx = WorkingContextSize(field_bitlen); - size_t num_polys = num_seal_ctx * num_splits; - SPU_ENFORCE_EQ(out.size(), num_polys, + int64_t num_splits = CeilDiv(num_elts, num_slots()); + int64_t num_seal_ctx = WorkingContextSize(options); + int64_t num_polys = num_seal_ctx * num_splits; + SPU_ENFORCE_EQ(out.size(), (size_t)num_polys, "out size mismatch, expect={}, got={}size", num_polys, out.size()); - SPU_ENFORCE(ms_helpers_.count(field_bitlen) > 0); + SPU_ENFORCE(ms_helpers_.count(options) > 0); - auto &ms_helper = ms_helpers_.find(field_bitlen)->second; + auto &ms_helper = ms_helpers_.find(options)->second; yacl::parallel_for( - 0, num_seal_ctx, kParallelGrain, [&](size_t cntxt_bgn, size_t cntxt_end) { - std::vector u64tmp(num_slots()); - - for (size_t cidx = cntxt_bgn; cidx < cntxt_end; ++cidx) { - const size_t offset = cidx * num_splits; - - for (size_t idx = 0; idx < num_splits; ++idx) { - auto slice = array.slice(idx * max_pack, - std::min(num_elts, (idx + 1) * max_pack)); - absl::Span dst(u64tmp.data(), slice.numel()); - - if (options.scale_delta) { - ms_helper.ModulusUpAt(slice, cidx, dst); - } else { - ms_helper.CenteralizeAt(slice, cidx, dst); - } - - // zero-padding the rest - std::fill_n(u64tmp.data() + slice.numel(), - u64tmp.size() - slice.numel(), 0); - - CATCH_SEAL_ERROR( - bfv_encoders_[cidx]->encode(u64tmp, out[offset + idx])); + 0, num_polys, CalculateWorkLoad(num_polys), + [&](int64_t job_bgn, int64_t job_end) { + std::vector _u64tmp(num_slots()); + auto u64tmp = absl::MakeSpan(_u64tmp); + + for (int64_t job_id = job_bgn; job_id < job_end; ++job_id) { + int64_t cntxt_id = job_id / num_splits; + int64_t split_id = job_id % num_splits; + int64_t slice_bgn = split_id * 
num_slots(); + int64_t slice_end = + std::min(num_elts, slice_bgn + static_cast(num_slots())); + + auto slice = array.slice({slice_bgn}, {slice_end}, {1}); + auto dst = u64tmp.subspan(0, slice_end - slice_bgn); + if (need_encrypt) { + ms_helper.ModulusUpAt(slice, cntxt_id, dst); + } else { + ms_helper.CenteralizeAt(slice, cntxt_id, dst); } + // zero-padding the rest + std::fill_n(u64tmp.data() + slice.numel(), + u64tmp.size() - slice.numel(), 0); + + CATCH_SEAL_ERROR( + bfv_encoders_[cntxt_id]->encode(_u64tmp, out[job_id])); } }); } -ArrayRef CheetahMul::Impl::MuThenResponse( - FieldType field, size_t num_elts, absl::Span ciphers, - absl::Span plains, yacl::link::Context *conn) { - SPU_ENFORCE(!ciphers.empty(), "BeaverCheetah: empty cipher"); +void CheetahMul::Impl::MulThenResponse(FieldType field, int64_t num_elts, + const Options &options, + absl::Span ciphers, + absl::Span plains, + absl::Span ecd_random, + yacl::link::Context *conn) { + SPU_ENFORCE(!ciphers.empty(), "CheetahMul: empty cipher"); SPU_ENFORCE(plains.size() == ciphers.size(), - "BeaverCheetah: ct/pt size mismatch"); - - const size_t max_pack = num_slots(); - const size_t num_splits = CeilDiv(num_elts, max_pack); - const size_t field_bitlen = FieldBitLen(field); - const size_t num_seal_ctx = WorkingContextSize(field_bitlen); - const size_t num_ciphers = num_seal_ctx * num_splits; - SPU_ENFORCE(ciphers.size() == num_ciphers, - fmt::format("MuThenResponse: expect {} != {}", num_ciphers, - ciphers.size())); - - std::vector ecd_random; - auto rnd_mask = PrepareRandomMask(field, num_elts, &ecd_random); - SPU_ENFORCE(ecd_random.size() == num_ciphers, - "BeaverCheetah: encoded poly size mismatch"); + "CheetahMul: ct/pt size mismatch"); + + const int64_t num_splits = CeilDiv(num_elts, num_slots()); + const int64_t num_seal_ctx = WorkingContextSize(options); + const int64_t num_ciphers = num_seal_ctx * num_splits; + SPU_ENFORCE(ciphers.size() == (size_t)num_ciphers, + "CheetahMul : expect {} != {}", num_ciphers, ciphers.size()); + SPU_ENFORCE(ecd_random.size() == (size_t)num_ciphers, + "CheetahMul: encoded rnaomd size mismatch"); std::vector response(num_ciphers); yacl::parallel_for( - 0, num_seal_ctx, kParallelGrain, [&](size_t cntxt_bgn, size_t cntxt_end) { + 0, num_ciphers, CalculateWorkLoad(num_ciphers), + [&](int64_t job_bgn, int64_t job_end) { RLWECt ct; - for (size_t cidx = cntxt_bgn; cidx < cntxt_end; ++cidx) { - const size_t offset = cidx * num_splits; - const auto &seal_cntxt = seal_cntxts_[cidx]; + std::vector u64tmp(num_slots(), 0); + for (int64_t job_id = job_bgn; job_id < job_end; ++job_id) { + int64_t cntxt_id = job_id / num_splits; + // int64_t offset = cntxt_id * num_splits; + const auto &seal_cntxt = seal_cntxts_[cntxt_id]; seal::Evaluator evaluator(seal_cntxt); - - std::vector u64tmp(max_pack, 0); // Multiply-then-H2A - for (size_t idx = 0; idx < num_splits; ++idx) { - DecodeSEALObject(ciphers.at(offset + idx), seal_cntxt, &ct); - // Multiply step - CATCH_SEAL_ERROR( - evaluator.multiply_plain_inplace(ct, plains[offset + idx])); - // re-randomize the ciphertext (e.g., noise flood) - RandomizeCipherForDecryption(ct, cidx); - // H2A - CATCH_SEAL_ERROR( - evaluator.sub_plain_inplace(ct, ecd_random[offset + idx])); - // Truncate for a smaller communication - TruncateBFVForDecryption(ct, seal_cntxt); - response[offset + idx] = EncodeSEALObject(ct); - } + DecodeSEALObject(ciphers.at(job_id), seal_cntxt, &ct); + // Multiply step + CATCH_SEAL_ERROR( + evaluator.multiply_plain_inplace(ct, plains[job_id])); + // H2A + 
CATCH_SEAL_ERROR(evaluator.sub_plain_inplace(ct, ecd_random[job_id])); + // re-randomize the ciphertext (e.g., noise flood) + RandomizeCipherForDecryption(ct, cntxt_id); + response[job_id] = EncodeSEALObject(ct); } }); @@ -590,69 +589,55 @@ ArrayRef CheetahMul::Impl::MuThenResponse( } int nxt_rank = conn->NextRank(); - for (size_t i = 0; i < response.size(); i += kCtAsyncParallel) { - size_t this_batch = std::min(response.size() - i, kCtAsyncParallel); - conn->Send(nxt_rank, response[i], ""); - for (size_t j = 1; j < this_batch; ++j) { - conn->SendAsync(nxt_rank, response[i + j], ""); + for (int64_t i = 0; i < num_ciphers; i += kCtAsyncParallel) { + int64_t this_batch = std::min(num_ciphers - i, kCtAsyncParallel); + conn->Send(nxt_rank, response[i], + fmt::format("MulThenResponse ct[{}] to rank{}", i, nxt_rank)); + for (int64_t j = 1; j < this_batch; ++j) { + conn->SendAsync( + nxt_rank, response[i + j], + fmt::format("MulThenResponse ct[{}] to rank{}", i + j, nxt_rank)); } } - - for (auto &pt : ecd_random) { - AutoMemGuard{&pt}; - } - - return rnd_mask; } -ArrayRef CheetahMul::Impl::DecryptArray( - FieldType field, size_t size, const std::vector &ct_array) { - const size_t max_pack = num_slots(); - const size_t num_splits = CeilDiv(size, max_pack); - const size_t field_bitlen = FieldBitLen(field); - const size_t num_seal_ctx = WorkingContextSize(field_bitlen); - const size_t num_ciphers = num_seal_ctx * num_splits; - SPU_ENFORCE(ct_array.size() == num_ciphers, - "BeaverCheetah: cipher size mismatch"); - SPU_ENFORCE(ms_helpers_.count(field_bitlen) > 0); - - auto rns_temp = - ring_zeros(FieldType::FM64, {static_cast(size * num_seal_ctx)}); - auto xrns_temp = xt_mutable_adapt(rns_temp); - +NdArrayRef CheetahMul::Impl::DecryptArray( + FieldType field, int64_t size, const Options &options, + const std::vector &ct_array) { + const int64_t num_splits = CeilDiv(size, num_slots()); + const int64_t num_seal_ctx = WorkingContextSize(options); + const int64_t num_ciphers = num_seal_ctx * num_splits; + SPU_ENFORCE(ct_array.size() == (size_t)num_ciphers, + "CheetahMul: cipher size mismatch"); + SPU_ENFORCE(ms_helpers_.count(options) > 0); + + // Decrypt ciphertexts into size x num_modulus + // Then apply the ModulusDown to get value in Z_{2^k}. + std::vector rns_temp(size * num_seal_ctx, 0); yacl::parallel_for( - 0, num_seal_ctx, kParallelGrain, [&](size_t cntxt_bgn, size_t cntxt_end) { - // Loop each SEALContext - // For each context, we obtain `size` uint64 from `num_splits` polys. - // Each poly will decode to `max_pack` uint64, i.e., `max_pack * - // num_splits - // >= size`. 
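Condensing the Multiply-then-H2A flow above into one relation (an editor's summary; the exact scaling is handled by ModulusUpAt on the sender side and ModulusDownRNS on both outputs):

\[
\mathsf{ct}' = \mathsf{ct}_x \cdot \hat{y} - \hat{r} \pmod{P},\qquad
\underbrace{\mathrm{MsDown}\big(\mathrm{Dec}(\mathsf{ct}')\big)}_{\text{from DecryptArray}}
+ \underbrace{\mathrm{MsDown}(r)}_{\text{from PrepareRandomMask}}
\equiv x \cdot y \pmod{2^k},
\]

where MsDown denotes the round(2^k * (.) / P) conversion noted earlier, and the equality holds up to the one-bit error tolerated by the approximate re-sharing referenced above.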
- for (size_t cidx = cntxt_bgn; cidx < cntxt_end; ++cidx) { - const size_t offset = cidx * num_splits; - auto ctx_slice = - xt::view(xrns_temp, xt::range(cidx * size, cidx * size + size)); - - RLWEPt pt; - RLWECt ct; - std::vector subarray(max_pack, 0); - - for (size_t idx = 0; idx < num_splits; ++idx) { - DecodeSEALObject(ct_array.at(offset + idx), seal_cntxts_[cidx], - &ct); - CATCH_SEAL_ERROR(decryptors_[cidx]->decrypt(ct, pt)); - CATCH_SEAL_ERROR(bfv_encoders_[cidx]->decode(pt, subarray)); - - size_t bgn = idx * max_pack; - size_t end = std::min(size, bgn + max_pack); - size_t len = end - bgn; - std::copy_n(subarray.data(), len, ctx_slice.begin() + bgn); - } + 0, num_ciphers, CalculateWorkLoad(num_ciphers), + [&](int64_t job_bgn, int64_t job_end) { + RLWEPt pt; + RLWECt ct; + std::vector subarray(num_slots(), 0); + for (int64_t job_id = job_bgn; job_id < job_end; ++job_id) { + int64_t cntxt_id = job_id / num_splits; + int64_t split_id = job_id % num_splits; + + DecodeSEALObject(ct_array.at(job_id), seal_cntxts_[cntxt_id], &ct); + CATCH_SEAL_ERROR(decryptors_[cntxt_id]->decrypt(ct, pt)); + CATCH_SEAL_ERROR(bfv_encoders_[cntxt_id]->decode(pt, subarray)); + + int64_t slice_bgn = split_id * num_slots(); + int64_t slice_end = std::min(size, slice_bgn + num_slots()); + int64_t slice_modulus_bgn = cntxt_id * size + slice_bgn; + std::copy_n(subarray.data(), slice_end - slice_bgn, + rns_temp.data() + slice_modulus_bgn); } }); - auto &ms_helper = ms_helpers_.find(field_bitlen)->second; - absl::Span inp(xrns_temp.data(), xrns_temp.size()); - return ms_helper.ModulusDownRNS(field, inp); + auto &ms_helper = ms_helpers_.find(options)->second; + return ms_helper.ModulusDownRNS(field, {size}, absl::MakeSpan(rns_temp)); } void CheetahMul::Impl::NoiseFloodCiphertext(RLWECt &ct, @@ -672,11 +657,8 @@ void CheetahMul::Impl::NoiseFloodCiphertext(RLWECt &ct, // sample random from [0, 2^{kStatRandom}) auto random = CPRNG(FieldType::FM64, num_coeffs); - for (int64_t idx = 0; idx < random.numel(); ++idx) { - random.at(idx) &= range_mask; - } - AutoMemGuard guard(&random); - + NdArrayView xrandom(random); + pforeach(0, xrandom.numel(), [&](int64_t i) { xrandom[i] &= range_mask; }); // add random to each modulus of ct.data(0) uint64_t *dst_ptr = ct.data(0); for (size_t l = 0; l < num_modulus; ++l) { @@ -684,22 +666,21 @@ void CheetahMul::Impl::NoiseFloodCiphertext(RLWECt &ct, if (modulus[l].bit_count() > kNoiseFloodRandomBits) { // When prime[l] > 2^{kNoiseFloodRandomBits} then we add directly add // the random - add_poly_coeffmod(reinterpret_cast(random.data()), dst_ptr, - num_coeffs, modulus[l], dst_ptr); + add_poly_coeffmod(&xrandom[0], dst_ptr, num_coeffs, modulus[l], dst_ptr); } else { - // When prime[l] < 2^{kNoiseFloodRandomBits} we need to compute mod - // prime[l] first + // When prime[l] < 2^{kNoiseFloodRandomBits} we need to compute modulo + // first. 
std::vector tmp(num_coeffs); - modulo_poly_coeffs(reinterpret_cast(random.data()), - num_coeffs, modulus[l], tmp.data()); + modulo_poly_coeffs(&xrandom[0], num_coeffs, modulus[l], tmp.data()); add_poly_coeffmod(tmp.data(), dst_ptr, num_coeffs, modulus[l], dst_ptr); } dst_ptr += num_coeffs; } } -void CheetahMul::Impl::RandomizeCipherForDecryption(RLWECt &ct, size_t cidx) { - auto &seal_cntxt = seal_cntxts_.at(cidx); +void CheetahMul::Impl::RandomizeCipherForDecryption(RLWECt &ct, + size_t context_id) { + auto &seal_cntxt = seal_cntxts_.at(context_id); auto context_data = seal_cntxt.last_context_data(); yacl::CheckNotNull(context_data.get()); seal::Evaluator evaluator(seal_cntxt); @@ -713,12 +694,17 @@ void CheetahMul::Impl::RandomizeCipherForDecryption(RLWECt &ct, size_t cidx) { // 3. Add zero-encryption for re-randomization RLWECt zero_enc; - CATCH_SEAL_ERROR(pk_encryptors_[cidx]->encrypt_zero(ct.parms_id(), zero_enc)); + CATCH_SEAL_ERROR( + pk_encryptors_[context_id]->encrypt_zero(ct.parms_id(), zero_enc)); CATCH_SEAL_ERROR(evaluator.add_inplace(ct, zero_enc)); + + // 4. Truncate for smaller communication + TruncateBFVForDecryption(ct, seal_cntxt); } -CheetahMul::CheetahMul(std::shared_ptr lctx) { - impl_ = std::make_unique(lctx); +CheetahMul::CheetahMul(std::shared_ptr lctx, + bool allow_high_prob_one_bit_error) { + impl_ = std::make_unique(lctx, allow_high_prob_one_bit_error); } CheetahMul::~CheetahMul() = default; @@ -730,16 +716,17 @@ size_t CheetahMul::OLEBatchSize() const { return impl_->OLEBatchSize(); } -ArrayRef CheetahMul::MulOLE(const ArrayRef &inp, yacl::link::Context *conn, - bool evaluator) { +NdArrayRef CheetahMul::MulOLE(const NdArrayRef &inp, yacl::link::Context *conn, + bool is_evaluator, uint32_t msg_width_hint) { SPU_ENFORCE(impl_ != nullptr); SPU_ENFORCE(conn != nullptr); - return impl_->MulOLE(inp, conn, evaluator); + return impl_->MulOLE(inp, conn, is_evaluator, msg_width_hint); } -ArrayRef CheetahMul::MulOLE(const ArrayRef &inp, bool evaluator) { +NdArrayRef CheetahMul::MulOLE(const NdArrayRef &inp, bool is_evaluator, + uint32_t msg_width_hint) { SPU_ENFORCE(impl_ != nullptr); - return impl_->MulOLE(inp, nullptr, evaluator); + return impl_->MulOLE(inp, nullptr, is_evaluator, msg_width_hint); } -} // namespace spu::mpc::cheetah +} // namespace spu::mpc::cheetah \ No newline at end of file diff --git a/libspu/mpc/cheetah/arith/cheetah_mul.h b/libspu/mpc/cheetah/arith/cheetah_mul.h index 8f711850..03ccd65d 100644 --- a/libspu/mpc/cheetah/arith/cheetah_mul.h +++ b/libspu/mpc/cheetah/arith/cheetah_mul.h @@ -18,7 +18,7 @@ #include "yacl/link/context.h" -#include "libspu/mpc/cheetah/array_ref.h" +#include "libspu/core/ndarray_ref.h" namespace spu::mpc::cheetah { @@ -28,7 +28,8 @@ namespace spu::mpc::cheetah { // https://eprint.iacr.org/2019/577.pdf class CheetahMul { public: - explicit CheetahMul(std::shared_ptr lctx); + explicit CheetahMul(std::shared_ptr lctx, + bool allow_high_prob_one_bit_error = false); ~CheetahMul(); @@ -38,10 +39,11 @@ class CheetahMul { CheetahMul(CheetahMul&&) = delete; - ArrayRef MulOLE(const ArrayRef& inp, yacl::link::Context* conn, - bool evaluator); + NdArrayRef MulOLE(const NdArrayRef& inp, yacl::link::Context* conn, + bool is_evaluator, uint32_t msg_width_hint = 0); - ArrayRef MulOLE(const ArrayRef& inp, bool evaluator); + NdArrayRef MulOLE(const NdArrayRef& inp, bool is_evaluator, + uint32_t msg_width_hint = 0); int Rank() const; @@ -53,4 +55,4 @@ class CheetahMul { std::unique_ptr impl_{nullptr}; }; -} // namespace spu::mpc::cheetah +} // 
namespace spu::mpc::cheetah \ No newline at end of file diff --git a/libspu/mpc/cheetah/arith/cheetah_mul_test.cc b/libspu/mpc/cheetah/arith/cheetah_mul_test.cc index 9187cb12..abaac8ce 100644 --- a/libspu/mpc/cheetah/arith/cheetah_mul_test.cc +++ b/libspu/mpc/cheetah/arith/cheetah_mul_test.cc @@ -58,15 +58,11 @@ TEST_P(CheetahMulTest, Basic) { NdArrayRef cross0, cross1; if (rank == 0) { - cross0 = - unflatten(mul->MulOLE(flatten(a_shr[0]), true), a_shr[0].shape()); - cross1 = - unflatten(mul->MulOLE(flatten(b_shr[0]), true), a_shr[0].shape()); + cross0 = mul->MulOLE(a_shr[0], true); + cross1 = mul->MulOLE(b_shr[0], true); } else { - cross0 = - unflatten(mul->MulOLE(flatten(b_shr[1]), false), a_shr[1].shape()); - cross1 = - unflatten(mul->MulOLE(flatten(a_shr[1]), false), b_shr[1].shape()); + cross0 = mul->MulOLE(b_shr[1], false); + cross1 = mul->MulOLE(a_shr[1], false); } result[rank] = ring_mul(a_shr[rank], b_shr[rank]); @@ -105,15 +101,11 @@ TEST_P(CheetahMulTest, BasicBinary) { NdArrayRef cross0, cross1; if (rank == 0) { - cross0 = - unflatten(mul->MulOLE(flatten(a_shr[0]), true), a_shr[0].shape()); - cross1 = - unflatten(mul->MulOLE(flatten(b_shr[0]), true), b_shr[0].shape()); + cross0 = mul->MulOLE(a_shr[0], true); + cross1 = mul->MulOLE(b_shr[0], true); } else { - cross0 = - unflatten(mul->MulOLE(flatten(b_shr[1]), false), a_shr[1].shape()); - cross1 = - unflatten(mul->MulOLE(flatten(a_shr[1]), false), b_shr[1].shape()); + cross0 = mul->MulOLE(b_shr[1], false); + cross1 = mul->MulOLE(a_shr[1], false); } result[rank] = ring_mul(a_shr[rank], b_shr[rank]); @@ -173,15 +165,11 @@ TEST_P(CheetahMulTest, MixedRingSizeMul) { NdArrayRef cross0, cross1; if (rank == 0) { - cross0 = - unflatten(mul->MulOLE(flatten(a_shr[0]), true), a_shr[0].shape()); - cross1 = - unflatten(mul->MulOLE(flatten(b_shr[0]), true), b_shr[0].shape()); + cross0 = mul->MulOLE(a_shr[0], true); + cross1 = mul->MulOLE(b_shr[0], true); } else { - cross0 = - unflatten(mul->MulOLE(flatten(b_shr[1]), false), a_shr[1].shape()); - cross1 = - unflatten(mul->MulOLE(flatten(a_shr[1]), false), a_shr[1].shape()); + cross0 = mul->MulOLE(b_shr[1], false); + cross1 = mul->MulOLE(a_shr[1], false); } result[rank] = ring_mul(a_shr[rank], b_shr[rank]); @@ -189,15 +177,11 @@ TEST_P(CheetahMulTest, MixedRingSizeMul) { ring_add_(result[rank], cross1); if (rank == 0) { - cross0 = - unflatten(mul->MulOLE(flatten(c_shr[0]), true), c_shr[0].shape()); - cross1 = - unflatten(mul->MulOLE(flatten(d_shr[0]), true), d_shr[0].shape()); + cross0 = mul->MulOLE(c_shr[0], true); + cross1 = mul->MulOLE(d_shr[0], true); } else { - cross1 = - unflatten(mul->MulOLE(flatten(d_shr[1]), false), d_shr[1].shape()); - cross0 = - unflatten(mul->MulOLE(flatten(c_shr[1]), false), c_shr[1].shape()); + cross1 = mul->MulOLE(d_shr[1], false); + cross0 = mul->MulOLE(c_shr[1], false); } result2[rank] = ring_mul(c_shr[rank], d_shr[rank]); diff --git a/libspu/mpc/cheetah/arith/common.cc b/libspu/mpc/cheetah/arith/common.cc index 83383719..acab69f0 100644 --- a/libspu/mpc/cheetah/arith/common.cc +++ b/libspu/mpc/cheetah/arith/common.cc @@ -21,6 +21,13 @@ #include "libspu/mpc/utils/ring_ops.h" namespace spu::mpc::cheetah { +size_t CalculateWorkLoad(size_t num_jobs, size_t num_cores) { + if (num_cores == 0) { + num_cores = spu::getNumberOfProc(); + } + return (num_jobs + num_cores - 1) / num_cores; +} + EnableCPRNG::EnableCPRNG() : seed_(yacl::crypto::RandSeed(/*drbg*/ true)), prng_counter_(0) {} @@ -35,7 +42,7 @@ void EnableCPRNG::UniformPrime(const seal::Modulus &prime, 
max_random - BarrettReduce(max_random, prime) - 1; auto r = CPRNG(FieldType::FM64, dst.size()); - ArrayView xr(r); + NdArrayView xr(r); pforeach(0, dst.size(), [&](int64_t i) { dst[i] = xr[i]; }); // std::copy_n(xr.data(), xr.size(), dst.data()); std::transform(dst.data(), dst.data() + dst.size(), dst.data(), @@ -70,7 +77,7 @@ void EnableCPRNG::UniformPoly(const seal::SEALContext &context, RLWEPt *poly, } // Uniform random on ring 2^k -ArrayRef EnableCPRNG::CPRNG(FieldType field, size_t size) { +NdArrayRef EnableCPRNG::CPRNG(FieldType field, size_t size) { constexpr uint64_t kPRNG_THREASHOLD = 1ULL << 50; // Lock prng_counter_ std::scoped_lock guard(counter_lock_); @@ -78,8 +85,7 @@ ArrayRef EnableCPRNG::CPRNG(FieldType field, size_t size) { seed_ = yacl::crypto::RandSeed(true); prng_counter_ = 0; } - return flatten( - ring_rand(field, {static_cast(size)}, seed_, &prng_counter_)); + return ring_rand(field, {static_cast(size)}, seed_, &prng_counter_); } NdArrayRef ring_conv2d(const NdArrayRef &tensor, const NdArrayRef &filter, diff --git a/libspu/mpc/cheetah/arith/common.h b/libspu/mpc/cheetah/arith/common.h index eb551154..31350576 100644 --- a/libspu/mpc/cheetah/arith/common.h +++ b/libspu/mpc/cheetah/arith/common.h @@ -15,7 +15,7 @@ #include -#include "libspu/mpc/cheetah/array_ref.h" +#include "libspu/core/ndarray_ref.h" #include "libspu/mpc/cheetah/rlwe/types.h" namespace seal { @@ -46,7 +46,7 @@ struct EnableCPRNG { seal::parms_id_type pid = seal::parms_id_zero); // Uniform random on ring 2^k - ArrayRef CPRNG(FieldType field, size_t size); + NdArrayRef CPRNG(FieldType field, size_t size); protected: mutable std::mutex counter_lock_; @@ -64,4 +64,5 @@ inline int64_t calcNumel(absl::Span shape) { std::multiplies<>()); } +size_t CalculateWorkLoad(size_t num_jobs, size_t num_cores = 0); } // namespace spu::mpc::cheetah diff --git a/libspu/mpc/cheetah/arith/conv2d_prot.cc b/libspu/mpc/cheetah/arith/conv2d_prot.cc index b9e9fde8..1c197deb 100644 --- a/libspu/mpc/cheetah/arith/conv2d_prot.cc +++ b/libspu/mpc/cheetah/arith/conv2d_prot.cc @@ -394,11 +394,11 @@ ArrayRef Conv2DProtocol::ParseResult(FieldType field, const Meta &meta, return ParseResult(field, meta, rlwe, tencoder_->ms_helper()); } -ArrayRef Conv2DProtocol::ParseResult( - FieldType field, const Meta &meta, absl::Span rlwe, - const ModulusSwitchHelper &ms_helper) const { +ArrayRef Conv2DProtocol::ParseResult(FieldType field, const Meta &meta, + absl::Span rlwe, + const ModulusSwitchHelper &msh) const { SPU_ENFORCE_EQ(rlwe.size(), GetOutSize(meta)); - size_t expected_n = poly_deg_ * ms_helper.coeff_modulus_size(); + size_t expected_n = poly_deg_ * msh.coeff_modulus_size(); SPU_ENFORCE(std::all_of(rlwe.data(), rlwe.data() + rlwe.size(), [expected_n](const RLWEPt &pt) { return pt.coeff_count() == expected_n; @@ -428,8 +428,8 @@ ArrayRef Conv2DProtocol::ParseResult( SPU_ENFORCE_EQ(static_cast(coefficients.size()), calcNumel(slice_shape)); - auto computed = toNdArray( - ms_helper.ModulusDownRNS(field, MakeSpan(rlwe[poly_idx++]))); + auto computed = + msh.ModulusDownRNS(field, {poly_deg_}, MakeSpan(rlwe[poly_idx++])); auto *coeff_iter = coefficients.data(); DISPATCH_ALL_FIELDS(field, "ParseResult", [&]() { diff --git a/libspu/mpc/cheetah/arith/matmat_prot.cc b/libspu/mpc/cheetah/arith/matmat_prot.cc index c72d97e3..95698de9 100644 --- a/libspu/mpc/cheetah/arith/matmat_prot.cc +++ b/libspu/mpc/cheetah/arith/matmat_prot.cc @@ -19,8 +19,10 @@ #include #include "seal/seal.h" +#include "seal/util/polyarithsmallmod.h" #include "spdlog/spdlog.h" 
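The new CalculateWorkLoad helper above is a ceiling division of jobs over cores; a small usage fragment showing how it pairs with yacl::parallel_for throughout this patch (ProcessAll is a hypothetical caller, and CalculateWorkLoad comes from common.h):

#include "yacl/utils/parallel.h"

// grain = ceil(num_jobs / num_cores), so each worker receives one contiguous chunk.
void ProcessAll(size_t num_jobs) {
  yacl::parallel_for(0, num_jobs, CalculateWorkLoad(num_jobs),
                     [&](int64_t bgn, int64_t end) {
                       for (int64_t i = bgn; i < end; ++i) {
                         // handle job i
                       }
                     });
}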
#include "yacl/utils/parallel.h" +#include "yacl/utils/platform_utils.h" #include "libspu/mpc/cheetah/arith/vector_encoder.h" #include "libspu/mpc/cheetah/rlwe/utils.h" @@ -31,7 +33,7 @@ struct std::hash { size_t operator()( const spu::mpc::cheetah::MatMatProtocol::Meta& s) const noexcept { using namespace spu::mpc::cheetah; - // FIXME(juhou): use a better way for hash + // FIXME(lwj): use a better way for hash size_t h = std::hash()("MatMatProtocol::Meta"); for (int i : {0, 1, 2}) { h = (h << 1) ^ std::hash()(s.dims[i]); @@ -47,23 +49,33 @@ bool operator==(const MatMatProtocol::Meta& x, const MatMatProtocol::Meta& y) { class LHSIndexer { public: - explicit LHSIndexer(const Shape3D& meta) { + explicit LHSIndexer(const Shape3D& meta, size_t poly_deg) { offsets_[0] = meta[1] * meta[2]; offsets_[1] = meta[1] - 1; + poly_deg_ = poly_deg; } - size_t operator()(size_t r, size_t c) const { - return r * offsets_[0] + offsets_[1] - c; + inline int64_t get(int64_t r, int64_t c, bool& flip_sign) const { + flip_sign = r == 0 && c > 0; + return poly_deg_ * static_cast(flip_sign) + r * offsets_[0] - c; } - size_t offsets_[2]; + int64_t offsets_[2]; + int64_t poly_deg_ = 0; }; class RHSIndexer { public: - explicit RHSIndexer(const Shape3D& meta) { offset_ = meta[1]; } - size_t operator()(size_t r, size_t c) const { return c * offset_ + r; } - size_t offset_{0}; + explicit RHSIndexer(const Shape3D& meta, size_t /*poly_deg*/) { + offset_ = meta[1]; + } + + inline int64_t get(int64_t r, int64_t c, bool& flip_sign) const { + flip_sign = false; + return c * offset_ + r; + } + + int64_t offset_{0}; }; class ResultIndexer { @@ -71,25 +83,25 @@ class ResultIndexer { explicit ResultIndexer(const Shape3D& meta) { offsets_[0] = meta[1] * meta[2]; offsets_[1] = meta[1]; - offsets_[2] = meta[1] - 1; } - size_t operator()(size_t r, size_t c) const { - return r * offsets_[0] + c * offsets_[1] + offsets_[2]; + inline int64_t get(int64_t r, int64_t c) const { + return r * offsets_[0] + c * offsets_[1]; } - size_t offsets_[3]; + int64_t offsets_[2]; }; template -ArrayRef ConcatSubMatrix(const ArrayRef& mat, const Shape2D& mat_shape, - const Shape2D& starts, const Shape2D& extents, - const Shape2D& submat_shape, int64_t num_el, - const Indexer& indexer) { +NdArrayRef ConcatSubMatrix(const NdArrayRef& mat, const Shape2D& mat_shape, + const Shape2D& starts, const Shape2D& extents, + const Shape2D& submat_shape, int64_t num_coeff, + const Indexer& indexer) { const Type& eltype = mat.eltype(); SPU_ENFORCE(eltype.isa(), "must be ring_type, got={}", eltype); + // SPU_ENFORCE(mat.ndim() == 2, "should be a 2D matrix"); SPU_ENFORCE_EQ(mat.numel(), mat_shape[0] * mat_shape[1]); - SPU_ENFORCE(num_el >= submat_shape[0] * submat_shape[1]); + SPU_ENFORCE(num_coeff >= submat_shape[0] * submat_shape[1]); for (size_t d : {0, 1}) { SPU_ENFORCE(starts[d] < mat_shape[d]); SPU_ENFORCE(extents[d] > 0); @@ -98,49 +110,33 @@ ArrayRef ConcatSubMatrix(const ArrayRef& mat, const Shape2D& mat_shape, const auto field = eltype.as()->field(); // NOTE: zero padding via initialization - ArrayRef f = flatten(ring_zeros(field, {num_el})); + NdArrayRef flatten = ring_zeros(field, {num_coeff}); + + DISPATCH_ALL_FIELDS(field, "ConcatSubMat", [&]() { + using uT = std::make_unsigned::type; - DISPATCH_ALL_FIELDS(field, "", [&]() { for (int64_t r = 0, rr = starts[0]; r < extents[0]; ++r, ++rr) { for (int64_t c = 0, cc = starts[1]; c < extents[1]; ++c, ++cc) { - f.at(indexer(r, c)) = - mat.at(rr * mat_shape[1] + cc); + bool flip_sign = false; + auto idx = 
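The flip_sign logic in the new LHSIndexer above is the usual negacyclic wrap-around: a coefficient that would sit at a negative exponent -c is instead stored at exponent N - c with its sign flipped, because in R = Z[X]/(X^N + 1)

\[
X^{N-c} \equiv -X^{-c} \pmod{X^N + 1},
\]

which is what ConcatSubMatrix applies when it writes flip_sign ? -v : v (editor's reading of the indexer; the offsets themselves come from the submatrix shape).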
indexer.get(r, c, flip_sign); + auto v = mat.at(rr * mat_shape[1] + cc); + flatten.at(idx) = flip_sign ? -v : v; } } }); - return f; + return flatten; } MatMatProtocol::MatMatProtocol(const seal::SEALContext& context, const ModulusSwitchHelper& ms_helper, - bool use_montgomery_fma) - : use_montgomery_fma_(use_montgomery_fma), - context_(context), - ms_helper_(ms_helper) { + bool disable_pack) + : disable_pack_(disable_pack), context_(context), msh_(ms_helper) { SPU_ENFORCE(context_.parameters_set()); - SPU_ENFORCE(context_.first_parms_id() == ms_helper_.parms_id()); + SPU_ENFORCE(context_.first_parms_id() == msh_.parms_id()); poly_deg_ = context_.first_context_data()->parms().poly_modulus_degree(); - vencoder_ = std::make_unique(context_, ms_helper_); - - if (use_montgomery_fma_) { - const auto& cntxt_data = context.first_context_data(); - const auto& modulus = cntxt_data->parms().coeff_modulus(); - - for (const seal::Modulus& moduli : modulus) { - uint64_t prime_inv = [](uint64_t prime) { - uint64_t inv = 1; - for (int i = 0; i < 63; ++i) { - inv *= prime; - prime *= prime; - } - return inv; // prime^{-1} mod 2^64 - }(moduli.value()); - - montgomery_precond_.push_back(prime_inv); - } - } + vencoder_ = std::make_unique(context_, msh_); } bool MatMatProtocol::IsValidMeta(const Meta& meta) const { @@ -177,75 +173,74 @@ size_t MatMatProtocol::GetOutSize(const Meta& meta, } Shape3D MatMatProtocol::GetSubMatShape(const Meta& meta) const { - static std::unordered_map memo_; - static std::shared_mutex lock_; - { - std::shared_lock guard(lock_); - auto val = memo_.find(meta); - if (val != memo_.end()) return val->second; - } - std::unique_lock guard(lock_); - auto val = memo_.find(meta); - if (val != memo_.end()) return val->second; + return GetSubMatShape(meta, poly_deg_, disable_pack_); +} - constexpr int64_t enc_price = 3; - constexpr int64_t eval_price = 1; - constexpr int64_t dec_price = 2; +Shape3D MatMatProtocol::GetSubMatShape(const Meta& meta, int64_t poly_deg, + bool disable_pack) { + const Shape3D& dim3 = meta.dims; + const double cpu_price = 1.0; + const double bandwidth_price = 1000.0; - Shape3D blk; Shape3D subshape; - int64_t min_cost = std::numeric_limits::max(); + Shape3D blk; - for (int64_t d0 = 1; d0 <= meta.dims[0]; ++d0) { - if (d0 > poly_deg_) break; - blk[0] = CeilDiv(meta.dims[0], d0); + const int64_t n = poly_deg; + double min_cost = std::numeric_limits::max(); + for (int64_t d0 = 1; d0 <= std::min(n, dim3[0]); d0 += 1) { + blk[0] = CeilDiv(dim3[0], d0); - for (int64_t d1 = 1; d1 <= meta.dims[1]; ++d1) { - if (d0 * d1 > poly_deg_) break; - blk[1] = CeilDiv(meta.dims[1], d1); + int64_t d1 = 1; + while (d1 <= dim3[1]) { + if (d0 * d1 > n) { + break; + } + blk[1] = CeilDiv(dim3[1], d1); - for (int64_t d2 = 1; d2 <= meta.dims[2]; ++d2) { - if (d0 * d1 * d2 > poly_deg_) break; - blk[2] = CeilDiv(meta.dims[2], d2); + int64_t d2 = std::min(dim3[2], n / d0 / d1); + blk[2] = CeilDiv(dim3[2], d2); - int pivot = blk[0] < blk[2] ? 0 : 2; - int64_t enc_cost = blk[1] * blk[pivot] * enc_price; - int64_t eval_cost = blk[0] * blk[1] * blk[2] * eval_price; - int64_t dec_cost = blk[0] * blk[2] * dec_price; + int64_t sent_ct = std::min(blk[0], blk[2]) * blk[1]; + int64_t response_ct = CeilDiv(blk[0] * blk[2], d1); + int64_t num_automorph = + disable_pack ? 
0 : CeilDiv(dim3[0] * dim3[2], n) * d1; + int64_t num_ct_mul = blk[0] * blk[1] * blk[2]; - int64_t cost = enc_cost + eval_cost + dec_cost; + double cost = (sent_ct + response_ct) * bandwidth_price + + num_automorph * cpu_price + num_ct_mul * cpu_price / 10.0; - if (cost <= min_cost) { - min_cost = cost; - subshape[0] = d0; - subshape[1] = d1; - subshape[2] = d2; - } + if (cost <= min_cost) { + min_cost = cost; + subshape = {d0, d1, d2}; } + + d1 = disable_pack ? d1 + 1 : d1 * 2; } } - memo_.insert({meta, subshape}); return subshape; } -void MatMatProtocol::EncodeLHS(const ArrayRef& mat, const Meta& meta, +void MatMatProtocol::EncodeLHS(const NdArrayRef& mat, const Meta& meta, bool need_encrypt, absl::Span out) const { int pivot = 0; - EncodeMatrix(mat, meta, pivot, need_encrypt, out); + EncodeMatrix(mat, meta, pivot, need_encrypt, + LayoutType::row_major, out); } -void MatMatProtocol::EncodeRHS(const ArrayRef& mat, const Meta& meta, +void MatMatProtocol::EncodeRHS(const NdArrayRef& mat, const Meta& meta, bool need_encrypt, absl::Span out) const { int pivot = 1; - EncodeMatrix(mat, meta, pivot, need_encrypt, out); + EncodeMatrix(mat, meta, pivot, need_encrypt, + LayoutType::col_major, out); } template -void MatMatProtocol::EncodeMatrix(const ArrayRef& mat, const Meta& meta, +void MatMatProtocol::EncodeMatrix(const NdArrayRef& mat, const Meta& meta, int pivot, bool need_encrypt, + LayoutType layout, absl::Span out) const { const int R = pivot; const int C = pivot + 1; @@ -258,14 +253,20 @@ void MatMatProtocol::EncodeMatrix(const ArrayRef& mat, const Meta& meta, Shape2D mat_shape = {meta.dims[R], meta.dims[C]}; Shape2D submat_shape = {subshape[R], subshape[C]}; - Indexer indexer(subshape); - std::array extents; - for (int64_t rb = 0; rb < num_row_blocks; ++rb) { - int64_t row_start = rb * subshape[R]; - int64_t row_end = std::min(meta.dims[R], row_start + subshape[R]); - extents[0] = row_end - row_start; - for (int64_t cb = 0; cb < num_col_blocks; ++cb) { - int64_t col_start = cb * subshape[C]; + Indexer indexer(subshape, poly_deg_); + size_t num_jobs = num_row_blocks * num_col_blocks; + size_t wload = CalculateWorkLoad(num_jobs); + + yacl::parallel_for(0, num_jobs, wload, [&](int64_t job_bgn, int64_t job_end) { + std::array extents; + for (int64_t job_id = job_bgn; job_id < job_end; ++job_id) { + int64_t rblk = job_id / num_col_blocks; + int64_t cblk = job_id % num_col_blocks; + int64_t row_start = rblk * subshape[R]; + int64_t row_end = std::min(meta.dims[R], row_start + subshape[R]); + extents[0] = row_end - row_start; + + int64_t col_start = cblk * subshape[C]; int64_t col_end = std::min(meta.dims[C], col_start + subshape[C]); extents[1] = col_end - col_start; @@ -273,14 +274,19 @@ void MatMatProtocol::EncodeMatrix(const ArrayRef& mat, const Meta& meta, ConcatSubMatrix(mat, mat_shape, {row_start, col_start}, extents, submat_shape, poly_deg_, indexer); - vencoder_->Forward(flatten, &out[rb * num_col_blocks + cb], need_encrypt); + int64_t poly_idx = layout == LayoutType::row_major + ? 
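Spelled out, the heuristic that the new GetSubMatShape minimizes over candidate subshapes (d0, d1, d2) is (editor's transcription of the loop above, with N the polynomial degree and b_i = CeilDiv(dims[i], d_i)):

\[
\mathrm{cost} = B\big(\min(b_0, b_2)\, b_1 + \lceil b_0 b_2 / d_1 \rceil\big)
 + C\,\lceil \mathrm{dims}_0\, \mathrm{dims}_2 / N \rceil\, d_1
 + \tfrac{C}{10}\, b_0 b_1 b_2,
\]

with B = bandwidth_price = 1000 and C = cpu_price = 1, and the automorphism term dropped when packing is disabled; the large B means the search is driven mainly by the number of ciphertexts sent and returned.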
rblk * num_col_blocks + cblk + : cblk * num_row_blocks + rblk; + + vencoder_->Forward(flatten, &out[poly_idx], need_encrypt); } - } + }); } template <> void MatMatProtocol::FusedMulAddInplace(RLWECt& acc, const RLWECt& lhs, const RLWEPt& rhs) const { + namespace sut = seal::util; SPU_ENFORCE(lhs.parms_id() == rhs.parms_id()); auto cntxt_data = context_.get_context_data(lhs.parms_id()); SPU_ENFORCE(cntxt_data != nullptr); @@ -298,30 +304,26 @@ void MatMatProtocol::FusedMulAddInplace(RLWECt& acc, const RLWECt& lhs, size_t coeff_count = parms.poly_modulus_degree(); const auto& modulus = parms.coeff_modulus(); + std::vector tmp(coeff_count); for (size_t k = 0; k < lhs.size(); ++k) { - using namespace seal::util; const auto* op0 = lhs.data(k); const auto* op1 = rhs.data(); auto* dst = acc.data(k); - for (size_t j = 0; j < modulus.size(); ++j) { - if (use_montgomery_fma_) { - const uint64_t prime = modulus[j].value(); - const uint64_t prime_inv = montgomery_precond_.at(j); - const uint64_t tbl[2]{prime, 0}; - - unsigned long long wide[2], H; - for (size_t i = 0; i < coeff_count; ++i) { - multiply_uint64(*op0++, *op1++, wide); - uint64_t R = wide[0] * prime_inv; - multiply_uint64_hw64(R, prime, &H); - uint64_t r = static_cast(wide[1] - H) + prime; - r -= tbl[r < prime]; - r += *dst; - *dst++ = (r - tbl[r < prime]); - } + + for (const auto& moduli : modulus) { + if (yacl::hasAVX2()) { + sut::dyadic_product_coeffmod(op0, op1, coeff_count, moduli, tmp.data()); + sut::add_poly_coeffmod(tmp.data(), dst, coeff_count, moduli, dst); + op0 += coeff_count; + op1 += coeff_count; + dst += coeff_count; } else { - for (size_t i = 0; i < coeff_count; ++i, ++dst) { - *dst = multiply_add_uint_mod(*op0++, *op1++, *dst, modulus[j]); + // un-roll 4 + for (size_t ii = 0; ii < coeff_count; ii += 4, dst += 4) { + dst[0] = sut::multiply_add_uint_mod(*op0++, *op1++, dst[0], moduli); + dst[1] = sut::multiply_add_uint_mod(*op0++, *op1++, dst[1], moduli); + dst[2] = sut::multiply_add_uint_mod(*op0++, *op1++, dst[2], moduli); + dst[3] = sut::multiply_add_uint_mod(*op0++, *op1++, dst[3], moduli); } } } @@ -378,9 +380,9 @@ void TakeCoefficientsFromPoly(const RLWEPt& poly, size_t poly_degree, } } -ArrayRef MatMatProtocol::ParseResult( - FieldType field, const Meta& meta, absl::Span ans_poly, - const ModulusSwitchHelper& ms_helper) const { +NdArrayRef MatMatProtocol::ParseResult(FieldType field, const Meta& meta, + absl::Span ans_poly, + const ModulusSwitchHelper& msh) const { auto subdims = GetSubMatShape(meta); Shape2D out_blks = {CeilDiv(meta.dims[0], subdims[0]), CeilDiv(meta.dims[2], subdims[2])}; @@ -392,11 +394,11 @@ ArrayRef MatMatProtocol::ParseResult( for (int64_t r = 0; r < subdims[0]; ++r) { for (int64_t c = 0; c < subdims[2]; ++c) { - target_coeffs.at(r * subdims[2] + c) = ans_indexer(r, c); + target_coeffs.at(r * subdims[2] + c) = ans_indexer.get(r, c); } } - ArrayRef matmat = flatten(ring_zeros(field, {meta.dims[0] * meta.dims[2]})); + NdArrayRef matmat = ring_zeros(field, {meta.dims[0] * meta.dims[2]}); for (int64_t rb = 0; rb < out_blks[0]; ++rb) { const auto* this_ans = ans_poly.data() + rb * out_blks[1]; @@ -415,13 +417,13 @@ ArrayRef MatMatProtocol::ParseResult( absl::MakeSpan(target_coeffs), absl::MakeSpan(subset)); - auto result_poly = - ms_helper.ModulusDownRNS(field, absl::MakeSpan(subset)); + auto result_poly = msh.ModulusDownRNS(field, {subdims[0], subdims[2]}, + absl::MakeSpan(subset)); for (int64_t r = 0; r < row_ext; ++r) { for (int64_t c = 0; c < col_ext; ++c) { int64_t dst_idx = (r + row_start) * 
meta.dims[2] + col_start + c; - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, "ParseResult", [&]() { matmat.at(dst_idx) = result_poly.at(r * subdims[2] + c); }); @@ -430,12 +432,96 @@ ArrayRef MatMatProtocol::ParseResult( } } - return matmat; + return matmat.reshape({meta.dims[0], meta.dims[2]}); } -ArrayRef MatMatProtocol::ParseResult(FieldType field, const Meta& meta, - absl::Span ans_poly) const { - return ParseResult(field, meta, ans_poly, ms_helper_); +struct PackRLWEMappingHelper { + explicit PackRLWEMappingHelper(size_t poly_deg, size_t gap, size_t num_polys) + : gap_(gap), num_polys_(num_polys), group_size_(poly_deg / gap) {} + + // the i-th poly's j-th coeffient is packed into where ? + std::pair GetPackedIndex(size_t poly_idx, + size_t offset) const { + SPU_ENFORCE(poly_idx < num_polys_); + SPU_ENFORCE(offset < group_size_); + + auto idx0 = poly_idx / gap_; + auto idx1 = offset * gap_ + poly_idx % gap_; + return {idx0, idx1}; + } + + size_t gap_; + size_t num_polys_; + size_t group_size_; +}; + +NdArrayRef MatMatProtocol::ParseBatchPackedResult( + FieldType field, size_t batch_size, const Meta& meta, + absl::Span polys, const ModulusSwitchHelper& msh) const { + auto subshape = GetSubMatShape(meta); + const int64_t out_n = GetOutSize(meta, subshape); + SPU_ENFORCE_EQ(polys.size(), + CeilDiv(out_n * batch_size, subshape[1])); + + const int64_t gap = subshape[1]; + const int64_t total_polys = out_n * batch_size; + const int64_t polys_per_dot = out_n; + const int64_t numel_per_dot = meta.dims[0] * meta.dims[2]; + + std::vector decoded_vectors(polys.size()); + yacl::parallel_for(0, polys.size(), CalculateWorkLoad(polys.size()), + [&](int64_t bgn, int64_t end) { + for (int64_t i = bgn; i < end; ++i) { + decoded_vectors[i] = msh.ModulusDownRNS( + field, {poly_deg_}, + {polys[i].data(), polys[i].coeff_count()}); + } + }); + + PackRLWEMappingHelper mapper(poly_deg_, gap, total_polys); + + NdArrayRef mat = + ring_zeros(field, {static_cast(batch_size * numel_per_dot)}); + + for (int64_t idx = 0; idx < total_polys; idx += polys_per_dot) { + int64_t num_col_blk = CeilDiv(meta.dims[2], subshape[2]); + int64_t out_mat_idx = idx / polys_per_dot; + auto out_slice = + mat.slice({out_mat_idx * numel_per_dot}, + {out_mat_idx * numel_per_dot + numel_per_dot}, {1}); + + for (int64_t poly_idx = 0; poly_idx < polys_per_dot; ++poly_idx) { + int64_t out_rblk = poly_idx / num_col_blk; + int64_t out_cblk = poly_idx % num_col_blk; + + int64_t out_row_bgn = out_rblk * subshape[0]; + int64_t out_col_bgn = out_cblk * subshape[2]; + int64_t row_ext = + std::min(out_row_bgn + subshape[0], meta.dims[0]) - out_row_bgn; + int64_t col_ext = + std::min(out_col_bgn + subshape[2], meta.dims[2]) - out_col_bgn; + + for (int64_t r = 0; r < row_ext; ++r) { + for (int64_t c = 0; c < col_ext; ++c) { + int64_t coeff_idx = r * subshape[2] + c; + auto o = mapper.GetPackedIndex(idx + poly_idx, coeff_idx); + const NdArrayRef& dcd_vec = decoded_vectors.at(o.first); + std::memcpy( + &out_slice.at((r + out_row_bgn) * meta.dims[2] + out_col_bgn + c), + &dcd_vec.at(o.second), decoded_vectors[0].elsize()); + } + } + } + } + + return mat.reshape( + {static_cast(batch_size), meta.dims[0], meta.dims[2]}); +} + +NdArrayRef MatMatProtocol::ParseResult( + FieldType field, const Meta& meta, + absl::Span ans_poly) const { + return ParseResult(field, meta, ans_poly, msh_); } void MatMatProtocol::ExtractLWEsInplace(const Meta& meta, @@ -449,7 +535,7 @@ void MatMatProtocol::ExtractLWEsInplace(const Meta& meta, std::set 
to_keep; for (int64_t r = 0; r < subdims[0]; ++r) { for (int64_t c = 0; c < subdims[2]; ++c) { - to_keep.insert(ans_indexer(r, c)); + to_keep.insert(ans_indexer.get(r, c)); } } @@ -476,7 +562,7 @@ void MatMatProtocol::ExtractLWEsInplace(const Meta& meta, std::set to_keep_on_margin; for (int64_t r = 0; r < row_ext; ++r) { for (int64_t c = 0; c < col_ext; ++c) { - to_keep_on_margin.insert(ans_indexer(r, c)); + to_keep_on_margin.insert(ans_indexer.get(r, c)); } } KeepCoefficientsInplace(this_ans[cb], to_keep_on_margin); @@ -502,85 +588,69 @@ void MatMatProtocol::DoCompute(absl::Span lhs, dims[d] = CeilDiv(meta.dims[d], subshape[d]); } - constexpr int kMinLoopDim2 = 4; - if (dims[2] < kMinLoopDim2) { - // k, i, j - for (int64_t k = 0; k < dims[2]; ++k) { - yacl::parallel_for(0, dims[0], 1, [&](size_t bgn, size_t end) { - for (size_t i = bgn; i < end; ++i) { - auto lhs_row = lhs.data() + i * dims[1]; - auto out_row = out.data() + i * dims[2]; - yacl::parallel_for(0, dims[1], 1, [&](size_t bgn, size_t end) { - for (size_t j = bgn; j < end; ++j) { - auto rhs_row = rhs.data() + j * dims[2]; - FusedMulAddInplace(out_row[k], lhs_row[j], - rhs_row[k]); - } - }); + if (dims[0] >= dims[2]) { + auto wload = CalculateWorkLoad(dims[0]); + yacl::parallel_for(0, dims[0], wload, [&](int64_t bgn, int64_t end) { + // Loop dim0 + for (int64_t i = bgn; i < end; ++i) { + // out[i, k] + // NOTE(lwj): LHS is stored in row-major + auto lhs_row = lhs.data() + i * dims[1]; + auto out_row = out.data() + i * dims[2]; + + for (int64_t k = 0; k < dims[2]; ++k) { + auto rhs_col = rhs.data() + k * dims[1]; + for (int64_t j = 0; j < dims[1]; ++j) { + // rhs[j, k] + FusedMulAddInplace(out_row[k], lhs_row[j], rhs_col[j]); + } } - }); - } + } + }); } else { - // i, k, j - for (int64_t i = 0; i < dims[0]; ++i) { - auto lhs_row = lhs.data() + i * dims[1]; - auto out_row = out.data() + i * dims[2]; - yacl::parallel_for(0, dims[2], 1, [&](size_t bgn, size_t end) { - for (size_t k = bgn; k < end; ++k) { + auto wload = CalculateWorkLoad(dims[2]); + yacl::parallel_for(0, dims[2], wload, [&](int64_t bgn, int64_t end) { + // Loop dim2 + for (int64_t k = bgn; k < end; ++k) { + // NOTE(lwj): RHS is stored in column-major + auto rhs_col = rhs.data() + k * dims[1]; + for (int64_t i = 0; i < dims[0]; ++i) { + auto lhs_row = lhs.data() + i * dims[1]; + auto out_row = out.data() + i * dims[2]; for (int64_t j = 0; j < dims[1]; ++j) { - auto rhs_row = rhs.data() + j * dims[2]; - FusedMulAddInplace(out_row[k], lhs_row[j], rhs_row[k]); + FusedMulAddInplace(out_row[k], lhs_row[j], rhs_col[j]); } } - }); - } + } + }); } } void MatMatProtocol::Compute(absl::Span lhs_mat, absl::Span rhs_mat, const Meta& meta, absl::Span out_mat) const { + for (auto& ct : out_mat) { + ct.release(); + } DoCompute(lhs_mat, rhs_mat, meta, out_mat); } void MatMatProtocol::Compute(absl::Span lhs_mat, absl::Span rhs_mat, const Meta& meta, absl::Span out_mat) const { + for (auto& ct : out_mat) { + ct.release(); + } DoCompute(lhs_mat, rhs_mat, meta, out_mat); } void MatMatProtocol::Compute(absl::Span lhs_mat, absl::Span rhs_mat, const Meta& meta, absl::Span out_mat) const { - DoCompute(lhs_mat, rhs_mat, meta, out_mat); -} - -void MatMatProtocol::Montgomerize(absl::Span pt) const { - if (not use_montgomery_fma_) { - return; - } - size_t n = pt.size(); - for (size_t i = 0; i < n; ++i) { - auto cntxt_data = context_.get_context_data(pt[i].parms_id()); - SPU_ENFORCE(cntxt_data != nullptr); - const auto& modulus = cntxt_data->parms().coeff_modulus(); - - uint64_t* pt_rns = 
pt[i].data(); - for (const auto& moduli : modulus) { - uint64_t prime = moduli.value(); - uint64_t r0 = moduli.const_ratio()[0]; // r0 = hi64(2^128 / prime) - uint64_t r1 = moduli.const_ratio()[1]; // r1 = lo64(2^128 / prime) - // lazy Montgomery form, i.e., in the range [0, 2p). - std::transform(pt_rns, pt_rns + poly_deg_, pt_rns, - [r0, r1, prime](uint64_t a) -> uint64_t { - // SEAL requires explicit ULL type. - unsigned long long hi; - seal::util::multiply_uint64_hw64(a, r0, &hi); - return -((a * r1) + static_cast(hi)) * prime; - }); - pt_rns += poly_deg_; - } + for (auto& ct : out_mat) { + ct.release(); } + DoCompute(lhs_mat, rhs_mat, meta, out_mat); } -} // namespace spu::mpc::cheetah +} // namespace spu::mpc::cheetah \ No newline at end of file diff --git a/libspu/mpc/cheetah/arith/matmat_prot.h b/libspu/mpc/cheetah/arith/matmat_prot.h index 254e1c73..38e6874e 100644 --- a/libspu/mpc/cheetah/arith/matmat_prot.h +++ b/libspu/mpc/cheetah/arith/matmat_prot.h @@ -1,4 +1,3 @@ - // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -29,7 +28,7 @@ class MatMatProtocol { explicit MatMatProtocol(const seal::SEALContext& context, const ModulusSwitchHelper& ms_helper, - bool use_montgomery_fma = false); + bool disable_pack = false); size_t GetLeftSize(const Meta& meta) const { return GetLeftSize(meta, GetSubMatShape(meta)); @@ -51,26 +50,36 @@ class MatMatProtocol { Shape3D GetSubMatShape(const Meta& meta) const; - // Set `need_encrypt=true` if the encoded polynoials are going to be - // encrypted. Encoded polynomials are stored in non-NTT form. - void EncodeLHS(const ArrayRef& lhs_mat, const Meta& meta, bool need_encrypt, - absl::Span out) const; + static Shape3D GetSubMatShape(const Meta& meta, int64_t poly_deg, + bool disable_pack = false); - void Montgomerize(absl::Span pt) const; + void EncodeLHS(const NdArrayRef& lhs_mat, const Meta& meta, bool need_encrypt, + absl::Span out) const; - // Set `need_encrypt=true` if the encoded polynoials are going to be - // encrypted. Encoded polynomials are stored in non-NTT form. 
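For the PackRLWEMappingHelper introduced in the .cc hunk above, the answer to its "packed into where?" question is idx0 = poly_idx / gap and idx1 = offset * gap + poly_idx % gap. A tiny check with made-up parameters (editor's illustration, assuming the helper is visible at the call site):

// Hypothetical numbers, purely to exercise GetPackedIndex:
PackRLWEMappingHelper mapper(/*poly_deg=*/16, /*gap=*/4, /*num_polys=*/8);
auto [packed_poly, slot] = mapper.GetPackedIndex(/*poly_idx=*/6, /*offset=*/3);
// packed_poly == 6 / 4 == 1, slot == 3 * 4 + 6 % 4 == 14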
- void EncodeRHS(const ArrayRef& rhs_mat, const Meta& meta, bool need_encrypt, + void EncodeRHS(const NdArrayRef& rhs_mat, const Meta& meta, bool need_encrypt, absl::Span out) const; bool IsValidMeta(const Meta& meta) const; - ArrayRef ParseResult(FieldType field, const Meta& meta, - absl::Span ans_poly) const; + NdArrayRef ParseResult(FieldType field, const Meta& meta, + absl::Span ans_poly) const; + + NdArrayRef ParseResult(FieldType field, const Meta& meta, + absl::Span ans_poly, + const ModulusSwitchHelper& msh) const; - ArrayRef ParseResult(FieldType field, const Meta& meta, - absl::Span ans_poly, - const ModulusSwitchHelper& helper) const; + // Coefficients via Packed Batched MatMul + // output shape batch_size x dims[0] x dims[2] + NdArrayRef ParseBatchPackedResult(FieldType field, size_t batch_size, + const Meta& meta, + absl::Span polys, + const ModulusSwitchHelper& msh) const; + + // Coefficients via Packed MatMul + // output shape dims[0] x dims[2] + NdArrayRef ParsePackedResult(FieldType field, const Meta& meta, + absl::Span ans_poly, + const ModulusSwitchHelper& msh) const; void ExtractLWEsInplace(const Meta& meta, absl::Span rlwe) const; @@ -102,17 +111,18 @@ class MatMatProtocol { bool IsValidSubShape(const Shape3D& shape) const; + enum class LayoutType { row_major, col_major }; + template - void EncodeMatrix(const ArrayRef& mat, const Meta& meta, int pivot, - bool need_encrypt, absl::Span out) const; + void EncodeMatrix(const NdArrayRef& mat, const Meta& meta, int pivot, + bool need_encrypt, LayoutType layout, + absl::Span out) const; int64_t poly_deg_{0}; - bool use_montgomery_fma_{false}; + bool disable_pack_ = false; seal::SEALContext context_; - ModulusSwitchHelper ms_helper_; + ModulusSwitchHelper msh_; std::unique_ptr vencoder_{nullptr}; - - std::vector montgomery_precond_; }; -} // namespace spu::mpc::cheetah +} // namespace spu::mpc::cheetah \ No newline at end of file diff --git a/libspu/mpc/cheetah/arith/matmat_prot_test.cc b/libspu/mpc/cheetah/arith/matmat_prot_test.cc index 3f2869fc..e70d6abe 100644 --- a/libspu/mpc/cheetah/arith/matmat_prot_test.cc +++ b/libspu/mpc/cheetah/arith/matmat_prot_test.cc @@ -20,7 +20,6 @@ #include "seal/util/polyarithsmallmod.h" #include "yacl/crypto/utils/rand.h" -#include "libspu/core/xt_helper.h" #include "libspu/mpc/cheetah/rlwe/types.h" #include "libspu/mpc/cheetah/rlwe/utils.h" #include "libspu/mpc/utils/ring_ops.h" @@ -43,10 +42,6 @@ class MatMatProtTest inline uint32_t FieldBitLen(FieldType f) const { return 8 * SizeOf(f); } - ArrayRef CPRNG(FieldType field, size_t size) { - return flatten(ring_rand(field, {(int64_t)size}, seed_, &prng_counter_)); - } - void SetUp() override { field_ = std::get<0>(GetParam()); std::vector modulus_bits; @@ -113,12 +108,14 @@ TEST_P(MatMatProtTest, Plain) { meta.dims = std::get<1>(GetParam()); // NOTE(juhou): Cheetah now supports strided ArrayRef - for (size_t stride : {1, 2, 3}) { - auto _lhs = CPRNG(field_, meta.dims[0] * meta.dims[1] * stride); - auto _rhs = CPRNG(field_, meta.dims[1] * meta.dims[2] * stride + 1); + for (int64_t stride : {1, 2, 3}) { + auto _lhs = ring_rand(field_, {meta.dims[0] * meta.dims[1] * stride}); + auto _rhs = ring_rand(field_, {meta.dims[1] * meta.dims[2] * stride + 1}); - auto lhs = _lhs.slice(0, _lhs.numel(), stride); - auto rhs = _rhs.slice(1, _rhs.numel(), stride); + auto lhs = _lhs.slice({0}, {_lhs.numel()}, {stride}); + auto rhs = _rhs.slice({1}, {_rhs.numel()}, {stride}); + lhs = lhs.reshape({meta.dims[0], meta.dims[1]}); + rhs = rhs.reshape({meta.dims[1], 
meta.dims[2]}); SPU_ENFORCE_EQ(lhs.numel(), meta.dims[0] * meta.dims[1]); SPU_ENFORCE_EQ(rhs.numel(), meta.dims[1] * meta.dims[2]); @@ -128,6 +125,7 @@ TEST_P(MatMatProtTest, Plain) { size_t lhs_n = matmat_prot.GetLeftSize(meta); size_t rhs_n = matmat_prot.GetRightSize(meta); size_t out_n = matmat_prot.GetOutSize(meta); + std::vector lhs_poly(lhs_n); std::vector rhs_poly(rhs_n); std::vector out_poly(out_n); @@ -149,18 +147,16 @@ TEST_P(MatMatProtTest, Plain) { InvNttInplace(p, *context_); } - auto expected = ring_mmul(unflatten(lhs, {meta.dims[0], meta.dims[1]}), - unflatten(rhs, {meta.dims[1], meta.dims[2]})); - auto computed = unflatten( - matmat_prot.ParseResult(field_, meta, absl::MakeSpan(out_poly)), - {meta.dims[0], meta.dims[2]}); + auto expected = ring_mmul(lhs, rhs); + auto computed = + matmat_prot.ParseResult(field_, meta, absl::MakeSpan(out_poly)); EXPECT_EQ(expected.numel(), computed.numel()); DISPATCH_ALL_FIELDS(field_, "", [&]() { - auto xe = xt_adapt(expected); - auto xc = xt_adapt(computed); - for (size_t i = 0; i < xc.size(); ++i) { + auto xe = NdArrayView(expected); + auto xc = NdArrayView(computed); + for (int64_t i = 0; i < xc.numel(); ++i) { EXPECT_EQ(xe[i], xc[i]); } }); @@ -171,9 +167,9 @@ TEST_P(MatMatProtTest, EncLHS) { MatMatProtocol::Meta meta; meta.dims = std::get<1>(GetParam()); - auto lhs = CPRNG(field_, meta.dims[0] * meta.dims[1]); - auto rhs = CPRNG(field_, meta.dims[1] * meta.dims[2]); - MatMatProtocol matmat_prot(*context_, *ms_helper_, /*mont*/ true); + auto lhs = ring_rand(field_, {meta.dims[0], meta.dims[1]}); + auto rhs = ring_rand(field_, {meta.dims[1], meta.dims[2]}); + MatMatProtocol matmat_prot(*context_, *ms_helper_); size_t lhs_n = matmat_prot.GetLeftSize(meta); size_t rhs_n = matmat_prot.GetRightSize(meta); @@ -192,7 +188,6 @@ TEST_P(MatMatProtTest, EncLHS) { for (auto& p : rhs_poly) { NttInplace(p, *context_); } - matmat_prot.Montgomerize(absl::MakeSpan(rhs_poly)); std::vector out_ct(out_n); matmat_prot.Compute(absl::MakeSpan(enc_poly), absl::MakeSpan(rhs_poly), meta, @@ -203,7 +198,9 @@ TEST_P(MatMatProtTest, EncLHS) { seal::Decryptor decryptor(*context_, *rlwe_sk_); std::vector out_poly(out_n); for (size_t i = 0; i < out_n; ++i) { - if (!out_ct[i].is_ntt_form()) evaluator.transform_to_ntt_inplace(out_ct[i]); + if (!out_ct[i].is_ntt_form()) { + evaluator.transform_to_ntt_inplace(out_ct[i]); + } decryptor.decrypt(out_ct[i], out_poly[i]); } @@ -212,18 +209,16 @@ TEST_P(MatMatProtTest, EncLHS) { InvNttInplace(p, *context_); } - auto expected = ring_mmul(unflatten(lhs, {meta.dims[0], meta.dims[1]}), - unflatten(rhs, {meta.dims[1], meta.dims[2]})); + auto expected = ring_mmul(lhs, rhs); auto computed = - unflatten(matmat_prot.ParseResult(field_, meta, absl::MakeSpan(out_poly)), - {meta.dims[0], meta.dims[2]}); + matmat_prot.ParseResult(field_, meta, absl::MakeSpan(out_poly)); EXPECT_EQ(expected.numel(), computed.numel()); DISPATCH_ALL_FIELDS(field_, "", [&]() { - auto xe = xt_adapt(expected); - auto xc = xt_adapt(computed); - for (size_t i = 0; i < xc.size(); ++i) { + auto xe = NdArrayView(expected); + auto xc = NdArrayView(computed); + for (int64_t i = 0; i < xc.numel(); ++i) { EXPECT_EQ(xe[i], xc[i]); } }); @@ -233,8 +228,8 @@ TEST_P(MatMatProtTest, EncRHS) { MatMatProtocol::Meta meta; meta.dims = std::get<1>(GetParam()); - auto lhs = CPRNG(field_, meta.dims[0] * meta.dims[1]); - auto rhs = CPRNG(field_, meta.dims[1] * meta.dims[2]); + auto lhs = ring_rand(field_, {meta.dims[0], meta.dims[1]}); + auto rhs = ring_rand(field_, {meta.dims[1], 
meta.dims[2]}); MatMatProtocol matmat_prot(*context_, *ms_helper_); size_t lhs_n = matmat_prot.GetLeftSize(meta); @@ -264,7 +259,9 @@ TEST_P(MatMatProtTest, EncRHS) { seal::Decryptor decryptor(*context_, *rlwe_sk_); std::vector out_poly(out_n); for (size_t i = 0; i < out_n; ++i) { - if (!out_ct[i].is_ntt_form()) evaluator.transform_to_ntt_inplace(out_ct[i]); + if (!out_ct[i].is_ntt_form()) { + evaluator.transform_to_ntt_inplace(out_ct[i]); + } decryptor.decrypt(out_ct[i], out_poly[i]); } @@ -273,18 +270,16 @@ TEST_P(MatMatProtTest, EncRHS) { InvNttInplace(p, *context_); } - auto expected = ring_mmul(unflatten(lhs, {meta.dims[0], meta.dims[1]}), - unflatten(rhs, {meta.dims[1], meta.dims[2]})); + auto expected = ring_mmul(lhs, rhs); auto computed = - unflatten(matmat_prot.ParseResult(field_, meta, absl::MakeSpan(out_poly)), - {meta.dims[0], meta.dims[2]}); + matmat_prot.ParseResult(field_, meta, absl::MakeSpan(out_poly)); EXPECT_EQ(expected.numel(), computed.numel()); DISPATCH_ALL_FIELDS(field_, "", [&]() { - auto xe = xt_adapt(expected); - auto xc = xt_adapt(computed); - for (size_t i = 0; i < xc.size(); ++i) { + auto xe = NdArrayView(expected); + auto xc = NdArrayView(computed); + for (int64_t i = 0; i < xc.numel(); ++i) { EXPECT_EQ(xe[i], xc[i]); } }); diff --git a/libspu/mpc/cheetah/arith/tensor_encoder.cc b/libspu/mpc/cheetah/arith/tensor_encoder.cc index 4459b7dd..5c2b17ad 100644 --- a/libspu/mpc/cheetah/arith/tensor_encoder.cc +++ b/libspu/mpc/cheetah/arith/tensor_encoder.cc @@ -26,7 +26,7 @@ constexpr int kC = 2; TensorEncoder::TensorEncoder(const seal::SEALContext &context, const ModulusSwitchHelper &ms_helper) - : ms_helper_(ms_helper) { + : msh_(ms_helper) { SPU_ENFORCE(context.parameters_set()); auto pid0 = context.first_parms_id(); auto pid1 = ms_helper.parms_id(); @@ -47,9 +47,9 @@ void TensorEncoder::EncodeInput(const Sliced3DTensor &input, SPU_ENFORCE(poly_deg_ >= calcNumel(input.shape())); InputIndexer indexer(input_shape, kernel_shape); - auto poly = Tensor2Poly(input_shape, kernel_shape, input, indexer); + auto poly = toNdArray(Tensor2Poly(input_shape, kernel_shape, input, indexer)); - size_t num_modulus = ms_helper_.coeff_modulus_size(); + size_t num_modulus = msh_.coeff_modulus_size(); out->parms_id() = seal::parms_id_zero; out->resize( seal::util::mul_safe(static_cast(poly_deg_), num_modulus)); @@ -59,14 +59,14 @@ void TensorEncoder::EncodeInput(const Sliced3DTensor &input, std::fill_n(dst, poly_deg_, 0); absl::Span dst_wrap(dst, poly_deg_); if (need_encrypt) { - ms_helper_.ModulusUpAt(poly, mod_idx, dst_wrap); + msh_.ModulusUpAt(poly, mod_idx, dst_wrap); } else { - ms_helper_.CenteralizeAt(poly, mod_idx, dst_wrap); + msh_.CenteralizeAt(poly, mod_idx, dst_wrap); } dst += poly_deg_; } - out->parms_id() = ms_helper_.parms_id(); + out->parms_id() = msh_.parms_id(); out->scale() = 1.; } @@ -80,9 +80,10 @@ void TensorEncoder::EncodeKernel(const Sliced3DTensor &kernel, SPU_ENFORCE(poly_deg_ >= calcNumel(kernel.shape())); KernelIndexer indexer(input_shape, kernel_shape); - auto poly = Tensor2Poly(input_shape, kernel_shape, kernel, indexer); + auto poly = + toNdArray(Tensor2Poly(input_shape, kernel_shape, kernel, indexer)); - size_t num_modulus = ms_helper_.coeff_modulus_size(); + size_t num_modulus = msh_.coeff_modulus_size(); out->parms_id() = seal::parms_id_zero; out->resize( seal::util::mul_safe(static_cast(poly_deg_), num_modulus)); @@ -92,14 +93,14 @@ void TensorEncoder::EncodeKernel(const Sliced3DTensor &kernel, std::fill_n(dst, poly_deg_, 0); absl::Span dst_wrap(dst, 
poly_deg_); if (need_encrypt) { - ms_helper_.ModulusUpAt(poly, mod_idx, dst_wrap); + msh_.ModulusUpAt(poly, mod_idx, dst_wrap); } else { - ms_helper_.CenteralizeAt(poly, mod_idx, dst_wrap); + msh_.CenteralizeAt(poly, mod_idx, dst_wrap); } dst += poly_deg_; } - out->parms_id() = ms_helper_.parms_id(); + out->parms_id() = msh_.parms_id(); out->scale() = 1.; } @@ -110,13 +111,13 @@ ArrayRef TensorEncoder::Tensor2Poly(const Shape3D &input_shape, const Indexer &indexer) const { int64_t isze = calcNumel(input_shape); int64_t ksze = calcNumel(kernel_shape); - int64_t n_elt = tensor.numel(); + int64_t numel = tensor.numel(); int64_t N = poly_deg_; SPU_ENFORCE(isze > 0 && ksze > 0, "invalid shapes"); - SPU_ENFORCE(n_elt == isze || n_elt == ksze, "shape mismatch"); - SPU_ENFORCE(n_elt <= N, "too large tensor to encode as one poly"); + SPU_ENFORCE(numel == isze || numel == ksze, "shape mismatch"); + SPU_ENFORCE(numel <= N, "too large tensor to encode as one poly"); - Shape3D shape = isze == n_elt ? input_shape : kernel_shape; + Shape3D shape = isze == numel ? input_shape : kernel_shape; const auto field = tensor.field(); return DISPATCH_ALL_FIELDS(field, "Tensor2Poly", [&]() { diff --git a/libspu/mpc/cheetah/arith/tensor_encoder.h b/libspu/mpc/cheetah/arith/tensor_encoder.h index b716fcb4..30944e26 100644 --- a/libspu/mpc/cheetah/arith/tensor_encoder.h +++ b/libspu/mpc/cheetah/arith/tensor_encoder.h @@ -35,7 +35,7 @@ class TensorEncoder { void EncodeKernel(const Sliced3DTensor &kernel, const Shape3D &input_shape, bool need_encrypt, RLWEPt *out) const; - const ModulusSwitchHelper &ms_helper() const { return ms_helper_; } + const ModulusSwitchHelper &ms_helper() const { return msh_; } private: template @@ -46,7 +46,7 @@ class TensorEncoder { private: int64_t poly_deg_{0}; // Take the copy - ModulusSwitchHelper ms_helper_; + ModulusSwitchHelper msh_; std::shared_ptr conv2d_helper_{nullptr}; }; diff --git a/libspu/mpc/cheetah/arith/vector_encoder.cc b/libspu/mpc/cheetah/arith/vector_encoder.cc index dfce64d0..51df186a 100644 --- a/libspu/mpc/cheetah/arith/vector_encoder.cc +++ b/libspu/mpc/cheetah/arith/vector_encoder.cc @@ -22,24 +22,25 @@ namespace spu::mpc::cheetah { VectorEncoder::VectorEncoder(const seal::SEALContext &context, - const ModulusSwitchHelper &ms_helper) { + const ModulusSwitchHelper &msh) { SPU_ENFORCE(context.parameters_set()); auto pid0 = context.first_parms_id(); - auto pid1 = ms_helper.parms_id(); + auto pid1 = msh.parms_id(); SPU_ENFORCE_EQ(0, std::memcmp(&pid0, &pid1, sizeof(seal::parms_id_type)), fmt::format("parameter set mismatch")); - ms_helper_ = std::make_shared(ms_helper); + msh_ = std::make_shared(msh); poly_deg_ = context.first_context_data()->parms().poly_modulus_degree(); } -void VectorEncoder::Forward(const ArrayRef &vec, RLWEPt *out, +void VectorEncoder::Forward(const NdArrayRef &vec, RLWEPt *out, bool scale_delta) const { // Place the vector elements as polynomial coefficients forwardly. 
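As a small worked case of the Forward/Backward pairing (the Math comment in vector_encoder.h below states that Forward(x) * Backward(y) mod X^N + 1 yields the inner product; this editor's example uses N = 4 and length-3 vectors):

\[
\mathrm{Fwd}(a) = a_0 + a_1 X + a_2 X^2,\qquad
\mathrm{Bwd}(b) = b_0 - b_2 X^{2} - b_1 X^{3},
\]
\[
\mathrm{Fwd}(a)\,\mathrm{Bwd}(b) \bmod (X^4 + 1)
 = (a_0 b_0 + a_1 b_1 + a_2 b_2) + (\text{higher coefficients}),
\]

since for i > 0 each term a_i X^i * (-b_i X^{4-i}) = -a_i b_i X^4 folds back to +a_i b_i under X^4 = -1; the inner product lands in the constant coefficient.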
// a0, a1, ..., an -> \sum_i ai*X^i yacl::CheckNotNull(out); size_t num_coeffs = vec.numel(); - size_t num_modulus = ms_helper_->coeff_modulus_size(); + size_t num_modulus = msh_->coeff_modulus_size(); + SPU_ENFORCE(vec.shape().size() == 1, "need 1D array"); SPU_ENFORCE_GT(num_coeffs, 0UL); SPU_ENFORCE(num_coeffs <= poly_deg_); @@ -50,27 +51,29 @@ void VectorEncoder::Forward(const ArrayRef &vec, RLWEPt *out, for (size_t mod_idx = 0; mod_idx < num_modulus; ++mod_idx) { std::fill_n(dst, poly_deg_, 0); absl::Span dst_wrap(dst, num_coeffs); + if (scale_delta) { - ms_helper_->ModulusUpAt(vec, mod_idx, dst_wrap); + msh_->ModulusUpAt(vec, mod_idx, dst_wrap); } else { - ms_helper_->CenteralizeAt(vec, mod_idx, dst_wrap); + msh_->CenteralizeAt(vec, mod_idx, dst_wrap); } dst += poly_deg_; } - out->parms_id() = ms_helper_->parms_id(); + out->parms_id() = msh_->parms_id(); out->scale() = 1.; } -void VectorEncoder::Backward(const ArrayRef &vec, RLWEPt *out, +void VectorEncoder::Backward(const NdArrayRef &vec, RLWEPt *out, bool scale_delta) const { // Place the vector elements as polynomial coefficients in backward. // a0, a1, ..., an -> a0 - \sum_{i>0} ai*X^{N-i} // where N defines the base ring X^N + 1. yacl::CheckNotNull(out); + SPU_ENFORCE(vec.shape().size() == 1, "need 1D array"); size_t num_coeffs = vec.numel(); - size_t num_modulus = ms_helper_->coeff_modulus_size(); + size_t num_modulus = msh_->coeff_modulus_size(); SPU_ENFORCE_GT(num_coeffs, 0UL); SPU_ENFORCE(num_coeffs <= poly_deg_); @@ -83,15 +86,14 @@ void VectorEncoder::Backward(const ArrayRef &vec, RLWEPt *out, DISPATCH_ALL_FIELDS(field, "Backward", [&]() { auto tmp_buff = ring_zeros(field, {(int64_t)poly_deg_}); - auto nd_vec = toNdArray(vec); - auto xvec = xt_adapt(nd_vec); - auto xtmp = xt_mutable_adapt(tmp_buff); + auto xvec = NdArrayView(vec); + auto xtmp = NdArrayView(tmp_buff); xtmp[0] = xvec[0]; // reverse and sign flip - std::transform(xvec.data() + 1, xvec.data() + num_coeffs, - std::reverse_iterator(xtmp.data() + poly_deg_), - [](ring2k_t x) { return -x; }); + for (size_t i = 1; i < num_coeffs; ++i) { + xtmp[num_coeffs - 1 - i] = -xvec[i]; + } uint64_t *dst = out->data(); for (size_t mod_idx = 0; mod_idx < num_modulus; ++mod_idx) { @@ -99,18 +101,19 @@ void VectorEncoder::Backward(const ArrayRef &vec, RLWEPt *out, absl::Span dst_wrap(dst, poly_deg_); if (scale_delta) { - ms_helper_->ModulusUpAt(flatten(tmp_buff), mod_idx, dst_wrap); + msh_->ModulusUpAt(tmp_buff, mod_idx, dst_wrap); } else { - ms_helper_->CenteralizeAt(flatten(tmp_buff), mod_idx, dst_wrap); + msh_->CenteralizeAt(tmp_buff, mod_idx, dst_wrap); } dst += poly_deg_; } // clean up sensitive data - seal::util::seal_memzero(xtmp.data(), sizeof(ring2k_t) * poly_deg_); + seal::util::seal_memzero(&tmp_buff.at(0), + sizeof(ring2k_t) * poly_deg_); }); - out->parms_id() = ms_helper_->parms_id(); + out->parms_id() = msh_->parms_id(); out->scale() = 1.; } diff --git a/libspu/mpc/cheetah/arith/vector_encoder.h b/libspu/mpc/cheetah/arith/vector_encoder.h index 12ebb763..9687f5b3 100644 --- a/libspu/mpc/cheetah/arith/vector_encoder.h +++ b/libspu/mpc/cheetah/arith/vector_encoder.h @@ -16,7 +16,7 @@ #include "absl/types/span.h" -#include "libspu/mpc/cheetah/array_ref.h" +#include "libspu/core/ndarray_ref.h" #include "libspu/mpc/cheetah/rlwe/modswitch_helper.h" #include "libspu/mpc/cheetah/rlwe/types.h" @@ -25,7 +25,7 @@ namespace spu::mpc::cheetah { class VectorEncoder { public: explicit VectorEncoder(const seal::SEALContext &context, - const ModulusSwitchHelper &ms_helper); + const 
ModulusSwitchHelper &msh);

   // clang-format off
   // Math:
   //  Forward(x) * Backward(y) mod X^N + 1 gives the inner product
   //  For example Enc(Forward(Delta*x)) * Backward(y) can give the inner product without error.
   //  Or we can encrypt the backward part, i.e., Enc(Backward(Delta*y)) * Forward(x) also gives the inner product without error.
   // clang-format on

-  void Forward(const ArrayRef &vec, RLWEPt *out, bool scaleup = true) const;
+  void Forward(const NdArrayRef &vec, RLWEPt *out,
+               bool scale_delta = true) const;

-  void Backward(const ArrayRef &vec, RLWEPt *out, bool scaleup = false) const;
+  void Backward(const NdArrayRef &vec, RLWEPt *out,
+                bool scale_delta = false) const;

-  const ModulusSwitchHelper &ms_helper() const { return *ms_helper_; }
+  const ModulusSwitchHelper &ms_helper() const { return *msh_; }

   size_t poly_degree() const { return poly_deg_; }

  private:
   size_t poly_deg_{0};
-  std::shared_ptr ms_helper_;
+  std::shared_ptr msh_;
 };

 }  // namespace spu::mpc::cheetah
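A quick standalone check of the VectorEncoder identity described in the comment above (plain C++ over Z_{2^64} with a toy degree N = 8 chosen only for illustration; this sketch is not part of the patch and its names are hypothetical): the degree-0 coefficient of Forward(x) * Backward(y) in Z[X]/(X^N + 1) equals the inner product of x and y.

#include <array>
#include <cstddef>
#include <cstdint>
#include <cstdio>

int main() {
  constexpr size_t N = 8;  // toy ring degree; the real encoder uses poly_deg_
  std::array<uint64_t, N> x{1, 2, 3, 4, 0, 0, 0, 0};
  std::array<uint64_t, N> y{5, 6, 7, 8, 0, 0, 0, 0};

  // Forward encoding: coefficient i is x[i].
  std::array<uint64_t, N> fwd{};
  for (size_t i = 0; i < N; ++i) fwd[i] = x[i];

  // Backward encoding: b[0] = y[0], b[N-i] = -y[i] for i > 0 (mod 2^64).
  std::array<uint64_t, N> bwd{};
  bwd[0] = y[0];
  for (size_t i = 1; i < N; ++i) bwd[N - i] = -y[i];

  // Negacyclic product in Z_{2^64}[X]/(X^N + 1): X^N == -1, so any term
  // X^{i+j} with i + j >= N wraps around with a sign flip.
  std::array<uint64_t, N> prod{};
  for (size_t i = 0; i < N; ++i) {
    for (size_t j = 0; j < N; ++j) {
      uint64_t t = fwd[i] * bwd[j];
      size_t k = i + j;
      if (k < N) {
        prod[k] += t;
      } else {
        prod[k - N] -= t;
      }
    }
  }

  // Plain inner product for comparison.
  uint64_t expect = 0;
  for (size_t i = 0; i < N; ++i) expect += x[i] * y[i];

  std::printf("prod[0] = %llu, inner product = %llu\n",
              (unsigned long long)prod[0], (unsigned long long)expect);
  return prod[0] == expect ? 0 : 1;
}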
diff --git a/libspu/mpc/cheetah/arithmetic.cc b/libspu/mpc/cheetah/arithmetic.cc
index e8c8776b..5810d06b 100644
--- a/libspu/mpc/cheetah/arithmetic.cc
+++ b/libspu/mpc/cheetah/arithmetic.cc
@@ -16,6 +16,7 @@
 #include 
+#include "libspu/core/ndarray_ref.h"
 #include "libspu/core/trace.h"
 #include "libspu/core/xt_helper.h"
 #include "libspu/mpc/cheetah/arith/common.h"
@@ -23,60 +24,44 @@
 #include "libspu/mpc/cheetah/nonlinear/equal_prot.h"
 #include "libspu/mpc/cheetah/nonlinear/truncate_prot.h"
 #include "libspu/mpc/cheetah/state.h"
+#include "libspu/mpc/cheetah/type.h"
 #include "libspu/mpc/common/communicator.h"
 #include "libspu/mpc/common/pv2k.h"
-#include "libspu/mpc/semi2k/type.h"
 #include "libspu/mpc/utils/ring_ops.h"

 namespace spu::mpc::cheetah {
-namespace {
-
-constexpr int64_t kMinWorkSize = 5000;
-
-TruncateProtocol::MSB_st getMsbType(SignType sign) {
-  switch (sign) {
-    case SignType::Positive:
-      return TruncateProtocol::MSB_st::zero;
-    case SignType::Negative:
-      return TruncateProtocol::MSB_st::one;
-    case SignType::Unknown:
-      return TruncateProtocol::MSB_st::unknown;
-    default:
-      SPU_THROW("should not be here");
-  }
-}
-
-}  // namespace

 NdArrayRef TruncA::proc(KernelEvalContext* ctx, const NdArrayRef& x,
                         size_t bits, SignType sign) const {
   SPU_TRACE_MPC_LEAF(ctx, x);
-  auto* comm = ctx->getState();
-  auto* ot_state = ctx->getState();
-  int64_t n = x.numel();
-  int64_t nworker = std::min(static_cast(ot_state->parallel_size()),
-                             CeilDiv(n, kMinWorkSize));
-  int64_t work_load = nworker == 0 ? 0 : CeilDiv(n, nworker);
-  for (int64_t w = 0; w < nworker; ++w) {
-    ot_state->LazyInit(comm, w);
+  size_t n = x.numel();
+  NdArrayRef out(x.eltype(), x.shape());
+  if (n == 0) {
+    return out;
   }
-  NdArrayRef out(x.eltype(), x.shape());
+  size_t nworker = InitOTState(ctx, n);
+  size_t work_load = nworker == 0 ? 0 : CeilDiv(n, nworker);
+
   TruncateProtocol::Meta meta;
   meta.signed_arith = true;
-  meta.msb = getMsbType(sign);
+  meta.sign = sign;
   meta.shift_bits = bits;
-  auto f_x = flatten(x);
+  meta.use_heuristic = true;
+
+  // Operate on 1D array
+  auto flatten_x = x.reshape({x.numel()});
   yacl::parallel_for(0, nworker, 1, [&](int64_t bgn, int64_t end) {
     for (int64_t job = bgn; job < end; ++job) {
-      int64_t slice_bgn = std::min(job * work_load, n);
-      int64_t slice_end = std::min(slice_bgn + work_load, n);
+      int64_t slice_bgn = std::min(job * work_load, n);
+      int64_t slice_end = std::min(slice_bgn + work_load, n);
       if (slice_end == slice_bgn) {
         break;
       }
       TruncateProtocol prot(ctx->getState()->get(job));
-      auto out_slice = prot.Compute(f_x.slice(slice_bgn, slice_end), meta);
+      auto out_slice =
+          prot.Compute(flatten_x.slice({slice_bgn}, {slice_end}, {1}), meta);
       std::memcpy(&out.at(slice_bgn), &out_slice.at(0),
                   out_slice.numel() * out_slice.elsize());
     }
@@ -85,56 +70,59 @@ NdArrayRef TruncA::proc(KernelEvalContext* ctx, const NdArrayRef& x,
   return out;
 }

+// Math:
+// msb(x0 + x1 mod 2^k) = msb(x0) ^ msb(x1) ^ 1{(x0 + x1) > 2^{k-1} - 1}
+// The carry bit
+// 1{(x0 + x1) > 2^{k - 1} - 1} = 1{x0 > 2^{k - 1} - 1 - x1}
+// is computed using a Millionaire protocol.
 NdArrayRef MsbA2B::proc(KernelEvalContext* ctx, const NdArrayRef& x) const {
-  auto* comm = ctx->getState();
-  auto* ot_state = ctx->getState();
-  int64_t n = x.numel();
-  int64_t nworker = std::min(static_cast(ot_state->parallel_size()),
-                             CeilDiv(n, kMinWorkSize));
-  int64_t work_load = nworker == 0 ? 0 : CeilDiv(n, nworker);
-  for (int64_t w = 0; w < nworker; ++w) {
-    ot_state->LazyInit(comm, w);
+  SPU_TRACE_MPC_LEAF(ctx, x);
+  const int64_t numel = x.numel();
+  const auto field = ctx->getState()->getDefaultField();
+  const size_t nbits = nbits_ == 0 ? SizeOf(field) * 8 : nbits_;
+  const size_t shft = nbits - 1;
+  SPU_ENFORCE(nbits <= 8 * SizeOf(field));
+
+  NdArrayRef out(x.eltype(), x.shape());
+  if (numel == 0) {
+    return out.as(makeType(field, 1));
   }
-  // Math:
-  // msb(x0 + x1 mod 2^k) = msb(x0) ^ msb(x1) ^ 1{(x0 + x1) > 2^{k-1} - 1}
-  // The carry bit
-  // 1{(x0 + x1) > 2^{k - 1} - 1} = 1{x0 > 2^{k - 1} - 1 - x1}
-  // is computed using a Millionare protocol.
-  const auto field = ctx->getState()->getDefaultField();
-  const int rank = comm->getRank();
-  const size_t shft = SizeOf(field) * 8 - 1;
+  const int64_t nworker = InitOTState(ctx, numel);
+  const int64_t work_load = nworker == 0 ?
0 : CeilDiv(numel, nworker); + const int rank = ctx->getState()->getRank(); return DISPATCH_ALL_FIELDS(field, "_", [&]() { using u2k = std::make_unsigned::type; const u2k mask = (static_cast(1) << shft) - 1; - NdArrayRef adjusted = ring_zeros(field, {n}); + NdArrayRef adjusted = ring_zeros(field, {numel}); auto xinp = NdArrayView(x); auto xadj = NdArrayView(adjusted); if (rank == 0) { // x0 - pforeach(0, n, [&](int64_t i) { xadj[i] = xinp[i] & mask; }); + pforeach(0, numel, [&](int64_t i) { xadj[i] = xinp[i] & mask; }); } else { // 2^{k - 1} - 1 - x1 - pforeach(0, n, [&](int64_t i) { xadj[i] = (mask - xinp[i]) & mask; }); + pforeach(0, numel, [&](int64_t i) { xadj[i] = (mask - xinp[i]) & mask; }); } NdArrayRef carry_bit(x.eltype(), x.shape()); - auto f_adjusted = flatten(adjusted); yacl::parallel_for(0, nworker, 1, [&](int64_t bgn, int64_t end) { for (int64_t job = bgn; job < end; ++job) { - int64_t slice_bgn = std::min(job * work_load, n); - int64_t slice_end = std::min(slice_bgn + work_load, n); + int64_t slice_bgn = std::min(job * work_load, numel); + int64_t slice_end = std::min(slice_bgn + work_load, numel); if (slice_end == slice_bgn) { break; } - CompareProtocol prot(ot_state->get(job)); + CompareProtocol prot(ctx->getState()->get(job)); + // 1{x0 > 2^{k - 1} - 1 - x1} - auto out_slice = prot.Compute(f_adjusted.slice(slice_bgn, slice_end), - /*greater*/ true); + auto out_slice = + prot.Compute(adjusted.slice({slice_bgn}, {slice_end}, {1}), + /*greater*/ true); std::memcpy(&carry_bit.at(slice_bgn), &out_slice.at(0), out_slice.numel() * out_slice.elsize()); @@ -143,14 +131,15 @@ NdArrayRef MsbA2B::proc(KernelEvalContext* ctx, const NdArrayRef& x) const { // [msb(x)]_B <- [1{x0 + x1 > 2^{k- 1} - 1]_B ^ msb(x0) NdArrayView _carry_bit(carry_bit); - pforeach(0, n, [&](int64_t i) { _carry_bit[i] ^= (xinp[i] >> shft); }); + pforeach(0, numel, [&](int64_t i) { _carry_bit[i] ^= (xinp[i] >> shft); }); - return carry_bit.as(makeType(field, 1)); + return carry_bit.as(makeType(field, 1)); }); } NdArrayRef EqualAP::proc(KernelEvalContext* ctx, const NdArrayRef& x, const NdArrayRef& y) const { + SPU_TRACE_MPC_LEAF(ctx, x, y); EqualAA equal_aa; const auto field = ctx->getState()->getDefaultField(); // TODO(juhou): Can we use any place holder to indicate the dummy 0s. @@ -166,18 +155,19 @@ NdArrayRef EqualAA::proc(KernelEvalContext* ctx, const NdArrayRef& x, SPU_TRACE_MPC_LEAF(ctx, x, y); SPU_ENFORCE_EQ(x.shape(), y.shape()); - auto* comm = ctx->getState(); - auto* ot_state = ctx->getState(); - int64_t n = x.numel(); - int64_t nworker = std::min(static_cast(ot_state->parallel_size()), - CeilDiv(n, kMinWorkSize)); - int64_t work_load = nworker == 0 ? 0 : CeilDiv(n, nworker); - for (int64_t w = 0; w < nworker; ++w) { - ot_state->LazyInit(comm, w); + const int64_t numel = x.numel(); + const auto field = ctx->getState()->getDefaultField(); + const size_t nbits = nbits_ == 0 ? SizeOf(field) * 8 : nbits_; + SPU_ENFORCE(nbits <= 8 * SizeOf(field)); + + NdArrayRef eq_bit(x.eltype(), x.shape()); + if (numel == 0) { + return eq_bit.as(makeType(field, 1)); } - const auto field = ctx->getState()->getDefaultField(); - const int rank = comm->getRank(); + const int64_t nworker = InitOTState(ctx, numel); + const int64_t work_load = nworker == 0 ? 
0 : CeilDiv(numel, nworker); + const int rank = ctx->getState()->getRank(); // x0 + x1 = y0 + y1 mod 2k // <=> x0 - y0 = y1 - x1 mod 2k @@ -188,56 +178,58 @@ NdArrayRef EqualAA::proc(KernelEvalContext* ctx, const NdArrayRef& x, adjusted = ring_sub(y, x); } - NdArrayRef eq_bit(x.eltype(), x.shape()); - auto f_adjusted = flatten(adjusted); + // Need 1D array + adjusted = adjusted.reshape({adjusted.numel()}); yacl::parallel_for(0, nworker, 1, [&](int64_t bgn, int64_t end) { for (int64_t job = bgn; job < end; ++job) { - int64_t slice_bgn = std::min(job * work_load, n); - int64_t slice_end = std::min(slice_bgn + work_load, n); + int64_t slice_bgn = std::min(job * work_load, numel); + int64_t slice_end = std::min(slice_bgn + work_load, numel); if (slice_end == slice_bgn) { break; } - EqualProtocol prot(ot_state->get(job)); - auto out_slice = prot.Compute(f_adjusted.slice(slice_bgn, slice_end)); + EqualProtocol prot(ctx->getState()->get(job)); + auto out_slice = + prot.Compute(adjusted.slice({slice_bgn}, {slice_end}, {1}), nbits); std::memcpy(&eq_bit.at(slice_bgn), &out_slice.at(0), out_slice.numel() * out_slice.elsize()); } }); - return eq_bit.as(makeType(field, 1)); + return eq_bit.as(makeType(field, 1)); } -NdArrayRef MulA1B::proc(KernelEvalContext* ctx, const NdArrayRef& x, - const NdArrayRef& y) const { - SPU_ENFORCE_EQ(x.shape(), y.shape()); - auto* comm = ctx->getState(); - auto* ot_state = ctx->getState(); - int64_t n = x.numel(); - int64_t nworker = std::min(static_cast(ot_state->parallel_size()), - CeilDiv(n, kMinWorkSize)); - int64_t work_load = nworker == 0 ? 0 : CeilDiv(n, nworker); - - for (int64_t w = 0; w < nworker; ++w) { - ot_state->LazyInit(comm, w); +NdArrayRef MulA1B::proc(KernelEvalContext* ctx, const NdArrayRef& ashr, + const NdArrayRef& bshr) const { + SPU_TRACE_MPC_LEAF(ctx, ashr, bshr); + SPU_ENFORCE_EQ(ashr.shape(), bshr.shape()); + const int64_t numel = ashr.numel(); + NdArrayRef out(ashr.eltype(), ashr.shape()); + + if (numel == 0) { + return out; } - NdArrayRef out(x.eltype(), x.shape()); - auto f_x = flatten(x); - auto f_y = flatten(y); + const int64_t nworker = InitOTState(ctx, numel); + const int64_t work_load = nworker == 0 ? 
0 : CeilDiv(numel, nworker); + + // Need 1D Array + auto flatten_a = ashr.reshape({ashr.numel()}); + auto flatten_b = bshr.reshape({bshr.numel()}); yacl::parallel_for(0, nworker, 1, [&](int64_t bgn, int64_t end) { for (int64_t job = bgn; job < end; ++job) { - int64_t slice_bgn = std::min(job * work_load, n); - int64_t slice_end = std::min(slice_bgn + work_load, n); + int64_t slice_bgn = std::min(job * work_load, numel); + int64_t slice_end = std::min(slice_bgn + work_load, numel); if (slice_end == slice_bgn) { break; } - auto out_slice = ot_state->get(job)->Multiplexer( - f_x.slice(slice_bgn, slice_end), f_y.slice(slice_bgn, slice_end)); + auto out_slice = ctx->getState()->get(job)->Multiplexer( + flatten_a.slice({slice_bgn}, {slice_end}, {1}), + flatten_b.slice({slice_bgn}, {slice_end}, {1})); std::memcpy(&out.at(slice_bgn), &out_slice.at(0), out_slice.numel() * out_slice.elsize()); @@ -249,10 +241,12 @@ NdArrayRef MulA1B::proc(KernelEvalContext* ctx, const NdArrayRef& x, NdArrayRef MulAA::proc(KernelEvalContext* ctx, const NdArrayRef& x, const NdArrayRef& y) const { - SPU_ENFORCE_EQ(x.numel(), y.numel()); + SPU_TRACE_MPC_LEAF(ctx, x, y); + SPU_ENFORCE_EQ(x.shape(), y.shape()); + + int64_t batch_sze = ctx->getState()->get()->OLEBatchSize(); + int64_t numel = x.numel(); - size_t batch_sze = ctx->getState()->get()->OLEBatchSize(); - size_t numel = x.numel(); if (numel >= batch_sze) { return mulDirectly(ctx, x, y); } @@ -271,20 +265,21 @@ NdArrayRef MulAA::mulWithBeaver(KernelEvalContext* ctx, const NdArrayRef& x, ctx->getState()->TakeCachedBeaver(field, numel); YACL_ENFORCE_EQ(a.numel(), numel); + a = a.reshape(x.shape()); + b = b.reshape(x.shape()); + c = c.reshape(x.shape()); + auto* comm = ctx->getState(); // Open x - a & y - b - auto res = vmap({ring_sub(x, unflatten(a, x.shape())), - ring_sub(y, unflatten(b, y.shape()))}, - [&](const NdArrayRef& s) { - return comm->allReduce(ReduceOp::ADD, s, kBindName); - }); + auto res = vmap({ring_sub(x, a), ring_sub(y, b)}, [&](const NdArrayRef& s) { + return comm->allReduce(ReduceOp::ADD, s, kBindName); + }); auto x_a = std::move(res[0]); auto y_b = std::move(res[1]); // Zi = Ci + (X - A) * Bi + (Y - B) * Ai + <(X - A) * (Y - B)> - auto z = ring_add(ring_mul(x_a, unflatten(b, x_a.shape())), - ring_mul(y_b, unflatten(a, y_b.shape()))); - ring_add_(z, unflatten(c, z.shape())); + auto z = ring_add(ring_mul(x_a, b), ring_mul(y_b, a)); + ring_add_(z, c); if (comm->getRank() == 0) { // z += (X-A) * (Y-B); @@ -301,28 +296,27 @@ NdArrayRef MulAA::mulDirectly(KernelEvalContext* ctx, const NdArrayRef& x, auto* comm = ctx->getState(); auto* mul_prot = ctx->getState()->get(); const int rank = comm->getRank(); + auto fx = x.reshape({x.numel()}); + auto fy = y.reshape({y.numel()}); - auto* conn = comm->lctx().get(); - auto dupx = conn->Spawn(); - std::future task = std::async(std::launch::async, [&] { + auto dupx = ctx->getState()->duplx(); + std::future task = std::async(std::launch::async, [&] { if (rank == 0) { - return mul_prot->MulOLE(flatten(x), dupx.get(), true); + return mul_prot->MulOLE(fx, dupx.get(), true); } - return mul_prot->MulOLE(flatten(y), dupx.get(), false); + return mul_prot->MulOLE(fy, dupx.get(), false); }); - ArrayRef x0y1; - ArrayRef x1y0; + NdArrayRef x1y0; if (rank == 0) { - x1y0 = mul_prot->MulOLE(flatten(y), conn, false); + x1y0 = mul_prot->MulOLE(fy, false); } else { - x1y0 = mul_prot->MulOLE(flatten(x), conn, true); + x1y0 = mul_prot->MulOLE(fx, true); } - x0y1 = task.get(); - return ring_add(unflatten(x0y1, x.shape()), - 
ring_add(unflatten(x1y0, x.shape()), ring_mul(x, y))) - .as(x.eltype()); + x1y0 = x1y0.reshape(x.shape()); + NdArrayRef x0y1 = task.get().reshape(x.shape()); + return ring_add(x0y1, ring_add(x1y0, ring_mul(x, y))).as(x.eltype()); } // A is (M, K); B is (K, N) @@ -338,91 +332,29 @@ NdArrayRef MatMulAA::proc(KernelEvalContext* ctx, const NdArrayRef& x, // (x0 + x1) * (y0 + y1) // Compute the cross terms homomorphically - const Shape3D dim3 = {static_cast(x.shape()[0]), - static_cast(x.shape()[1]), - static_cast(y.shape()[1])}; + const Shape3D dim3 = {x.shape()[0], x.shape()[1], y.shape()[1]}; auto* conn = comm->lctx().get(); - auto dupx = conn->Spawn(); - std::future task = std::async(std::launch::async, [&] { + auto dupx = ctx->getState()->duplx(); + std::future task = std::async(std::launch::async, [&] { + // Compute x0*y1 if (rank == 0) { - return dot_prot->DotOLE(flatten(x), dupx.get(), dim3, true); + return dot_prot->DotOLE(x, dupx.get(), dim3, true); } else { - return dot_prot->DotOLE(flatten(y), dupx.get(), dim3, false); + return dot_prot->DotOLE(y, dupx.get(), dim3, false); } }); - ArrayRef x1y0; + NdArrayRef x1y0; if (rank == 0) { - x1y0 = dot_prot->DotOLE(flatten(y), conn, dim3, false); + x1y0 = dot_prot->DotOLE(y, conn, dim3, false); } else { - x1y0 = dot_prot->DotOLE(flatten(x), conn, dim3, true); + x1y0 = dot_prot->DotOLE(x, conn, dim3, true); } auto ret = ring_mmul(x, y); - ring_add_(ret, unflatten(x1y0, ret.shape())); - return ring_add(ret, unflatten(task.get(), ret.shape())).as(x.eltype()); -} - -NdArrayRef Conv2DAA::proc(KernelEvalContext* ctx, const NdArrayRef& tensor, - const NdArrayRef& filter, int64_t stride_h, - int64_t stride_w) const { - SPU_TRACE_MPC_LEAF(ctx, tensor, filter); - if (0 == tensor.numel() || 0 == filter.numel()) { - return NdArrayRef(tensor.eltype(), {0}); - } - - auto N = tensor.shape()[0]; - auto C = tensor.shape()[3]; - auto H = tensor.shape()[1]; - auto W = tensor.shape()[2]; - - auto h = filter.shape()[0]; - auto w = filter.shape()[1]; - auto O = filter.shape()[3]; - - int64_t tensor_sze = N * H * W * C; - int64_t filter_sze = h * w * C * O; - SPU_ENFORCE_EQ(tensor.numel(), tensor_sze); - SPU_ENFORCE_EQ(filter.numel(), filter_sze); - auto* comm = ctx->getState(); - const int rank = comm->getRank(); - - Shape3D tensor_shape{H, W, C}; - Shape3D filter_shape{h, w, C}; - Shape2D window_strides{stride_h, stride_w}; - - auto* conv2d_prot = ctx->getState()->get(); - - auto* conn = comm->lctx().get(); - auto dupx = conn->Spawn(); - - auto f_tensor = flatten(tensor); - auto f_filter = flatten(filter); - std::future task = std::async(std::launch::async, [&] { - if (rank == 0) { - return conv2d_prot->Conv2dOLE(f_tensor, dupx.get(), N, tensor_shape, O, - filter_shape, window_strides, true); - } else { - return conv2d_prot->Conv2dOLE(f_filter, dupx.get(), N, tensor_shape, O, - filter_shape, window_strides, false); - } - }); - - ArrayRef x1y0; - if (rank == 0) { - x1y0 = conv2d_prot->Conv2dOLE(f_filter, comm->lctx().get(), N, tensor_shape, - O, filter_shape, window_strides, false); - - } else { - x1y0 = conv2d_prot->Conv2dOLE(f_tensor, comm->lctx().get(), N, tensor_shape, - O, filter_shape, window_strides, true); - } - - auto ret = ring_conv2d(tensor, filter, N, tensor_shape, O, filter_shape, - window_strides); - ring_add_(ret, unflatten(x1y0, ret.shape())); - return ring_add(ret, unflatten(task.get(), ret.shape())).as(tensor.eltype()); + ring_add_(ret, x1y0); + return ring_add(ret, task.get()).as(x.eltype()); } -} // namespace spu::mpc::cheetah +} // namespace 
spu::mpc::cheetah \ No newline at end of file diff --git a/libspu/mpc/cheetah/arithmetic.h b/libspu/mpc/cheetah/arithmetic.h index e2829450..b66d38b5 100644 --- a/libspu/mpc/cheetah/arithmetic.h +++ b/libspu/mpc/cheetah/arithmetic.h @@ -15,74 +15,101 @@ #pragma once #include "libspu/mpc/kernel.h" -#include "libspu/mpc/semi2k/arithmetic.h" namespace spu::mpc::cheetah { +class RandA : public RandKernel { + public: + static constexpr char kBindName[] = "rand_a"; -using RandA = spu::mpc::semi2k::RandA; + Kind kind() const override { return Kind::Dynamic; } -using P2A = spu::mpc::semi2k::P2A; + NdArrayRef proc(KernelEvalContext* ctx, const Shape& shape) const override; +}; -using A2P = spu::mpc::semi2k::A2P; +class P2A : public UnaryKernel { + public: + static constexpr char kBindName[] = "p2a"; -using V2A = spu::mpc::semi2k::V2A; + Kind kind() const override { return Kind::Dynamic; } -using A2V = spu::mpc::semi2k::A2V; + NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in) const override; +}; -using NotA = spu::mpc::semi2k::NotA; +class A2P : public UnaryKernel { + public: + static constexpr char kBindName[] = "a2p"; -using AddAP = spu::mpc::semi2k::AddAP; + Kind kind() const override { return Kind::Dynamic; } -using AddAA = spu::mpc::semi2k::AddAA; + NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in) const override; +}; -using MulAP = spu::mpc::semi2k::MulAP; +class A2V : public RevealToKernel { + public: + static constexpr char kBindName[] = "a2v"; -using MatMulAP = spu::mpc::semi2k::MatMulAP; + Kind kind() const override { return Kind::Dynamic; } -using LShiftA = spu::mpc::semi2k::LShiftA; + NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, + size_t rank) const override; +}; -class TruncA : public TruncAKernel { +class V2A : public UnaryKernel { public: - static constexpr char kBindName[] = "trunc_a"; + static constexpr char kBindName[] = "v2a"; Kind kind() const override { return Kind::Dynamic; } - NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& x, size_t bits, - SignType sign) const override; + NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in) const override; +}; - bool hasMsbError() const override { return false; } +class NotA : public UnaryKernel { + public: + static constexpr char kBindName[] = "not_a"; - TruncLsbRounding lsbRounding() const override { - return TruncLsbRounding::Probabilistic; - } + ce::CExpr latency() const override { return ce::Const(0); } + + ce::CExpr comm() const override { return ce::Const(0); } + + NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in) const override; }; -class MsbA2B : public UnaryKernel { +class AddAP : public BinaryKernel { public: - static constexpr char kBindName[] = "msb_a2b"; + static constexpr char kBindName[] = "add_ap"; - NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& x) const override; + ce::CExpr latency() const override { return ce::Const(0); } + + ce::CExpr comm() const override { return ce::Const(0); } + + NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& lhs, + const NdArrayRef& rhs) const override; }; -class EqualAA : public BinaryKernel { +class AddAA : public BinaryKernel { public: - static constexpr char kBindName[] = "equal_aa"; + static constexpr char kBindName[] = "add_aa"; - Kind kind() const override { return Kind::Dynamic; } + ce::CExpr latency() const override { return ce::Const(0); } - NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& x, - const NdArrayRef& y) const override; + ce::CExpr comm() const override { return ce::Const(0); 
}
+
+  NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& lhs,
+                  const NdArrayRef& rhs) const override;
 };

-class EqualAP : public BinaryKernel {
+class MulAP : public BinaryKernel {
  public:
-  static constexpr char kBindName[] = "equal_ap";
+  static constexpr char kBindName[] = "mul_ap";

-  Kind kind() const override { return Kind::Dynamic; }
+  ce::CExpr latency() const override { return ce::Const(0); }

-  NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& x,
-                  const NdArrayRef& y) const override;
+  ce::CExpr comm() const override { return ce::Const(0); }
+
+  NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& lhs,
+                  const NdArrayRef& rhs) const override;
 };

+/// Kernels that are identical to semi2k ///
 class MulA1B : public BinaryKernel {
  public:
@@ -102,6 +129,8 @@ class MulAA : public BinaryKernel {
   NdArrayRef mulWithBeaver(KernelEvalContext* ctx, const NdArrayRef& lhs,
                            const NdArrayRef& rhs) const;

+  NdArrayRef squareDirectly(KernelEvalContext* ctx, const NdArrayRef& x) const;
+
  public:
   static constexpr char kBindName[] = "mul_aa";

@@ -111,12 +140,28 @@ class MulAA : public BinaryKernel {
                   const NdArrayRef& y) const override;
 };

+////////////////////////////////////////////////////////////////////
+// matmul family
+////////////////////////////////////////////////////////////////////
+class MatMulAP : public MatmulKernel {
+ public:
+  static constexpr char kBindName[] = "mmul_ap";
+
+  ce::CExpr latency() const override { return ce::Const(0); }
+
+  ce::CExpr comm() const override { return ce::Const(0); }
+
+  NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& x,
+                  const NdArrayRef& y) const override;
+};
+
 class MatMulAA : public MatmulKernel {
  public:
   static constexpr char kBindName[] = "mmul_aa";

   Kind kind() const override { return Kind::Dynamic; }
-
+  // LHS: m x k
+  // RHS: k x n
   NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& x,
                   const NdArrayRef& y) const override;
 };
@@ -132,4 +177,91 @@ class Conv2DAA : public Conv2DKernel {
                   int64_t stride_w) const override;
 };

-}  // namespace spu::mpc::cheetah
+class TruncA : public TruncAKernel {
+ public:
+  static constexpr char kBindName[] = "trunc_a";
+
+  Kind kind() const override { return Kind::Dynamic; }
+
+  NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& x, size_t bits,
+                  SignType sign) const override;
+
+  bool hasMsbError() const override { return false; }
+
+  TruncLsbRounding lsbRounding() const override {
+    return TruncLsbRounding::Probabilistic;
+  }
+};
+
+class LShiftA : public ShiftKernel {
+ public:
+  static constexpr char kBindName[] = "lshift_a";
+
+  ce::CExpr latency() const override { return ce::Const(0); }
+
+  ce::CExpr comm() const override { return ce::Const(0); }
+
+  NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in,
+                  size_t bits) const override;
+};
+
+class MsbA2B : public UnaryKernel {
+ public:
+  static constexpr char kBindName[] = "msb_a2b";
+
+  MsbA2B(size_t nbits = 0) : nbits_(nbits) {}
+
+  Kind kind() const override { return Kind::Dynamic; }
+
+  NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& x) const override;
+
+ private:
+  size_t nbits_;
+};
+
+class EqualAA : public BinaryKernel {
+ public:
+  static constexpr char kBindName[] = "equal_aa";
+
+  EqualAA(size_t nbits = 0) : nbits_(nbits) {}
+
+  Kind kind() const override { return Kind::Dynamic; }
+
+  NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& x,
+                  const NdArrayRef& y) const override;
+
+ private:
+  size_t nbits_;
+};
+
+class EqualAP : public BinaryKernel {
+ public:
+  static constexpr
char kBindName[] = "equal_ap"; + + Kind kind() const override { return Kind::Dynamic; } + + NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& x, + const NdArrayRef& y) const override; +}; + +class LessAP : public BinaryKernel { + public: + static constexpr char kBindName[] = "f_less_ap"; + + Kind kind() const override { return Kind::Dynamic; } + + NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& x, + const NdArrayRef& y) const override; +}; + +class LessPA : public BinaryKernel { + public: + static constexpr char kBindName[] = "f_less_pa"; + + Kind kind() const override { return Kind::Dynamic; } + + NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& x, + const NdArrayRef& y) const override; +}; + +} // namespace spu::mpc::cheetah \ No newline at end of file diff --git a/libspu/mpc/cheetah/arithmetic_semi2k.cc b/libspu/mpc/cheetah/arithmetic_semi2k.cc new file mode 100644 index 00000000..17cbc78b --- /dev/null +++ b/libspu/mpc/cheetah/arithmetic_semi2k.cc @@ -0,0 +1,151 @@ +// Copyright 2021 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "libspu/mpc/cheetah/arithmetic.h" +#include "libspu/mpc/cheetah/state.h" +#include "libspu/mpc/cheetah/type.h" +#include "libspu/mpc/common/communicator.h" +#include "libspu/mpc/common/prg_state.h" +#include "libspu/mpc/common/pv2k.h" +#include "libspu/mpc/utils/ring_ops.h" +namespace spu::mpc::cheetah { +/// Kernels that identical to semi2k /// +NdArrayRef RandA::proc(KernelEvalContext* ctx, const Shape& shape) const { + auto* prg_state = ctx->getState(); + const auto field = ctx->getState()->getDefaultField(); + return ring_rshift(prg_state->genPriv(field, shape), 2) + .as(makeType(field)); +} + +NdArrayRef P2A::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { + const auto field = in.eltype().as()->field(); + auto* prg_state = ctx->getState(); + auto* comm = ctx->getState(); + + auto [r0, r1] = prg_state->genPrssPair(field, in.shape()); + auto x = ring_sub(r0, r1).as(makeType(field)); + + if (comm->getRank() == 0) { + ring_add_(x, in); + } + + return x.as(makeType(field)); +} + +NdArrayRef A2P::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { + const auto field = in.eltype().as()->field(); + auto* comm = ctx->getState(); + auto out = comm->allReduce(ReduceOp::ADD, in, kBindName); + return out.as(makeType(field)); +} + +NdArrayRef A2V::proc(KernelEvalContext* ctx, const NdArrayRef& in, + size_t rank) const { + auto* comm = ctx->getState(); + const auto field = in.eltype().as()->field(); + auto out_ty = makeType(field, rank); + + auto numel = in.numel(); + + return DISPATCH_ALL_FIELDS(field, "_", [&]() { + std::vector share(numel); + NdArrayView _in(in); + pforeach(0, numel, [&](int64_t idx) { share[idx] = _in[idx]; }); + + std::vector> shares = + comm->gather(share, rank, "a2v"); // comm => 1, k + if (comm->getRank() == rank) { + SPU_ENFORCE(shares.size() == comm->getWorldSize()); + NdArrayRef out(out_ty, in.shape()); + NdArrayView _out(out); + pforeach(0, numel, 
[&](int64_t idx) { + ring2k_t s = 0; + for (auto& share : shares) { + s += share[idx]; + } + _out[idx] = s; + }); + return out; + } else { + return makeConstantArrayRef(out_ty, in.shape()); + } + }); +} +NdArrayRef V2A::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { + const auto* in_ty = in.eltype().as(); + const size_t owner_rank = in_ty->owner(); + const auto field = in_ty->field(); + auto* prg_state = ctx->getState(); + auto* comm = ctx->getState(); + + auto [r0, r1] = prg_state->genPrssPair(field, in.shape()); + auto x = ring_sub(r0, r1).as(makeType(field)); + + if (comm->getRank() == owner_rank) { + ring_add_(x, in); + } + + return x.as(makeType(field)); +} + +NdArrayRef NotA::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { + auto* comm = ctx->getState(); + auto res = ring_neg(in); + if (comm->getRank() == 0) { + const auto field = in.eltype().as()->field(); + ring_add_(res, ring_not(ring_zeros(field, in.shape()))); + } + + return res.as(in.eltype()); +} + +NdArrayRef AddAP::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, + const NdArrayRef& rhs) const { + SPU_ENFORCE(lhs.numel() == rhs.numel()); + auto* comm = ctx->getState(); + + if (comm->getRank() == 0) { + return ring_add(lhs, rhs).as(lhs.eltype()); + } + + return lhs; +} + +NdArrayRef AddAA::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, + const NdArrayRef& rhs) const { + SPU_ENFORCE(lhs.numel() == rhs.numel()); + SPU_ENFORCE(lhs.eltype() == rhs.eltype()); + + return ring_add(lhs, rhs).as(lhs.eltype()); +} + +NdArrayRef MulAP::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, + const NdArrayRef& rhs) const { + return ring_mul(lhs, rhs).as(lhs.eltype()); +} + +NdArrayRef MatMulAP::proc(KernelEvalContext* ctx, const NdArrayRef& x, + const NdArrayRef& y) const { + return ring_mmul(x, y).as(x.eltype()); +} + +NdArrayRef LShiftA::proc(KernelEvalContext* ctx, const NdArrayRef& in, + size_t bits) const { + const auto field = in.eltype().as()->field(); + bits %= SizeOf(field) * 8; + + return ring_lshift(in, bits).as(in.eltype()); +} + +} // namespace spu::mpc::cheetah \ No newline at end of file diff --git a/libspu/mpc/cheetah/array_ref.cc b/libspu/mpc/cheetah/array_ref.cc deleted file mode 100644 index d5775ded..00000000 --- a/libspu/mpc/cheetah/array_ref.cc +++ /dev/null @@ -1,168 +0,0 @@ -// Copyright 2021 Ant Group Co., Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "libspu/mpc/cheetah/array_ref.h" - -#include - -#include "fmt/format.h" -#include "fmt/ostream.h" - -#include "libspu/core/parallel_utils.h" - -namespace spu { -namespace detail { - -void strided_copy(int64_t numel, int64_t elsize, void* dst, int64_t dstride, - void const* src, int64_t sstride) { - const char* src_itr = static_cast(src); - char* dst_itr = static_cast(dst); - - if (dstride == 1 && sstride == 1) { - std::memcpy(dst_itr, src_itr, elsize * numel); - } else { - dstride *= elsize; - sstride *= elsize; - - pforeach(0, numel, [&](int64_t idx) { - std::memcpy(&dst_itr[idx * dstride], &src_itr[idx * sstride], elsize); - }); - } -} - -} // namespace detail - -ArrayRef::ArrayRef(std::shared_ptr buf, Type eltype, - int64_t numel, int64_t stride, int64_t offset) - : buf_(std::move(buf)), - eltype_(std::move(eltype)), - numel_(numel), - stride_(stride), - offset_(offset) { - // sanity check. - if (numel != 0) { - const auto elsize = static_cast(eltype_.size()); - const auto bufsize = buf_->size(); - SPU_ENFORCE(offset >= 0 && offset + elsize <= bufsize); - SPU_ENFORCE( - (offset + stride * (numel - 1) >= 0) && - (offset + stride * (numel - 1) + elsize <= bufsize), - "sanity failed, eltype={}, offset={}, stride={}, numel={}, buf.size={}", - eltype_, offset, stride, numel, bufsize); - } -} - -ArrayRef::ArrayRef(const Type& eltype, size_t numel) - : ArrayRef(std::make_shared(numel * eltype.size()), - eltype, // eltype - numel, // numel - 1, // stride, - 0 // offset - ) {} - -ArrayRef makeConstantArrayRef(const Type& eltype, size_t numel) { - auto buf = std::make_shared(eltype.size()); - memset(buf->data(), 0, eltype.size()); - return ArrayRef(buf, // buf - eltype, // eltype - numel, // numel - 0, // stride, - 0 // offset - ); -} - -bool ArrayRef::isCompact() const { return stride_ == 1 || numel_ == 1; } - -std::shared_ptr ArrayRef::getOrCreateCompactBuf() const { - if (isCompact() && offset_ == 0) { - return buf(); - } - return clone().buf(); -} - -ArrayRef ArrayRef::clone() const { - ArrayRef res(eltype(), numel()); - - detail::strided_copy(numel(), elsize(), res.data(), res.stride(), data(), - stride()); - return res; -} - -ArrayRef ArrayRef::as(const Type& new_ty, bool force) const { - if (!force) { - SPU_ENFORCE(elsize() == new_ty.size(), - "viewed type={} not equal to origin type={}", new_ty, eltype()); - } - - return ArrayRef(buf(), new_ty, numel(), stride(), offset()); -} - -ArrayRef ArrayRef::slice(int64_t start, int64_t stop, int64_t stride) const { - // From - // https://numpy.org/doc/stable/user/basics.indexing.html#slicing-and-striding - // - // The basic slice syntax is i:j:k where i is the starting index, j is the - // stopping index, and k is the step (). This selects the m elements (in the - // corresponding dimension) with index values i, i + k, …, i + (m - 1) k - // where and q and r are the quotient and remainder obtained by dividing j - - // i by k: j - i = q k + r, so that i + (m - 1) k < j. 
- SPU_ENFORCE(start < numel_, "start={}, numel_={}", start, numel_); - - const int64_t q = (stop - start) / stride; - const int64_t r = (stop - start) % stride; - const int64_t m = q + static_cast(r != 0); - - const int64_t n_stride = stride_ * stride; - const int64_t n_offset = offset_ + start * stride_ * elsize(); - - return ArrayRef(buf(), eltype_, m, n_stride, n_offset); -} - -bool ArrayRef::operator==(const ArrayRef& other) const { - if (numel() != other.numel() || eltype() != other.eltype()) { - return false; - } - - for (int64_t idx = 0; idx < numel(); idx++) { - const auto* a = &at(idx); - const auto* b = &other.at(idx); - - if (memcmp(a, b, elsize()) != 0) { - return false; - } - } - - return true; -} - -std::ostream& operator<<(std::ostream& out, const ArrayRef& v) { - out << fmt::format("ArrayRef<{}x{}>", v.numel(), v.eltype().toString()); - return out; -} - -ArrayRef flatten(const NdArrayRef& in) { - auto f = in.reshape({in.numel()}); - return ArrayRef(f.buf(), f.eltype(), in.numel(), f.strides()[0], f.offset()); -} - -NdArrayRef toNdArray(const ArrayRef& in) { - return NdArrayRef(in.buf(), in.eltype(), {in.numel()}, {in.stride()}, - in.offset()); -} - -NdArrayRef unflatten(const ArrayRef& in, const Shape& shape) { - return toNdArray(in).reshape(shape); -} - -} // namespace spu diff --git a/libspu/mpc/cheetah/array_ref.h b/libspu/mpc/cheetah/array_ref.h deleted file mode 100644 index 88d50501..00000000 --- a/libspu/mpc/cheetah/array_ref.h +++ /dev/null @@ -1,240 +0,0 @@ -// Copyright 2021 Ant Group Co., Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include - -#include "yacl/base/buffer.h" - -#include "libspu/core/bit_utils.h" -#include "libspu/core/ndarray_ref.h" -#include "libspu/core/parallel_utils.h" -#include "libspu/core/type.h" -#include "libspu/core/vectorize.h" - -namespace spu { -namespace detail { - -void strided_copy(int64_t numel, int64_t elsize, void* dst, int64_t dstride, - void const* src, int64_t sstride); - -} - -// ArrayRef is a reference type which represent an strided array of objects. -class ArrayRef { - std::shared_ptr buf_{nullptr}; - - // element type. - Type eltype_{}; - - // number of elements. - int64_t numel_{0}; - - // element stride, in number of elements. - int64_t stride_{0}; - - // start offset from the mem buffer, in bytes. - int64_t offset_{0}; - - public: - ArrayRef() = default; - - // full constructor - explicit ArrayRef(std::shared_ptr buf, Type eltype, - int64_t numel, int64_t stride, int64_t offset); - - // create a new buffer of uninitialized elements and ref to it. - explicit ArrayRef(const Type& eltype, size_t numel); - - // Return total number of elements. 
- int64_t numel() const { return numel_; } - - inline size_t elsize() const { return eltype_.size(); } - - int64_t stride() const { return stride_; } - - int64_t offset() const { return offset_; } - - const Type& eltype() const { return eltype_; } - - Type& eltype() { return eltype_; } - - // https://numpy.org/doc/stable/user/basics.indexing.html#slicing-and-striding - ArrayRef slice(int64_t start, int64_t stop, int64_t stride = 1) const; - - std::shared_ptr buf() const { return buf_; } - - // Create a new buffer if current underline buffer is not compact. - // while compact means offset > 0 or stride != elsize. - // Or return the underline buffer. - std::shared_ptr getOrCreateCompactBuf() const; - - bool isCompact() const; - - ArrayRef clone() const; - - // View this array ref as another type. - // @param force, true if ignore the type check. - ArrayRef as(const Type& new_ty, bool force = false) const; - - // Test two array are bitwise equal - bool operator==(const ArrayRef& other) const; - - // Get data pointer - void* data() { - if (buf_) { - return reinterpret_cast(buf_->data() + offset_); - } - return nullptr; - } - void const* data() const { - if (buf_) { - return reinterpret_cast(buf_->data() + offset_); - } - return nullptr; - } - - // Get element. - template - T& at(int64_t pos) { - return *reinterpret_cast(static_cast(data()) + - stride_ * pos * elsize()); - } - template - const T& at(int64_t pos) const { - return *reinterpret_cast(static_cast(data()) + - stride_ * pos * elsize()); - } -}; - -std::ostream& operator<<(std::ostream& out, const ArrayRef& v); - -template <> -struct SimdTrait { - using PackInfo = std::vector; - - template - static ArrayRef pack(InputIt first, InputIt last, PackInfo& pi) { - SPU_ENFORCE(first != last); - - size_t total_numel = 0; - const Type ty = first->eltype(); - for (auto itr = first; itr != last; ++itr) { - SPU_ENFORCE(itr->eltype() == ty, "type mismatch {} != {}", itr->eltype(), - ty); - total_numel += itr->numel(); - } - ArrayRef result(first->eltype(), total_numel); - int64_t res_idx = 0; - for (; first != last; ++first) { - detail::strided_copy(first->numel(), ty.size(), &result.at(res_idx), - result.stride(), &first->at(0), first->stride()); - pi.push_back(first->numel()); - res_idx += first->numel(); - } - return result; - } - - template - static OutputIt unpack(const ArrayRef& v, OutputIt result, - const PackInfo& pi) { - const int64_t total_num = - std::accumulate(pi.begin(), pi.end(), 0, std::plus<>()); - - SPU_ENFORCE(v.numel() == total_num, "split number mismatch {} != {}", - v.numel(), total_num); - - int64_t offset = 0; - for (const auto& sz : pi) { - *result++ = ArrayRef(v.buf(), v.eltype(), sz, v.stride(), offset); - offset += sz * v.elsize(); - } - - return result; - } -}; - -ArrayRef makeConstantArrayRef(const Type& eltype, size_t numel); - -// A strided array view type. -template -class ArrayView { - T* const data_; - int64_t const stride_; - int64_t const numel_; - - public: - // Note: we explicit discard const correctness due to the complexity. 
- explicit ArrayView(const ArrayRef& arr) - : data_(const_cast(&arr.at(0))), - stride_(arr.stride()), - numel_(arr.numel()) {} - - explicit ArrayView(T* data, int64_t stride, int64_t numel) - : data_(data), stride_(stride), numel_(numel) {} - - int64_t numel() const { return numel_; } - - int64_t stride() const { return stride_; } - - bool isCompact() const { return stride_ == 1 || numel_ == 1; } - - ArrayRef clone() const { - ArrayRef res(makePtType(), numel_); - detail::strided_copy(numel_, sizeof(T), res.data(), res.stride(), data_, - stride_); - return res; - } - - T* data() { return data_; } - - T& operator[](size_t idx) { return *(data_ + idx * stride_); } - - T const& operator[](size_t idx) const { return *(data_ + idx * stride_); } - - // TODO: test me. - size_t maxBitWidth() const { - if (numel_ == 0) { - return 0; - } - if (stride() == 0) { - return BitWidth(this->operator[](0)); - } - - size_t res = preduce( - 0, numel(), - [&](int64_t begin, int64_t end) { - size_t partial_max = 0; - for (int64_t idx = begin; idx < end; ++idx) { - partial_max = - std::max(partial_max, BitWidth(this->operator[](idx))); - } - return partial_max; - }, - [](const size_t& a, const size_t& b) { return std::max(a, b); }); - - return res; - } -}; - -// This function is not free, miht have cost...do not use it in a loop -ArrayRef flatten(const NdArrayRef& in); - -// These two functions are free cast from ArrayRef to NdArrayRef -NdArrayRef unflatten(const ArrayRef& in, const Shape& shape); -NdArrayRef toNdArray(const ArrayRef& in); - -} // namespace spu diff --git a/libspu/mpc/cheetah/array_ref_test.cc b/libspu/mpc/cheetah/array_ref_test.cc deleted file mode 100644 index f431205b..00000000 --- a/libspu/mpc/cheetah/array_ref_test.cc +++ /dev/null @@ -1,165 +0,0 @@ -// Copyright 2021 Ant Group Co., Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "libspu/mpc/cheetah/array_ref.h" - -#include "gmock/gmock.h" -#include "gtest/gtest.h" - -namespace spu { - -TEST(ArrayRefTest, Empty) { - ArrayRef a; - EXPECT_EQ(a.numel(), 0); - EXPECT_EQ(a.stride(), 0); - EXPECT_EQ(a.elsize(), 0); - EXPECT_EQ(a.offset(), 0); -} - -TEST(ArrayRefTest, Simple) { - constexpr size_t kBufSize = 100; - auto mem = std::make_shared(kBufSize); - - ArrayRef a(mem, I32, 100 / 4, 1, 0); - - EXPECT_EQ(a.numel(), kBufSize / I32.size()); - EXPECT_EQ(a.stride(), 1); - EXPECT_EQ(a.elsize(), I32.size()); - EXPECT_EQ(a.offset(), 0); - - ArrayRef b(mem, I32, 10, 2, 20); - EXPECT_EQ(b.numel(), 10); - EXPECT_EQ(b.stride(), 2); - EXPECT_EQ(b.elsize(), 4); - EXPECT_EQ(b.offset(), 20); - - EXPECT_EQ(a.buf(), b.buf()); - - // test packing - { - EXPECT_EQ(HasSimdTrait::value, true); - EXPECT_EQ(HasSimdTrait::value, false); - // - SimdTrait::PackInfo pi; - std::vector l = {a, b}; - ArrayRef packed = SimdTrait::pack(l.begin(), l.end(), pi); - - EXPECT_EQ(packed.numel(), a.numel() + b.numel()); - EXPECT_EQ(packed.stride(), 1); - EXPECT_EQ(packed.elsize(), I32.size()); - EXPECT_EQ(packed.offset(), 0); - - for (int64_t idx = 0; idx < packed.numel(); idx++) { - if (idx < a.numel()) { - EXPECT_EQ(memcmp(&packed.at(idx), &a.at(idx), a.elsize()), 0); - } else { - EXPECT_EQ(memcmp(&packed.at(idx), &b.at(idx - a.numel()), b.elsize()), - 0); - } - } - - std::vector unpacked; - SimdTrait::unpack(packed, std::back_inserter(unpacked), pi); - - EXPECT_EQ(unpacked.size(), 2); - EXPECT_EQ(unpacked[0].numel(), a.numel()); - EXPECT_EQ(unpacked[0].eltype(), a.eltype()); - EXPECT_EQ(unpacked[1].numel(), b.numel()); - EXPECT_EQ(unpacked[1].eltype(), b.eltype()); - - for (int64_t idx = 0; idx < a.numel(); idx++) { - EXPECT_EQ(memcmp(&unpacked[0].at(idx), &a.at(idx), a.elsize()), 0); - } - - for (int64_t idx = 0; idx < b.numel(); idx++) { - EXPECT_EQ(memcmp(&unpacked[1].at(idx), &b.at(idx), b.elsize()), 0); - } - } -} - -TEST(ArrayRefTest, Slice) { - constexpr size_t kBufSize = 100; - auto mem = std::make_shared(kBufSize); - - // GIVEN - ArrayRef a(mem, I32, 100 / 4, 1, 0); - EXPECT_EQ(a.numel(), kBufSize / I32.size()); - EXPECT_EQ(a.stride(), 1); - EXPECT_EQ(a.elsize(), I32.size()); - EXPECT_EQ(a.offset(), 0); - for (int32_t i = 0; i < 25; i++) { - a.at(i) = i; - } - - // WHEN - ArrayRef b = a.slice(10, 20); - // THEN - { - EXPECT_EQ(a.buf(), b.buf()); - EXPECT_EQ(b.numel(), 10); - EXPECT_EQ(b.stride(), 1); - EXPECT_EQ(b.elsize(), 4); - EXPECT_EQ(b.offset(), 40); - for (int32_t i = 0; i < 10; i++) { - EXPECT_EQ(b.at(i), 10 + i); - } - } - - // WHEN - ArrayRef c = b.slice(5, 10, 2); - // THEN - { - EXPECT_EQ(a.buf(), c.buf()); - EXPECT_EQ(c.numel(), 3); - EXPECT_EQ(c.stride(), 2); - EXPECT_EQ(c.elsize(), 4); - EXPECT_EQ(c.offset(), 60); - for (int32_t i = 0; i < 3; i++) { - EXPECT_EQ(c.at(i), 15 + 2 * i); - } - } -} - -TEST(ArrayRefTest, Strides) { - // Make 3 element, strides = 2 array - ArrayRef a(std::make_shared(6 * sizeof(int32_t)), - makePtType(PT_I32), 3, 2, 0); - - EXPECT_EQ(a.numel(), 3); - EXPECT_EQ(a.stride(), 2); - EXPECT_EQ(a.elsize(), sizeof(int32_t)); - - // Fill array with 0 1 2 3 4 5 - std::iota(static_cast(a.data()), - reinterpret_cast(reinterpret_cast(a.data()) + - a.buf()->size()), - 0); - - EXPECT_EQ(a.at(0), 0); - EXPECT_EQ(a.at(1), 2); - EXPECT_EQ(a.at(2), 4); - - // Make a compact clone - auto b = a.clone(); - - EXPECT_TRUE(b.isCompact()); - EXPECT_EQ(b.numel(), 3); - EXPECT_EQ(b.stride(), 1); - - EXPECT_EQ(b.at(0), 0); - EXPECT_EQ(b.at(1), 2); - EXPECT_EQ(b.at(2), 4); 
-} - -} // namespace spu diff --git a/libspu/mpc/cheetah/boolean.cc b/libspu/mpc/cheetah/boolean.cc index 2743ccbc..ca181ac3 100644 --- a/libspu/mpc/cheetah/boolean.cc +++ b/libspu/mpc/cheetah/boolean.cc @@ -15,44 +15,41 @@ #include "libspu/mpc/cheetah/boolean.h" #include "libspu/mpc/cheetah/state.h" -#include "libspu/mpc/common/pv2k.h" -#include "libspu/mpc/utils/ring_ops.h" namespace spu::mpc::cheetah { -constexpr size_t kMinWorkSize = 5000; NdArrayRef AndBB::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, const NdArrayRef& rhs) const { SPU_ENFORCE_EQ(lhs.shape(), rhs.shape()); - auto* comm = ctx->getState(); - auto* ot_state = ctx->getState(); - size_t n = lhs.numel(); - size_t nworker = - std::min(ot_state->parallel_size(), CeilDiv(n, kMinWorkSize)); - size_t work_load = nworker == 0 ? 0 : CeilDiv(n, nworker); - for (size_t w = 0; w < nworker; ++w) { - ot_state->LazyInit(comm, w); + int64_t numel = lhs.numel(); + NdArrayRef out(lhs.eltype(), lhs.shape()); + if (numel == 0) { + return out; } - NdArrayRef z(lhs.eltype(), lhs.shape()); - auto flat_lhs = flatten(lhs); - auto flat_rhs = flatten(rhs); - yacl::parallel_for(0, nworker, 1, [&](size_t bgn, size_t end) { - for (size_t job = bgn; job < end; ++job) { - size_t slice_bgn = std::min(n, job * work_load); - size_t slice_end = std::min(n, slice_bgn + work_load); + int64_t nworker = InitOTState(ctx, numel); + int64_t work_load = nworker == 0 ? 0 : CeilDiv(numel, nworker); + + auto flat_lhs = lhs.reshape({lhs.numel()}); + auto flat_rhs = rhs.reshape({rhs.numel()}); + yacl::parallel_for(0, nworker, 1, [&](int64_t bgn, int64_t end) { + for (int64_t job = bgn; job < end; ++job) { + int64_t slice_bgn = std::min(numel, job * work_load); + int64_t slice_end = std::min(numel, slice_bgn + work_load); if (slice_bgn == slice_end) { break; } - auto out_slice = - ot_state->get(job)->BitwiseAnd(flat_lhs.slice(slice_bgn, slice_end), - flat_rhs.slice(slice_bgn, slice_end)); - std::memcpy(&z.at(slice_bgn), &out_slice.at(0), + + auto out_slice = ctx->getState()->get(job)->BitwiseAnd( + flat_lhs.slice({slice_bgn}, {slice_end}, {1}), + flat_rhs.slice({slice_bgn}, {slice_end}, {1})); + std::memcpy(&out.at(slice_bgn), &out_slice.at(0), out_slice.elsize() * out_slice.numel()); } }); - return z.as(lhs.eltype()); + + return out; } } // namespace spu::mpc::cheetah diff --git a/libspu/mpc/cheetah/boolean.h b/libspu/mpc/cheetah/boolean.h index 17254d4e..cdca12db 100644 --- a/libspu/mpc/cheetah/boolean.h +++ b/libspu/mpc/cheetah/boolean.h @@ -15,19 +15,63 @@ #pragma once #include "libspu/mpc/kernel.h" -#include "libspu/mpc/semi2k/boolean.h" namespace spu::mpc::cheetah { -using CommonTypeB = spu::mpc::semi2k::CommonTypeB; +class CommonTypeB : public Kernel { + public: + static constexpr char kBindName[] = "common_type_b"; + + Kind kind() const override { return Kind::Dynamic; } + + void evaluate(KernelEvalContext* ctx) const override; +}; + +class CastTypeB : public CastTypeKernel { + public: + static constexpr char kBindName[] = "cast_type_b"; + + ce::CExpr latency() const override { return ce::Const(0); } + + ce::CExpr comm() const override { return ce::Const(0); } + + NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, + const Type& to_type) const override; +}; + +class B2P : public UnaryKernel { + public: + static constexpr char kBindName[] = "b2p"; + + ce::CExpr latency() const override { return ce::Const(1); } + + ce::CExpr comm() const override { return ce::K() * (ce::N() - 1); } + + NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in) 
const override; +}; -using CastTypeB = spu::mpc::semi2k::CastTypeB; +class P2B : public UnaryKernel { + public: + static constexpr char kBindName[] = "p2b"; + + ce::CExpr latency() const override { return ce::Const(0); } -using B2P = spu::mpc::semi2k::B2P; + ce::CExpr comm() const override { return ce::Const(0); } -using P2B = spu::mpc::semi2k::P2B; + NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in) const override; +}; + +class AndBP : public BinaryKernel { + public: + static constexpr char kBindName[] = "and_bp"; -using AndBP = spu::mpc::semi2k::AndBP; + ce::CExpr latency() const override { return ce::Const(0); } + + ce::CExpr comm() const override { return ce::Const(0); } + + NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& lhs, + const NdArrayRef& rhs) const override; +}; class AndBB : public BinaryKernel { public: @@ -35,24 +79,103 @@ class AndBB : public BinaryKernel { Kind kind() const override { return Kind::Dynamic; } - ce::CExpr latency() const override { return ce::Const(1); } + NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& lhs, + const NdArrayRef& rhs) const override; +}; + +class XorBP : public BinaryKernel { + public: + static constexpr char kBindName[] = "xor_bp"; + + ce::CExpr latency() const override { return ce::Const(0); } + + ce::CExpr comm() const override { return ce::Const(0); } - ce::CExpr comm() const override { return ce::K() * 2 * (ce::N() - 1); } + NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& lhs, + const NdArrayRef& rhs) const override; +}; + +class XorBB : public BinaryKernel { + public: + static constexpr char kBindName[] = "xor_bb"; + + ce::CExpr latency() const override { return ce::Const(0); } + + ce::CExpr comm() const override { return ce::Const(0); } NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& lhs, const NdArrayRef& rhs) const override; }; -using XorBP = spu::mpc::semi2k::XorBP; +class LShiftB : public ShiftKernel { + public: + static constexpr char kBindName[] = "lshift_b"; + + ce::CExpr latency() const override { return ce::Const(0); } + + ce::CExpr comm() const override { return ce::Const(0); } + + NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, + size_t shift) const override; +}; + +class RShiftB : public ShiftKernel { + public: + static constexpr char kBindName[] = "rshift_b"; + + ce::CExpr latency() const override { return ce::Const(0); } + + ce::CExpr comm() const override { return ce::Const(0); } + + NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, + size_t shift) const override; +}; + +class ARShiftB : public ShiftKernel { + public: + static constexpr char kBindName[] = "arshift_b"; + + ce::CExpr latency() const override { return ce::Const(0); } + + ce::CExpr comm() const override { return ce::Const(0); } + + NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, + size_t shift) const override; +}; + +class BitrevB : public BitrevKernel { + public: + static constexpr char kBindName[] = "bitrev_b"; + + ce::CExpr latency() const override { return ce::Const(0); } + + ce::CExpr comm() const override { return ce::Const(0); } + + NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, size_t start, + size_t end) const override; +}; + +class BitIntlB : public BitSplitKernel { + public: + static constexpr char kBindName[] = "bitintl_b"; -using XorBB = spu::mpc::semi2k::XorBB; + ce::CExpr latency() const override { return ce::Const(0); } -using LShiftB = spu::mpc::semi2k::LShiftB; + ce::CExpr comm() const override { return ce::Const(0); } -using RShiftB = 
spu::mpc::semi2k::RShiftB; + NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, + size_t stride) const override; +}; + +class BitDeintlB : public BitSplitKernel { + public: + static constexpr char kBindName[] = "bitdeintl_b"; -using ARShiftB = spu::mpc::semi2k::ARShiftB; + ce::CExpr latency() const override { return ce::Const(0); } -using BitrevB = spu::mpc::semi2k::BitrevB; + ce::CExpr comm() const override { return ce::Const(0); } + NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& in, + size_t stride) const override; +}; } // namespace spu::mpc::cheetah diff --git a/libspu/mpc/cheetah/boolean_semi2k.cc b/libspu/mpc/cheetah/boolean_semi2k.cc new file mode 100644 index 00000000..06ad18a6 --- /dev/null +++ b/libspu/mpc/cheetah/boolean_semi2k.cc @@ -0,0 +1,214 @@ +// Copyright 2021 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "libspu/mpc/cheetah/boolean.h" +#include "libspu/mpc/cheetah/state.h" +#include "libspu/mpc/cheetah/type.h" +#include "libspu/mpc/common/prg_state.h" +#include "libspu/mpc/common/pv2k.h" +#include "libspu/mpc/utils/ring_ops.h" + +namespace spu::mpc::cheetah { +namespace { + +size_t getNumBits(const NdArrayRef& in) { + if (in.eltype().isa()) { + const auto field = in.eltype().as()->field(); + return DISPATCH_ALL_FIELDS(field, "_", + [&]() { return maxBitWidth(in); }); + } else if (in.eltype().isa()) { + return in.eltype().as()->nbits(); + } else { + SPU_THROW("should not be here, {}", in.eltype()); + } +} + +NdArrayRef makeBShare(const NdArrayRef& r, FieldType field, size_t nbits) { + const auto ty = makeType(field, nbits); + return r.as(ty); +} +} // namespace + +void CommonTypeB::evaluate(KernelEvalContext* ctx) const { + const Type& lhs = ctx->getParam(0); + const Type& rhs = ctx->getParam(1); + + SPU_ENFORCE(lhs == rhs, "cheetah always use same bshare type, lhs={}, rhs={}", + lhs, rhs); + + ctx->setOutput(lhs); +} + +NdArrayRef CastTypeB::proc(KernelEvalContext* ctx, const NdArrayRef& in, + const Type& to_type) const { + SPU_ENFORCE(in.eltype() == to_type, + "cheetah always use same bshare type, lhs={}, rhs={}", + in.eltype(), to_type); + return in; +} + +NdArrayRef B2P::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { + const auto field = in.eltype().as()->field(); + auto* comm = ctx->getState(); + auto out = comm->allReduce(ReduceOp::XOR, in, kBindName); + return out.as(makeType(field)); +} + +NdArrayRef P2B::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { + const auto field = in.eltype().as()->field(); + auto* prg_state = ctx->getState(); + + auto* comm = ctx->getState(); + + auto [r0, r1] = prg_state->genPrssPair(field, in.shape()); + auto x = ring_xor(r0, r1).as(makeType(field, 0)); + + if (comm->getRank() == 0) { + ring_xor_(x, in); + } + + return makeBShare(x, field, getNumBits(in)); +} + +NdArrayRef AndBP::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, + const NdArrayRef& rhs) const { + SPU_ENFORCE(lhs.shape() == rhs.shape()); + + const auto field = 
ctx->getState()->getDefaultField(); + const size_t out_nbits = std::min(getNumBits(lhs), getNumBits(rhs)); + NdArrayRef out(makeType(field, out_nbits), lhs.shape()); + + DISPATCH_ALL_FIELDS(field, "_", [&]() { + NdArrayView _lhs(lhs); + NdArrayView _rhs(rhs); + NdArrayView _out(out); + + pforeach(0, lhs.numel(), + [&](int64_t idx) { _out[idx] = _lhs[idx] & _rhs[idx]; }); + }); + return out; +} + +NdArrayRef XorBP::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, + const NdArrayRef& rhs) const { + SPU_ENFORCE(lhs.numel() == rhs.numel()); + + auto* comm = ctx->getState(); + + const auto field = lhs.eltype().as()->field(); + const size_t out_nbits = std::max(getNumBits(lhs), getNumBits(rhs)); + + if (comm->getRank() == 0) { + return makeBShare(ring_xor(lhs, rhs), field, out_nbits); + } + + return makeBShare(lhs, field, out_nbits); +} + +NdArrayRef XorBB::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, + const NdArrayRef& rhs) const { + SPU_ENFORCE(lhs.numel() == rhs.numel()); + + const auto field = ctx->getState()->getDefaultField(); + const size_t out_nbits = std::max(getNumBits(lhs), getNumBits(rhs)); + return makeBShare(ring_xor(lhs, rhs), field, out_nbits); +} + +NdArrayRef LShiftB::proc(KernelEvalContext* ctx, const NdArrayRef& in, + size_t shift) const { + const auto field = in.eltype().as()->field(); + shift %= SizeOf(field) * 8; + + size_t out_nbits = in.eltype().as()->nbits() + shift; + out_nbits = std::clamp(out_nbits, static_cast(0), SizeOf(field) * 8); + + return makeBShare(ring_lshift(in, shift), field, out_nbits); +} + +NdArrayRef RShiftB::proc(KernelEvalContext* ctx, const NdArrayRef& in, + size_t shift) const { + const auto field = in.eltype().as()->field(); + shift %= SizeOf(field) * 8; + + size_t nbits = in.eltype().as()->nbits(); + size_t out_nbits = nbits - std::min(nbits, shift); + SPU_ENFORCE(nbits <= SizeOf(field) * 8); + + return makeBShare(ring_rshift(in, shift), field, out_nbits); +} + +NdArrayRef ARShiftB::proc(KernelEvalContext* ctx, const NdArrayRef& in, + size_t shift) const { + const auto field = in.eltype().as()->field(); + shift %= SizeOf(field) * 8; + + // arithmetic right shift expects to work on ring, or the behaviour is + // undefined. + return makeBShare(ring_arshift(in, shift), field, SizeOf(field) * 8); +} + +NdArrayRef BitrevB::proc(KernelEvalContext* ctx, const NdArrayRef& in, + size_t start, size_t end) const { + const auto field = in.eltype().as()->field(); + + SPU_ENFORCE(start <= end); + SPU_ENFORCE(end <= SizeOf(field) * 8); + const size_t out_nbits = std::max(getNumBits(in), end); + + // TODO: more accurate bits. 
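// A minimal plaintext sketch (standard C++ only, no MPC state; all values are
// illustrative) of the XOR-sharing linearity that the local boolean kernels
// above rely on: shifts are applied by each party to its own share, while a
// public XOR operand is folded into exactly one party's share, as XorBP and
// P2B do on rank 0.
#include <cassert>
#include <cstdint>

int main() {
  uint64_t x = 0xDEADBEEFCAFEF00DULL;   // secret value
  uint64_t x0 = 0x0123456789ABCDEFULL;  // rank-0 share (random in the protocol)
  uint64_t x1 = x ^ x0;                 // rank-1 share, so x0 ^ x1 == x

  // Left shift: both parties shift locally, the shared value shifts too.
  assert(((x0 << 3) ^ (x1 << 3)) == (x << 3));

  // XOR with a public value p: only rank 0 touches its share.
  uint64_t p = 0x00FF00FF00FF00FFULL;
  assert(((x0 ^ p) ^ x1) == (x ^ p));
  return 0;
}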
+ return makeBShare(ring_bitrev(in, start, end), field, out_nbits); +} + +NdArrayRef BitIntlB::proc(KernelEvalContext* ctx, const NdArrayRef& in, + size_t stride) const { + const auto field = in.eltype().as()->field(); + const auto nbits = getNumBits(in); + SPU_ENFORCE(absl::has_single_bit(nbits)); + + NdArrayRef out(in.eltype(), in.shape()); + auto numel = in.numel(); + + DISPATCH_ALL_FIELDS(field, "_", [&]() { + NdArrayView _in(in); + NdArrayView _out(out); + + pforeach(0, numel, [&](int64_t idx) { + _out[idx] = BitIntl(_in[idx], stride, nbits); + }); + }); + + return out; +} + +NdArrayRef BitDeintlB::proc(KernelEvalContext* ctx, const NdArrayRef& in, + size_t stride) const { + const auto field = in.eltype().as()->field(); + const auto nbits = getNumBits(in); + SPU_ENFORCE(absl::has_single_bit(nbits)); + + NdArrayRef out(in.eltype(), in.shape()); + auto numel = in.numel(); + + DISPATCH_ALL_FIELDS(field, "_", [&]() { + NdArrayView _in(in); + NdArrayView _out(out); + + pforeach(0, numel, [&](int64_t idx) { + _out[idx] = BitDeintl(_in[idx], stride, nbits); + }); + }); + + return out; +} +} // namespace spu::mpc::cheetah \ No newline at end of file diff --git a/libspu/mpc/cheetah/conversion.cc b/libspu/mpc/cheetah/conversion.cc index 70476afe..8d73e16d 100644 --- a/libspu/mpc/cheetah/conversion.cc +++ b/libspu/mpc/cheetah/conversion.cc @@ -14,41 +14,71 @@ #include "libspu/mpc/cheetah/conversion.h" +#include "yacl/utils/parallel.h" + +#include "libspu/mpc/ab_api.h" #include "libspu/mpc/cheetah/state.h" +#include "libspu/mpc/cheetah/type.h" +#include "libspu/mpc/common/prg_state.h" #include "libspu/mpc/common/pv2k.h" -#include "libspu/mpc/semi2k/type.h" // TODO: use cheetah type +#include "libspu/mpc/utils/ring_ops.h" namespace spu::mpc::cheetah { -constexpr size_t kMinWorkSize = 5000; -NdArrayRef B2A::proc(KernelEvalContext* ctx, const NdArrayRef& x) const { +static NdArrayRef wrap_add_bb(SPUContext* ctx, const NdArrayRef& x, + const NdArrayRef& y) { + SPU_ENFORCE(x.shape() == y.shape()); + return UnwrapValue(add_bb(ctx, WrapValue(x), WrapValue(y))); +} + +NdArrayRef A2B::proc(KernelEvalContext* ctx, const NdArrayRef& x) const { + const auto field = x.eltype().as()->field(); auto* comm = ctx->getState(); - auto* ot_state = ctx->getState(); + auto* prg_state = ctx->getState(); + + std::vector bshrs; + const auto bty = makeType(field); + for (size_t idx = 0; idx < comm->getWorldSize(); idx++) { + auto [r0, r1] = prg_state->genPrssPair(field, x.shape()); + auto b = ring_xor(r0, r1).as(bty); + + if (idx == comm->getRank()) { + ring_xor_(b, x); + } + bshrs.push_back(b.as(bty)); + } + + NdArrayRef res = vreduce(bshrs.begin(), bshrs.end(), + [&](const NdArrayRef& xx, const NdArrayRef& yy) { + return wrap_add_bb(ctx->sctx(), xx, yy); + }); + return res.as(bty); +} + +NdArrayRef B2A::proc(KernelEvalContext* ctx, const NdArrayRef& x) const { size_t n = x.numel(); - size_t nworker = - std::min(ot_state->parallel_size(), CeilDiv(n, kMinWorkSize)); + size_t nworker = InitOTState(ctx, n); size_t work_load = nworker == 0 ? 
0 : CeilDiv(n, nworker); - for (size_t w = 0; w < nworker; ++w) { - ot_state->LazyInit(comm, w); - } const auto field = ctx->getState()->getDefaultField(); + const auto flatten_x = x.reshape({static_cast(n)}); NdArrayRef out(x.eltype(), x.shape()); - yacl::parallel_for(0, nworker, 1, [&](size_t bgn, size_t end) { - for (size_t job = bgn; job < end; ++job) { - size_t slice_bgn = std::min(n, job * work_load); - size_t slice_end = std::min(n, slice_bgn + work_load); + + yacl::parallel_for(0, nworker, 1, [&](int64_t bgn, int64_t end) { + for (int64_t job = bgn; job < end; ++job) { + auto slice_bgn = std::min(n, job * work_load); + auto slice_end = std::min(n, slice_bgn + work_load); if (slice_bgn == slice_end) { break; } - auto out_slice = - ot_state->get(job)->B2A(flatten(x).slice(slice_bgn, slice_end)); + auto out_slice = ctx->getState()->get(job)->B2A( + flatten_x.slice({slice_bgn}, {slice_end}, {1})); std::memcpy(&out.at(slice_bgn), &out_slice.at(0), out_slice.elsize() * out_slice.numel()); } }); - return out.as(makeType(field)); + return out.as(makeType(field)); } } // namespace spu::mpc::cheetah diff --git a/libspu/mpc/cheetah/conversion.h b/libspu/mpc/cheetah/conversion.h index c7f8ea04..a49424e0 100644 --- a/libspu/mpc/cheetah/conversion.h +++ b/libspu/mpc/cheetah/conversion.h @@ -15,11 +15,19 @@ #pragma once #include "libspu/mpc/kernel.h" -#include "libspu/mpc/semi2k/conversion.h" namespace spu::mpc::cheetah { -typedef spu::mpc::semi2k::A2B A2B; +class A2B : public UnaryKernel { + public: + static constexpr char kBindName[] = "a2b"; + + ce::CExpr latency() const override { return ce::Const(0); } + + ce::CExpr comm() const override { return ce::Const(0); } + + NdArrayRef proc(KernelEvalContext* ctx, const NdArrayRef& x) const override; +}; class B2A : public UnaryKernel { public: diff --git a/libspu/mpc/cheetah/io.cc b/libspu/mpc/cheetah/io.cc index c3edf35c..bfb4c505 100644 --- a/libspu/mpc/cheetah/io.cc +++ b/libspu/mpc/cheetah/io.cc @@ -14,12 +14,108 @@ #include "libspu/mpc/cheetah/io.h" -#include "libspu/mpc/semi2k/type.h" +#include "libspu/mpc/cheetah/type.h" +#include "libspu/mpc/common/pv2k.h" +#include "libspu/mpc/utils/ring_ops.h" namespace spu::mpc::cheetah { +Type CheetahIo::getShareType(Visibility vis, int owner_rank) const { + if (vis == VIS_PUBLIC) { + return makeType(field_); + } else if (vis == VIS_SECRET) { + if (owner_rank >= 0 && owner_rank < static_cast(world_size_)) { + return makeType(field_, owner_rank); + } else { + return makeType(field_); + } + } + + SPU_THROW("unsupported vis type {}", vis); +} + +std::vector CheetahIo::toShares(const NdArrayRef& raw, + Visibility vis, + int owner_rank) const { + SPU_ENFORCE(raw.eltype().isa(), "expected RingTy, got {}", + raw.eltype()); + const auto field = raw.eltype().as()->field(); + SPU_ENFORCE(field == field_, "expect raw value encoded in field={}, got={}", + field_, field); + + if (vis == VIS_PUBLIC) { + const auto share = raw.as(makeType(field)); + return std::vector(world_size_, share); + } else if (vis == VIS_SECRET) { +#if !defined(SPU_ENABLE_PRIVATE_TYPE) + owner_rank = -1; +#endif + + if (owner_rank >= 0 && owner_rank < static_cast(world_size_)) { + // indicates private + std::vector shares; + const auto ty = makeType(field, owner_rank); + for (int idx = 0; idx < static_cast(world_size_); idx++) { + if (idx == owner_rank) { + shares.push_back(raw.as(ty)); + } else { + shares.push_back(makeConstantArrayRef(ty, raw.shape())); + } + } + return shares; + } else { + // normal secret + SPU_ENFORCE(owner_rank == -1, 
"not a valid owner {}", owner_rank); + + std::vector shares; + const auto ty = makeType(field); + + // by default, make as arithmetic share. + const auto splits = ring_rand_additive_splits(raw, world_size_); + shares.reserve(splits.size()); + for (const auto& split : splits) { + shares.emplace_back(split.as(ty)); + } + return shares; + } + } + SPU_THROW("unsupported vis type {}", vis); +} + +NdArrayRef CheetahIo::fromShares(const std::vector& shares) const { + const auto& eltype = shares.at(0).eltype(); + const auto field = eltype.as()->field(); + + if (eltype.isa()) { + return shares[0].as(makeType(field)); + } else if (eltype.isa()) { + SPU_ENFORCE(field_ == eltype.as()->field()); + const size_t owner = eltype.as()->owner(); + return shares[owner].as(makeType(field_)); + } else if (eltype.isa()) { + auto res = ring_zeros(field, shares[0].shape()); + + for (const auto& share : shares) { + // Currently, only packed zeros are not compact, this is for colocation + // optimization + if (!share.isCompact()) { + continue; + } + if (eltype.isa()) { + ring_add_(res, share); + } else if (eltype.isa()) { + ring_xor_(res, share); + } else { + SPU_THROW("invalid share type {}", eltype); + } + } + return res; + } + SPU_THROW("unsupported eltype {}", eltype); +} + std::unique_ptr makeCheetahIo(FieldType field, size_t npc) { - semi2k::registerTypes(); + cheetah::registerTypes(); return std::make_unique(field, npc); } diff --git a/libspu/mpc/cheetah/io.h b/libspu/mpc/cheetah/io.h index bb163cd0..2154e31e 100644 --- a/libspu/mpc/cheetah/io.h +++ b/libspu/mpc/cheetah/io.h @@ -14,15 +14,22 @@ #pragma once -#include "libspu/mpc/semi2k/io.h" +#include "libspu/mpc/io_interface.h" namespace spu::mpc::cheetah { - -class CheetahIo final : public semi2k::Semi2kIo { +class CheetahIo : public BaseIo { public: - using semi2k::Semi2kIo::Semi2kIo; + using BaseIo::BaseIo; + + // when owner rank is valid (0 <= owner_rank < world_size), colocation + // optization may be applied + std::vector toShares(const NdArrayRef& raw, Visibility vis, + int owner_rank) const override; + + Type getShareType(Visibility vis, int owner_rank = -1) const override; + + NdArrayRef fromShares(const std::vector& shares) const override; }; std::unique_ptr makeCheetahIo(FieldType field, size_t npc); - } // namespace spu::mpc::cheetah diff --git a/libspu/mpc/cheetah/nonlinear/compare_prot.cc b/libspu/mpc/cheetah/nonlinear/compare_prot.cc index 6035a012..af147360 100644 --- a/libspu/mpc/cheetah/nonlinear/compare_prot.cc +++ b/libspu/mpc/cheetah/nonlinear/compare_prot.cc @@ -14,15 +14,15 @@ #include "libspu/mpc/cheetah/nonlinear/compare_prot.h" -#include "emp-tool/utils/prg.h" +#include "yacl/crypto/tools/prg.h" #include "yacl/link/link.h" #include "libspu/core/type.h" #include "libspu/mpc/cheetah/ot/basic_ot_prot.h" #include "libspu/mpc/cheetah/ot/ferret.h" #include "libspu/mpc/cheetah/ot/util.h" +#include "libspu/mpc/cheetah/type.h" #include "libspu/mpc/common/communicator.h" -#include "libspu/mpc/semi2k/type.h" #include "libspu/mpc/utils/ring_ops.h" namespace spu::mpc::cheetah { @@ -55,27 +55,24 @@ void SetLeafOTMsg(absl::Span ot_messages, uint8_t digit, // The Mill protocol from "CrypTFlow2: Practical 2-Party Secure Inference" // Algorithm 1. 
REF: https://arxiv.org/pdf/2010.06457.pdf -ArrayRef CompareProtocol::DoCompute(const ArrayRef& inp, bool greater_than, - ArrayRef* keep_eq) { +NdArrayRef CompareProtocol::DoCompute(const NdArrayRef& inp, bool greater_than, + NdArrayRef* keep_eq, int64_t bitwidth) { + SPU_ENFORCE(inp.shape().size() == 1, "need 1D array"); auto field = inp.eltype().as()->field(); - size_t bit_width = SizeOf(field) * 8; - SPU_ENFORCE(bit_width % compare_radix_ == 0, "invalid compare radix {}", - compare_radix_); - - size_t num_digits = CeilDiv(bit_width, compare_radix_); + int64_t num_digits = CeilDiv(bitwidth, (int64_t)compare_radix_); size_t radix = static_cast(1) << compare_radix_; // one-of-N OT - size_t num_cmp = inp.numel(); + int64_t num_cmp = inp.numel(); // init to all zero std::vector digits(num_cmp * num_digits, 0); // Step 1 break into digits \in [0, radix) - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, "break_digits", [&]() { using u2k = std::make_unsigned::type; const auto mask_radix = makeBitsMask(compare_radix_); - ArrayView xinp(inp); + NdArrayView xinp(inp); - for (size_t i = 0; i < num_cmp; ++i) { - for (size_t j = 0; j < num_digits; ++j) { + for (int64_t i = 0; i < num_cmp; ++i) { + for (int64_t j = 0; j < num_digits; ++j) { uint32_t shft = j * compare_radix_; digits[i * num_digits + j] = (xinp[i] >> shft) & mask_radix; } @@ -86,9 +83,15 @@ ArrayRef CompareProtocol::DoCompute(const ArrayRef& inp, bool greater_than, std::vector leaf_eq(num_cmp * num_digits, 0); if (is_sender_) { // Step 2 sample random bits - emp::PRG prg; - prg.random_bool(reinterpret_cast(leaf_cmp.data()), leaf_cmp.size()); - prg.random_bool(reinterpret_cast(leaf_eq.data()), leaf_eq.size()); + yacl::crypto::Prg prg; + prg.Fill(absl::MakeSpan(leaf_cmp)); + prg.Fill(absl::MakeSpan(leaf_eq)); + + // convert u8 random to boolean random + std::transform(leaf_cmp.begin(), leaf_cmp.end(), leaf_cmp.data(), + [](uint8_t v) { return v & 1; }); + std::transform(leaf_eq.begin(), leaf_eq.end(), leaf_eq.data(), + [](uint8_t v) { return v & 1; }); // Step 6-7 set the OT messages with two packed bits (one for compare, one // for equal) @@ -99,13 +102,13 @@ ArrayRef CompareProtocol::DoCompute(const ArrayRef& inp, bool greater_than, absl::Span{leaf_ot_msg.data() + i * radix, radix}; } - for (size_t i = 0; i < num_cmp; ++i) { + for (int64_t i = 0; i < num_cmp; ++i) { auto* this_ot_msg = each_leaf_ot_msg.data() + i * num_digits; auto* this_digit = digits.data() + i * num_digits; auto* this_leaf_cmp = leaf_cmp.data() + i * num_digits; auto* this_leaf_eq = leaf_eq.data() + i * num_digits; - for (size_t j = 0; j < num_digits; ++j) { + for (int64_t j = 0; j < num_digits; ++j) { uint8_t rnd_cmp = this_leaf_cmp[j] & 1; uint8_t rnd_eq = this_leaf_eq[j] & 1; SetLeafOTMsg(this_ot_msg[j], this_digit[j], rnd_cmp, rnd_eq, @@ -122,28 +125,28 @@ ArrayRef CompareProtocol::DoCompute(const ArrayRef& inp, bool greater_than, basic_ot_prot_->GetReceiverCOT()->RecvCMCC(absl::MakeSpan(digits), radix, absl::MakeSpan(leaf_cmp), 2); // extract equality bits from packed messages - for (size_t i = 0; i < num_cmp; ++i) { + for (int64_t i = 0; i < num_cmp; ++i) { auto* this_leaf_cmp = leaf_cmp.data() + i * num_digits; auto* this_leaf_eq = leaf_eq.data() + i * num_digits; - for (size_t j = 0; j < num_digits; ++j) { + for (int64_t j = 0; j < num_digits; ++j) { this_leaf_eq[j] = (this_leaf_cmp[j] >> 1) & 1; this_leaf_cmp[j] &= 1; } } } - using BShrTy = semi2k::BShrTy; auto boolean_t = makeType(field, 1); - ArrayRef prev_cmp = - 
flatten(ring_zeros(field, {static_cast(num_digits * num_cmp)})) + NdArrayRef prev_cmp = + ring_zeros(field, {static_cast(num_digits * num_cmp)}) .as(boolean_t); - ArrayRef prev_eq = - flatten(ring_zeros(field, {static_cast(num_digits * num_cmp)})) + NdArrayRef prev_eq = + ring_zeros(field, {static_cast(num_digits * num_cmp)}) .as(boolean_t); - DISPATCH_ALL_FIELDS(field, "", [&]() { - ArrayView xprev_cmp(prev_cmp); - ArrayView xprev_eq(prev_eq); - pforeach(0, xprev_cmp.numel(), [&](int64_t i) { + + DISPATCH_ALL_FIELDS(field, "copy_leaf", [&]() { + NdArrayView xprev_cmp(prev_cmp); + NdArrayView xprev_eq(prev_eq); + pforeach(0, prev_cmp.numel(), [&](int64_t i) { xprev_cmp[i] = leaf_cmp[i]; xprev_eq[i] = leaf_eq[i]; }); @@ -162,24 +165,26 @@ ArrayRef CompareProtocol::DoCompute(const ArrayRef& inp, bool greater_than, return _gt.as(boolean_t); } -std::array CompareProtocol::TraversalANDWithEq(ArrayRef cmp, - ArrayRef eq, - size_t num_input, - size_t num_digits) { +std::array CompareProtocol::TraversalANDWithEq( + NdArrayRef cmp, NdArrayRef eq, size_t num_input, size_t num_digits) { + SPU_ENFORCE(cmp.shape().size() == 1, "need 1D array"); + SPU_ENFORCE_EQ(cmp.shape(), eq.shape()); SPU_ENFORCE_EQ(cmp.numel(), eq.numel()); SPU_ENFORCE_EQ(num_input * num_digits, (size_t)cmp.numel()); for (size_t i = 1; i <= num_digits; i += 1) { - size_t current_num_digits = num_digits / (1 << (i - 1)); - if (current_num_digits == 1) break; + int64_t current_num_digits = num_digits / (1 << (i - 1)); + if (current_num_digits == 1) { + break; + } // eq[i-1, j] <- eq[i, 2*j] * eq[i, 2*j+1] // cmp[i-1, j] <- cmp[i,2*j] * eq[i,2*j+1] ^ cmp[i,2*j+1] - size_t n = current_num_digits * num_input; - auto lhs_eq = eq.slice(0, n, 2); - auto rhs_eq = eq.slice(1, n, 2); + int64_t n = current_num_digits * num_input; + auto lhs_eq = eq.slice({0}, {n}, {2}); + auto rhs_eq = eq.slice({1}, {n}, {2}); - auto lhs_cmp = cmp.slice(0, n, 2); - auto rhs_cmp = cmp.slice(1, n, 2); + auto lhs_cmp = cmp.slice({0}, {n}, {2}); + auto rhs_cmp = cmp.slice({1}, {n}, {2}); // Correlated ANDs // _eq = rhs_eq & lhs_eq @@ -187,14 +192,14 @@ std::array CompareProtocol::TraversalANDWithEq(ArrayRef cmp, auto [_eq, _cmp] = basic_ot_prot_->CorrelatedBitwiseAnd(rhs_eq, lhs_eq, lhs_cmp); eq = _eq; - cmp = flatten(ring_xor(toNdArray(_cmp), toNdArray(rhs_cmp))); + cmp = ring_xor(_cmp, rhs_cmp); } return {cmp, eq}; } -ArrayRef CompareProtocol::TraversalAND(ArrayRef cmp, ArrayRef eq, - size_t num_input, size_t num_digits) { +NdArrayRef CompareProtocol::TraversalAND(NdArrayRef cmp, NdArrayRef eq, + size_t num_input, size_t num_digits) { // Tree-based traversal ANDs // lt0[0], lt0[1], ..., lt0[M], // lt1[0], lt1[1], ..., lt1[M], @@ -214,46 +219,52 @@ ArrayRef CompareProtocol::TraversalAND(ArrayRef cmp, ArrayRef eq, // lt1[1], lt0[3], ..., lt0[2*j+1] // .... 
// ltn[1], ltn[3], ..., ltn[2*j+1] - SPU_ENFORCE_EQ(cmp.numel(), eq.numel()); + SPU_ENFORCE(cmp.shape().size() == 1, "need 1D Array"); + SPU_ENFORCE_EQ(cmp.shape(), eq.shape()); SPU_ENFORCE_EQ(num_input * num_digits, (size_t)cmp.numel()); for (size_t i = 1; i <= num_digits; i += 1) { size_t current_num_digits = num_digits / (1 << (i - 1)); - if (current_num_digits == 1) break; + if (current_num_digits == 1) { + break; + } // eq[i-1, j] <- eq[i, 2*j] * eq[i, 2*j+1] // cmp[i-1, j] <- cmp[i,2*j] * eq[i,2*j+1] ^ cmp[i,2*j+1] - size_t n = current_num_digits * num_input; - auto lhs_eq = eq.slice(0, n, 2); - auto rhs_eq = eq.slice(1, n, 2); - auto lhs_cmp = cmp.slice(0, n, 2); - auto rhs_cmp = cmp.slice(1, n, 2); + int64_t n = current_num_digits * num_input; + + auto lhs_eq = eq.slice({0}, {n}, {2}); + auto rhs_eq = eq.slice({1}, {n}, {2}); + auto lhs_cmp = cmp.slice({0}, {n}, {2}); + auto rhs_cmp = cmp.slice({1}, {n}, {2}); if (current_num_digits == 2) { cmp = basic_ot_prot_->BitwiseAnd(lhs_cmp, rhs_eq); - auto nd_cmp = toNdArray(cmp); - ring_xor_(nd_cmp, toNdArray(rhs_cmp)); + ring_xor_(cmp, rhs_cmp); // We skip the ANDs for eq on the last loop continue; } // We skip the AND on the 0-th digit which is unnecessary for the next loop. - size_t nrow = num_input; - size_t ncol = current_num_digits / 2; - SPU_ENFORCE_EQ((size_t)lhs_eq.numel(), nrow * ncol); + int64_t nrow = num_input; + int64_t ncol = current_num_digits / 2; + SPU_ENFORCE_EQ(lhs_eq.numel(), nrow * ncol); - ArrayRef _lhs_eq(lhs_eq.eltype(), nrow * (ncol - 1)); - ArrayRef _rhs_eq(lhs_eq.eltype(), nrow * (ncol - 1)); - ArrayRef _lhs_cmp(lhs_cmp.eltype(), nrow * (ncol - 1)); + Shape subshape = {nrow * (ncol - 1)}; + NdArrayRef _lhs_eq(lhs_eq.eltype(), subshape); + NdArrayRef _rhs_eq(rhs_eq.eltype(), subshape); + NdArrayRef _lhs_cmp(lhs_cmp.eltype(), subshape); - ArrayRef _lhs_cmp_col0(lhs_cmp.eltype(), nrow); - ArrayRef _rhs_eq_col0(rhs_eq.eltype(), nrow); + NdArrayRef _lhs_cmp_col0(lhs_cmp.eltype(), {static_cast(nrow)}); + NdArrayRef _rhs_eq_col0(rhs_eq.eltype(), _lhs_cmp_col0.shape()); - for (size_t r = 0; r < nrow; ++r) { + // Skip the 0-th column and take the remains columns + // TODO(lwj): Can we have a better way to avoid such copying? 
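// A plaintext sketch (standard C++, no OT, no secret sharing; radix and inputs
// are illustrative) of the digit-decomposition plus tree-combine recurrence
// that TraversalAND / TraversalANDWithEq evaluate on shared bits above:
//   eq'[j]  = eq[2j] & eq[2j+1]
//   cmp'[j] = (cmp[2j] & eq[2j+1]) ^ cmp[2j+1]
// i.e. the higher digit decides the comparison unless it ties, in which case
// the lower digit does.
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  const int radix = 4;                       // bits per digit
  const int num_digits = 32 / radix;         // 8 digits, a power of two
  uint32_t x = 0x12345678, y = 0x12335678;

  std::vector<int> cmp(num_digits), eq(num_digits);
  for (int j = 0; j < num_digits; ++j) {     // digit j holds bits [j*radix, (j+1)*radix)
    uint32_t dx = (x >> (j * radix)) & 0xF;
    uint32_t dy = (y >> (j * radix)) & 0xF;
    cmp[j] = dx > dy;                        // leaf "greater" bit
    eq[j] = dx == dy;                        // leaf "equal" bit
  }

  for (int n = num_digits; n > 1; n /= 2) {  // halve the digit count each round
    for (int j = 0; j < n / 2; ++j) {
      cmp[j] = (cmp[2 * j] & eq[2 * j + 1]) ^ cmp[2 * j + 1];
      eq[j] = eq[2 * j] & eq[2 * j + 1];
    }
  }
  assert(cmp[0] == (x > y));                 // matches the direct comparison
  return 0;
}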
+ for (int64_t r = 0; r < nrow; ++r) { std::memcpy(&_rhs_eq_col0.at(r), &rhs_eq.at(r * ncol), rhs_eq.elsize()); std::memcpy(&_lhs_cmp_col0.at(r), &lhs_cmp.at(r * ncol), lhs_cmp.elsize()); - for (size_t c = 1; c < ncol; ++c) { + for (int64_t c = 1; c < ncol; ++c) { std::memcpy(&_lhs_cmp.at(r * (ncol - 1) + c - 1), &lhs_cmp.at(r * ncol + c), lhs_eq.elsize()); std::memcpy(&_lhs_eq.at(r * (ncol - 1) + c - 1), @@ -262,6 +273,7 @@ ArrayRef CompareProtocol::TraversalAND(ArrayRef cmp, ArrayRef eq, &rhs_eq.at(r * ncol + c), rhs_eq.elsize()); } } + // Normal AND on 0-th column auto _next_cmp_col0 = basic_ot_prot_->BitwiseAnd(_rhs_eq_col0, _lhs_cmp_col0); @@ -270,13 +282,14 @@ ArrayRef CompareProtocol::TraversalAND(ArrayRef cmp, ArrayRef eq, auto [_next_cmp, _next_eq] = basic_ot_prot_->CorrelatedBitwiseAnd(_rhs_eq, _lhs_cmp, _lhs_eq); - eq = ArrayRef(_next_eq.eltype(), nrow * ncol); - cmp = ArrayRef(_next_cmp.eltype(), nrow * ncol); + // Concat two ANDs + eq = NdArrayRef(eq.eltype(), {_next_cmp_col0.numel() + _next_cmp.numel()}); + cmp = NdArrayRef(cmp.eltype(), eq.shape()); - for (size_t r = 0; r < nrow; ++r) { + for (int64_t r = 0; r < nrow; ++r) { std::memcpy(&cmp.at(r * ncol), &_next_cmp_col0.at(r), cmp.elsize()); - for (size_t c = 1; c < ncol; ++c) { + for (int64_t c = 1; c < ncol; ++c) { std::memcpy(&cmp.at(r * ncol + c), &_next_cmp.at(r * (ncol - 1) + c - 1), _next_cmp.elsize()); @@ -285,22 +298,40 @@ ArrayRef CompareProtocol::TraversalAND(ArrayRef cmp, ArrayRef eq, } } - auto nd_cmp = toNdArray(cmp); - ring_xor_(nd_cmp, toNdArray(rhs_cmp)); + ring_xor_(cmp, rhs_cmp); } return cmp; } -ArrayRef CompareProtocol::Compute(const ArrayRef& inp, bool greater_than) { - return DoCompute(inp, greater_than, nullptr); +NdArrayRef CompareProtocol::Compute(const NdArrayRef& inp, bool greater_than, + int64_t bitwidth) { + int64_t bw = SizeOf(inp.eltype().as()->field()) * 8; + SPU_ENFORCE(bitwidth >= 0 && bitwidth <= bw, "bit_width={} out of bound", + bitwidth); + if (bitwidth == 0) { + bitwidth = bw; + } + // NOTE(lwj): reshape might need copy + auto flatten = inp.reshape({inp.numel()}); + return DoCompute(flatten, greater_than, nullptr, bitwidth) + .reshape(inp.shape()); } -std::array CompareProtocol::ComputeWithEq(const ArrayRef& inp, - bool greater_than) { - ArrayRef eq; - auto cmp = DoCompute(inp, greater_than, &eq); - return {cmp, eq}; +std::array CompareProtocol::ComputeWithEq(const NdArrayRef& inp, + bool greater_than, + int64_t bitwidth) { + int64_t bw = SizeOf(inp.eltype().as()->field()) * 8; + SPU_ENFORCE(bitwidth >= 0 && bitwidth <= bw, "bit_width={} out of bound", + bitwidth); + if (bitwidth == 0) { + bitwidth = bw; + } + NdArrayRef eq; + // NOTE(lwj): reshape might need copy + auto flatten = inp.reshape({inp.numel()}); + auto cmp = DoCompute(flatten, greater_than, &eq, bitwidth); + return {cmp.reshape(inp.shape()), eq.reshape(inp.shape())}; } } // namespace spu::mpc::cheetah diff --git a/libspu/mpc/cheetah/nonlinear/compare_prot.h b/libspu/mpc/cheetah/nonlinear/compare_prot.h index 460edc08..763be10a 100644 --- a/libspu/mpc/cheetah/nonlinear/compare_prot.h +++ b/libspu/mpc/cheetah/nonlinear/compare_prot.h @@ -16,7 +16,7 @@ #include -#include "libspu/mpc/cheetah/array_ref.h" +#include "libspu/core/ndarray_ref.h" namespace spu::mpc::cheetah { @@ -45,21 +45,26 @@ class CompareProtocol { ~CompareProtocol(); - ArrayRef Compute(const ArrayRef& inp, bool greater_than); + NdArrayRef Compute(const NdArrayRef& inp, bool greater_than, + int64_t bitwidth = 0); - std::array ComputeWithEq(const ArrayRef& inp, 
bool greater_than); + std::array ComputeWithEq(const NdArrayRef& inp, + bool greater_than, + int64_t bitwidth = 0); private: - ArrayRef DoCompute(const ArrayRef& inp, bool greater_than, - ArrayRef* eq = nullptr); + // Require 1D array + NdArrayRef DoCompute(const NdArrayRef& inp, bool greater_than, + NdArrayRef* eq = nullptr, int64_t bitwidth = 0); - ArrayRef TraversalAND(ArrayRef cmp, ArrayRef eq, size_t num_input, - size_t num_digits); - - std::array TraversalANDWithEq(ArrayRef cmp, ArrayRef eq, - size_t num_input, - size_t num_digits); + // Require 1D array + NdArrayRef TraversalAND(NdArrayRef cmp, NdArrayRef eq, size_t num_input, + size_t num_digits); + // Require 1D array + std::array TraversalANDWithEq(NdArrayRef cmp, NdArrayRef eq, + size_t num_input, + size_t num_digits); size_t compare_radix_; bool is_sender_{false}; std::shared_ptr basic_ot_prot_; diff --git a/libspu/mpc/cheetah/nonlinear/compare_prot_test.cc b/libspu/mpc/cheetah/nonlinear/compare_prot_test.cc index 4f151869..e7405d4a 100644 --- a/libspu/mpc/cheetah/nonlinear/compare_prot_test.cc +++ b/libspu/mpc/cheetah/nonlinear/compare_prot_test.cc @@ -18,9 +18,8 @@ #include "gtest/gtest.h" -#include "libspu/core/xt_helper.h" #include "libspu/mpc/cheetah/ot/basic_ot_prot.h" -#include "libspu/mpc/semi2k/type.h" +#include "libspu/mpc/cheetah/type.h" #include "libspu/mpc/utils/ring_ops.h" #include "libspu/mpc/utils/simulate.h" @@ -43,44 +42,111 @@ INSTANTIATE_TEST_SUITE_P( TEST_P(CompareProtTest, Basic) { size_t kWorldSize = 2; - int64_t n = 16; + Shape shape = {13, 2, 3}; FieldType field = std::get<0>(GetParam()); size_t radix = std::get<2>(GetParam()); bool greater_than = std::get<1>(GetParam()); NdArrayRef inp[2]; - inp[0] = ring_rand(field, {n}); - inp[1] = ring_rand(field, {n}); + inp[0] = ring_rand(field, shape); + inp[1] = ring_rand(field, shape); DISPATCH_ALL_FIELDS(field, "", [&]() { - auto xinp = xt_mutable_adapt(inp[0]); + auto xinp = NdArrayView(inp[0]); xinp[0] = 1; xinp[1] = 10; xinp[2] = 100; - xinp = xt_mutable_adapt(inp[1]); + xinp = NdArrayView(inp[1]); xinp[0] = 1; xinp[1] = 9; xinp[2] = 1000; }); - ArrayRef cmp_oup[2]; + NdArrayRef cmp_oup[2]; utils::simulate(kWorldSize, [&](std::shared_ptr ctx) { auto conn = std::make_shared(ctx); int rank = ctx->Rank(); auto base = std::make_shared(conn); CompareProtocol comp_prot(base, radix); - auto _c = comp_prot.Compute(flatten(inp[rank]), greater_than); + auto _c = comp_prot.Compute(inp[rank], greater_than); cmp_oup[rank] = _c; }); DISPATCH_ALL_FIELDS(field, "", [&]() { - auto xout0 = ArrayView(cmp_oup[0]); - auto xout1 = ArrayView(cmp_oup[1]); - auto xinp0 = xt_adapt(inp[0]); - auto xinp1 = xt_adapt(inp[1]); + auto xout0 = NdArrayView(cmp_oup[0]); + auto xout1 = NdArrayView(cmp_oup[1]); + auto xinp0 = NdArrayView(inp[0]); + auto xinp1 = NdArrayView(inp[1]); - for (int64_t i = 0; i < n; ++i) { + for (int64_t i = 0; i < shape.numel(); ++i) { + bool expected = greater_than ? 
xinp0[i] > xinp1[i] : xinp0[i] < xinp1[i]; + bool got_cmp = xout0[i] ^ xout1[i]; + ASSERT_EQ(expected, got_cmp); + } + }); +} + +TEST_P(CompareProtTest, SpecifiedBitwidth) { + size_t kWorldSize = 2; + FieldType field = std::get<0>(GetParam()); + size_t radix = std::get<2>(GetParam()); + bool greater_than = std::get<1>(GetParam()); + int64_t bw = 16; + + NdArrayRef inp[2]; + int64_t n = 1 << 18; + inp[0] = ring_rand(field, {n * 2}); + inp[1] = ring_rand(field, {n * 2}); + + DISPATCH_ALL_FIELDS(field, "", [&]() { + ring2k_t mask = (static_cast(1) << bw) - 1; + auto xinp = NdArrayView(inp[0]); + xinp[0] = 1; + xinp[1] = 10; + xinp[2] = 100; + pforeach(0, inp[0].numel(), [&](int64_t i) { xinp[i] &= mask; }); + + xinp = NdArrayView(inp[1]); + xinp[0] = 1; + xinp[1] = 9; + xinp[2] = 1000; + pforeach(0, inp[0].numel(), [&](int64_t i) { xinp[i] &= mask; }); + }); + + inp[0] = inp[0].reshape({n, 2}); + inp[1] = inp[1].reshape({n, 2}); + + NdArrayRef cmp_oup[2]; + utils::simulate(kWorldSize, [&](std::shared_ptr ctx) { + auto conn = std::make_shared(ctx); + int rank = ctx->Rank(); + auto base = std::make_shared(conn); + + CompareProtocol comp_prot(base, radix); + + [[maybe_unused]] auto b0 = ctx->GetStats()->sent_bytes.load(); + [[maybe_unused]] auto s0 = ctx->GetStats()->sent_actions.load(); + + auto _c = comp_prot.Compute(inp[rank], greater_than, bw); + + [[maybe_unused]] auto b1 = ctx->GetStats()->sent_bytes.load(); + [[maybe_unused]] auto s1 = ctx->GetStats()->sent_actions.load(); + + SPDLOG_DEBUG( + "Compare {} bits {} elements sent {} bytes, {} bits each #sent {}", bw, + inp[0].numel(), (b1 - b0), (b1 - b0) * 8. / inp[0].numel(), (s1 - s0)); + + cmp_oup[rank] = _c; + }); + + DISPATCH_ALL_FIELDS(field, "", [&]() { + auto xout0 = NdArrayView(cmp_oup[0]); + auto xout1 = NdArrayView(cmp_oup[1]); + auto xinp0 = NdArrayView(inp[0]); + auto xinp1 = NdArrayView(inp[1]); + + for (int64_t i = 0; i < inp[0].numel(); ++i) { bool expected = greater_than ? 
xinp0[i] > xinp1[i] : xinp0[i] < xinp1[i]; bool got_cmp = xout0[i] ^ xout1[i]; EXPECT_EQ(expected, got_cmp); @@ -90,48 +156,53 @@ TEST_P(CompareProtTest, Basic) { TEST_P(CompareProtTest, WithEq) { size_t kWorldSize = 2; - int64_t n = 10; + Shape shape = {10, 10, 10}; FieldType field = std::get<0>(GetParam()); size_t radix = std::get<2>(GetParam()); bool greater_than = std::get<1>(GetParam()); + NdArrayRef _inp[2]; + _inp[0] = ring_rand(field, shape); + _inp[1] = ring_rand(field, shape); + NdArrayRef inp[2]; - inp[0] = ring_rand(field, {n}); - inp[1] = ring_rand(field, {n}); + inp[0] = _inp[0].slice({0, 0, 0}, {10, 10, 10}, {2, 3, 2}); + inp[1] = _inp[0].slice({0, 0, 0}, {10, 10, 10}, {2, 3, 2}); + shape = inp[0].shape(); DISPATCH_ALL_FIELDS(field, "", [&]() { - auto xinp = xt_mutable_adapt(inp[0]); + auto xinp = NdArrayView(inp[0]); xinp[0] = 1; xinp[1] = 10; xinp[2] = 100; - xinp = xt_mutable_adapt(inp[1]); + xinp = NdArrayView(inp[1]); xinp[0] = 1; xinp[1] = 9; xinp[2] = 1000; }); - ArrayRef cmp_oup[2]; - ArrayRef eq_oup[2]; + NdArrayRef cmp_oup[2]; + NdArrayRef eq_oup[2]; utils::simulate(kWorldSize, [&](std::shared_ptr ctx) { auto conn = std::make_shared(ctx); int rank = ctx->Rank(); auto base = std::make_shared(conn); CompareProtocol comp_prot(base, radix); - auto [_c, _e] = comp_prot.ComputeWithEq(flatten(inp[rank]), greater_than); + auto [_c, _e] = comp_prot.ComputeWithEq(inp[rank], greater_than); cmp_oup[rank] = _c; eq_oup[rank] = _e; }); DISPATCH_ALL_FIELDS(field, "", [&]() { - auto xout0 = ArrayView(cmp_oup[0]); - auto xout1 = ArrayView(cmp_oup[1]); - auto xeq0 = ArrayView(eq_oup[0]); - auto xeq1 = ArrayView(eq_oup[1]); - auto xinp0 = xt_adapt(inp[0]); - auto xinp1 = xt_adapt(inp[1]); - - for (int64_t i = 0; i < n; ++i) { + auto xout0 = NdArrayView(cmp_oup[0]); + auto xout1 = NdArrayView(cmp_oup[1]); + auto xeq0 = NdArrayView(eq_oup[0]); + auto xeq1 = NdArrayView(eq_oup[1]); + auto xinp0 = NdArrayView(inp[0]); + auto xinp1 = NdArrayView(inp[1]); + + for (int64_t i = 0; i < shape.numel(); ++i) { bool expected = greater_than ? 
xinp0[i] > xinp1[i] : xinp0[i] < xinp1[i]; bool got_cmp = xout0[i] ^ xout1[i]; bool got_eq = xeq0[i] ^ xeq1[i]; diff --git a/libspu/mpc/cheetah/nonlinear/equal_prot.cc b/libspu/mpc/cheetah/nonlinear/equal_prot.cc index 5d339db4..21b1a23c 100644 --- a/libspu/mpc/cheetah/nonlinear/equal_prot.cc +++ b/libspu/mpc/cheetah/nonlinear/equal_prot.cc @@ -14,15 +14,15 @@ #include "libspu/mpc/cheetah/nonlinear/equal_prot.h" -#include "emp-tool/utils/prg.h" +#include "yacl/crypto/tools/prg.h" #include "yacl/link/link.h" #include "libspu/core/type.h" #include "libspu/mpc/cheetah/ot/basic_ot_prot.h" #include "libspu/mpc/cheetah/ot/ferret.h" #include "libspu/mpc/cheetah/ot/util.h" +#include "libspu/mpc/cheetah/type.h" #include "libspu/mpc/common/communicator.h" -#include "libspu/mpc/semi2k/type.h" #include "libspu/mpc/utils/ring_ops.h" namespace spu::mpc::cheetah { @@ -47,25 +47,29 @@ static void SetLeafOTMsg(absl::Span ot_messages, uint8_t digit, } } -ArrayRef EqualProtocol::Compute(const ArrayRef& inp) { +NdArrayRef EqualProtocol::DoCompute(const NdArrayRef& inp, size_t bit_width) { + SPU_ENFORCE(inp.shape().size() == 1, "need 1D array"); auto field = inp.eltype().as()->field(); - size_t bit_width = SizeOf(field) * 8; + if (bit_width == 0) { + bit_width = SizeOf(field) * 8; + } + bit_width = std::min(bit_width, SizeOf(field) * 8); - size_t remain = bit_width % compare_radix_; - size_t num_digits = CeilDiv(bit_width, compare_radix_); - size_t radix = static_cast(1) << compare_radix_; // one-of-N OT - size_t num_cmp = inp.numel(); + int64_t remain = bit_width % compare_radix_; + int64_t num_digits = CeilDiv(bit_width, compare_radix_); + int64_t radix = static_cast(1) << compare_radix_; // one-of-N OT + int64_t num_cmp = inp.numel(); // init to all zero std::vector digits(num_cmp * num_digits, 0); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, "Equal_break_digits", [&]() { using u2k = std::make_unsigned::type; const auto mask_radix = makeBitsMask(compare_radix_); const auto mask_remain = makeBitsMask(remain); - ArrayView xinp(inp); + NdArrayView xinp(inp); - for (size_t i = 0; i < num_cmp; ++i) { - for (size_t j = 0; j < num_digits; ++j) { + for (int64_t i = 0; i < num_cmp; ++i) { + for (int64_t j = 0; j < num_digits; ++j) { uint32_t shft = j * compare_radix_; digits[i * num_digits + j] = (xinp[i] >> shft) & mask_radix; // last digits @@ -78,8 +82,11 @@ ArrayRef EqualProtocol::Compute(const ArrayRef& inp) { std::vector leaf_eq(num_cmp * num_digits, 0); if (is_sender_) { - emp::PRG prg; - prg.random_bool(reinterpret_cast(leaf_eq.data()), leaf_eq.size()); + yacl::crypto::Prg prg; + prg.Fill(absl::MakeSpan(leaf_eq)); + // convert u8 random to boolean random + std::transform(leaf_eq.begin(), leaf_eq.end(), leaf_eq.data(), + [](uint8_t v) { return v & 1; }); // n*M instances of 1-of-N OT std::vector leaf_ot_msg(radix * num_cmp * num_digits, 0); @@ -87,16 +94,16 @@ ArrayRef EqualProtocol::Compute(const ArrayRef& inp) { std::vector > each_leaf_ot_msg(num_cmp * num_digits); for (size_t i = 0; i < each_leaf_ot_msg.size(); ++i) { each_leaf_ot_msg[i] = - absl::Span{leaf_ot_msg.data() + i * radix, radix}; + absl::MakeSpan(leaf_ot_msg.data() + i * radix, radix); } - for (size_t i = 0; i < num_cmp; ++i) { + for (int64_t i = 0; i < num_cmp; ++i) { auto* this_ot_msg = each_leaf_ot_msg.data() + i * num_digits; auto* this_digit = digits.data() + i * num_digits; auto* this_leaf_eq = leaf_eq.data() + i * num_digits; // Step 6, 7 of Alg1 in CF2's paper - for (size_t j = 0; j < num_digits; ++j) { + for 
(int64_t j = 0; j < num_digits; ++j) { uint8_t rnd_eq = this_leaf_eq[j] & 1; SetLeafOTMsg(this_ot_msg[j], this_digit[j], rnd_eq); } @@ -110,10 +117,9 @@ ArrayRef EqualProtocol::Compute(const ArrayRef& inp) { absl::MakeSpan(leaf_eq), 1); } - using BShrTy = semi2k::BShrTy; auto boolean_t = makeType(field, 1); - ArrayRef prev_eq = - flatten(ring_zeros(field, {static_cast(num_digits * num_cmp)})) + NdArrayRef prev_eq = + ring_zeros(field, {static_cast(num_digits * num_cmp)}) .as(boolean_t); // Transpose from msg-major order @@ -128,10 +134,10 @@ ArrayRef EqualProtocol::Compute(const ArrayRef& inp) { // m0[1], m1[1], ..., mN[1] // ... // m0[M], m1[M], ..., mN[M] - DISPATCH_ALL_FIELDS(field, "", [&]() { - ArrayView xprev_eq(prev_eq); - for (size_t r = 0; r < num_cmp; ++r) { - for (size_t c = 0; c < num_digits; ++c) { + DISPATCH_ALL_FIELDS(field, "Equal_transpose", [&]() { + NdArrayView xprev_eq(prev_eq); + for (int64_t r = 0; r < num_cmp; ++r) { + for (int64_t c = 0; c < num_digits; ++c) { xprev_eq[c * num_cmp + r] = leaf_eq[r * num_digits + c]; } } @@ -145,18 +151,19 @@ ArrayRef EqualProtocol::Compute(const ArrayRef& inp) { // || // \/ // eq[0], eq[1], ..., eq[N] - size_t current_num_digits = num_digits; + int64_t current_num_digits = num_digits; while (current_num_digits > 1) { // eq[i-1, j] <- eq[i, 2*j] * eq[i, 2*j+1] - size_t half_d = current_num_digits / 2; - auto lhs_eq = prev_eq.slice(0, half_d * num_cmp); - auto rhs_eq = prev_eq.slice(lhs_eq.numel(), lhs_eq.numel() * 2); + int64_t half_d = current_num_digits / 2; + auto lhs_eq = prev_eq.slice({0}, {half_d * num_cmp}, {1}); + auto rhs_eq = prev_eq.slice({lhs_eq.numel()}, {lhs_eq.numel() * 2}, {1}); SPU_ENFORCE_EQ(lhs_eq.numel(), rhs_eq.numel()); - size_t remain_d = current_num_digits - 2 * half_d; + int64_t remain_d = current_num_digits - 2 * half_d; auto next_eq = basic_ot_prot_->BitwiseAnd(lhs_eq, rhs_eq); if (remain_d > 0) { - ArrayRef tmp(prev_eq.eltype(), next_eq.numel() + remain_d * num_cmp); + NdArrayRef tmp(prev_eq.eltype(), {next_eq.numel() + remain_d * num_cmp}); + std::memcpy(&tmp.at(0), &next_eq.at(0), prev_eq.elsize() * next_eq.numel()); @@ -166,10 +173,15 @@ ArrayRef EqualProtocol::Compute(const ArrayRef& inp) { } else { prev_eq = next_eq; } - current_num_digits = CeilDiv(prev_eq.numel(), num_cmp); + current_num_digits = CeilDiv(prev_eq.numel(), num_cmp); } return prev_eq.as(boolean_t); } +NdArrayRef EqualProtocol::Compute(const NdArrayRef& inp, size_t bit_width) { + const auto& shape = inp.shape(); + return DoCompute(inp.reshape({inp.numel()}), bit_width).reshape(shape); +} + } // namespace spu::mpc::cheetah diff --git a/libspu/mpc/cheetah/nonlinear/equal_prot.h b/libspu/mpc/cheetah/nonlinear/equal_prot.h index 72962679..9dd93fd3 100644 --- a/libspu/mpc/cheetah/nonlinear/equal_prot.h +++ b/libspu/mpc/cheetah/nonlinear/equal_prot.h @@ -16,7 +16,7 @@ #include -#include "libspu/mpc/cheetah/array_ref.h" +#include "libspu/core/ndarray_ref.h" namespace spu::mpc::cheetah { @@ -43,9 +43,12 @@ class EqualProtocol { ~EqualProtocol(); - ArrayRef Compute(const ArrayRef& inp); + NdArrayRef Compute(const NdArrayRef& inp, size_t bit_width = 0); private: + // Need 1D Array + NdArrayRef DoCompute(const NdArrayRef& inp, size_t bit_width = 0); + size_t compare_radix_; bool is_sender_{false}; std::shared_ptr basic_ot_prot_; diff --git a/libspu/mpc/cheetah/nonlinear/equal_prot_test.cc b/libspu/mpc/cheetah/nonlinear/equal_prot_test.cc index d569b143..b3137cc3 100644 --- a/libspu/mpc/cheetah/nonlinear/equal_prot_test.cc +++ 
b/libspu/mpc/cheetah/nonlinear/equal_prot_test.cc @@ -20,7 +20,7 @@ #include "libspu/core/xt_helper.h" #include "libspu/mpc/cheetah/ot/basic_ot_prot.h" -#include "libspu/mpc/semi2k/type.h" +#include "libspu/mpc/cheetah/type.h" #include "libspu/mpc/utils/ring_ops.h" #include "libspu/mpc/utils/simulate.h" @@ -39,12 +39,12 @@ INSTANTIATE_TEST_SUITE_P( TEST_P(EqualProtTest, Basic) { size_t kWorldSize = 2; - int64_t n = 11; + Shape shape = {10, 11, 12}; FieldType field = GetParam(); NdArrayRef inp[2]; - inp[0] = ring_rand(field, {n}); - inp[1] = ring_rand(field, {n}); + inp[0] = ring_rand(field, shape); + inp[1] = ring_rand(field, shape); DISPATCH_ALL_FIELDS(field, "", [&]() { auto xinp0 = xt_mutable_adapt(inp[0]); @@ -52,22 +52,25 @@ TEST_P(EqualProtTest, Basic) { std::copy_n(xinp1.data(), 5, xinp0.data()); }); - ArrayRef eq_oup[2]; + NdArrayRef eq_oup[2]; utils::simulate(kWorldSize, [&](std::shared_ptr ctx) { auto conn = std::make_shared(ctx); int rank = ctx->Rank(); auto base = std::make_shared(conn); EqualProtocol eq_prot(base); - eq_oup[rank] = eq_prot.Compute(flatten(inp[rank])); + eq_oup[rank] = eq_prot.Compute(inp[rank]); }); + SPU_ENFORCE_EQ(eq_oup[0].shape(), shape); + SPU_ENFORCE_EQ(eq_oup[1].shape(), shape); + DISPATCH_ALL_FIELDS(field, "", [&]() { - auto xeq0 = ArrayView(eq_oup[0]); - auto xeq1 = ArrayView(eq_oup[1]); - auto xinp0 = xt_adapt(inp[0]); - auto xinp1 = xt_adapt(inp[1]); + auto xeq0 = NdArrayView(eq_oup[0]); + auto xeq1 = NdArrayView(eq_oup[1]); + auto xinp0 = NdArrayView(inp[0]); + auto xinp1 = NdArrayView(inp[1]); - for (int64_t i = 0; i < n; ++i) { + for (int64_t i = 0; i < shape.numel(); ++i) { bool expected = xinp0[i] == xinp1[i]; bool got_eq = xeq0[i] ^ xeq1[i]; EXPECT_EQ(expected, got_eq); diff --git a/libspu/mpc/cheetah/nonlinear/truncate_prot.cc b/libspu/mpc/cheetah/nonlinear/truncate_prot.cc index 4f673571..de56a866 100644 --- a/libspu/mpc/cheetah/nonlinear/truncate_prot.cc +++ b/libspu/mpc/cheetah/nonlinear/truncate_prot.cc @@ -17,7 +17,7 @@ #include "libspu/mpc/cheetah/nonlinear/compare_prot.h" #include "libspu/mpc/cheetah/ot/basic_ot_prot.h" #include "libspu/mpc/cheetah/ot/util.h" -#include "libspu/mpc/semi2k/type.h" +#include "libspu/mpc/cheetah/type.h" #include "libspu/mpc/utils/ring_ops.h" namespace spu::mpc::cheetah { @@ -29,11 +29,12 @@ TruncateProtocol::TruncateProtocol(std::shared_ptr base) TruncateProtocol::~TruncateProtocol() { basic_ot_prot_->Flush(); } -ArrayRef TruncateProtocol::ComputeWrap(const ArrayRef& inp, const Meta& meta) { +NdArrayRef TruncateProtocol::ComputeWrap(const NdArrayRef& inp, + const Meta& meta) { const int rank = basic_ot_prot_->Rank(); - switch (meta.msb) { - case MSB_st::zero: { + switch (meta.sign) { + case SignType::Positive: { if (!meta.signed_arith) { // without sign flip return MSB0ToWrap(inp, meta.shift_bits); @@ -43,7 +44,7 @@ ArrayRef TruncateProtocol::ComputeWrap(const ArrayRef& inp, const Meta& meta) { } break; } - case MSB_st::one: { + case SignType::Negative: { if (!meta.signed_arith) { // without sign flip return MSB1ToWrap(inp, meta.shift_bits); @@ -53,25 +54,25 @@ ArrayRef TruncateProtocol::ComputeWrap(const ArrayRef& inp, const Meta& meta) { } break; } - case MSB_st::unknown: + case SignType::Unknown: default: { CompareProtocol compare_prot(basic_ot_prot_); - ArrayRef wrap_bool; + NdArrayRef wrap_bool; // w = 1{x_A + x_B > 2^k - 1} // = 1{x_A > 2^k - 1 - x_B} const auto field = inp.eltype().as()->field(); if (rank == 0) { wrap_bool = compare_prot.Compute(inp, true); } else { - auto adjusted = 
flatten(ring_neg(toNdArray(inp))); - DISPATCH_ALL_FIELDS(field, "", [&]() { - ArrayView xadj(adjusted); + auto adjusted = ring_neg(inp); + DISPATCH_ALL_FIELDS(field, "wrap_adjust", [&]() { + NdArrayView xadj(adjusted); pforeach(0, inp.numel(), [&](int64_t i) { xadj[i] -= 1; }); }); wrap_bool = compare_prot.Compute(adjusted, true); } return basic_ot_prot_->B2ASingleBitWithSize( - wrap_bool.as(makeType(field, 1)), meta.shift_bits); + wrap_bool.as(makeType(field, 1)), meta.shift_bits); break; } } @@ -83,17 +84,18 @@ ArrayRef TruncateProtocol::ComputeWrap(const ArrayRef& inp, const Meta& meta) { // COT msg corr=msb(xA) on choice msb(xB) // - msb(xB) = 0: get(-x, x) => 0 // - msb(xB) = 1: get(-x, x + msb(xA)) => msb(xA) -ArrayRef TruncateProtocol::MSB1ToWrap(const ArrayRef& inp, size_t shift_bits) { +NdArrayRef TruncateProtocol::MSB1ToWrap(const NdArrayRef& inp, + size_t shift_bits) { const auto field = inp.eltype().as()->field(); const int64_t numel = inp.numel(); const int rank = basic_ot_prot_->Rank(); const size_t bw = SizeOf(field) * 8; - ArrayRef cot_output = flatten(ring_zeros(field, {numel})); - DISPATCH_ALL_FIELDS(field, "", [&]() { + NdArrayRef cot_output = ring_zeros(field, inp.shape()); + DISPATCH_ALL_FIELDS(field, "MSB1ToWrap", [&]() { using u2k = std::make_unsigned::type; - ArrayView xinp(inp); - ArrayView xout(cot_output); + NdArrayView xinp(inp); + auto xout = absl::MakeSpan(&cot_output.at(0), cot_output.numel()); if (rank == 0) { std::vector cot_input(numel); @@ -101,8 +103,7 @@ ArrayRef TruncateProtocol::MSB1ToWrap(const ArrayRef& inp, size_t shift_bits) { [&](int64_t i) { cot_input[i] = ((xinp[i] >> (bw - 1)) & 1); }); auto sender = basic_ot_prot_->GetSenderCOT(); - sender->SendCAMCC(absl::MakeSpan(cot_input), - {xout.data(), (size_t)xout.numel()}, shift_bits); + sender->SendCAMCC(absl::MakeSpan(cot_input), xout, shift_bits); sender->Flush(); pforeach(0, numel, [&](int64_t i) { xout[i] = -xout[i]; }); } else { @@ -110,13 +111,12 @@ ArrayRef TruncateProtocol::MSB1ToWrap(const ArrayRef& inp, size_t shift_bits) { pforeach(0, numel, [&](int64_t i) { cot_input[i] = ((xinp[i] >> (bw - 1)) & 1); }); - basic_ot_prot_->GetReceiverCOT()->RecvCAMCC( - absl::MakeSpan(cot_input), {xout.data(), (size_t)xout.numel()}, - shift_bits); + basic_ot_prot_->GetReceiverCOT()->RecvCAMCC(absl::MakeSpan(cot_input), + xout, shift_bits); } }); - return cot_output.as(makeType(field, 1)); + return cot_output.as(makeType(field, 1)); } // Given msb(xA + xB mod 2^k) = 0, and xA, xB \in [0, 2^k) @@ -132,7 +132,8 @@ ArrayRef TruncateProtocol::MSB1ToWrap(const ArrayRef& inp, size_t shift_bits) { // 1-of-2 OT msg (r^msb(xA), r^1) on choice msb(xB) // - msb(xB) = 0: get (r, r^msb(xA)) => msb(xA) // - msb(xB) = 1: get (r, r^1) => 1 -ArrayRef TruncateProtocol::MSB0ToWrap(const ArrayRef& inp, size_t shift_bits) { +NdArrayRef TruncateProtocol::MSB0ToWrap(const NdArrayRef& inp, + size_t shift_bits) { const auto field = inp.eltype().as()->field(); const int64_t numel = inp.numel(); const int rank = basic_ot_prot_->Rank(); @@ -141,15 +142,15 @@ ArrayRef TruncateProtocol::MSB0ToWrap(const ArrayRef& inp, size_t shift_bits) { constexpr size_t N = 2; // 1-of-2 OT constexpr size_t nbits = 1; - ArrayRef outp; + NdArrayRef outp; if (0 == rank) { - outp = flatten(ring_randbit(field, {numel})); + outp = ring_randbit(field, inp.shape()); std::vector send(numel * N); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, "MSB0_adjust", [&]() { using u2k = std::make_unsigned::type; - ArrayView xinp(inp); - ArrayView 
xrnd(outp); + NdArrayView xinp(inp); + NdArrayView xrnd(outp); // when msb(xA) = 0, set (r, 1^r) // ow. msb(xA) = 1, set (1^r, 1^r) // Equals to (r^msb(xA), r^1) @@ -164,9 +165,9 @@ ArrayRef TruncateProtocol::MSB0ToWrap(const ArrayRef& inp, size_t shift_bits) { sender->Flush(); } else { std::vector choices(numel, 0); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, "MSB0_adjust", [&]() { using u2k = std::make_unsigned::type; - ArrayView xinp(inp); + NdArrayView xinp(inp); for (int64_t i = 0; i < numel; ++i) { choices[i] = (xinp[i] >> (bw - 1)) & 1; } @@ -176,9 +177,9 @@ ArrayRef TruncateProtocol::MSB0ToWrap(const ArrayRef& inp, size_t shift_bits) { basic_ot_prot_->GetReceiverCOT()->RecvCMCC(absl::MakeSpan(choices), N, absl::MakeSpan(recv), nbits); - outp = flatten(ring_zeros(field, {numel})); - DISPATCH_ALL_FIELDS(field, "", [&]() { - ArrayView xoup(outp); + outp = ring_zeros(field, inp.shape()); + DISPATCH_ALL_FIELDS(field, "MSB0_finalize", [&]() { + NdArrayView xoup(outp); pforeach(0, numel, [&](int64_t i) { xoup[i] = static_cast(recv[i] & 1); }); @@ -186,12 +187,15 @@ ArrayRef TruncateProtocol::MSB0ToWrap(const ArrayRef& inp, size_t shift_bits) { } return basic_ot_prot_->B2ASingleBitWithSize( - outp.as(makeType(field, 1)), (int)shift_bits); + outp.as(makeType(field, 1)), static_cast(shift_bits)); } -ArrayRef TruncateProtocol::Compute(const ArrayRef& inp, Meta meta) { +NdArrayRef TruncateProtocol::Compute(const NdArrayRef& inp, Meta meta) { size_t shift = meta.shift_bits; - if (shift == 0) return inp; + if (shift == 0) { + return inp; + } + auto field = inp.eltype().as()->field(); const size_t bit_width = SizeOf(field) * 8; SPU_ENFORCE(shift < bit_width, "truncate should not truncate full bit width"); @@ -202,31 +206,67 @@ ArrayRef TruncateProtocol::Compute(const ArrayRef& inp, Meta meta) { const int rank = basic_ot_prot_->Rank(); - ArrayRef wrap_ashr; - ArrayRef out = flatten(ring_zeros(field, {inp.numel()})); + if (meta.signed_arith && meta.sign == SignType::Unknown && + meta.use_heuristic) { + // Use heuristic optimization from SecureQ8: Add a large positive to make + // sure the value is always positive + // We assume |x| < 2^{k - b - 1} + // 1. x' = x + 2^{k - b} (should no wrap round 2^k) + // 2. y = TruncMSB0(x' ,f) ie y = (x + 2^{k - b}) / 2^f + // 3. output y - 2^{k - b - f} + meta.use_heuristic = false; + meta.sign = SignType::Positive; + + if (rank == 0) { + NdArrayRef tmp = inp.clone(); + DISPATCH_ALL_FIELDS(field, "trunc_with_heuristic", [&] { + NdArrayView _inp(tmp); + ring2k_t big_value = static_cast(1) + << (bit_width - kHeuristicBound); + pforeach(0, inp.numel(), + [&](int64_t i) { _inp[i] = _inp[i] + big_value; }); + }); + + tmp = Compute(tmp, meta); - return DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, "trunc_with_heuristic", [&] { + NdArrayView _outp(tmp); + ring2k_t big_value = static_cast(1) + << (bit_width - kHeuristicBound - shift); + pforeach(0, inp.numel(), + [&](int64_t i) { _outp[i] = _outp[i] - big_value; }); + }); + return tmp; + } else { + return Compute(inp, meta); + } + } + + NdArrayRef wrap_ashr; + NdArrayRef out = ring_zeros(field, inp.shape()); + + return DISPATCH_ALL_FIELDS(field, "Truncate", [&]() { const ring2k_t component = (static_cast(1) << (bit_width - 1)); - ArrayView xinp(inp); + NdArrayView xinp(inp); // Compute w = 1{x0 + x1 >= 2^{k}} if (meta.signed_arith && rank == 0) { // For signed arith right shift, we convert to unsigned logic right shift // by convert to two-component form. 
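// A plaintext sketch (standard C++, k = 16, no OT; share values are
// illustrative) of the two identities the signed truncation above builds on:
//  (1) arithmetic shift via the 2^{k-1} offset ("component"):
//      arith_shift(x, f) == ((x + 2^{k-1}) logical_shift f) - 2^{k-1-f}
//  (2) share-wise truncation with wrap correction, exact up to a +/-1 carry:
//      (m >> f) ~= (m0 >> f) + (m1 >> f) - w * 2^{k-f},  w = 1{m0 + m1 >= 2^k}
#include <cassert>
#include <cstdint>
#include <cstdlib>

int main() {
  const uint32_t k = 16, f = 4;
  const uint32_t mod = 1u << k;

  // (1) a negative value, offset by 2^{k-1} so a logical shift suffices.
  int32_t x = -1234;
  uint32_t offset = (static_cast<uint32_t>(x) + (1u << (k - 1))) & (mod - 1);
  int32_t via_offset = static_cast<int32_t>(offset >> f) - (1 << (k - 1 - f));
  assert(via_offset == -78);  // -78 == floor(-1234 / 2^4), the arithmetic shift

  // (2) additive shares of an unsigned value m, truncated share-wise.
  uint32_t m = 40000, m0 = 51111;
  uint32_t m1 = (m - m0) & (mod - 1);     // m0 + m1 == m (mod 2^k)
  uint32_t w = (m0 + m1) >= mod ? 1 : 0;  // wrap bit, shared via ComputeWrap above
  int64_t rebuilt = static_cast<int64_t>(m0 >> f) + (m1 >> f) -
                    static_cast<int64_t>(w) * (1u << (k - f));
  assert(std::llabs(rebuilt - static_cast<int64_t>(m >> f)) <= 1);
  return 0;
}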
- auto tmp = flatten(ring_zeros(field, {inp.numel()})); - ArrayView xtmp(tmp); + auto tmp = ring_zeros(field, inp.shape()); + NdArrayView xtmp(tmp); pforeach(0, inp.numel(), [&](int64_t i) { xtmp[i] = xinp[i] + component; }); wrap_ashr = ComputeWrap(tmp, meta); } else { wrap_ashr = ComputeWrap(inp, meta); } - ArrayView xwrap(wrap_ashr); + NdArrayView xwrap(wrap_ashr); - // NOTE(juhou) We need logic right shift here + // NOTE(lwj) We need logic right shift here /// m' = (m >> shift) - wrap * 2^{k - shift} // [m']_A = (m0 >> shift) - [wrap]_A * 2^{k - shift} - ArrayView xout(out); + NdArrayView xout(out); if (meta.signed_arith && rank == 0) { pforeach(0, inp.numel(), [&](int64_t i) { xout[i] = ((xinp[i] + component) >> shift); @@ -243,6 +283,14 @@ ArrayRef TruncateProtocol::Compute(const ArrayRef& inp, Meta meta) { pforeach(0, inp.numel(), [&](int64_t i) { xout[i] -= u; }); } + if (rank == 0) { + // The origin Truncate introduce -1 error by 50%. + // We balance it by +1 at the rate of 50%. + // As a result, we introduce 0 error by 50%, +1 error by 25% and -1 error + // by 25%. + pforeach(0, inp.numel(), [&](int64_t i) { xout[i] += (xout[i] & 1); }); + } + basic_ot_prot_->Flush(); return out.as(inp.eltype()); }); diff --git a/libspu/mpc/cheetah/nonlinear/truncate_prot.h b/libspu/mpc/cheetah/nonlinear/truncate_prot.h index 581bedac..b2487d77 100644 --- a/libspu/mpc/cheetah/nonlinear/truncate_prot.h +++ b/libspu/mpc/cheetah/nonlinear/truncate_prot.h @@ -16,7 +16,8 @@ #include -#include "libspu/mpc/cheetah/array_ref.h" +#include "libspu/core/ndarray_ref.h" +#include "libspu/core/type_util.h" namespace spu::mpc::cheetah { @@ -35,41 +36,33 @@ class BasicOTProtocols; // where w = 1{x0 + x1 > 2^{k} - 1} indicates whether the sum wrap round 2^k class TruncateProtocol { public: - enum class MSB_st { - zero, - one, - unknown, - }; + // For x \in [-2^{k - 2}, 2^{k - 2}) + // 0 <= x + 2^{k - 2} < 2^{k - 1} ie the MSB is always positive + static constexpr size_t kHeuristicBound = 2; struct Meta { - MSB_st msb; - bool use_heuristic; // not implemented yet - bool signed_arith; - size_t shift_bits; - - Meta() - : msb(MSB_st::unknown), - use_heuristic(false), - signed_arith(true), - shift_bits(0) {} + SignType sign = SignType::Unknown; + bool use_heuristic = false; + bool signed_arith = true; + size_t shift_bits = 0; }; explicit TruncateProtocol(std::shared_ptr base); ~TruncateProtocol(); - ArrayRef Compute(const ArrayRef &inp, Meta meta); + NdArrayRef Compute(const NdArrayRef &inp, Meta meta); private: - ArrayRef ComputeWrap(const ArrayRef &inp, const Meta &meta); + NdArrayRef ComputeWrap(const NdArrayRef &inp, const Meta &meta); // w = msbA | msbB - ArrayRef MSB0ToWrap(const ArrayRef &inp, size_t shift_bits); + NdArrayRef MSB0ToWrap(const NdArrayRef &inp, size_t shift_bits); // w = msbA & msbB - ArrayRef MSB1ToWrap(const ArrayRef &inp, size_t shift_bits); + NdArrayRef MSB1ToWrap(const NdArrayRef &inp, size_t shift_bits); std::shared_ptr basic_ot_prot_{nullptr}; }; -} // namespace spu::mpc::cheetah +} // namespace spu::mpc::cheetah \ No newline at end of file diff --git a/libspu/mpc/cheetah/nonlinear/truncate_prot_test.cc b/libspu/mpc/cheetah/nonlinear/truncate_prot_test.cc index 5907a19d..ba0bdad0 100644 --- a/libspu/mpc/cheetah/nonlinear/truncate_prot_test.cc +++ b/libspu/mpc/cheetah/nonlinear/truncate_prot_test.cc @@ -18,9 +18,8 @@ #include "gtest/gtest.h" -#include "libspu/core/xt_helper.h" #include "libspu/mpc/cheetah/ot/basic_ot_prot.h" -#include "libspu/mpc/semi2k/type.h" +#include 
"libspu/mpc/cheetah/type.h" #include "libspu/mpc/utils/ring_ops.h" #include "libspu/mpc/utils/simulate.h" @@ -56,78 +55,77 @@ TEST_P(TruncateProtTest, Basic) { FieldType field = std::get<0>(GetParam()); bool signed_arith = std::get<1>(GetParam()); std::string msb = std::get<2>(GetParam()); - TruncateProtocol::MSB_st msb_t; + SignType sign; NdArrayRef inp[2]; inp[0] = ring_rand(field, {n}); if (msb == "Unknown") { inp[1] = ring_rand(field, {n}); - msb_t = TruncateProtocol::MSB_st::unknown; + sign = SignType::Unknown; } else { auto msg = ring_rand(field, {n}); DISPATCH_ALL_FIELDS(field, "", [&]() { - auto xmsg = xt_mutable_adapt(msg); + auto xmsg = NdArrayView(msg); size_t bw = SizeOf(field) * 8; if (msb == "Zero") { ring2k_t mask = (static_cast(1) << (bw - 1)) - 1; - xmsg &= mask; - msb_t = TruncateProtocol::MSB_st::zero; + pforeach(0, msg.numel(), [&](int64_t i) { xmsg[i] &= mask; }); + sign = SignType::Positive; } else { ring2k_t mask = (static_cast(1) << (bw - 1)); - xmsg |= mask; - msb_t = TruncateProtocol::MSB_st::one; + pforeach(0, msg.numel(), [&](int64_t i) { xmsg[i] |= mask; }); + sign = SignType::Negative; } }); inp[1] = ring_sub(msg, inp[0]); } - ArrayRef oup[2]; + NdArrayRef oup[2]; utils::simulate(kWorldSize, [&](std::shared_ptr ctx) { int rank = ctx->Rank(); auto conn = std::make_shared(ctx); auto base = std::make_shared(conn); TruncateProtocol trunc_prot(base); TruncateProtocol::Meta meta; - meta.msb = msb_t; + meta.sign = sign; meta.signed_arith = signed_arith; meta.shift_bits = shift; - oup[rank] = trunc_prot.Compute(flatten(inp[rank]), meta); + oup[rank] = trunc_prot.Compute(inp[rank], meta); }); + EXPECT_EQ(oup[0].shape(), oup[1].shape()); DISPATCH_ALL_FIELDS(field, "", [&]() { using signed_t = std::make_signed::type; using usigned_t = std::make_unsigned::type; if (signed_arith) { - auto xout0 = ArrayView(oup[0]); - auto xout1 = ArrayView(oup[1]); - auto xinp0 = xt_adapt(inp[0]); - auto xinp1 = xt_adapt(inp[1]); + auto xout0 = NdArrayView(oup[0]); + auto xout1 = NdArrayView(oup[1]); + auto xinp0 = absl::MakeSpan(&inp[0].at(0), inp[0].numel()); + auto xinp1 = absl::MakeSpan(&inp[1].at(0), inp[1].numel()); for (int64_t i = 0; i < n; ++i) { signed_t in = xinp0[i] + xinp1[i]; signed_t expected = in >> shift; - if (msb_t != TruncateProtocol::MSB_st::unknown) { - ASSERT_EQ(SignBit(in), - msb_t == TruncateProtocol::MSB_st::one); + if (sign != SignType::Unknown) { + ASSERT_EQ(SignBit(in), sign == SignType::Negative); } signed_t got = xout0[i] + xout1[i]; EXPECT_NEAR(expected, got, 1); } } else { - auto xout0 = ArrayView(oup[0]); - auto xout1 = ArrayView(oup[1]); - auto xinp0 = xt_adapt(inp[0]); - auto xinp1 = xt_adapt(inp[1]); + auto xout0 = NdArrayView(oup[0]); + auto xout1 = NdArrayView(oup[1]); + auto xinp0 = absl::MakeSpan(&inp[0].at(0), inp[0].numel()); + auto xinp1 = absl::MakeSpan(&inp[1].at(0), inp[1].numel()); for (int64_t i = 0; i < n; ++i) { usigned_t in = xinp0[i] + xinp1[i]; usigned_t expected = (in) >> shift; - if (msb_t != TruncateProtocol::MSB_st::unknown) { - ASSERT_EQ(SignBit(in), - msb_t == TruncateProtocol::MSB_st::one); + if (sign != SignType::Unknown) { + ASSERT_EQ(SignBit(in), sign == SignType::Negative); } usigned_t got = xout0[i] + xout1[i]; ASSERT_NEAR(expected, got, 1); @@ -136,4 +134,72 @@ TEST_P(TruncateProtTest, Basic) { }); } +TEST_P(TruncateProtTest, Heuristic) { + size_t kWorldSize = 2; + int64_t n = 100; + size_t shift = 13; + FieldType field = std::get<0>(GetParam()); + bool signed_arith = std::get<1>(GetParam()); + std::string msb = 
std::get<2>(GetParam()); + if (not signed_arith or msb != "Unknown") { + return; + } + + NdArrayRef inp[2]; + inp[0] = ring_rand(field, {n}); + + DISPATCH_ALL_FIELDS(field, "", [&]() { + auto msg = ring_rand(field, {n}); + ring_rshift_(msg, TruncateProtocol::kHeuristicBound); + NdArrayView xmsg(msg); + for (int64_t i = 0; i < n; i += 2) { + xmsg[i] = -xmsg[i]; + } + inp[1] = ring_sub(msg, inp[0]); + }); + + NdArrayRef oup[2]; + utils::simulate(kWorldSize, [&](std::shared_ptr ctx) { + int rank = ctx->Rank(); + auto conn = std::make_shared(ctx); + auto base = std::make_shared(conn); + TruncateProtocol trunc_prot(base); + TruncateProtocol::Meta meta; + meta.sign = SignType::Unknown; + meta.signed_arith = true; + meta.shift_bits = shift; + meta.use_heuristic = true; + oup[rank] = trunc_prot.Compute(inp[rank], meta); + }); + + [[maybe_unused]] int count_zero = 0; + [[maybe_unused]] int count_pos = 0; + [[maybe_unused]] int count_neg = 0; + DISPATCH_ALL_FIELDS(field, "", [&]() { + using signed_t = std::make_signed::type; + + auto xout0 = NdArrayView(oup[0]); + auto xout1 = NdArrayView(oup[1]); + auto xinp0 = NdArrayView(inp[0]); + auto xinp1 = NdArrayView(inp[1]); + + for (int64_t i = 0; i < n; ++i) { + signed_t in = xinp0[i] + xinp1[i]; + signed_t expected = in >> shift; + signed_t got = xout0[i] + xout1[i]; + EXPECT_NEAR(expected, got, 1); + if (expected == got) { + count_zero += 1; + } else if (expected < got) { + count_pos += 1; + } else { + count_neg += 1; + } + } + }); + + // printf("0 %d %f, +1 %d %f -1 %d %f\n", count_zero, count_zero * 1. / n, + // count_pos, count_pos * 1. / n, count_neg, count_neg * 1. / n); +} + } // namespace spu::mpc::cheetah diff --git a/libspu/mpc/cheetah/ot/BUILD.bazel b/libspu/mpc/cheetah/ot/BUILD.bazel index b64a64ae..2f6e94c8 100644 --- a/libspu/mpc/cheetah/ot/BUILD.bazel +++ b/libspu/mpc/cheetah/ot/BUILD.bazel @@ -32,8 +32,9 @@ spu_cc_library( ], copts = AES_COPT_FLAGS + ["-Wno-ignored-attributes"], deps = [ - "//libspu/mpc/cheetah:array_ref", - "//libspu/mpc/semi2k:conversion", + "//libspu/core:xt_helper", + "//libspu/mpc/cheetah:type", + "//libspu/mpc/common:communicator", "@com_github_emptoolkit_emp_ot//:emp-ot", "@com_github_emptoolkit_emp_tool//:emp-tool", "@yacl//yacl/base:int128", diff --git a/libspu/mpc/cheetah/ot/basic_ot_prot.cc b/libspu/mpc/cheetah/ot/basic_ot_prot.cc index c7a2cbf0..293ebac0 100644 --- a/libspu/mpc/cheetah/ot/basic_ot_prot.cc +++ b/libspu/mpc/cheetah/ot/basic_ot_prot.cc @@ -15,14 +15,12 @@ #include "libspu/mpc/cheetah/ot/basic_ot_prot.h" #include "libspu/mpc/cheetah/ot/util.h" +#include "libspu/mpc/cheetah/type.h" #include "libspu/mpc/common/communicator.h" -#include "libspu/mpc/semi2k/type.h" #include "libspu/mpc/utils/ring_ops.h" namespace spu::mpc::cheetah { -using BShrTy = semi2k::BShrTy; - BasicOTProtocols::BasicOTProtocols(std::shared_ptr conn) : conn_(std::move(conn)) { SPU_ENFORCE(conn_ != nullptr); @@ -43,7 +41,7 @@ void BasicOTProtocols::Flush() { } } -ArrayRef BasicOTProtocols::B2A(const ArrayRef &inp) { +NdArrayRef BasicOTProtocols::B2A(const NdArrayRef &inp) { const auto *share_t = inp.eltype().as(); if (share_t->nbits() == 1) { return SingleB2A(inp); @@ -51,67 +49,57 @@ ArrayRef BasicOTProtocols::B2A(const ArrayRef &inp) { return PackedB2A(inp); } -// Random bit r \in {0, 1} and return as AShr -ArrayRef BasicOTProtocols::RandBits(FieldType filed, size_t numel) { - // TODO(juhou): profile ring_randbit performance - auto r = ring_randbit(filed, {(int64_t)numel}).as(makeType(filed, 1)); - return SingleB2A(flatten(r)); 
-} - -ArrayRef BasicOTProtocols::PackedB2A(const ArrayRef &inp) { +NdArrayRef BasicOTProtocols::PackedB2A(const NdArrayRef &inp) { const auto *share_t = inp.eltype().as(); auto field = inp.eltype().as()->field(); const size_t nbits = share_t->nbits(); SPU_ENFORCE(nbits > 0 && nbits <= 8 * SizeOf(field)); - auto convert_from_bits_form = [&](ArrayRef bform) { - const int64_t n = bform.numel() / nbits; + auto convert_from_bits_form = [&](NdArrayRef _bits) { + SPU_ENFORCE(_bits.isCompact(), "need compact input"); + const int64_t n = _bits.numel() / nbits; // init as all 0s. - auto iform = flatten(ring_zeros(field, {n})); - DISPATCH_ALL_FIELDS(field, "", [&]() { - auto xb = ArrayView(bform); - auto xi = ArrayView(iform); - SPU_ENFORCE(xb.isCompact()); + auto iform = ring_zeros(field, inp.shape()); + DISPATCH_ALL_FIELDS(field, "conv_to_bits", [&]() { + auto bits = NdArrayView(_bits); + auto digit = NdArrayView(iform); for (int64_t i = 0; i < n; ++i) { // LSB is bits[0]; MSB is bits[nbits - 1] - // We use reverse_iterator to iterate the bits. - auto bits = xb.data() + i * nbits; - std::reverse_iterator rbits(bits + nbits); - xi[i] = std::accumulate(rbits, rbits + nbits, - /*init*/ static_cast(0), - [](ring2k_t init, ring2k_t next) { - return (init << 1) | (next & 1); - }); + // We iterate the bits in reversed order + const size_t offset = i * nbits; + digit[i] = 0; + for (size_t j = nbits; j > 0; --j) { + digit[i] = (digit[i] << 1) | (bits[offset + j - 1] & 1); + } } }); return iform; }; const int64_t n = inp.numel(); - auto rand_bits = RandBits(field, n * nbits); + auto rand_bits = RandBits(field, Shape{n * static_cast(nbits)}); auto rand = convert_from_bits_form(rand_bits); // open c = x ^ r // FIXME(juhou): Actually, we only want to exchange the low-end bits. 
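// Standalone illustration (annotation, not part of the patch) of the
// re-composition loop in convert_from_bits_form above: nbits boolean limbs
// stored LSB-first are folded into one integer by visiting them from the most
// significant limb down to the least significant one.
#include <cstddef>
#include <cstdint>

uint64_t ComposeFromBits(const uint64_t* bits, size_t nbits) {
  uint64_t digit = 0;
  for (size_t j = nbits; j > 0; --j) {         // MSB-first accumulation
    digit = (digit << 1) | (bits[j - 1] & 1);  // bits[0] ends up as the LSB
  }
  return digit;  // e.g. bits = {1, 0, 1} (LSB-first) yields 0b101 = 5
}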
- auto opened = flatten( - conn_->allReduce(ReduceOp::XOR, ring_xor(toNdArray(inp), toNdArray(rand)), - "B2AFull_open")); + auto opened = + conn_->allReduce(ReduceOp::XOR, ring_xor(inp, rand), "B2AFull_open"); // compute c + (1 - 2*c)* - ArrayRef oup = flatten(ring_zeros(field, {n})); - DISPATCH_ALL_FIELDS(field, "", [&]() { + NdArrayRef oup = ring_zeros(field, inp.shape()); + DISPATCH_ALL_FIELDS(field, "packed_b2a", [&]() { using u2k = std::make_unsigned::type; int rank = Rank(); - auto xr = ArrayView(rand_bits); - auto xc = ArrayView(opened); - auto xo = ArrayView(oup); + auto xr = NdArrayView(rand_bits); + auto xc = NdArrayView(opened); + auto xo = NdArrayView(oup); for (int64_t i = 0; i < n; ++i) { - auto rbits = xr.data() + i * nbits; + const size_t offset = i * nbits; u2k this_elt = xc[i]; for (size_t j = 0; j < nbits; ++j, this_elt >>= 1) { u2k c_ij = this_elt & 1; - ring2k_t one_bit = (1 - c_ij * 2) * rbits[j]; + ring2k_t one_bit = (1 - c_ij * 2) * xr[offset + j]; if (rank == 0) { one_bit += c_ij; } @@ -122,26 +110,16 @@ ArrayRef BasicOTProtocols::PackedB2A(const ArrayRef &inp) { return oup; } -ArrayRef BasicOTProtocols::B2ASingleBitWithSize(const ArrayRef &inp, - int bit_width) { - const auto *share_t = inp.eltype().as(); - SPU_ENFORCE(share_t->nbits() == 1, "Support for 1bit boolean only"); - auto field = inp.eltype().as()->field(); - SPU_ENFORCE(bit_width > 1 && bit_width < (int)(8 * SizeOf(field)), - "bit_width={} is invalid", bit_width); - return SingleB2A(inp, bit_width); -} - -ArrayRef BasicOTProtocols::SingleB2A(const ArrayRef &inp, int bit_width) { - // Math: - // b0^b1 = b0 + b1 - 2*b0*b1 - // Sender set corr = -2*b0 - // Recv set choice b1 - // Sender gets x0 - // Recv gets x1 = x0 + corr*b1 = x0 - 2*b0*b1 - // - // b0 - x0 + b1 + x1 - // = b0 - x0 + b1 + x0 - 2*b0*b1 +// Math: +// b0^b1 = b0 + b1 - 2*b0*b1 +// Sender set corr = -2*b0 +// Recv set choice b1 +// Sender gets x0 +// Recv gets x1 = x0 + corr*b1 = x0 - 2*b0*b1 +// +// b0 - x0 + b1 + x1 +// = b0 - x0 + b1 + x0 - 2*b0*b1 +NdArrayRef BasicOTProtocols::SingleB2A(const NdArrayRef &inp, int bit_width) { const auto *share_t = inp.eltype().as(); SPU_ENFORCE_EQ(share_t->nbits(), 1UL); auto field = inp.eltype().as()->field(); @@ -150,177 +128,188 @@ ArrayRef BasicOTProtocols::SingleB2A(const ArrayRef &inp, int bit_width) { } const int64_t n = inp.numel(); - ArrayRef oup = flatten(ring_zeros(field, {n})); - DISPATCH_ALL_FIELDS(field, "", [&]() { + NdArrayRef oup = ring_zeros(field, inp.shape()); + DISPATCH_ALL_FIELDS(field, "single_b2a", [&]() { using u2k = std::make_unsigned::type; - auto xinp = ArrayView(inp); - auto xoup = ArrayView(oup); - SPU_ENFORCE(xoup.isCompact()); + auto input = NdArrayView(inp); + // NOTE(lwj): oup is compact, so we just use Span + auto output = absl::MakeSpan(&oup.at(0), n); + + SPU_ENFORCE(oup.isCompact()); if (Rank() == 0) { std::vector corr_data(n); - // NOTE(juhou): Masking to make sure there is only single bit. + // NOTE(lwj): Masking to make sure there is only single bit. 
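// Quick sanity sketch (annotation, not part of the patch) of the identity
// stated in the comment above SingleB2A.  Through the correlated OT the
// receiver learns x1 = x0 + corr * b1 with corr = -2*b0, so the local shares
//   sender:   s0 = b0 - x0
//   receiver: s1 = b1 + x1
// open to b0 + b1 - 2*b0*b1 = b0 ^ b1 (all arithmetic mod 2^64 here).
#include <cassert>
#include <cstdint>

void VerifySingleB2AIdentity() {
  for (uint64_t b0 : {0, 1}) {
    for (uint64_t b1 : {0, 1}) {
      uint64_t x0 = 0x1234567890abcdefULL;  // stands in for the random COT output
      uint64_t corr = -(2 * b0);            // sender's correlation
      uint64_t x1 = x0 + corr * b1;         // receiver's COT output
      uint64_t s0 = b0 - x0;                // sender's arithmetic share
      uint64_t s1 = b1 + x1;                // receiver's arithmetic share
      assert(s0 + s1 == (b0 ^ b1));         // shares open to the XOR
    }
  }
}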
for (int64_t i = 0; i < n; ++i) { // corr=-2*xi - corr_data[i] = -((xinp[i] & 1) << 1); + corr_data[i] = -((input[i] & 1) << 1); } - ferret_sender_->SendCAMCC(absl::MakeSpan(corr_data), - {xoup.data(), (size_t)n}, bit_width); + ferret_sender_->SendCAMCC(absl::MakeSpan(corr_data), output, bit_width); ferret_sender_->Flush(); for (int64_t i = 0; i < n; ++i) { - xoup[i] = (xinp[i] & 1) - xoup[i]; + output[i] = (input[i] & 1) - output[i]; } } else { std::vector choices(n); for (int64_t i = 0; i < n; ++i) { - choices[i] = static_cast(xinp[i] & 1); + choices[i] = static_cast(input[i] & 1); } - ferret_receiver_->RecvCAMCC(absl::MakeSpan(choices), - {xoup.data(), (size_t)n}, bit_width); + ferret_receiver_->RecvCAMCC(absl::MakeSpan(choices), output, bit_width); for (int64_t i = 0; i < n; ++i) { - xoup[i] = (xinp[i] & 1) + xoup[i]; + output[i] = (input[i] & 1) + output[i]; } } }); return oup; } -ArrayRef BasicOTProtocols::BitwiseAnd(const ArrayRef &lhs, - const ArrayRef &rhs) { - SPU_ENFORCE_EQ(lhs.numel(), rhs.numel()); +// Random bit r \in {0, 1} and return as AShr +NdArrayRef BasicOTProtocols::RandBits(FieldType filed, const Shape &shape) { + // TODO(juhou): profile ring_randbit performance + auto r = ring_randbit(filed, shape).as(makeType(filed, 1)); + return SingleB2A(r); +} + +NdArrayRef BasicOTProtocols::B2ASingleBitWithSize(const NdArrayRef &inp, + int bit_width) { + const auto *share_t = inp.eltype().as(); + SPU_ENFORCE(share_t->nbits() == 1, "Support for 1bit boolean only"); + auto field = inp.eltype().as()->field(); + SPU_ENFORCE(bit_width > 1 && bit_width < (int)(8 * SizeOf(field)), + "bit_width={} is invalid", bit_width); + return SingleB2A(inp, bit_width); +} + +NdArrayRef BasicOTProtocols::BitwiseAnd(const NdArrayRef &lhs, + const NdArrayRef &rhs) { + SPU_ENFORCE_EQ(lhs.shape(), rhs.shape()); auto field = lhs.eltype().as()->field(); - const auto *shareType = lhs.eltype().as(); - size_t size = lhs.numel(); - auto [a, b, c] = AndTriple(field, size, shareType->nbits()); + const auto *shareType = lhs.eltype().as(); + size_t numel = lhs.numel(); + auto [a, b, c] = AndTriple(field, lhs.shape(), shareType->nbits()); - ArrayRef x_a = flatten(ring_xor(toNdArray(lhs), toNdArray(a))); - ArrayRef y_b = flatten(ring_xor(toNdArray(rhs), toNdArray(b))); + NdArrayRef x_a = ring_xor(lhs, a); + NdArrayRef y_b = ring_xor(rhs, b); size_t pack_load = 8 * SizeOf(field) / shareType->nbits(); if (pack_load == 1) { // Open x^a, y^b - auto res = vmap({x_a, y_b}, [&](const ArrayRef &s) { - return flatten( - conn_->allReduce(ReduceOp::XOR, toNdArray(s), "BitwiseAnd")); + auto res = vmap({x_a, y_b}, [&](const NdArrayRef &s) { + return conn_->allReduce(ReduceOp::XOR, s, "BitwiseAnd"); }); x_a = std::move(res[0]); y_b = std::move(res[1]); } else { // Open x^a, y^b - // pack multiple nbits() into single field element before exchange - SPU_ENFORCE(x_a.isCompact() && y_b.isCompact(), - "x_a numel = {}, stride = {}, y_b numel = {}, stride = {}", - x_a.numel(), x_a.stride(), y_b.numel(), y_b.stride()); - size_t packed_sze = CeilDiv(size, pack_load); - - ArrayRef packed_xa(x_a.eltype(), packed_sze); - ArrayRef packed_yb(y_b.eltype(), packed_sze); + // pack multiple nbits() into single field element before sending through + // network + SPU_ENFORCE(x_a.isCompact() && y_b.isCompact()); + int64_t packed_sze = CeilDiv(numel, pack_load); + + NdArrayRef packed_xa(x_a.eltype(), {packed_sze}); + NdArrayRef packed_yb(y_b.eltype(), {packed_sze}); + DISPATCH_ALL_FIELDS(field, "_", [&]() { - size_t used = - ZipArray({&x_a.at(0), 
size}, shareType->nbits(), - {&packed_xa.at(0), packed_sze}); + auto xa_wrap = absl::MakeSpan(&x_a.at(0), numel); + auto yb_wrap = absl::MakeSpan(&y_b.at(0), numel); + auto packed_xa_wrap = + absl::MakeSpan(&packed_xa.at(0), packed_sze); + auto packed_yb_wrap = + absl::MakeSpan(&packed_yb.at(0), packed_sze); + + int64_t used = + ZipArray(xa_wrap, shareType->nbits(), packed_xa_wrap); + (void)ZipArray(yb_wrap, shareType->nbits(), packed_yb_wrap); SPU_ENFORCE_EQ(used, packed_sze); - (void)ZipArray({&y_b.at(0), size}, shareType->nbits(), - {&packed_yb.at(0), packed_sze}); - // open x^a, y^b - auto res = vmap({packed_xa, packed_yb}, [&](const ArrayRef &s) { - return flatten( - conn_->allReduce(ReduceOp::XOR, toNdArray(s), "BitwiseAnd")); + auto res = vmap({packed_xa, packed_yb}, [&](const NdArrayRef &s) { + return conn_->allReduce(ReduceOp::XOR, s, "BitwiseAnd"); }); packed_xa = std::move(res[0]); packed_yb = std::move(res[1]); - UnzipArray({&packed_xa.at(0), packed_sze}, - shareType->nbits(), {&x_a.at(0), size}); - - UnzipArray({&packed_yb.at(0), packed_sze}, - shareType->nbits(), {&y_b.at(0), size}); + packed_xa_wrap = absl::MakeSpan(&packed_xa.at(0), packed_sze); + packed_yb_wrap = absl::MakeSpan(&packed_yb.at(0), packed_sze); + UnzipArray(packed_xa_wrap, shareType->nbits(), xa_wrap); + UnzipArray(packed_yb_wrap, shareType->nbits(), yb_wrap); }); } // Zi = Ci ^ ((X ^ A) & Bi) ^ ((Y ^ B) & Ai) ^ <(X ^ A) & (Y ^ B)> - auto z = flatten(ring_xor(ring_xor(ring_and(toNdArray(x_a), toNdArray(b)), - ring_and(toNdArray(y_b), toNdArray(a))), - toNdArray(c))); + auto z = ring_xor(ring_xor(ring_and(x_a, b), ring_and(y_b, a)), c); if (conn_->getRank() == 0) { - auto _z = toNdArray(z); - ring_xor_(_z, ring_and(toNdArray(x_a), toNdArray(y_b))); + ring_xor_(z, ring_and(x_a, y_b)); } return z.as(lhs.eltype()); } -std::array BasicOTProtocols::CorrelatedBitwiseAnd( - const ArrayRef &lhs, const ArrayRef &rhs0, const ArrayRef &rhs1) { - SPU_ENFORCE_EQ(lhs.numel(), rhs0.numel()); +std::array BasicOTProtocols::CorrelatedBitwiseAnd( + const NdArrayRef &lhs, const NdArrayRef &rhs0, const NdArrayRef &rhs1) { + SPU_ENFORCE_EQ(lhs.shape(), rhs0.shape()); SPU_ENFORCE(lhs.eltype() == rhs0.eltype()); - SPU_ENFORCE_EQ(lhs.numel(), rhs1.numel()); + SPU_ENFORCE_EQ(lhs.shape(), rhs1.shape()); SPU_ENFORCE(lhs.eltype() == rhs1.eltype()); auto field = lhs.eltype().as()->field(); - const auto *shareType = lhs.eltype().as(); + const auto *shareType = lhs.eltype().as(); SPU_ENFORCE_EQ(shareType->nbits(), 1UL); - size_t size = lhs.numel(); - auto [a, b0, c0, b1, c1] = CorrelatedAndTriple(field, size); + auto [a, b0, c0, b1, c1] = CorrelatedAndTriple(field, lhs.shape()); // open x^a, y^b0, y1^b1 - auto res = vmap({ring_xor(toNdArray(lhs), toNdArray(a)), - ring_xor(toNdArray(rhs0), toNdArray(b0)), - ring_xor(toNdArray(rhs1), toNdArray(b1))}, - [&](const NdArrayRef &s) { - return conn_->allReduce(ReduceOp::XOR, s, "BitwiseAnd"); - }); + auto res = + vmap({ring_xor(lhs, a), ring_xor(rhs0, b0), ring_xor(rhs1, b1)}, + [&](const NdArrayRef &s) { + return conn_->allReduce(ReduceOp::XOR, s, "CorrelatedBitwiseAnd"); + }); auto xa = std::move(res[0]); auto y0b0 = std::move(res[1]); auto y1b1 = std::move(res[2]); // Zi = Ci ^ ((X ^ A) & Bi) ^ ((Y ^ B) & Ai) ^ <(X ^ A) & (Y ^ B)> - auto z0 = ring_xor( - ring_xor(ring_and(xa, toNdArray(b0)), ring_and(y0b0, toNdArray(a))), - toNdArray(c0)); - auto z1 = ring_xor( - ring_xor(ring_and(xa, toNdArray(b1)), ring_and(y1b1, toNdArray(a))), - toNdArray(c1)); + auto z0 = ring_xor(ring_xor(ring_and(xa, b0), 
ring_and(y0b0, a)), c0); + auto z1 = ring_xor(ring_xor(ring_and(xa, b1), ring_and(y1b1, a)), c1); if (conn_->getRank() == 0) { ring_xor_(z0, ring_and(xa, y0b0)); ring_xor_(z1, ring_and(xa, y1b1)); } - return {flatten(z0).as(lhs.eltype()), flatten(z1).as(lhs.eltype())}; + return {z0.as(lhs.eltype()), z1.as(lhs.eltype())}; } // Ref: https://eprint.iacr.org/2013/552.pdf // Algorithm 1. AND triple using 1-of-2 ROT. -std::array BasicOTProtocols::AndTriple(FieldType field, - size_t numel, - size_t nbits_each) { - // ROT sender obtains x_0, x_1 - // ROT recevier obtains x_a, a for a \in {0, 1} - // - // Sender set (b = x0 ^ x1, v = x0) - // Recevier set (a, u = x_a) - // a & b = a & (x0 ^ x1) - // = a & (x0 ^ x1) ^ (x0 ^ x0) <- zero m0 ^ m0 - // = (a & (x0 ^ x1) ^ x0) ^ x0 - // = (x_a) ^ x0 - // = u ^ v - // - // P0 acts as S to obtain (a0, u0) - // P1 acts as R to obtain (b1, v1) - // such that a0 & b1 = u0 ^ v1 - // - // Flip the role - // P1 obtains (a1, u1) - // P0 obtains (b0, v0) - // such that a1 & b0 = u1 ^ v0 - // - // Pi sets ci = ai & bi ^ ui ^ vi - // such that (a0 ^ a1) & (b0 ^ b1) = (c0 ^ c1) +// Math +// ROT sender obtains x_0, x_1 +// ROT recevier obtains x_a, a for a \in {0, 1} +// +// Sender set (b = x0 ^ x1, v = x0) +// Recevier set (a, u = x_a) +// a & b = a & (x0 ^ x1) +// = a & (x0 ^ x1) ^ (x0 ^ x0) <- zero m0 ^ m0 +// = (a & (x0 ^ x1) ^ x0) ^ x0 +// = (x_a) ^ x0 +// = u ^ v +// +// P0 acts as S to obtain (a0, u0) +// P1 acts as R to obtain (b1, v1) +// such that a0 & b1 = u0 ^ v1 +// +// Flip the role +// P1 obtains (a1, u1) +// P0 obtains (b0, v0) +// such that a1 & b0 = u1 ^ v0 +// +// Pi sets ci = ai & bi ^ ui ^ vi +// such that (a0 ^ a1) & (b0 ^ b1) = (c0 ^ c1) +std::array BasicOTProtocols::AndTriple(FieldType field, + const Shape &shape, + size_t nbits_each) { + int64_t numel = shape.numel(); SPU_ENFORCE(numel > 0); SPU_ENFORCE(nbits_each >= 1 && nbits_each <= SizeOf(field) * 8, "invalid packing load {} for one AND", nbits_each); @@ -352,13 +341,14 @@ std::array BasicOTProtocols::AndTriple(FieldType field, }); // init as zero - auto AND_a = flatten(ring_zeros(field, {static_cast(numel)})); - auto AND_b = flatten(ring_zeros(field, {static_cast(numel)})); - auto AND_c = flatten(ring_zeros(field, {static_cast(numel)})); + auto AND_a = ring_zeros(field, shape); + auto AND_b = ring_zeros(field, shape); + auto AND_c = ring_zeros(field, shape); + DISPATCH_ALL_FIELDS(field, "AndTriple", [&]() { - auto AND_xa = ArrayView(AND_a); - auto AND_xb = ArrayView(AND_b); - auto AND_xc = ArrayView(AND_c); + auto AND_xa = NdArrayView(AND_a); + auto AND_xb = NdArrayView(AND_b); + auto AND_xc = NdArrayView(AND_c); pforeach(0, numel, [&](int64_t i) { int64_t bgn = i * nbits_each; int64_t end = bgn + nbits_each; @@ -373,8 +363,9 @@ std::array BasicOTProtocols::AndTriple(FieldType field, return {AND_a, AND_b, AND_c}; } -std::array BasicOTProtocols::CorrelatedAndTriple(FieldType field, - size_t numel) { +std::array BasicOTProtocols::CorrelatedAndTriple( + FieldType field, const Shape &shape) { + int64_t numel = shape.numel(); SPU_ENFORCE(numel > 0); // NOTE(juhou): we use uint8_t to store 2-bit ROT constexpr size_t ot_msg_width = 2; @@ -405,18 +396,18 @@ std::array BasicOTProtocols::CorrelatedAndTriple(FieldType field, c[i] = (((a[i] << 1) | a[i]) & b[i]) ^ u[i] ^ v[i]; }); - auto AND_a = flatten(ring_zeros(field, {static_cast(numel)})); - auto AND_b0 = flatten(ring_zeros(field, {static_cast(numel)})); - auto AND_c0 = flatten(ring_zeros(field, {static_cast(numel)})); - auto AND_b1 = 
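// Sanity sketch (annotation, not part of the patch) of the identity behind the
// AND-triple construction commented above: from one random OT in which the
// sender holds (x0, x1) and the receiver holds (a, x_a), setting
//   sender:   b = x0 ^ x1, v = x0
//   receiver: a,           u = x_a
// gives a & b = u ^ v, i.e. an XOR-shared AND of the two parties' random bits.
#include <cassert>
#include <cstdint>

void VerifyRotToAndShare() {
  for (uint64_t x0 : {0, 1}) {
    for (uint64_t x1 : {0, 1}) {
      for (uint64_t a : {0, 1}) {
        uint64_t xa = a ? x1 : x0;  // receiver's ROT output
        uint64_t b = x0 ^ x1;       // sender's derived bit
        uint64_t v = x0;            // sender's share of the product
        uint64_t u = xa;            // receiver's share of the product
        assert((a & b) == (u ^ v));
      }
    }
  }
}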
flatten(ring_zeros(field, {static_cast(numel)})); - auto AND_c1 = flatten(ring_zeros(field, {static_cast(numel)})); + auto AND_a = ring_zeros(field, shape); + auto AND_b0 = ring_zeros(field, shape); + auto AND_c0 = ring_zeros(field, shape); + auto AND_b1 = ring_zeros(field, shape); + auto AND_c1 = ring_zeros(field, shape); DISPATCH_ALL_FIELDS(field, "AndTriple", [&]() { - auto AND_xa = ArrayView(AND_a); - auto AND_xb0 = ArrayView(AND_b0); - auto AND_xc0 = ArrayView(AND_c0); - auto AND_xb1 = ArrayView(AND_b1); - auto AND_xc1 = ArrayView(AND_c1); + auto AND_xa = NdArrayView(AND_a); + auto AND_xb0 = NdArrayView(AND_b0); + auto AND_xc0 = NdArrayView(AND_c0); + auto AND_xb1 = NdArrayView(AND_b1); + auto AND_xc1 = NdArrayView(AND_c1); pforeach(0, numel, [&](int64_t i) { AND_xa[i] = a[i] & 1; AND_xb0[i] = b[i] & 1; @@ -425,32 +416,33 @@ std::array BasicOTProtocols::CorrelatedAndTriple(FieldType field, AND_xc1[i] = (c[i] >> 1) & 1; }); }); + return {AND_a, AND_b0, AND_c0, AND_b1, AND_c1}; } int BasicOTProtocols::Rank() const { return ferret_sender_->Rank(); } -ArrayRef BasicOTProtocols::Multiplexer(const ArrayRef &msg, - const ArrayRef &select) { - SPU_ENFORCE_EQ(msg.numel(), select.numel()); - const auto *shareType = select.eltype().as(); +NdArrayRef BasicOTProtocols::Multiplexer(const NdArrayRef &msg, + const NdArrayRef &select) { + SPU_ENFORCE_EQ(msg.shape(), select.shape()); + const auto *shareType = select.eltype().as(); SPU_ENFORCE_EQ(shareType->nbits(), 1UL); const auto field = msg.eltype().as()->field(); const int64_t size = msg.numel(); - auto _corr_data = flatten(ring_zeros(field, {size})); - auto _sent = flatten(ring_zeros(field, {size})); - auto _recv = flatten(ring_zeros(field, {size})); + auto _corr_data = ring_zeros(field, msg.shape()); + auto _sent = ring_zeros(field, msg.shape()); + auto _recv = ring_zeros(field, msg.shape()); std::vector sel(size); // Compute (x0 + x1) * (b0 ^ b1) // Also b0 ^ b1 = 1 - 2*b0*b1 return DISPATCH_ALL_FIELDS(field, "Multiplexer", [&]() { - ArrayView _msg(msg); - ArrayView _sel(select); - ArrayView corr_data(_corr_data); - ArrayView sent(_sent); - ArrayView recv(_recv); + NdArrayView _msg(msg); + NdArrayView _sel(select); + auto corr_data = absl::MakeSpan(&_corr_data.at(0), size); + auto sent = absl::MakeSpan(&_sent.at(0), size); + auto recv = absl::MakeSpan(&_recv.at(0), size); pforeach(0, size, [&](int64_t i) { sel[i] = static_cast(_sel[i] & 1); @@ -458,16 +450,12 @@ ArrayRef BasicOTProtocols::Multiplexer(const ArrayRef &msg, }); if (Rank() == 0) { - ferret_sender_->SendCAMCC({corr_data.data(), (size_t)size}, - {sent.data(), (size_t)size}); + ferret_sender_->SendCAMCC(corr_data, sent); ferret_sender_->Flush(); - ferret_receiver_->RecvCAMCC(absl::MakeSpan(sel), - {recv.data(), (size_t)size}); + ferret_receiver_->RecvCAMCC(absl::MakeSpan(sel), recv); } else { - ferret_receiver_->RecvCAMCC(absl::MakeSpan(sel), - {recv.data(), (size_t)size}); - ferret_sender_->SendCAMCC({corr_data.data(), (size_t)size}, - {sent.data(), (size_t)size}); + ferret_receiver_->RecvCAMCC(absl::MakeSpan(sel), recv); + ferret_sender_->SendCAMCC(corr_data, sent); ferret_sender_->Flush(); } @@ -478,4 +466,5 @@ ArrayRef BasicOTProtocols::Multiplexer(const ArrayRef &msg, return _recv; }); } + } // namespace spu::mpc::cheetah diff --git a/libspu/mpc/cheetah/ot/basic_ot_prot.h b/libspu/mpc/cheetah/ot/basic_ot_prot.h index 56a4d2dc..7c09b974 100644 --- a/libspu/mpc/cheetah/ot/basic_ot_prot.h +++ b/libspu/mpc/cheetah/ot/basic_ot_prot.h @@ -14,7 +14,7 @@ #pragma once -#include 
"libspu/mpc/cheetah/array_ref.h" +#include "libspu/core/ndarray_ref.h" #include "libspu/mpc/cheetah/ot/ferret.h" namespace spu::mpc::cheetah { @@ -27,31 +27,38 @@ class BasicOTProtocols { int Rank() const; - ArrayRef B2A(const ArrayRef &inp); + NdArrayRef B2A(const NdArrayRef &inp); + + NdArrayRef RandBits(FieldType filed, const Shape &shape); // NOTE(lwj): compute the B2A(b) and output to the specified ring // Require: input is 1-bit boolean and 1 <= bit_width < k. - ArrayRef B2ASingleBitWithSize(const ArrayRef &inp, int bit_width); - - ArrayRef RandBits(FieldType filed, size_t numel); + NdArrayRef B2ASingleBitWithSize(const NdArrayRef &inp, int bit_width); // msg * select for select \in {0, 1} - ArrayRef Multiplexer(const ArrayRef &msg, const ArrayRef &select); + NdArrayRef Multiplexer(const NdArrayRef &msg, const NdArrayRef &select); // Create `numel` of AND-triple. Each element contains `k` bits // 1 <= k <= field size - std::array AndTriple(FieldType field, size_t numel, size_t k); + // std::array AndTriple(FieldType field, size_t numel, size_t k); + + std::array AndTriple(FieldType field, const Shape &shape, + size_t k); // [a, b, b', c, c'] such that c = a*b and c' = a*b' for the same a - std::array CorrelatedAndTriple(FieldType field, size_t numel); + // std::array CorrelatedAndTriple(FieldType field, size_t numel); - ArrayRef BitwiseAnd(const ArrayRef &lhs, const ArrayRef &rhs); + std::array CorrelatedAndTriple(FieldType field, + const Shape &shape); + + // ArrayRef BitwiseAnd(const ArrayRef &lhs, const ArrayRef &rhs); + + NdArrayRef BitwiseAnd(const NdArrayRef &lhs, const NdArrayRef &rhs); // Compute the ANDs `lhs & rhs0` and `lhs & rhs1` - // Require non-packed Boolean currently. - std::array CorrelatedBitwiseAnd(const ArrayRef &lhs, - const ArrayRef &rhs0, - const ArrayRef &rhs1); + std::array CorrelatedBitwiseAnd(const NdArrayRef &lhs, + const NdArrayRef &rhs0, + const NdArrayRef &rhs1); std::shared_ptr GetSenderCOT() { return ferret_sender_; } @@ -60,12 +67,9 @@ class BasicOTProtocols { void Flush(); protected: - ArrayRef Compare(const ArrayRef &inp, bool greater_than, bool equality, - int radix_base); - - ArrayRef SingleB2A(const ArrayRef &inp, int bit_width = 0); + NdArrayRef SingleB2A(const NdArrayRef &inp, int bit_width = 0); - ArrayRef PackedB2A(const ArrayRef &inp); + NdArrayRef PackedB2A(const NdArrayRef &inp); private: std::shared_ptr conn_; diff --git a/libspu/mpc/cheetah/ot/basic_ot_prot_test.cc b/libspu/mpc/cheetah/ot/basic_ot_prot_test.cc index 2ec49342..e112d5d3 100644 --- a/libspu/mpc/cheetah/ot/basic_ot_prot_test.cc +++ b/libspu/mpc/cheetah/ot/basic_ot_prot_test.cc @@ -19,7 +19,7 @@ #include "gtest/gtest.h" #include "libspu/core/xt_helper.h" -#include "libspu/mpc/semi2k/type.h" +#include "libspu/mpc/cheetah/type.h" #include "libspu/mpc/utils/ring_ops.h" #include "libspu/mpc/utils/simulate.h" @@ -38,15 +38,15 @@ INSTANTIATE_TEST_SUITE_P( TEST_P(BasicOTProtTest, SingleB2A) { size_t kWorldSize = 2; - int64_t n = 7; + Shape shape = {10, 30}; FieldType field = GetParam(); size_t nbits = 8 * SizeOf(field) - 1; size_t packed_nbits = 8 * SizeOf(field) - nbits; - auto boolean_t = makeType(field, packed_nbits); + auto boolean_t = makeType(field, packed_nbits); - auto bshr0 = ring_rand(field, {n}).as(boolean_t); - auto bshr1 = ring_rand(field, {n}).as(boolean_t); + auto bshr0 = ring_rand(field, shape).as(boolean_t); + auto bshr1 = ring_rand(field, shape).as(boolean_t); DISPATCH_ALL_FIELDS(field, "", [&]() { auto mask = static_cast(-1); if (nbits > 0) { @@ -60,17 +60,21 @@ 
TEST_P(BasicOTProtTest, SingleB2A) { } }); - NdArrayRef ashr0, ashr1; + NdArrayRef ashr0; + NdArrayRef ashr1; utils::simulate(kWorldSize, [&](std::shared_ptr ctx) { auto conn = std::make_shared(ctx); BasicOTProtocols ot_prot(conn); if (ctx->Rank() == 0) { - ashr0 = toNdArray(ot_prot.B2A(flatten(bshr0))); + ashr0 = ot_prot.B2A(bshr0); } else { - ashr1 = toNdArray(ot_prot.B2A(flatten(bshr1))); + ashr1 = ot_prot.B2A(bshr1); } }); + EXPECT_EQ(ashr0.shape(), ashr1.shape()); + EXPECT_EQ(shape, ashr0.shape()); + DISPATCH_ALL_FIELDS(field, "", [&]() { auto b0 = xt_adapt(bshr0); auto b1 = xt_adapt(bshr1); @@ -80,7 +84,7 @@ TEST_P(BasicOTProtTest, SingleB2A) { if (nbits > 0) { mask = (static_cast(1) << packed_nbits) - 1; } - for (int64_t i = 0; i < n; ++i) { + for (int64_t i = 0; i < shape.numel(); ++i) { ring2k_t e = b0[i] ^ b1[i]; ring2k_t c = (a0[i] + a1[i]) & mask; EXPECT_EQ(e, c); @@ -90,15 +94,15 @@ TEST_P(BasicOTProtTest, SingleB2A) { TEST_P(BasicOTProtTest, PackedB2A) { size_t kWorldSize = 2; - int64_t n = 7; + Shape shape = {11, 12, 13}; FieldType field = GetParam(); for (size_t nbits : {1, 2}) { size_t packed_nbits = 8 * SizeOf(field) - nbits; - auto boolean_t = makeType(field, packed_nbits); + auto boolean_t = makeType(field, packed_nbits); - auto bshr0 = ring_rand(field, {n}).as(boolean_t); - auto bshr1 = ring_rand(field, {n}).as(boolean_t); + auto bshr0 = ring_rand(field, shape).as(boolean_t); + auto bshr1 = ring_rand(field, shape).as(boolean_t); DISPATCH_ALL_FIELDS(field, "", [&]() { auto mask = static_cast(-1); if (nbits > 0) { @@ -112,16 +116,19 @@ TEST_P(BasicOTProtTest, PackedB2A) { } }); - NdArrayRef ashr0, ashr1; + NdArrayRef ashr0; + NdArrayRef ashr1; utils::simulate(kWorldSize, [&](std::shared_ptr ctx) { auto conn = std::make_shared(ctx); BasicOTProtocols ot_prot(conn); if (ctx->Rank() == 0) { - ashr0 = toNdArray(ot_prot.B2A(flatten(bshr0))); + ashr0 = ot_prot.B2A(bshr0); } else { - ashr1 = toNdArray(ot_prot.B2A(flatten(bshr1))); + ashr1 = ot_prot.B2A(bshr1); } }); + EXPECT_EQ(ashr0.shape(), ashr1.shape()); + EXPECT_EQ(ashr0.shape(), shape); DISPATCH_ALL_FIELDS(field, "", [&]() { auto b0 = xt_adapt(bshr0); @@ -129,10 +136,12 @@ TEST_P(BasicOTProtTest, PackedB2A) { auto a0 = xt_adapt(ashr0); auto a1 = xt_adapt(ashr1); auto mask = static_cast(-1); + if (nbits > 0) { mask = (static_cast(1) << packed_nbits) - 1; } - for (int64_t i = 0; i < n; ++i) { + + for (int64_t i = 0; i < shape.numel(); ++i) { ring2k_t e = b0[i] ^ b1[i]; ring2k_t c = (a0[i] + a1[i]) & mask; EXPECT_EQ(e, c); @@ -143,15 +152,15 @@ TEST_P(BasicOTProtTest, PackedB2A) { TEST_P(BasicOTProtTest, PackedB2AFull) { size_t kWorldSize = 2; - int64_t n = 7; + Shape shape = {1, 2, 3, 4, 5}; FieldType field = GetParam(); for (size_t nbits : {0}) { size_t packed_nbits = 8 * SizeOf(field) - nbits; - auto boolean_t = makeType(field, packed_nbits); + auto boolean_t = makeType(field, packed_nbits); - auto bshr0 = ring_rand(field, {n}).as(boolean_t); - auto bshr1 = ring_rand(field, {n}).as(boolean_t); + auto bshr0 = ring_rand(field, shape).as(boolean_t); + auto bshr1 = ring_rand(field, shape).as(boolean_t); DISPATCH_ALL_FIELDS(field, "", [&]() { auto mask = static_cast(-1); if (nbits > 0) { @@ -165,17 +174,21 @@ TEST_P(BasicOTProtTest, PackedB2AFull) { } }); - NdArrayRef ashr0, ashr1; + NdArrayRef ashr0; + NdArrayRef ashr1; utils::simulate(kWorldSize, [&](std::shared_ptr ctx) { auto conn = std::make_shared(ctx); BasicOTProtocols ot_prot(conn); if (ctx->Rank() == 0) { - ashr0 = toNdArray(ot_prot.B2A(flatten(bshr0))); + ashr0 = 
ot_prot.B2A(bshr0); } else { - ashr1 = toNdArray(ot_prot.B2A(flatten(bshr1))); + ashr1 = ot_prot.B2A(bshr1); } }); + EXPECT_EQ(ashr0.shape(), ashr1.shape()); + EXPECT_EQ(ashr0.shape(), shape); + DISPATCH_ALL_FIELDS(field, "", [&]() { auto b0 = xt_adapt(bshr0); auto b1 = xt_adapt(bshr1); @@ -185,7 +198,7 @@ TEST_P(BasicOTProtTest, PackedB2AFull) { if (nbits > 0) { mask = (static_cast(1) << packed_nbits) - 1; } - for (int64_t i = 0; i < n; ++i) { + for (int64_t i = 0; i < shape.numel(); ++i) { ring2k_t e = b0[i] ^ b1[i]; ring2k_t c = (a0[i] + a1[i]) & mask; EXPECT_EQ(e, c); @@ -196,7 +209,7 @@ TEST_P(BasicOTProtTest, PackedB2AFull) { TEST_P(BasicOTProtTest, AndTripleSparse) { size_t kWorldSize = 2; - int64_t n = 55; + Shape shape = {55, 100}; FieldType field = GetParam(); size_t max_bit = 8 * SizeOf(field); @@ -208,8 +221,8 @@ TEST_P(BasicOTProtTest, AndTripleSparse) { auto conn = std::make_shared(ctx); BasicOTProtocols ot_prot(conn); - for (const auto& t : ot_prot.AndTriple(field, n, target_nbits)) { - triple[ctx->Rank()].emplace_back(toNdArray(t)); + for (const auto& t : ot_prot.AndTriple(field, shape, target_nbits)) { + triple[ctx->Rank()].emplace_back(t); } }); @@ -222,7 +235,7 @@ TEST_P(BasicOTProtTest, AndTripleSparse) { auto b1 = xt_adapt(triple[1][1]); auto c1 = xt_adapt(triple[1][2]); - for (int64_t i = 0; i < n; ++i) { + for (int64_t i = 0; i < shape.numel(); ++i) { EXPECT_TRUE(a0[i] < max && a1[i] < max); EXPECT_TRUE(b0[i] < max && b1[i] < max); EXPECT_TRUE(c0[i] < max && c1[i] < max); @@ -237,7 +250,7 @@ TEST_P(BasicOTProtTest, AndTripleSparse) { TEST_P(BasicOTProtTest, AndTripleFull) { size_t kWorldSize = 2; - int64_t n = 55; + Shape shape = {55, 11}; FieldType field = GetParam(); std::vector packed_triple[2]; @@ -245,8 +258,8 @@ TEST_P(BasicOTProtTest, AndTripleFull) { utils::simulate(kWorldSize, [&](std::shared_ptr ctx) { auto conn = std::make_shared(ctx); BasicOTProtocols ot_prot(conn); - for (const auto& t : ot_prot.AndTriple(field, n, SizeOf(field) * 8)) { - packed_triple[ctx->Rank()].emplace_back(toNdArray(t)); + for (const auto& t : ot_prot.AndTriple(field, shape, SizeOf(field) * 8)) { + packed_triple[ctx->Rank()].emplace_back(t); } }); @@ -259,7 +272,7 @@ TEST_P(BasicOTProtTest, AndTripleFull) { auto c1 = xt_adapt(packed_triple[1][2]); size_t nn = a0.size(); - EXPECT_TRUE(nn * 8 * SizeOf(field) >= (size_t)n); + EXPECT_TRUE(nn * 8 * SizeOf(field) >= (size_t)shape.numel()); for (size_t i = 0; i < nn; ++i) { ring2k_t e = (a0[i] ^ a1[i]) & (b0[i] ^ b1[i]); @@ -272,15 +285,15 @@ TEST_P(BasicOTProtTest, AndTripleFull) { TEST_P(BasicOTProtTest, Multiplexer) { size_t kWorldSize = 2; - int64_t n = 7; + Shape shape = {3, 4, 1, 3}; FieldType field = GetParam(); - auto boolean_t = makeType(field, 1); + auto boolean_t = makeType(field, 1); - auto ashr0 = ring_rand(field, {n}); - auto ashr1 = ring_rand(field, {n}); - auto bshr0 = ring_rand(field, {n}).as(boolean_t); - auto bshr1 = ring_rand(field, {n}).as(boolean_t); + auto ashr0 = ring_rand(field, shape); + auto ashr1 = ring_rand(field, shape); + auto bshr0 = ring_rand(field, shape).as(boolean_t); + auto bshr1 = ring_rand(field, shape).as(boolean_t); DISPATCH_ALL_FIELDS(field, "", [&]() { auto mask = static_cast(1); @@ -297,14 +310,15 @@ TEST_P(BasicOTProtTest, Multiplexer) { auto conn = std::make_shared(ctx); BasicOTProtocols ot_prot(conn); if (ctx->Rank() == 0) { - computed[0] = - toNdArray(ot_prot.Multiplexer(flatten(ashr0), flatten(bshr0))); + computed[0] = ot_prot.Multiplexer(ashr0, bshr0); } else { - computed[1] = - 
toNdArray(ot_prot.Multiplexer(flatten(ashr1), flatten(bshr1))); + computed[1] = ot_prot.Multiplexer(ashr1, bshr1); } }); + EXPECT_EQ(computed[0].shape(), computed[1].shape()); + EXPECT_EQ(computed[0].shape(), shape); + DISPATCH_ALL_FIELDS(field, "", [&]() { auto a0 = xt_adapt(ashr0); auto a1 = xt_adapt(ashr1); @@ -313,7 +327,7 @@ TEST_P(BasicOTProtTest, Multiplexer) { auto c0 = xt_adapt(computed[0]); auto c1 = xt_adapt(computed[1]); - for (int64_t i = 0; i < n; ++i) { + for (int64_t i = 0; i < shape.numel(); ++i) { ring2k_t msg = (a0[i] + a1[i]); ring2k_t sel = (b0[i] ^ b1[i]); ring2k_t exp = msg * sel; @@ -325,31 +339,37 @@ TEST_P(BasicOTProtTest, Multiplexer) { TEST_P(BasicOTProtTest, CorrelatedAndTriple) { size_t kWorldSize = 2; - size_t n = 10; + Shape shape = {10 * 8}; FieldType field = GetParam(); - std::array corr_triple[2]; + std::array corr_triple[2]; utils::simulate(kWorldSize, [&](std::shared_ptr ctx) { auto conn = std::make_shared(ctx); BasicOTProtocols ot_prot(conn); - corr_triple[ctx->Rank()] = ot_prot.CorrelatedAndTriple(field, n); + corr_triple[ctx->Rank()] = ot_prot.CorrelatedAndTriple(field, shape); }); + EXPECT_EQ(corr_triple[0][0].shape(), corr_triple[1][0].shape()); + for (int i = 1; i < 5; ++i) { + EXPECT_EQ(corr_triple[0][0].shape(), corr_triple[0][i].shape()); + EXPECT_EQ(corr_triple[1][0].shape(), corr_triple[1][i].shape()); + } + DISPATCH_ALL_FIELDS(field, "", [&]() { - auto a0 = ArrayView(corr_triple[0][0]); - auto b0 = ArrayView(corr_triple[0][1]); - auto c0 = ArrayView(corr_triple[0][2]); - auto d0 = ArrayView(corr_triple[0][3]); - auto e0 = ArrayView(corr_triple[0][4]); - - auto a1 = ArrayView(corr_triple[1][0]); - auto b1 = ArrayView(corr_triple[1][1]); - auto c1 = ArrayView(corr_triple[1][2]); - auto d1 = ArrayView(corr_triple[1][3]); - auto e1 = ArrayView(corr_triple[1][4]); - - for (size_t i = 0; i < n; ++i) { + auto a0 = NdArrayView(corr_triple[0][0]); + auto b0 = NdArrayView(corr_triple[0][1]); + auto c0 = NdArrayView(corr_triple[0][2]); + auto d0 = NdArrayView(corr_triple[0][3]); + auto e0 = NdArrayView(corr_triple[0][4]); + + auto a1 = NdArrayView(corr_triple[1][0]); + auto b1 = NdArrayView(corr_triple[1][1]); + auto c1 = NdArrayView(corr_triple[1][2]); + auto d1 = NdArrayView(corr_triple[1][3]); + auto e1 = NdArrayView(corr_triple[1][4]); + + for (int64_t i = 0; i < shape.numel(); ++i) { EXPECT_TRUE(a0[i] < 2 && a1[i] < 2); EXPECT_TRUE(b0[i] < 2 && b1[i] < 2); EXPECT_TRUE(c0[i] < 2 && c1[i] < 2); diff --git a/libspu/mpc/cheetah/ot/ferret.cc b/libspu/mpc/cheetah/ot/ferret.cc index b1bec43d..6d071529 100644 --- a/libspu/mpc/cheetah/ot/ferret.cc +++ b/libspu/mpc/cheetah/ot/ferret.cc @@ -518,6 +518,117 @@ struct FerretOT::Impl { }); } + // Specific for 1-of-2 OT with less hash calls + template + void SendChosenTwoMsgChosenChoice(absl::Span msg_array, + size_t bit_width) { + constexpr size_t N = 2; + const size_t Nn = msg_array.size(); + SPU_ENFORCE(Nn > 0 && 0 == (Nn % N)); + const size_t n = Nn / N; + + std::unique_ptr rcm_data(new OtBaseTyp[n]); + SendRandCorrelatedMsgChosenChoice(rcm_data.get(), n); + + // async a random seed + emp::block seed; + ferret_->prg.random_block(&seed, 1); + io_->send_block(&seed, 1); + io_->flush(); + ferret_->mitccrh.setS(seed); + + const size_t pack_load = 8 * sizeof(T) / bit_width; + std::vector pad(kOTBatchSize * N); + std::vector to_send(kOTBatchSize * N); + std::vector packed_to_send; + if (pack_load > 1) { + // NOTE: pack bit chunks into single T element if possible + packed_to_send.resize(CeilDiv(to_send.size(), 
pack_load)); + } + + for (size_t i = 0; i < n; i += kOTBatchSize) { + size_t this_batch = std::min(kOTBatchSize, n - i); + std::memset(pad.data(), 0, pad.size() * sizeof(OtBaseTyp)); + + for (size_t j = 0; j < this_batch; ++j) { + pad[2 * j] = rcm_data[i + j]; + pad[2 * j + 1] = rcm_data[i + j] ^ ferret_->Delta; + } + + ferret_->mitccrh.template hash(pad.data()); + + for (size_t j = 0; j < this_batch; ++j) { + const auto* this_msg = msg_array.data() + (i + j) * N; + to_send[2 * j] = ConvFromBlock(pad[2 * j]) ^ this_msg[0]; + to_send[2 * j + 1] = ConvFromBlock(pad[2 * j + 1]) ^ this_msg[1]; + } + + if (pack_load == 1) { + io_->send_data(to_send.data(), sizeof(T) * this_batch * N); + } else { + size_t used = ZipArray({to_send.data(), N * this_batch}, bit_width, + absl::MakeSpan(packed_to_send)); + SPU_ENFORCE(used == CeilDiv(N * this_batch, pack_load)); + io_->send_data(packed_to_send.data(), used * sizeof(T)); + } + } + } + + template + void RecvChosenTwoMsgChosenChoice(absl::Span choices, + absl::Span output, size_t bit_width) { + SPU_ENFORCE(bit_width > 0 && bit_width <= 8 * sizeof(T)); + + constexpr size_t N = 2; + const size_t n = choices.size(); + SPU_ENFORCE_EQ(output.size(), n); + + SPU_ENFORCE(std::all_of(choices.data(), choices.data() + n, + [](uint8_t b) { return b < 2; }), + "choice out-of-bound N=2"); + + std::vector rcm_data(n); + RecvRandCorrelatedMsgChosenChoice(choices, absl::MakeSpan(rcm_data)); + + // async a seed from sender + emp::block seed; + io_->recv_block(&seed, 1); + ferret_->mitccrh.setS(seed); + + const T msg_mask = makeBitsMask(bit_width); + const size_t pack_load = 8 * sizeof(T) / bit_width; + + std::vector pad(kOTBatchSize); + std::vector recv(kOTBatchSize * N); + std::vector packed_recv; + if (pack_load > 1) { + packed_recv.resize(CeilDiv(recv.size(), pack_load)); + } + + for (size_t i = 0; i < n; i += kOTBatchSize) { + size_t this_batch = std::min(kOTBatchSize, n - i); + for (size_t j = 0; j < this_batch; ++j) { + pad[j] = rcm_data[i + j]; + } + ferret_->mitccrh.template hash(pad.data()); + + if (pack_load == 1) { + io_->recv_data(recv.data(), N * this_batch * sizeof(T)); + } else { + size_t used = CeilDiv(N * this_batch, pack_load); + io_->recv_data(packed_recv.data(), used * sizeof(T)); + UnzipArray({packed_recv.data(), used}, bit_width, + {recv.data(), N * this_batch}); + } + + for (size_t j = 0; j < this_batch; ++j) { + output[i + j] = recv[2 * j + choices[i + j]] ^ ConvFromBlock(pad[j]); + output[i + j] &= msg_mask; + } + } + } + + // General 1-of-N OT template void SendChosenMsgChosenChoice(absl::Span msg_array, size_t N, size_t bit_width) { @@ -556,12 +667,13 @@ struct FerretOT::Impl { std::vector pad(kOTBatchSize * N); const T msg_mask = makeBitsMask(bit_width); - size_t pack_load = 8 * sizeof(T) / bit_width; + bool packable = 8 * sizeof(T) > bit_width; + size_t packed_sze = CeilDiv(bit_width * N * kOTBatchSize, sizeof(T) * 8); std::vector to_send(kOTBatchSize * N); std::vector packed_to_send; - if (pack_load > 1) { + if (packable) { // NOTE: pack bit chunks into single T element if possible - packed_to_send.resize(CeilDiv(to_send.size(), pack_load)); + packed_to_send.resize(packed_sze); } for (size_t i = 0; i < n; i += kOTBatchSize) { @@ -599,10 +711,10 @@ struct FerretOT::Impl { } } - if (pack_load > 1) { - size_t used = ZipArray({to_send.data(), N * this_batch}, bit_width, - absl::MakeSpan(packed_to_send)); - SPU_ENFORCE(used == CeilDiv(N * this_batch, pack_load)); + if (packable) { + size_t used = ZipArrayBit({to_send.data(), N * this_batch}, + 
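// Conceptual sketch (annotation, not part of the patch, details simplified) of
// the 1-of-2 chosen-message OT specialization above: each message pair is
// one-time-padded with hashes of two correlated keys k0 and k0 ^ delta; the
// receiver holds exactly one of the keys, selected by its choice bit, and can
// strip only that pad.  H below is a toy stand-in for the MITCCRH hash.
#include <cstdint>

struct TwoMsgOtExample {
  static uint64_t H(uint64_t k) { return k * 0x9E3779B97F4A7C15ULL; }  // toy hash

  // Sender side: mask m0 with H(k0) and m1 with H(k0 ^ delta).
  static void Send(uint64_t k0, uint64_t delta, uint64_t m0, uint64_t m1,
                   uint64_t out[2]) {
    out[0] = H(k0) ^ m0;
    out[1] = H(k0 ^ delta) ^ m1;
  }

  // Receiver side: holds kc = k0 ^ (c ? delta : 0) from the random COT.
  static uint64_t Recv(const uint64_t wire[2], uint64_t kc, int c) {
    return wire[c] ^ H(kc);  // recovers m_c; the other pad stays unknown
  }
};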
bit_width, absl::MakeSpan(packed_to_send)); + SPU_ENFORCE(used == CeilDiv(N * this_batch * bit_width, sizeof(T) * 8)); io_->send_data(packed_to_send.data(), used * sizeof(T)); } else { io_->send_data(to_send.data(), N * this_batch * sizeof(T)); @@ -643,16 +755,16 @@ struct FerretOT::Impl { std::vector pad(kOTBatchSize); const T msg_mask = makeBitsMask(bit_width); - const size_t pack_load = 8 * sizeof(T) / bit_width; std::vector recv(kOTBatchSize * N); - std::vector packed_recv(CeilDiv(recv.size(), pack_load)); + const size_t packed_size = CeilDiv(recv.size() * bit_width, sizeof(T) * 8); + std::vector packed_recv(packed_size); for (size_t i = 0; i < n; i += kOTBatchSize) { size_t this_batch = std::min(kOTBatchSize, n - i); - size_t used = CeilDiv(N * this_batch, pack_load); + size_t used = CeilDiv(N * this_batch * bit_width, sizeof(T) * 8); io_->recv_data(packed_recv.data(), used * sizeof(T)); - UnzipArray({packed_recv.data(), used}, bit_width, - {recv.data(), N * this_batch}); + UnzipArrayBit({packed_recv.data(), used}, bit_width, + {recv.data(), N * this_batch}); std::memset(pad.data(), 0, kOTBatchSize * sizeof(OtBaseTyp)); for (size_t j = 0; j < this_batch; ++j) { @@ -754,11 +866,19 @@ size_t CheckBitWidth(size_t bw) { void FerretOT::SendCMCC(absl::Span msg_array, size_t N, \ size_t bit_width) { \ bit_width = CheckBitWidth(bit_width); \ + if (N == 2) { \ + impl_->SendChosenTwoMsgChosenChoice(msg_array, bit_width); \ + return; \ + } \ impl_->SendChosenMsgChosenChoice(msg_array, N, bit_width); \ } \ void FerretOT::RecvCMCC(absl::Span choices, size_t N, \ absl::Span output, size_t bit_width) { \ bit_width = CheckBitWidth(bit_width); \ + if (N == 2) { \ + impl_->RecvChosenTwoMsgChosenChoice(choices, output, bit_width); \ + return; \ + } \ impl_->RecvChosenMsgChosenChoice(choices, N, output, bit_width); \ } \ void FerretOT::SendRMCC(absl::Span output0, absl::Span output1, \ diff --git a/libspu/mpc/cheetah/ot/ferret_test.cc b/libspu/mpc/cheetah/ot/ferret_test.cc index a5137736..d5301f25 100644 --- a/libspu/mpc/cheetah/ot/ferret_test.cc +++ b/libspu/mpc/cheetah/ot/ferret_test.cc @@ -20,7 +20,6 @@ #include "libspu/core/xt_helper.h" #include "libspu/mpc/cheetah/ot/basic_ot_prot.h" -#include "libspu/mpc/semi2k/type.h" #include "libspu/mpc/utils/ring_ops.h" #include "libspu/mpc/utils/simulate.h" @@ -154,7 +153,7 @@ TEST_P(FerretCOTTest, RndMsgChosenChoice) { TEST_P(FerretCOTTest, ChosenMsgChosenChoice) { size_t kWorldSize = 2; - int64_t n = 106; + int64_t n = 100; auto field = GetParam(); DISPATCH_ALL_FIELDS(field, "", [&]() { using scalar_t = ring2k_t; diff --git a/libspu/mpc/cheetah/ot/util.h b/libspu/mpc/cheetah/ot/util.h index f32bc688..b8db0e90 100644 --- a/libspu/mpc/cheetah/ot/util.h +++ b/libspu/mpc/cheetah/ot/util.h @@ -115,44 +115,78 @@ size_t UnzipArray(absl::Span inp, size_t bit_width, } template -size_t PackU8Array(absl::Span u8array, absl::Span packed) { - constexpr size_t elsze = sizeof(T); - const size_t nbytes = u8array.size(); - const size_t numel = CeilDiv(nbytes, elsze); +size_t ZipArrayBit(absl::Span inp, size_t bit_width, + absl::Span oup) { + size_t width = sizeof(T) * 8; + SPU_ENFORCE(bit_width > 0 && width >= bit_width); + size_t numel = inp.size(); + size_t packed_sze = CeilDiv(numel * bit_width, width); - SPU_ENFORCE(packed.size() >= numel); + SPU_ENFORCE(oup.size() >= packed_sze); - for (size_t i = 0; i < nbytes; i += elsze) { - size_t this_batch = std::min(nbytes - i, elsze); - T acc{0}; - for (size_t j = 0; j < this_batch; ++j) { - acc = (acc << 8) | u8array[i + j]; + 
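// Worked example (annotation, not part of the patch) of why ZipArrayBit packs
// tighter than the old pack_load scheme.  With T = uint64_t (width = 64) and
// bit_width = 7:
//   old:  pack_load = 64 / 7 = 9 whole values per word, so one bit of every
//         word stays unused; 100 values -> CeilDiv(100, 9) = 12 words;
//   new:  values may straddle word boundaries,
//         100 values -> CeilDiv(100 * 7, 64) = CeilDiv(700, 64) = 11 words.
// UnzipArrayBit reverses the layout and masks each recovered value with
// makeBitsMask(bit_width).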
const T mask = makeBitsMask(bit_width); + for (size_t i = 0; i < packed_sze; ++i) { + oup[i] = 0; + } + // shift will in [0, 2 * width] + for (size_t i = 0, has_done = 0; i < numel; i += 1, has_done += bit_width) { + T real_data = inp[i] & mask; + size_t packed_index0 = i * bit_width / width; + size_t shft0 = has_done % width; + oup[packed_index0] |= (real_data << shft0); + if (shft0 + bit_width > width) { + size_t shft1 = width - shft0; + size_t packed_index1 = packed_index0 + 1; + oup[packed_index1] |= (real_data >> shft1); } - packed[i / elsze] = acc; } + return packed_sze; +} - return numel; +template +size_t UnzipArrayBit(absl::Span inp, size_t bit_width, + absl::Span oup) { + size_t width = sizeof(T) * 8; + SPU_ENFORCE(bit_width > 0 && bit_width <= width); + + size_t packed_sze = inp.size(); + size_t n = oup.size(); + size_t raw_sze = packed_sze * width / bit_width; + SPU_ENFORCE(n > 0 && n <= raw_sze); + + const T mask = makeBitsMask(bit_width); + for (size_t i = 0, has_done = 0; i < n; i += 1, has_done += bit_width) { + size_t packed_index0 = i * bit_width / width; + size_t shft0 = has_done % width; + oup[i] = (inp[packed_index0] >> shft0); + if (shft0 + bit_width > width) { + size_t shft1 = width - shft0; + size_t packed_index1 = packed_index0 + 1; + oup[i] |= (inp[packed_index1] << shft1); + } + oup[i] &= mask; + } + return n; } template -size_t UnpackU8Array(absl::Span input, absl::Span u8array) { - using UT = typename std::make_unsigned::type; +size_t PackU8Array(absl::Span u8array, absl::Span packed) { constexpr size_t elsze = sizeof(T); - const size_t numel = input.size(); const size_t nbytes = u8array.size(); - SPU_ENFORCE(CeilDiv(nbytes, elsze) >= numel); + const size_t numel = CeilDiv(nbytes, elsze); + + SPU_ENFORCE(packed.size() >= numel); - constexpr T mask = (static_cast(1) << 8) - 1; for (size_t i = 0; i < nbytes; i += elsze) { size_t this_batch = std::min(nbytes - i, elsze); - UT acc = static_cast(input[i / elsze]); + T acc{0}; for (size_t j = 0; j < this_batch; ++j) { - u8array[i + this_batch - 1 - j] = acc & mask; - acc >>= 8; + acc = (acc << 8) | u8array[i + j]; } + packed[i / elsze] = acc; } - return nbytes; + return numel; } uint8_t BoolToU8(absl::Span bits); diff --git a/libspu/mpc/cheetah/ot/util_test.cc b/libspu/mpc/cheetah/ot/util_test.cc index 8008b95a..6a7293a2 100644 --- a/libspu/mpc/cheetah/ot/util_test.cc +++ b/libspu/mpc/cheetah/ot/util_test.cc @@ -34,28 +34,27 @@ INSTANTIATE_TEST_SUITE_P( }); TEST_P(UtilTest, ZipArray) { - const int64_t n = 20; + const int64_t n = 200; const auto field = GetParam(); const size_t elsze = SizeOf(field); auto unzip = ring_zeros(field, {n}); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, "UT_ZipArray", [&]() { for (size_t bw : {1, 2, 4, 7, 15, 16}) { int64_t pack_load = elsze * 8 / bw; auto zip = ring_zeros(field, {(n + pack_load - 1) / pack_load}); auto array = ring_rand(field, {n}); - auto inp = xt_mutable_adapt(array); + auto inp = absl::MakeSpan(&array.at(0), array.numel()); auto mask = makeBitsMask(bw); - inp &= mask; + std::transform(inp.begin(), inp.end(), inp.data(), + [&](auto v) { return v & mask; }); - auto _zip = xt_mutable_adapt(zip); - auto _unzip = xt_mutable_adapt(unzip); - size_t zip_sze = ZipArray({inp.data(), inp.size()}, bw, - {_zip.data(), _zip.size()}); + auto _zip = absl::MakeSpan(&zip.at(0), zip.numel()); + auto _unzip = absl::MakeSpan(&unzip.at(0), unzip.numel()); + (void)ZipArray(inp, bw, _zip); - UnzipArray({_zip.data(), zip_sze}, bw, - {_unzip.data(), _unzip.size()}); + 
UnzipArray(_zip, bw, _unzip); for (size_t i = 0; i < n; ++i) { EXPECT_EQ(inp[i], _unzip[i]); @@ -64,27 +63,42 @@ TEST_P(UtilTest, ZipArray) { }); } -TEST_P(UtilTest, PackU8Array) { - const int64_t num_bytes = 223; +TEST_P(UtilTest, ZipArrayBit) { + const size_t n = 1000; const auto field = GetParam(); - const int64_t elsze = SizeOf(field); - std::uniform_int_distribution uniform(0, -1); - std::default_random_engine rdv; - std::vector u8array(num_bytes); - std::generate_n(u8array.data(), u8array.size(), - [&]() { return uniform(rdv); }); + auto unzip = ring_zeros(field, {n}); + + DISPATCH_ALL_FIELDS(field, "UT_ZipArrayBit", [&]() { + const size_t elsze = SizeOf(field); + for (size_t bw : {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) { + size_t width = elsze * 8; + size_t pack_sze = CeilDiv(bw * n, width); + + auto zip = ring_zeros(field, {(int)pack_sze}); + auto array = ring_rand(field, {n}); + auto mask = makeBitsMask(bw); + + auto inp = absl::MakeSpan(&array.at(0), array.numel()); + auto _zip = absl::MakeSpan(&zip.at(0), zip.numel()); + auto _unzip = absl::MakeSpan(&unzip.at(0), unzip.numel()); + pforeach(0, array.numel(), [&](int64_t i) { inp[i] &= mask; }); - auto packed = ring_zeros(field, {(num_bytes + elsze - 1) / elsze}); + size_t zip_sze = ZipArrayBit(inp, bw, _zip); + SPU_ENFORCE(zip_sze == pack_sze); - DISPATCH_ALL_FIELDS(field, "", [&]() { - auto xp = xt_mutable_adapt(packed); - PackU8Array(absl::MakeSpan(u8array), {xp.data(), xp.size()}); - std::vector _u8(num_bytes, -1); - UnpackU8Array({xp.data(), xp.size()}, absl::MakeSpan(_u8)); + if (((n * bw) % width) != 0) { + // add some noises + _zip[pack_sze - 1] |= (static_cast(1) << (width - 1)); + } + + UnzipArrayBit(_zip, bw, _unzip); - EXPECT_TRUE(std::memcmp(_u8.data(), u8array.data(), num_bytes) == 0); + for (size_t i = 0; i < n; ++i) { + EXPECT_EQ(inp[i], _unzip[i]); + } + } }); } -} // namespace spu::mpc::cheetah::test +} // namespace spu::mpc::cheetah::test \ No newline at end of file diff --git a/libspu/mpc/cheetah/protocol.cc b/libspu/mpc/cheetah/protocol.cc index a25147dc..149811cd 100644 --- a/libspu/mpc/cheetah/protocol.cc +++ b/libspu/mpc/cheetah/protocol.cc @@ -22,15 +22,14 @@ #include "libspu/mpc/cheetah/boolean.h" #include "libspu/mpc/cheetah/conversion.h" #include "libspu/mpc/cheetah/state.h" +#include "libspu/mpc/cheetah/type.h" #include "libspu/mpc/common/pv2k.h" -#include "libspu/mpc/semi2k/state.h" -#include "libspu/mpc/semi2k/type.h" namespace spu::mpc { void regCheetahProtocol(SPUContext* ctx, const std::shared_ptr& lctx) { - semi2k::registerTypes(); + cheetah::registerTypes(); // add communicator ctx->prot()->addState(lctx); @@ -64,7 +63,7 @@ void regCheetahProtocol(SPUContext* ctx, ctx->prot()->regKernel(); ctx->prot()->regKernel(); ctx->prot()->regKernel(); - ctx->prot()->regKernel(); + // ctx->prot()->regKernel(); ctx->prot()->regKernel(); ctx->prot()->regKernel(); ctx->prot()->regKernel(); @@ -90,7 +89,7 @@ void regCheetahProtocol(SPUContext* ctx, std::unique_ptr makeCheetahProtocol( const RuntimeConfig& conf, const std::shared_ptr& lctx) { - semi2k::registerTypes(); + cheetah::registerTypes(); auto ctx = std::make_unique(conf, lctx); diff --git a/libspu/mpc/cheetah/protocol_ab_test.cc b/libspu/mpc/cheetah/protocol_ab_test.cc index 83b5efda..e68d047a 100644 --- a/libspu/mpc/cheetah/protocol_ab_test.cc +++ b/libspu/mpc/cheetah/protocol_ab_test.cc @@ -50,6 +50,17 @@ INSTANTIATE_TEST_SUITE_P( std::get<2>(p.param)); }); +INSTANTIATE_TEST_SUITE_P( + Cheetah, ConversionTest, + 
testing::Combine(testing::Values(makeCheetahProtocol), // + testing::Values(makeConfig(FieldType::FM32), // + makeConfig(FieldType::FM64)), // + testing::Values(2)), // + [](const testing::TestParamInfo& p) { + return fmt::format("{}x{}", std::get<1>(p.param).field(), + std::get<2>(p.param)); + }); + GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ConversionTest); } // namespace spu::mpc::test diff --git a/libspu/mpc/cheetah/rlwe/BUILD.bazel b/libspu/mpc/cheetah/rlwe/BUILD.bazel index d07fda06..f2871f86 100644 --- a/libspu/mpc/cheetah/rlwe/BUILD.bazel +++ b/libspu/mpc/cheetah/rlwe/BUILD.bazel @@ -25,15 +25,21 @@ spu_cc_library( spu_cc_library( name = "lwe", - srcs = [ - "lwe_ct.cc", - "lwe_decryptor.cc", - "lwe_secret_key.cc", - ], - hdrs = ["lwe_decryptor.h"], + srcs = ["lwe_ct.cc"], + hdrs = ["lwe_ct.h"], deps = [":modswitch_helper"], ) +spu_cc_library( + name = "packlwes", + srcs = ["packlwes.cc"], + hdrs = ["packlwes.h"], + deps = [ + ":lwe", + ":rlwe_utils", + ], +) + spu_cc_library( name = "modswitch_helper", srcs = [ @@ -55,7 +61,6 @@ spu_cc_library( "utils.h", ], deps = [ - "//libspu/mpc/cheetah:array_ref", "//libspu/mpc/utils:ring_ops", "@com_github_microsoft_seal//:seal", ], @@ -69,3 +74,12 @@ spu_cc_test( "@com_github_xtensor_xtensor//:xtensor", ], ) + +spu_cc_test( + name = "packlwes_test", + srcs = ["packlwes_test.cc"], + deps = [ + ":modswitch_helper", + ":packlwes", + ], +) diff --git a/libspu/mpc/cheetah/rlwe/lwe_ct.cc b/libspu/mpc/cheetah/rlwe/lwe_ct.cc index 4beb913b..66e3a55f 100644 --- a/libspu/mpc/cheetah/rlwe/lwe_ct.cc +++ b/libspu/mpc/cheetah/rlwe/lwe_ct.cc @@ -11,6 +11,7 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
+#include "libspu/mpc/cheetah/rlwe/lwe_ct.h" #include "seal/util/ntt.h" #include "seal/util/polyarithsmallmod.h" @@ -20,10 +21,14 @@ namespace spu::mpc::cheetah { -static size_t MaximumLazy(const seal::SEALContext &context) { - if (!context.parameters_set()) return 0; +size_t MaximumLazy(const seal::SEALContext &context) { + if (!context.parameters_set()) { + return 0; + } auto cntxt_dat = context.first_context_data(); - if (!cntxt_dat) return 0; + if (!cntxt_dat) { + return 0; + } const auto &modulus = cntxt_dat->parms().coeff_modulus(); constexpr int kBarrettLimit = 62; @@ -31,7 +36,11 @@ static size_t MaximumLazy(const seal::SEALContext &context) { for (const auto &mod : modulus) { nbits = std::max(mod.bit_count(), nbits); } - if (nbits >= kBarrettLimit) return 0; + + if (nbits >= kBarrettLimit) { + return 0; + } + // NOTE(lwj): not to lazy too much return 1UL << std::min(16, kBarrettLimit - nbits); } @@ -39,6 +48,10 @@ LWECt::LWECt() = default; LWECt::~LWECt() = default; +seal::parms_id_type LWECt::parms_id() const { return vec_.parms_id(); } + +seal::parms_id_type &LWECt::parms_id() { return vec_.parms_id(); } + LWECt::LWECt(const RLWECt &rlwe, size_t coeff_index, const seal::SEALContext &context) { SPU_ENFORCE(seal::is_metadata_valid_for(rlwe, context), @@ -447,4 +460,166 @@ void LWECt::load_members(const seal::SEALContext &context, std::istream &stream, std::swap(*this, tmp); } -} // namespace spu::mpc::cheetah +void LWECt::CastAsRLWE(const seal::SEALContext &context, uint64_t multiplier, + RLWECt *out) const { + SPU_ENFORCE(out != nullptr); + using namespace seal; + using namespace seal::util; + if (!IsValid()) { + out->release(); + return; + } + + SPU_ENFORCE(lazy_counter_ < 2, "invalid lazy_counter={}", lazy_counter_); + SPU_ENFORCE(multiplier > 0, "invalid multiplier={}", multiplier); + + auto work_cntxt = context.get_context_data(parms_id()); + SPU_ENFORCE(work_cntxt != nullptr); + + const auto &parms = work_cntxt->parms(); + const auto &modulus = parms.coeff_modulus(); + const auto *ntt_tables = work_cntxt->small_ntt_tables(); + size_t num_coeff = parms.poly_modulus_degree(); + size_t num_moduli = modulus.size(); + + out->resize(context, parms_id(), 2); + + auto *cnst_ptr = out->data(0); + auto *dst_ptr = out->data(1); + const auto *src_ptr = vec_.data(); + + for (size_t l = 0; l < num_moduli; ++l) { + uint64_t inv_multiplier; + if (multiplier == num_coeff) { + inv_multiplier = ntt_tables[l].inv_degree_modulo().operand; + } else { + // compute multiplier^{-1} mod p + SPU_ENFORCE( + try_invert_uint_mod(multiplier, modulus[l], inv_multiplier), + fmt::format("inverse mod for multiplier={} failed", multiplier)); + } + + MultiplyUIntModOperand fixed_mul; + fixed_mul.set(negate_uint_mod(inv_multiplier, modulus[l]), modulus[l]); + + // [a0, -a1, -a2, ..., -a_{n-1}] + dst_ptr[0] = multiply_uint_mod(src_ptr[0], inv_multiplier, modulus[l]); + multiply_poly_scalar_coeffmod(src_ptr + 1, num_coeff - 1, fixed_mul, + modulus[l], dst_ptr + 1); + // reverse [1, n) + std::reverse(dst_ptr + 1, dst_ptr + num_coeff); + + cnst_ptr[0] = multiply_uint_mod(cnst_term_[l], inv_multiplier, modulus[l]); + std::fill_n(cnst_ptr + 1, num_coeff - 1, 0); + + src_ptr += num_coeff; + cnst_ptr += num_coeff; + dst_ptr += num_coeff; + } + + out->is_ntt_form() = false; + out->scale() = 1.0; +} + +void PhantomLWECt::WrapIt(const RLWECt &ct, size_t coeff_index, + bool only_wrap_zero) { + SPU_ENFORCE(not ct.is_ntt_form() && ct.size() == 2 && + coeff_index < ct.poly_modulus_degree()); + only_wrap_zero_ = only_wrap_zero; 
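// Worked example (annotation, not part of the patch) for MaximumLazy above:
// the lazy-accumulation budget is 1 << min(16, 62 - nbits), where nbits is the
// largest coefficient-modulus bit count and 62 is the Barrett limit.
//   nbits = 45  ->  1 << 16 = 65536 additions may be deferred before reducing;
//   nbits = 60  ->  1 << 2  = 4 additions only;
//   nbits >= 62 ->  0, i.e. lazy accumulation is disabled.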
+ if (only_wrap_zero) { + SPU_ENFORCE(coeff_index == 0); + } + coeff_index_ = coeff_index; + pid_ = ct.parms_id(); + base_ = &ct; +} + +bool PhantomLWECt::IsValid() const { return base_ != nullptr; } + +size_t PhantomLWECt::poly_modulus_degree() const { + if (base_ != nullptr) { + return base_->poly_modulus_degree(); + } + return 0; +} + +size_t PhantomLWECt::coeff_modulus_size() const { + if (base_ != nullptr) { + return base_->coeff_modulus_size(); + } + return 0; +} + +void PhantomLWECt::CastAsRLWE(const seal::SEALContext &context, + uint64_t multiplier, RLWECt *out) const { + SPU_ENFORCE(out != nullptr); + if (!IsValid()) { + out->release(); + return; + } + + auto cntxt_data = context.get_context_data(parms_id()); + SPU_ENFORCE(cntxt_data != nullptr, "invalid pid for this context"); + + out->resize(context, parms_id(), 2); + const auto &modulus = cntxt_data->parms().coeff_modulus(); + const auto *ntt_tables = cntxt_data->small_ntt_tables(); + auto num_modulus = this->coeff_modulus_size(); + auto num_coeff = this->poly_modulus_degree(); + + const uint64_t *src_ptr = base_->data(1); + uint64_t *dst_ptr = out->data(1); + + std::fill_n(out->data(0), num_coeff * num_modulus, 0); + for (size_t l = 0; l < num_modulus; ++l) { + using namespace seal::util; + // multiply N^{-1} mod p to cancel out the multiplier + MultiplyUIntModOperand fixed_mul; + if (multiplier == num_coeff) { + fixed_mul = ntt_tables[l].inv_degree_modulo(); + } else { + // compute multiplier^{-1} mod p + uint64_t inv_multiplier; + SPU_ENFORCE( + try_invert_uint_mod(multiplier, modulus[l], inv_multiplier), + fmt::format("inverse mod for multiplier={} failed", multiplier)); + fixed_mul.set(negate_uint_mod(inv_multiplier, modulus[l]), modulus[l]); + } + + dst_ptr[0] = + multiply_uint_mod(src_ptr[coeff_index_], fixed_mul, modulus[l]); + + size_t offset = num_coeff - coeff_index_; + for (size_t i = 1; i < offset; ++i) { + size_t src_rlwe_idx = coeff_index_ + i; + dst_ptr[i] = + multiply_uint_mod(src_ptr[src_rlwe_idx], fixed_mul, modulus[l]); + } + + for (size_t i = offset; i < num_coeff; ++i) { + size_t src_rlwe_idx = i - offset; + dst_ptr[i] = multiply_uint_mod(modulus[l].value() - src_ptr[src_rlwe_idx], + fixed_mul, modulus[l]); + } + + if (coeff_index_ == 0 && only_wrap_zero_) { + auto ct0_ptr = base_->data(0) + l * num_coeff; + uint64_t acc = 0; + for (size_t i = 0; i < num_coeff; ++i) { + acc = add_uint_mod(ct0_ptr[i], acc, modulus[l]); + } + acc = + multiply_uint_mod(acc, ntt_tables[l].inv_degree_modulo(), modulus[l]); + acc = multiply_uint_mod(acc, fixed_mul, modulus[l]); + out->data(0)[l * num_coeff] = acc; + } else { + out->data(0)[l * num_coeff] = multiply_uint_mod( + base_->data(0)[l * num_coeff + coeff_index_], fixed_mul, modulus[l]); + } + + src_ptr += num_coeff; + dst_ptr += num_coeff; + } +} + +} // namespace spu::mpc::cheetah \ No newline at end of file diff --git a/libspu/mpc/cheetah/rlwe/lwe_ct.h b/libspu/mpc/cheetah/rlwe/lwe_ct.h new file mode 100644 index 00000000..15b8da6f --- /dev/null +++ b/libspu/mpc/cheetah/rlwe/lwe_ct.h @@ -0,0 +1,148 @@ +// Copyright 2022 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include + +#include "libspu/mpc/cheetah/rlwe/types.h" +namespace spu::mpc::cheetah { + +class LWECt { + public: + LWECt(); + + ~LWECt(); + + // RLWE(\sum_{i} a_iX^i), k -> LWE(a_k) + LWECt(const RLWECt &rlwe, size_t coeff_index, + const seal::SEALContext &context); + + LWECt &NegateInplace(const seal::SEALContext &context); + + LWECt &AddInplace(const LWECt &oth, const seal::SEALContext &context); + + LWECt &AddPlainInplace(const std::vector &plain, + const seal::SEALContext &context); + + LWECt &SubInplace(const LWECt &oth, const seal::SEALContext &context); + + LWECt &SubPlainInplace(const std::vector &plain, + const seal::SEALContext &context); + + LWECt &AddLazyInplace(const RLWECt &rlwe, size_t coeff_index, + const seal::SEALContext &context); + + LWECt &SubLazyInplace(const RLWECt &rlwe, size_t coeff_index, + const seal::SEALContext &context); + + // LWE(a), multiplier -> RLWE(multiplier * a + \sum_{i>0} a_iX^i) for some + // random \{a_i\} + void CastAsRLWE(const seal::SEALContext &context, uint64_t multiplier, + RLWECt *out) const; + + void Reduce(const seal::SEALContext &context); + + bool IsNeedReduce() const { return lazy_counter_ > 0; } + + bool IsValid() const { return poly_deg_ > 0; } + + seal::parms_id_type parms_id() const; + + seal::parms_id_type &parms_id(); + + size_t poly_modulus_degree() const { return poly_deg_; } + + size_t coeff_modulus_size() const { return cnst_term_.size(); } + + // Returns an upper bound on the size of the LWECt, as if it was written + // to an output stream. + size_t save_size(seal::compr_mode_type compr_mode = + seal::Serialization::compr_mode_default) const; + + // Saves the LWECt to a given memory location. The output is in binary + // format and not human-readable. + size_t save(seal::seal_byte *buffer, size_t size, + seal::compr_mode_type compr_mode = + seal::Serialization::compr_mode_default) const { + using namespace std::placeholders; + return seal::Serialization::Save(std::bind(&LWECt::save_members, this, _1), + save_size(seal::compr_mode_type::none), + buffer, size, compr_mode, false); + } + + // Loads an LWECt from an input stream overwriting the current LWECt. + // The loaded ciphertext is verified to be valid for the given SEALContext. + void load(const seal::SEALContext &context, const seal::seal_byte *buffer, + size_t size); + + // Loads an LWECt from a given memory location overwriting the current + // plaintext. No checking of the validity of the ciphertext data against + // encryption parameters is performed. This function should not be used + // unless the ciphertext comes from a fully trusted source. 
+ void unsafe_load(const seal::SEALContext &context, + const seal::seal_byte *buffer, size_t size) { + using namespace std::placeholders; + seal::Serialization::Load( + std::bind(&LWECt::load_members, this, context, _1, _2), buffer, size, + false); + } + + private: + void save_members(std::ostream &stream) const; + + void load_members(const seal::SEALContext &context, std::istream &stream, + seal::SEALVersion version); + + uint64_t maximum_lazy_{0}; + uint64_t lazy_counter_{0}; + + size_t poly_deg_{0}; + std::vector cnst_term_; + RLWEPt vec_; +}; + +size_t MaximumLazy(const seal::SEALContext &context); + +// Just wrap an RLWE(m(X)) and view it as an LWE of the k-th coefficient, +// LWE(m_k) +class PhantomLWECt { + public: + PhantomLWECt() = default; + + ~PhantomLWECt() = default; + + void WrapIt(const RLWECt &ct, size_t coeff_index, + bool only_wrap_zero = false); + + seal::parms_id_type parms_id() const { return pid_; } + + seal::parms_id_type &parms_id() { return pid_; } + + size_t poly_modulus_degree() const; + + size_t coeff_modulus_size() const; + + bool IsValid() const; + + void CastAsRLWE(const seal::SEALContext &context, uint64_t multiplier, + RLWECt *out) const; + + private: + bool only_wrap_zero_ = false; + size_t coeff_index_ = 0; + seal::parms_id_type pid_ = seal::parms_id_zero; + const RLWECt *base_ = nullptr; +}; + +} // namespace spu::mpc::cheetah \ No newline at end of file diff --git a/libspu/mpc/cheetah/rlwe/modswitch_helper.cc b/libspu/mpc/cheetah/rlwe/modswitch_helper.cc index c3040d63..2bfe3237 100644 --- a/libspu/mpc/cheetah/rlwe/modswitch_helper.cc +++ b/libspu/mpc/cheetah/rlwe/modswitch_helper.cc @@ -88,7 +88,7 @@ struct ModulusSwitchHelper::Impl { } template - void ModulusUpAt(ArrayView src, size_t mod_idx, + void ModulusUpAt(NdArrayView src, size_t mod_idx, absl::Span out) const { using namespace seal::util; SPU_ENFORCE(sizeof(Scalar) < sizeof(uint128_t)); @@ -118,7 +118,7 @@ struct ModulusSwitchHelper::Impl { // NOTE(juhou): we need 256-bit to store the product `x * (Q mod t)` for x, t // \in [2^64, 2^128). 
- void ModulusUpAt(ArrayView src, size_t mod_idx, + void ModulusUpAt(NdArrayView src, size_t mod_idx, absl::Span out) const { using namespace seal::util; SPU_ENFORCE_EQ(sizeof(uint128_t) * 8, absl::bit_ceil(base_mod_bitlen_), @@ -161,7 +161,7 @@ struct ModulusSwitchHelper::Impl { } template - void CenteralizeAt(ArrayView src, size_t mod_idx, + void CenteralizeAt(NdArrayView src, size_t mod_idx, absl::Span out) const { using namespace seal::util; SPU_ENFORCE_EQ(sizeof(Scalar) * 8, absl::bit_ceil(base_mod_bitlen_), @@ -426,22 +426,23 @@ DEFINE_MODSWITCH_FUNS(uint128_t) #undef DEFINE_MODSWITCH_FUNS -void ModulusSwitchHelper::ModulusUpAt(const ArrayRef &src, size_t mod_idx, +void ModulusSwitchHelper::ModulusUpAt(const NdArrayRef &src, size_t mod_idx, absl::Span out) const { yacl::CheckNotNull(impl_.get()); const Type &eltype = src.eltype(); const size_t numel = src.numel(); SPU_ENFORCE_EQ(numel, out.size()); + SPU_ENFORCE(src.shape().size() == 1, "need 1D array"); SPU_ENFORCE(eltype.isa(), "source must be ring_type, got={}", eltype); const auto field = eltype.as()->field(); DISPATCH_ALL_FIELDS(field, "ModulusUpAt", [&]() { using ring2u = std::make_unsigned::type; - impl_->ModulusUpAt(ArrayView(src), mod_idx, out); + impl_->ModulusUpAt(NdArrayView(src), mod_idx, out); }); } -void ModulusSwitchHelper::CenteralizeAt(const ArrayRef &src, size_t mod_idx, +void ModulusSwitchHelper::CenteralizeAt(const NdArrayRef &src, size_t mod_idx, absl::Span out) const { yacl::CheckNotNull(impl_.get()); const Type &eltype = src.eltype(); @@ -449,37 +450,38 @@ void ModulusSwitchHelper::CenteralizeAt(const ArrayRef &src, size_t mod_idx, SPU_ENFORCE_EQ(numel, out.size()); SPU_ENFORCE(eltype.isa(), "source must be ring_type, got={}", eltype); const auto field = eltype.as()->field(); - DISPATCH_ALL_FIELDS(field, "", [&]() { + DISPATCH_ALL_FIELDS(field, "CenteralizeAt", [&]() { using ring2u = std::make_unsigned::type; - impl_->CenteralizeAt(ArrayView(src), mod_idx, out); + impl_->CenteralizeAt(NdArrayView(src), mod_idx, out); }); } -ArrayRef ModulusSwitchHelper::ModulusDownRNS( - FieldType field, absl::Span src) const { +NdArrayRef ModulusSwitchHelper::ModulusDownRNS( + FieldType field, const Shape &shape, absl::Span src) const { yacl::CheckNotNull(impl_.get()); size_t num_modulus = impl_->coeff_modulus_size(); - int64_t num_elt = src.size() / num_modulus; - SPU_ENFORCE_EQ(num_elt * num_modulus, src.size()); + int64_t numel = src.size() / num_modulus; + SPU_ENFORCE_EQ(numel, shape.numel()); + SPU_ENFORCE_EQ(numel * num_modulus, src.size()); - auto out = flatten(ring_zeros(field, {num_elt})); + auto out = ring_zeros(field, shape); ModulusDownRNS(src, out); return out; } void ModulusSwitchHelper::ModulusDownRNS(absl::Span src, - ArrayRef out) const { + NdArrayRef out) const { yacl::CheckNotNull(impl_.get()); auto eltype = out.eltype(); SPU_ENFORCE(eltype.isa(), "must be ring_type, got={}", eltype); auto field = eltype.as()->field(); - SPU_ENFORCE(out.isCompact()); + SPU_ENFORCE(out.isCompact(), "need compact output"); size_t num_modulus = impl_->coeff_modulus_size(); size_t num_elt = out.numel(); SPU_ENFORCE_EQ(num_elt * num_modulus, src.size()); - return DISPATCH_ALL_FIELDS(field, "", [&]() { + return DISPATCH_ALL_FIELDS(field, "ModulusDownRNS", [&]() { using ring2u = std::make_unsigned::type; absl::Span out_wrap(reinterpret_cast(out.data()), num_elt); diff --git a/libspu/mpc/cheetah/rlwe/modswitch_helper.h b/libspu/mpc/cheetah/rlwe/modswitch_helper.h index 33e1cbb5..53a554c8 100644 --- 
a/libspu/mpc/cheetah/rlwe/modswitch_helper.h +++ b/libspu/mpc/cheetah/rlwe/modswitch_helper.h @@ -21,7 +21,7 @@ #include "absl/types/span.h" #include "seal/context.h" -#include "libspu/mpc/cheetah/array_ref.h" +#include "libspu/core/ndarray_ref.h" namespace spu::mpc::cheetah { @@ -41,18 +41,19 @@ class ModulusSwitchHelper { uint32_t coeff_modulus_size() const; // Given x \in [0, 2^k), compute round(Q/2^k*x) \in [0, Q) - void ModulusUpAt(const ArrayRef& src, size_t mod_idx, + // Need 1D array + void ModulusUpAt(const NdArrayRef& src, size_t mod_idx, absl::Span out) const; // Cast [-2^{k-1}, 2^{k-1}) to [0, qj) - void CenteralizeAt(const ArrayRef& src, size_t mod_idx, + // Need 1D array + void CenteralizeAt(const NdArrayRef& src, size_t mod_idx, absl::Span out) const; - // Given x' in \ [0, Q), compute round(x'*2^k/Q) \in [0, 2^k) - ArrayRef ModulusDownRNS(FieldType field, - absl::Span src) const; + NdArrayRef ModulusDownRNS(FieldType field, const Shape& shape, + absl::Span src) const; - void ModulusDownRNS(absl::Span src, ArrayRef out) const; + void ModulusDownRNS(absl::Span src, NdArrayRef out) const; void ModulusDownRNS(absl::Span src, absl::Span out) const; diff --git a/libspu/mpc/cheetah/rlwe/modswitch_helper_test.cc b/libspu/mpc/cheetah/rlwe/modswitch_helper_test.cc index 8ae03d16..7ebc4e69 100644 --- a/libspu/mpc/cheetah/rlwe/modswitch_helper_test.cc +++ b/libspu/mpc/cheetah/rlwe/modswitch_helper_test.cc @@ -43,9 +43,9 @@ class RLWE2LWETest : public testing::TestWithParam { inline uint32_t FieldBitLen(FieldType f) const { return 8 * SizeOf(f); } - ArrayRef CPRNG(FieldType field, size_t size) { - return flatten( - ring_rand(field, {static_cast(size)}, seed_, &prng_counter_)); + NdArrayRef CPRNG(FieldType field, size_t size) { + return ring_rand(field, {static_cast(size)}, seed_, + &prng_counter_); } void UniformPoly(const seal::SEALContext &context, RLWEPt *pt) { @@ -105,9 +105,9 @@ INSTANTIATE_TEST_SUITE_P( TEST_P(RLWE2LWETest, ModulusSwitch_UpDown) { // r <- R_t - for (size_t stride : {1, 2, 3}) { - auto vec = CPRNG(field_, poly_deg); - auto _vec = vec.slice(3, vec.numel(), stride); + for (int64_t stride : {1, 2, 3}) { + auto vec = ring_rand(field_, {poly_deg}); + auto _vec = vec.slice({3}, {vec.numel()}, {stride}); // r' = Delta*r \in R_q RLWEPt pt; @@ -120,12 +120,11 @@ TEST_P(RLWE2LWETest, ModulusSwitch_UpDown) { } // e = round(r'/Delta) mod t \in R_t auto src = absl::MakeSpan(pt.data(), pt.coeff_count()); - auto cmp = flatten(ring_zeros(field_, {(int64_t)poly_deg})); - ms_helper_->ModulusDownRNS(src, cmp); + auto cmp = ms_helper_->ModulusDownRNS(field_, {(int64_t)poly_deg}, src); // check r =? 
e DISPATCH_ALL_FIELDS(field_, "", [&]() { - auto expected = ArrayView(_vec); - auto computed = ArrayView(cmp); + auto expected = NdArrayView(_vec); + auto computed = NdArrayView(cmp); for (int64_t i = 0; i < expected.numel(); ++i) { EXPECT_EQ(expected[i], computed[i]); } @@ -135,7 +134,7 @@ TEST_P(RLWE2LWETest, ModulusSwitch_UpDown) { TEST_P(RLWE2LWETest, ModulusSwitch_DownUp) { // a <- R_t - auto vec_a = CPRNG(field_, poly_deg); + auto vec_a = ring_rand(field_, {poly_deg}); RLWEPt poly0; poly0.resize(poly_deg * ms_helper_->coeff_modulus_size()); // a' = round(Delta*a) in R_q @@ -159,8 +158,8 @@ TEST_P(RLWE2LWETest, ModulusSwitch_DownUp) { // r' = round(r/Delta) mod t \in R_t // b' = round(b/Delta) mod t \in R_t - auto shr0 = flatten(ring_zeros(field_, {(int64_t)poly_deg})); - auto shr1 = flatten(ring_zeros(field_, {(int64_t)poly_deg})); + auto shr0 = ring_zeros(field_, {(int64_t)poly_deg}); + auto shr1 = ring_zeros(field_, {(int64_t)poly_deg}); { auto src = absl::MakeSpan(rnd.data(), rnd.coeff_count()); ms_helper_->ModulusDownRNS(src, shr0); @@ -170,9 +169,9 @@ TEST_P(RLWE2LWETest, ModulusSwitch_DownUp) { } DISPATCH_ALL_FIELDS(field_, "", [&]() { - auto expected = ArrayView(vec_a); - auto computed0 = ArrayView(shr0); - auto computed1 = ArrayView(shr1); + auto expected = NdArrayView(vec_a); + auto computed0 = NdArrayView(shr0); + auto computed1 = NdArrayView(shr1); for (size_t i = 0; i < poly_deg; ++i) { auto cmp = computed0[i] + computed1[i]; diff --git a/libspu/mpc/cheetah/rlwe/packlwes.cc b/libspu/mpc/cheetah/rlwe/packlwes.cc new file mode 100644 index 00000000..c4884bab --- /dev/null +++ b/libspu/mpc/cheetah/rlwe/packlwes.cc @@ -0,0 +1,471 @@ +// Copyright 2022 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#include "libspu/mpc/cheetah/rlwe/packlwes.h" + +#include + +#include "seal/ciphertext.h" +#include "seal/evaluator.h" +#include "seal/galoiskeys.h" +#include "seal/keygenerator.h" +#include "seal/plaintext.h" +#include "seal/util/ntt.h" +#include "seal/util/polyarithsmallmod.h" +#include "seal/util/rlwe.h" +#include "seal/valcheck.h" +#include "yacl/utils/parallel.h" + +#include "libspu/core/parallel_utils.h" +#include "libspu/core/prelude.h" +#include "libspu/mpc/cheetah/rlwe/lwe_ct.h" +#include "libspu/mpc/cheetah/rlwe/utils.h" + +namespace spu::mpc::cheetah { + +static void NegacyclicRightShiftInplace(RLWECt &ct, size_t shift, + const seal::SEALContext &context); + +static size_t calculateWorkLoad(size_t num_jobs, size_t num_cores = 0) { + if (num_cores == 0) { + num_cores = spu::getNumberOfProc(); + } + return (num_jobs + num_cores - 1) / num_cores; +} + +PackingHelper::PackingHelper(size_t gap, const seal::GaloisKeys &galois_keys, + const seal::SEALContext &gk_context, + const seal::SEALContext &context) + : gap_(gap), + galois_keys_(galois_keys), + gk_context_(gk_context), + context_(context) { + SPU_ENFORCE(gk_context_.parameters_set()); + SPU_ENFORCE(seal::is_metadata_valid_for(galois_keys, gk_context)); + SPU_ENFORCE(context_.parameters_set()); + SPU_ENFORCE(gap > 0 && absl::has_single_bit(gap), "invalid gap={}", gap); + + // NOTE(lwj): dirty hack on SEAL's parms_id + if (context.key_parms_id() != gk_context_.key_parms_id()) { + SPU_ENFORCE_GT(context_.first_context_data()->chain_index(), + gk_context_.first_context_data()->chain_index()); + } + + auto n = gk_context.key_context_data()->parms().poly_modulus_degree(); + + size_t ks_level = absl::bit_width(gap) - 1; + for (size_t i = 0; i < ks_level; ++i) { + uint32_t galois = (n / (1 << i)) + 1; + SPU_ENFORCE(galois_keys.has_key(galois), "missing galois={}", galois); + } + + // pre-compute gap^{-1} mod Q + auto cntxt = context_.first_context_data(); + const auto &modulus = cntxt->parms().coeff_modulus(); + size_t num_modulus = modulus.size(); + inv_gap_.resize(num_modulus); + for (size_t i = 0; i < modulus.size(); ++i) { + uint64_t s = seal::util::barrett_reduce_64(gap, modulus[i]); + uint64_t _inv; + SPU_ENFORCE(seal::util::try_invert_uint_mod(s, modulus[i], _inv), + "failed to compute {}^{-1} mod {}", gap, modulus[i].value()); + inv_gap_[i].set(_inv, modulus[i]); + } +} + +void PackingHelper::MultiplyFixedScalarInplace(RLWECt &ct) const { + auto cntxt = context_.get_context_data(ct.parms_id()); + SPU_ENFORCE(cntxt != nullptr, "invalid ct"); + const auto &modulus = cntxt->parms().coeff_modulus(); + size_t num_modulus = ct.coeff_modulus_size(); + size_t num_coeff = ct.poly_modulus_degree(); + SPU_ENFORCE(num_modulus <= inv_gap_.size(), "invalid ct"); + + for (size_t k = 0; k < ct.size(); ++k) { + uint64_t *dst_ptr = ct.data(k); + for (size_t l = 0; l < num_modulus; ++l) { + seal::util::multiply_poly_scalar_coeffmod(dst_ptr, num_coeff, inv_gap_[l], + modulus.at(l), dst_ptr); + dst_ptr += num_coeff; + } + } +} + +void PackingHelper::PackingWithModulusDrop(absl::Span rlwes, + RLWECt &packed) const { + if (rlwes.empty()) { + packed.release(); + return; + } + SPU_ENFORCE(rlwes.size() <= gap_); + + auto pid = rlwes[0].parms_id(); + for (auto &rlwe : rlwes) { + if (rlwe.size() == 0) { + continue; + } + SPU_ENFORCE(rlwe.size() == 2); + SPU_ENFORCE(pid == rlwe.parms_id()); + } + + doPackingRLWEs(rlwes, packed); +} + +void PackingHelper::doPackingRLWEs(absl::Span rlwes, + RLWECt &out) const { + auto cntxt = context_.first_context_data(); + 
+ int64_t poly_degree = cntxt->parms().poly_modulus_degree(); + int64_t num_ct = rlwes.size(); + + SPU_ENFORCE(num_ct > 0 && num_ct <= (int)gap_, + fmt::format("invalid #rlwes = {} for gap = {}", num_ct, gap_)); + + size_t modulus_for_keyswitch = + gk_context_.first_context_data()->chain_index() + 1; + + yacl::parallel_for( + 0, num_ct, calculateWorkLoad(num_ct), [&](int64_t bgn, int64_t end) { + for (int64_t i = bgn; i < end; ++i) { + InvNttInplace(rlwes[i], context_, true); + // multiply gap^{-1} mod Q + MultiplyFixedScalarInplace(rlwes[i]); + // drop some modulus aiming a lighter KeySwitch + ModulusSwtichInplace(rlwes[i], modulus_for_keyswitch, context_); + // change pid to galois_context for KS + rlwes[i].parms_id() = gk_context_.first_parms_id(); + } + }); + + // FFT-like method to merge RLWEs into one RLWE. + seal::Evaluator evaluator(gk_context_); + const int64_t logn = absl::bit_width(gap_) - 1; + for (int64_t k = logn; k >= 1; --k) { + int64_t h = 1 << (k - 1); + yacl::parallel_for( + 0, h, calculateWorkLoad(h), [&](int64_t bgn, int64_t end) { + RLWECt dummy; // zero-padding with zero RLWE + for (int64_t i = bgn; i < end; ++i) { + // E' <- E + X^k*O + Auto(E - X^k*O, k') + RLWECt &ct_even = i < num_ct ? rlwes[i] : dummy; + RLWECt &ct_odd = i + h < num_ct ? rlwes[i + h] : dummy; + + bool is_odd_empty = ct_odd.size() == 0; + bool is_even_empty = ct_even.size() == 0; + if (is_even_empty && is_odd_empty) { + ct_even.release(); + continue; + } + + NegacyclicRightShiftInplace(ct_odd, h, gk_context_); + + if (!is_even_empty) { + seal::Ciphertext tmp = ct_even; + if (!is_odd_empty) { + // E - X^k*O + // E + X^k*O + CATCH_SEAL_ERROR(evaluator.sub_inplace(ct_even, ct_odd)); + CATCH_SEAL_ERROR(evaluator.add_inplace(tmp, ct_odd)); + } + + CATCH_SEAL_ERROR(evaluator.apply_galois_inplace( + ct_even, poly_degree / h + 1, galois_keys_)); + CATCH_SEAL_ERROR(evaluator.add_inplace(ct_even, tmp)); + } else { + evaluator.negate(ct_odd, ct_even); + CATCH_SEAL_ERROR(evaluator.apply_galois_inplace( + ct_even, poly_degree / h + 1, galois_keys_)); + CATCH_SEAL_ERROR(evaluator.add_inplace(ct_even, ct_odd)); + } + } + }); + } + + SPU_ENFORCE(rlwes[0].size() > 0, fmt::format("all empty RLWEs are invalid")); + out = rlwes[0]; + + out.parms_id() = [&]() -> seal::parms_id_type { + auto cntxt = context_.first_context_data(); + while ((cntxt->chain_index() + 1) > modulus_for_keyswitch) { + cntxt = cntxt->next_context_data(); + } + return cntxt->parms_id(); + }(); +} + +void GenerateGaloisKeyForPacking(const seal::SEALContext &context, + const RLWESecretKey &key, bool save_seed, + GaloisKeys *out) { + SPU_ENFORCE(out != nullptr); + SPU_ENFORCE(context.parameters_set()); + SPU_ENFORCE(seal::is_metadata_valid_for(key, context)); + + size_t N = context.key_context_data()->parms().poly_modulus_degree(); + size_t logN = absl::bit_width(N) - 1; + std::vector galois_elt; + for (uint32_t i = 1; i <= logN; i++) { + galois_elt.push_back((1u << i) + 1); + } + + seal::KeyGenerator keygen(context, key); + + if (save_seed) { + auto gk = keygen.create_galois_keys(galois_elt); + *out = gk.obj(); + } else { + keygen.create_galois_keys(galois_elt, *out); + } +} + +void NegacyclicRightShiftInplace(RLWECt &ct, size_t shift, + const seal::SEALContext &context) { + if (shift == 0 || ct.size() == 0) { + // nothing to do + return; + } + + auto cntxt = context.get_context_data(ct.parms_id()); + SPU_ENFORCE(cntxt != nullptr, "invalid ct"); + SPU_ENFORCE(not ct.is_ntt_form(), "need non-ntt ct for negacyclic shift"); + + size_t num_coeff = 
ct.poly_modulus_degree();
+  SPU_ENFORCE(shift < num_coeff);
+
+  std::vector tmp(shift);
+  // i < N - s ai*X^i -> ai*X^{i + s}
+  // i >= N - s ai*X^i -> -ai*X^{(i + s) % N}
+  const auto &modulus = cntxt->parms().coeff_modulus();
+  for (size_t k = 0; k < ct.size(); ++k) {
+    uint64_t *dst_ptr = ct.data(k);
+
+    for (const auto &prime : modulus) {
+      // save [N-s, N)
+      std::copy_n(dst_ptr + num_coeff - shift, shift, tmp.data());
+
+      // X^i for i \in [0, N-s)
+      for (size_t i = num_coeff - shift; i > 0; --i) {
+        dst_ptr[i - 1 + shift] = dst_ptr[i - 1];
+      }
+
+      // i \in [N-s, N)
+      for (size_t i = 0; i < shift; ++i) {
+        dst_ptr[i] = seal::util::negate_uint_mod(tmp[i], prime);
+      }
+
+      dst_ptr += num_coeff;
+    }
+  }
+}
+
+static void doPackingLWEs(absl::Span rlwes, const GaloisKeys &galois,
+                          const seal::SEALContext &context, RLWECt *out,
+                          bool apply_trace = false) {
+  SPU_ENFORCE(out != nullptr);
+  SPU_ENFORCE(context.parameters_set());
+  SPU_ENFORCE(seal::is_metadata_valid_for(galois, context));
+
+  auto cntxt = context.first_context_data();
+
+  size_t poly_degree = cntxt->parms().poly_modulus_degree();
+  size_t num_ct = rlwes.size();
+
+  SPU_ENFORCE(
+      num_ct <= poly_degree && absl::has_single_bit(num_ct),
+      fmt::format("invalid #rlwes = {} for degree = {}", num_ct, poly_degree));
+
+  // FFT-like method to merge RLWEs into one RLWE.
+  seal::Evaluator evaluator(context);
+  size_t depth = 1;
+  while (depth <= num_ct) {
+    size_t n = num_ct / depth;
+    size_t h = n / 2;
+    depth <<= 1;
+
+    auto merge_callback = [&](int64_t start, int64_t end) {
+      using namespace seal::util;
+      for (int64_t i = start; i < end; ++i) {
+        RLWECt &ct_even = rlwes[i];
+        RLWECt &ct_odd = rlwes[i + h];
+
+        bool is_odd_empty = ct_odd.size() == 0;
+        bool is_even_empty = ct_even.size() == 0;
+        if (is_even_empty && is_odd_empty) {
+          ct_even.release();
+          continue;
+        }
+
+        // GS-style butterfly
+        // E' <- E + X^k*O + Auto(E - X^k*O, k')
+        // O' <- E + X^k*O + Auto(E + X^k*O, k')
+        if (!is_odd_empty) {
+          NegacyclicRightShiftInplace(ct_odd, poly_degree / depth, context);
+        }
+
+        if (!is_even_empty) {
+          RLWECt tmp = ct_even;
+          if (!is_odd_empty) {
+            CATCH_SEAL_ERROR(evaluator.sub_inplace(ct_even, ct_odd));
+            CATCH_SEAL_ERROR(evaluator.add_inplace(tmp, ct_odd));
+          }
+          CATCH_SEAL_ERROR(evaluator.apply_galois_inplace(
+              ct_even, static_cast(depth + 1), galois));
+          CATCH_SEAL_ERROR(evaluator.add_inplace(ct_even, tmp));
+        } else {
+          evaluator.negate(ct_odd, ct_even);
+          CATCH_SEAL_ERROR(evaluator.apply_galois_inplace(
+              ct_even, static_cast(depth + 1), galois));
+          CATCH_SEAL_ERROR(evaluator.add_inplace(ct_even, ct_odd));
+        }
+      }
+    };
+
+    if (h > 0) {
+      yacl::parallel_for(0, h, calculateWorkLoad(h), merge_callback);
+    }
+  }
+
+  SPU_ENFORCE(rlwes[0].size() > 0, fmt::format("all empty LWEs are invalid"));
+  *out = rlwes[0];
+  out->is_ntt_form() = false;
+  out->scale() = 1.;
+  if (not apply_trace) {
+    return;
+  }
+
+  // Step 2: apply the trace to remove the extra factor (i.e., N/num_lwes)
+  // left over from Step 1.
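+  // In more detail: iteration k applies the automorphism X -> X^{g_k} with
+  // g_k = 2^{log2(N)-k+1} + 1 and adds the result back. Running this for
+  // k = 1..log2(N/num_lwes) sums the ciphertext over the N/num_lwes
+  // automorphisms that fix the packed slots, so coefficients at multiples of
+  // N/num_lwes are scaled by N/num_lwes while the in-between coefficients
+  // are cancelled (the trace trick of ia.cr/2020/015).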
+ size_t log2N = absl::bit_width(poly_degree) - 1; + size_t log2Nn = absl::bit_width(poly_degree / num_ct) - 1; + + for (size_t k = 1; k <= log2Nn; ++k) { + RLWECt tmp{*out}; + uint32_t exp = static_cast((1UL << (log2N - k + 1)) + 1); + CATCH_SEAL_ERROR(evaluator.apply_galois_inplace(tmp, exp, galois)); + evaluator.add_inplace(*out, tmp); + } +} + +template +static void doPackLWEs(absl::Span lwes, const GaloisKeys &galois, + const seal::SEALContext &context, RLWECt *out) { + SPU_ENFORCE(out != nullptr); + SPU_ENFORCE(context.parameters_set()); + SPU_ENFORCE(seal::is_metadata_valid_for(galois, context)); + + auto cntxt = context.first_context_data(); + + size_t poly_degree = cntxt->parms().poly_modulus_degree(); + size_t num_lwes = lwes.size(); + + SPU_ENFORCE( + num_lwes <= poly_degree && absl::has_single_bit(num_lwes), + fmt::format("invalid #lwes = {} for degree = {}", num_lwes, poly_degree)); + + // Step 1: cast all LWEs to RLWEs + std::vector rlwes(num_lwes); + yacl::parallel_for(0, num_lwes, calculateWorkLoad(num_lwes), + [&](size_t start, size_t end) { + for (size_t i = start; i < end; ++i) { + lwes[i].CastAsRLWE(context, poly_degree, &rlwes[i]); + } + }); + + doPackingLWEs(absl::MakeSpan(rlwes), galois, context, out, true); +} + +template +static bool IsValidLWEArray(absl::Span lwes, + size_t *poly_degree) { + size_t n = lwes.size(); + if (n == 0) { + return false; + } + + size_t N = 0; + seal::parms_id_type pid; + + for (size_t i = 0; i < n; ++i) { + if (not lwes[i].IsValid()) { + // skip invalid LWEs + continue; + } + + if (N == 0) { + N = lwes[i].poly_modulus_degree(); + pid = lwes[i].parms_id(); + } else if (lwes[i].poly_modulus_degree() != N || + lwes[i].parms_id() != pid) { + // parms mismatch + return false; + } + } + + // all LWEs are invalid + if (N == 0 || pid == seal::parms_id_zero) { + return false; + } + + if (poly_degree) { + *poly_degree = N; + } + + return true; +} + +size_t PackLWEs(absl::Span lwes, const GaloisKeys &galois, + const seal::SEALContext &context, absl::Span rlwes) { + size_t n = lwes.size(); + size_t m = rlwes.size(); + size_t poly_degree{0}; + + SPU_ENFORCE(IsValidLWEArray(lwes, &poly_degree), + "LWE.length mismatch the poly degree"); + size_t out_sze = (n + poly_degree - 1) / poly_degree; + + SPU_ENFORCE(out_sze <= m, + fmt::format("expect >= {} RLWEs but got={}", out_sze, m)); + + for (size_t o = 0; o < out_sze; ++o) { + size_t bgn = o * poly_degree; + size_t end = std::min(n, bgn + poly_degree); + size_t this_batch = end - bgn; + + doPackLWEs(lwes.subspan(bgn, this_batch), galois, context, &rlwes[o]); + + SPU_ENFORCE(not rlwes[o].is_transparent(), ""); + } + return out_sze; +} + +size_t PackLWEs(absl::Span lwes, const GaloisKeys &galois, + const seal::SEALContext &context, absl::Span rlwes) { + size_t n = lwes.size(); + size_t m = rlwes.size(); + size_t poly_degree{0}; + SPU_ENFORCE(IsValidLWEArray(lwes, &poly_degree), + "LWE.length mismatch the poly degree = {}", poly_degree); + + size_t out_sze = (n + poly_degree - 1) / poly_degree; + SPU_ENFORCE(out_sze <= m, + fmt::format("expect >= {} RLWEs but got={}", out_sze, m)); + + for (size_t o = 0; o < out_sze; ++o) { + size_t bgn = o * poly_degree; + size_t end = std::min(n, bgn + poly_degree); + size_t this_batch = end - bgn; + doPackLWEs(lwes.subspan(bgn, this_batch), galois, context, &rlwes[o]); + SPU_ENFORCE(not rlwes[o].is_transparent(), ""); + } + return out_sze; +} + +} // namespace spu::mpc::cheetah \ No newline at end of file diff --git a/libspu/mpc/cheetah/rlwe/packlwes.h 
b/libspu/mpc/cheetah/rlwe/packlwes.h new file mode 100644 index 00000000..c9337cfb --- /dev/null +++ b/libspu/mpc/cheetah/rlwe/packlwes.h @@ -0,0 +1,56 @@ +// Copyright 2022 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "absl/types/span.h" + +#include "libspu/mpc/cheetah/rlwe/types.h" + +namespace spu::mpc::cheetah { + +// REF: Efficient Homomorphic Conversion Between (Ring) LWE Ciphertexts +// https://eprint.iacr.org/2020/015.pdf +void GenerateGaloisKeyForPacking(const seal::SEALContext &context, + const RLWESecretKey &key, bool save_seed, + GaloisKeys *out); +class PackingHelper { + public: + PackingHelper(size_t gap, const seal::GaloisKeys &galois_keys, + const seal::SEALContext &gk_context, + const seal::SEALContext &context); + + // require ct_array.size() == gap + void PackingWithModulusDrop(absl::Span rlwes, RLWECt &packed) const; + + private: + void MultiplyFixedScalarInplace(RLWECt &ct) const; + + void doPackingRLWEs(absl::Span rlwes, RLWECt &out) const; + + size_t gap_; + const seal::GaloisKeys &galois_keys_; + const seal::SEALContext &gk_context_; + const seal::SEALContext &context_; + + std::vector inv_gap_; +}; + +// lwes[0, N) -> RLWE[0], lwe[N, 2N) -> RLWE[1] .... +// Return the number of output RLWE ciphertexts. 
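+//
+// A minimal usage sketch (illustrative only; assumes `context`, `galois` and
+// the LWE ciphertexts are set up as in packlwes_test.cc, with n a power of
+// two):
+//
+//   std::vector<LWECt> lwes = ...;   // n LWE ciphertexts of degree N
+//   std::vector<RLWECt> packed((lwes.size() + N - 1) / N);
+//   size_t used = PackLWEs(absl::MakeSpan(lwes), galois, context,
+//                          absl::MakeSpan(packed));
+//   // `used` <= packed.size() RLWEs are written; for n <= N the j-th LWE
+//   // value lands on coefficient j * (N / n) of packed[0]
+//   // (cf. PackLWEsTest.Basic).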
+size_t PackLWEs(absl::Span lwes, const GaloisKeys &galois, + const seal::SEALContext &context, absl::Span rlwes); + +size_t PackLWEs(absl::Span lwes, const GaloisKeys &galois, + const seal::SEALContext &context, absl::Span rlwes); +} // namespace spu::mpc::cheetah \ No newline at end of file diff --git a/libspu/mpc/cheetah/rlwe/packlwes_test.cc b/libspu/mpc/cheetah/rlwe/packlwes_test.cc new file mode 100644 index 00000000..a5a57253 --- /dev/null +++ b/libspu/mpc/cheetah/rlwe/packlwes_test.cc @@ -0,0 +1,360 @@ +#include "libspu/mpc/cheetah/rlwe/packlwes.h" + +#include + +#include "gtest/gtest.h" +#include "seal/seal.h" + +#include "libspu/core/prelude.h" +#include "libspu/core/type_util.h" +#include "libspu/mpc/cheetah/rlwe/lwe_ct.h" +#include "libspu/mpc/cheetah/rlwe/modswitch_helper.h" +#include "libspu/mpc/cheetah/rlwe/types.h" +#include "libspu/mpc/cheetah/rlwe/utils.h" +#include "libspu/mpc/utils/ring_ops.h" + +namespace spu::mpc::cheetah::test { +class VectorEncoder { + public: + explicit VectorEncoder(const seal::SEALContext &context, + const ModulusSwitchHelper &msh); + void Forward(const NdArrayRef &vec, RLWEPt *out, + bool scale_delta = true) const; + + void Backward(const NdArrayRef &vec, RLWEPt *out, + bool scale_delta = false) const; + + const ModulusSwitchHelper &ms_helper() const { return *msh_; } + + size_t poly_degree() const { return poly_deg_; } + + private: + size_t poly_deg_{0}; + std::shared_ptr msh_; +}; + +class PackLWEsTest + : public testing::TestWithParam> { + protected: + static constexpr size_t poly_N = 8192; + static constexpr size_t ring_len = 64; + + std::shared_ptr N_context_; + + std::shared_ptr N_rlwe_sk_; + + std::shared_ptr galois_; + std::shared_ptr N_encoder_; + std::shared_ptr N_ms_helper_; + + void SetUp() override { + std::vector modulus_bits; + modulus_bits = {55, 54, 50}; + + auto modulus = seal::CoeffModulus::Create(poly_N, modulus_bits); + + seal::EncryptionParameters parms(seal::scheme_type::ckks); + parms.set_use_special_prime(true); + parms.set_poly_modulus_degree(poly_N); + parms.set_coeff_modulus(modulus); + N_context_ = std::make_shared( + parms, true, seal::sec_level_type::none); + + auto m = modulus; + auto p = parms; + m.pop_back(); + p.set_coeff_modulus(m); + p.set_use_special_prime(false); + seal::SEALContext N_ms_cntxt(p, false, seal::sec_level_type::none); + N_ms_helper_ = std::make_shared(N_ms_cntxt, ring_len); + + N_encoder_ = std::make_shared(*N_context_, *N_ms_helper_); + + seal::KeyGenerator keygen(*N_context_); + N_rlwe_sk_ = std::make_shared(keygen.secret_key()); + galois_ = std::make_shared(); + GenerateGaloisKeyForPacking(*N_context_, *N_rlwe_sk_, /*seed*/ false, + galois_.get()); + } +}; + +INSTANTIATE_TEST_SUITE_P( + Cheetah, PackLWEsTest, + testing::Combine(testing::Values(FieldType::FM64), + testing::Values(1024UL, 2048UL, 4096UL, 8192UL)), + [](const testing::TestParamInfo &p) { + return fmt::format("{}n{}", std::get<0>(p.param), std::get<1>(p.param)); + }); + +TEST_P(PackLWEsTest, PackRLWEs) { + auto field = std::get<0>(GetParam()); // FM64 + size_t num_rlwes = std::get<1>(GetParam()) / 128; + SPU_ENFORCE(num_rlwes <= poly_N); + + seal::Encryptor encryptor(*N_context_, *N_rlwe_sk_); + seal::Decryptor decryptor(*N_context_, *N_rlwe_sk_); + + using scalar_t = uint64_t; + std::vector arrays(num_rlwes); + std::vector rlwes(num_rlwes); + for (size_t i = 0; i < num_rlwes; ++i) { + arrays[i] = ring_rand(field, {poly_N}); + RLWEPt pt; + N_encoder_->Forward(arrays[i], &pt, true); + NttInplace(pt, *N_context_); + 
CATCH_SEAL_ERROR(encryptor.encrypt_symmetric(pt, rlwes[i])); + InvNttInplace(rlwes[i], *N_context_); + } + + PackingHelper ph(num_rlwes, *galois_, *N_context_, *N_context_); + + RLWECt packed; + ph.PackingWithModulusDrop(absl::MakeSpan(rlwes), packed); + + EXPECT_EQ(packed.size(), 2); + EXPECT_FALSE(packed.is_ntt_form()); + EXPECT_EQ(packed.poly_modulus_degree(), poly_N); + EXPECT_EQ(packed.coeff_modulus_size(), N_ms_helper_->coeff_modulus_size()); + + NttInplace(packed, *N_context_); + RLWEPt dec; + decryptor.decrypt(packed, dec); + InvNttInplace(dec, *N_context_); + + std::vector coefficients(poly_N); + N_ms_helper_->ModulusDownRNS(absl::MakeSpan(dec.data(), dec.coeff_count()), + absl::MakeSpan(coefficients)); + + for (size_t i = 0; i < poly_N; i += num_rlwes) { + size_t offset = i / num_rlwes; + for (size_t j = 0; j < num_rlwes; ++j) { + NdArrayView expected(arrays[j]); + EXPECT_EQ(expected[offset * num_rlwes], coefficients[i + j]); + } + } +} + +TEST_P(PackLWEsTest, Basic) { + auto field = std::get<0>(GetParam()); // FM64 + size_t num_lwes = std::get<1>(GetParam()) / 128; + SPU_ENFORCE(num_lwes <= poly_N); + + seal::Encryptor encryptor(*N_context_, *N_rlwe_sk_); + seal::Evaluator N_evaluator(*N_context_); + seal::Decryptor decryptor(*N_context_, *N_rlwe_sk_); + + using scalar_t = uint64_t; + NdArrayRef array = ring_rand(field, {poly_N}); + NdArrayView _array(array); + + RLWEPt pt; + N_encoder_->Forward(array, &pt, true); + NttInplace(pt, *N_context_); + + RLWECt ct; + CATCH_SEAL_ERROR(encryptor.encrypt_symmetric(pt, ct)); + if (ct.is_ntt_form()) { + CATCH_SEAL_ERROR(N_evaluator.transform_from_ntt_inplace(ct)); + } + + // Perform some computation on LWEs and results at `num_lwes' LWEs. + size_t n_stride = poly_N / num_lwes; + std::vector n_lwes(num_lwes); + std::vector expects(num_lwes, 0); + + for (size_t i = 0, j = 0; i < poly_N; i += n_stride, ++j) { + for (size_t k = 0; k < n_stride; ++k) { + n_lwes[j].AddLazyInplace(ct, i + k, *N_context_); + expects[j] += _array[i + k]; + } + n_lwes[j].Reduce(*N_context_); + } + + // re-randomize to avoid transparent error + RLWECt zero; + encryptor.encrypt_zero_symmetric(zero); + N_evaluator.transform_from_ntt_inplace(zero); + for (size_t j = 0; j < num_lwes; ++j) { + n_lwes[j].AddLazyInplace(zero, j, *N_context_); + n_lwes[j].Reduce(*N_context_); + } + + RLWECt packed; + PackLWEs(absl::MakeSpan(n_lwes), *galois_, *N_context_, {&packed, 1}); + + EXPECT_EQ(packed.size(), 2); + EXPECT_FALSE(packed.is_ntt_form()); + EXPECT_EQ(packed.poly_modulus_degree(), poly_N); + EXPECT_EQ(packed.coeff_modulus_size(), N_ms_helper_->coeff_modulus_size()); + + N_evaluator.transform_to_ntt_inplace(packed); + RLWEPt dec; + decryptor.decrypt(packed, dec); + InvNttInplace(dec, *N_context_); + + std::vector coefficients(poly_N); + N_ms_helper_->ModulusDownRNS(absl::MakeSpan(dec.data(), dec.coeff_count()), + absl::MakeSpan(coefficients)); + size_t N_stride = poly_N / num_lwes; + for (size_t i = 0, j = 0; i < poly_N; i += N_stride, ++j) { + EXPECT_EQ(expects[j], coefficients[i]); + } +} + +TEST_P(PackLWEsTest, Phantom) { + seal::Encryptor encryptor(*N_context_, *N_rlwe_sk_); + seal::Evaluator N_evaluator(*N_context_); + seal::Decryptor decryptor(*N_context_, *N_rlwe_sk_); + + auto field = std::get<0>(GetParam()); // FM64 + size_t num_lwes = std::get<1>(GetParam()); // n + SPU_ENFORCE(num_lwes <= poly_N); + + using scalar_t = uint64_t; + NdArrayRef array = ring_rand(field, {poly_N}); + NdArrayView _array(array); + + RLWEPt pt; + N_encoder_->Forward(array, &pt, true); + 
NttInplace(pt, *N_context_); + + RLWECt rlwe0, rlwe1; + CATCH_SEAL_ERROR(encryptor.encrypt_symmetric(pt, rlwe0)); + CATCH_SEAL_ERROR(encryptor.encrypt_symmetric(pt, rlwe1)); + if (rlwe0.is_ntt_form()) { + CATCH_SEAL_ERROR(N_evaluator.transform_from_ntt_inplace(rlwe0)); + CATCH_SEAL_ERROR(N_evaluator.transform_from_ntt_inplace(rlwe1)); + } + + // Perform some computation on LWEs and results at `num_lwes' LWEs. + size_t n_stride = poly_N / num_lwes; + std::vector n_lwes(num_lwes); + std::vector expects(num_lwes, 0); + + for (size_t i = 0, j = 0; i < poly_N; i += n_stride, ++j) { + expects[j] = _array[i]; + if (1 == (j & 1)) { + n_lwes[j].WrapIt(rlwe0, i); + } else { + n_lwes[j].WrapIt(rlwe1, i); + } + } + + RLWECt packed; + PackLWEs(absl::MakeSpan(n_lwes), *galois_, *N_context_, {&packed, 1}); + + EXPECT_EQ(packed.size(), 2); + EXPECT_FALSE(packed.is_ntt_form()); + EXPECT_EQ(packed.poly_modulus_degree(), poly_N); + EXPECT_EQ(packed.coeff_modulus_size(), N_ms_helper_->coeff_modulus_size()); + + N_evaluator.transform_to_ntt_inplace(packed); + RLWEPt dec; + decryptor.decrypt(packed, dec); + InvNttInplace(dec, *N_context_); + + std::vector coefficients(poly_N); + N_ms_helper_->ModulusDownRNS(absl::MakeSpan(dec.data(), dec.coeff_count()), + absl::MakeSpan(coefficients)); + size_t N_stride = poly_N / num_lwes; + for (size_t i = 0, j = 0; i < poly_N; i += N_stride, ++j) { + ASSERT_EQ(expects[j], coefficients[i]); + } +} + +VectorEncoder::VectorEncoder(const seal::SEALContext &context, + const ModulusSwitchHelper &msh) { + SPU_ENFORCE(context.parameters_set()); + auto pid0 = context.first_parms_id(); + auto pid1 = msh.parms_id(); + SPU_ENFORCE_EQ(0, std::memcmp(&pid0, &pid1, sizeof(seal::parms_id_type)), + fmt::format("parameter set mismatch")); + msh_ = std::make_shared(msh); + poly_deg_ = context.first_context_data()->parms().poly_modulus_degree(); +} + +void VectorEncoder::Forward(const NdArrayRef &vec, RLWEPt *out, + bool scale_delta) const { + // Place the vector elements as polynomial coefficients forwardly. + // a0, a1, ..., an -> \sum_i ai*X^i + yacl::CheckNotNull(out); + + size_t num_coeffs = vec.numel(); + size_t num_modulus = msh_->coeff_modulus_size(); + SPU_ENFORCE(vec.shape().size() == 1, "need 1D array"); + SPU_ENFORCE_GT(num_coeffs, 0UL); + SPU_ENFORCE(num_coeffs <= poly_deg_); + + out->parms_id() = seal::parms_id_zero; + out->resize(seal::util::mul_safe(poly_deg_, num_modulus)); + + uint64_t *dst = out->data(); + for (size_t mod_idx = 0; mod_idx < num_modulus; ++mod_idx) { + std::fill_n(dst, poly_deg_, 0); + absl::Span dst_wrap(dst, num_coeffs); + + if (scale_delta) { + msh_->ModulusUpAt(vec, mod_idx, dst_wrap); + } else { + msh_->CenteralizeAt(vec, mod_idx, dst_wrap); + } + dst += poly_deg_; + } + + out->parms_id() = msh_->parms_id(); + out->scale() = 1.; +} + +void VectorEncoder::Backward(const NdArrayRef &vec, RLWEPt *out, + bool scale_delta) const { + // Place the vector elements as polynomial coefficients in backward. + // a0, a1, ..., an -> a0 - \sum_{i>0} ai*X^{N-i} + // where N defines the base ring X^N + 1. 
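+  // For example, with N = 4: Forward(a) = a0 + a1*X + a2*X^2 + a3*X^3 and
+  // Backward(b) = b0 - b3*X - b2*X^2 - b1*X^3. Since X^4 = -1 in the ring,
+  // the constant coefficient of Forward(a) * Backward(b) equals
+  // a0*b0 + a1*b1 + a2*b2 + a3*b3, i.e., the inner product of a and b.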
+ yacl::CheckNotNull(out); + SPU_ENFORCE(vec.shape().size() == 1, "need 1D array"); + + size_t num_coeffs = vec.numel(); + size_t num_modulus = msh_->coeff_modulus_size(); + SPU_ENFORCE_GT(num_coeffs, 0UL); + SPU_ENFORCE(num_coeffs <= poly_deg_); + + const Type &eltype = vec.eltype(); + SPU_ENFORCE(eltype.isa(), "must be ring_type, got={}", eltype); + out->parms_id() = seal::parms_id_zero; + out->resize(seal::util::mul_safe(poly_deg_, num_modulus)); + + const auto field = eltype.as()->field(); + + DISPATCH_ALL_FIELDS(field, "Backward", [&]() { + auto tmp_buff = ring_zeros(field, {(int64_t)poly_deg_}); + auto xvec = NdArrayView(vec); + auto xtmp = NdArrayView(tmp_buff); + + xtmp[0] = xvec[0]; + // reverse and sign flip + for (size_t i = 1; i < num_coeffs; ++i) { + xtmp[num_coeffs - 1 - i] = -xvec[i]; + } + + uint64_t *dst = out->data(); + for (size_t mod_idx = 0; mod_idx < num_modulus; ++mod_idx) { + std::fill_n(dst, poly_deg_, 0); + absl::Span dst_wrap(dst, poly_deg_); + + if (scale_delta) { + msh_->ModulusUpAt(tmp_buff, mod_idx, dst_wrap); + } else { + msh_->CenteralizeAt(tmp_buff, mod_idx, dst_wrap); + } + dst += poly_deg_; + } + + // clean up sensitive data + seal::util::seal_memzero(&tmp_buff.at(0), + sizeof(ring2k_t) * poly_deg_); + }); + + out->parms_id() = msh_->parms_id(); + out->scale() = 1.; +} +} // namespace spu::mpc::cheetah::test \ No newline at end of file diff --git a/libspu/mpc/cheetah/rlwe/types.h b/libspu/mpc/cheetah/rlwe/types.h index 1377fd58..008595e6 100644 --- a/libspu/mpc/cheetah/rlwe/types.h +++ b/libspu/mpc/cheetah/rlwe/types.h @@ -25,129 +25,18 @@ namespace spu::mpc::cheetah { using RLWESecretKey = seal::SecretKey; -using RLWECt = seal::Ciphertext; - -using RLWEPt = seal::Plaintext; - -class LWEDecryptor; - -class LWECt; - -class LWESecretKey { - public: - LWESecretKey() = default; - - explicit LWESecretKey(const RLWESecretKey &rlwe_sk, - const seal::SEALContext &context); - ~LWESecretKey(); - - size_t save_size(seal::compr_mode_type compr_mode = - seal::Serialization::compr_mode_default) const; - - size_t save(seal::seal_byte *buffer, size_t size, - seal::compr_mode_type compr_mode = - seal::Serialization::compr_mode_default) const; - - void load(const seal::SEALContext &context, const seal::seal_byte *buffer, - size_t size); - - void unsafe_load(const seal::SEALContext &context, - const seal::seal_byte *buffer, size_t size); - - private: - friend class LWEDecryptor; - // the secret key in the conventional form - RLWESecretKey secret_non_ntt_; -}; - -class LWECt { - public: - LWECt(); - - ~LWECt(); +using RLWEPublicKey = seal::PublicKey; - LWECt(const RLWECt &rlwe, size_t coeff_index, - const seal::SEALContext &context); +using KSwitchKeys = seal::KSwitchKeys; - LWECt &NegateInplace(const seal::SEALContext &context); +using GaloisKeys = seal::GaloisKeys; - LWECt &AddInplace(const LWECt &oth, const seal::SEALContext &context); - - LWECt &AddPlainInplace(const std::vector &plain, - const seal::SEALContext &context); - - LWECt &SubInplace(const LWECt &oth, const seal::SEALContext &context); - - LWECt &SubPlainInplace(const std::vector &plain, - const seal::SEALContext &context); - - LWECt &AddLazyInplace(const RLWECt &rlwe, size_t coeff_index, - const seal::SEALContext &context); - - LWECt &SubLazyInplace(const RLWECt &rlwe, size_t coeff_index, - const seal::SEALContext &context); - - void Reduce(const seal::SEALContext &context); - - inline bool IsValid() const { return poly_deg_ > 0; } - - seal::parms_id_type parms_id() const { return vec_.parms_id(); } - - 
inline size_t poly_modulus_degree() const { return poly_deg_; } - - inline size_t coeff_modulus_size() const { return cnst_term_.size(); } - - /** - Returns an upper bound on the size of the LWECt, as if it was written - to an output stream. - */ - size_t save_size(seal::compr_mode_type compr_mode = - seal::Serialization::compr_mode_default) const; - /** - Saves the LWECt to a given memory location. The output is in binary - format and not human-readable. - */ - size_t save(seal::seal_byte *buffer, size_t size, - seal::compr_mode_type compr_mode = - seal::Serialization::compr_mode_default) const { - using namespace std::placeholders; - return seal::Serialization::Save(std::bind(&LWECt::save_members, this, _1), - save_size(seal::compr_mode_type::none), - buffer, size, compr_mode, false); - } - /** - Loads an LWECt from an input stream overwriting the current LWECt. - The loaded ciphertext is verified to be valid for the given SEALContext. - */ - void load(const seal::SEALContext &context, const seal::seal_byte *buffer, - size_t size); - /** - Loads an LWECt from a given memory location overwriting the current - plaintext. No checking of the validity of the ciphertext data against - encryption parameters is performed. This function should not be used - unless the ciphertext comes from a fully trusted source. - */ - void unsafe_load(const seal::SEALContext &context, - const seal::seal_byte *buffer, size_t size) { - using namespace std::placeholders; - seal::Serialization::Load( - std::bind(&LWECt::load_members, this, context, _1, _2), buffer, size, - false); - } - - private: - void save_members(std::ostream &stream) const; +using RLWECt = seal::Ciphertext; - void load_members(const seal::SEALContext &context, std::istream &stream, - seal::SEALVersion version); +using RLWEPt = seal::Plaintext; - friend class LWEDecryptor; - uint64_t maximum_lazy_{0}; - uint64_t lazy_counter_{0}; +class LWECt; - size_t poly_deg_{0}; - std::vector cnst_term_; - RLWEPt vec_; -}; +class PhantomLWECt; -} // namespace spu::mpc::cheetah +} // namespace spu::mpc::cheetah \ No newline at end of file diff --git a/libspu/mpc/cheetah/rlwe/utils.cc b/libspu/mpc/cheetah/rlwe/utils.cc index 63a50f81..ef1dd278 100644 --- a/libspu/mpc/cheetah/rlwe/utils.cc +++ b/libspu/mpc/cheetah/rlwe/utils.cc @@ -14,10 +14,11 @@ #include "libspu/mpc/cheetah/rlwe/utils.h" +#include "seal/util/polyarithsmallmod.h" +#include "seal/util/rlwe.h" #include "spdlog/spdlog.h" #include "libspu/core/prelude.h" -#include "libspu/mpc/cheetah/array_ref.h" namespace spu::mpc::cheetah { @@ -63,32 +64,6 @@ void KeepCoefficientsInplace(RLWECt& ciphertext, RemoveCoefficientsInplace(ciphertext, to_remove); } -AutoMemGuard::AutoMemGuard(ArrayRef* obj) : obj_(obj) { - if (obj_ != nullptr) { - SPU_ENFORCE(obj_->stride() == 1); - } -} - -AutoMemGuard::AutoMemGuard(RLWEPt* pt) : pt_(pt){}; - -AutoMemGuard::~AutoMemGuard() { - try { - if (obj_ && obj_->numel() > 0 && obj_->elsize() > 0) { - auto* ptr = reinterpret_cast(obj_->data()); - auto nbytes = seal::util::mul_safe(obj_->numel(), obj_->elsize()); - seal::util::seal_memzero(ptr, nbytes); - } - - if (pt_ && pt_->coeff_count() > 0) { - size_t nbytes = - seal::util::mul_safe(pt_->coeff_count(), sizeof(uint64_t)); - seal::util::seal_memzero(pt_->data(), nbytes); - } - } catch (const std::exception& e) { - SPDLOG_ERROR("Error in ~MemGuard(): {}", e.what()); - } -} - // Truncate ciphertexts for a smaller communication. // NOTE: after truncation, further homomorphic operation is meaningless. 
void TruncateBFVForDecryption(seal::Ciphertext& ct, @@ -141,8 +116,53 @@ void TruncateBFVForDecryption(seal::Ciphertext& ct, #endif } -void NttInplace(seal::Plaintext& pt, const seal::SEALContext& context) { - using namespace seal::util; +void NttInplace(RLWECt& ct, const seal::SEALContext& context, bool lazy) { + namespace sut = seal::util; + if (ct.is_ntt_form()) { + return; + } + + SPU_ENFORCE(context.parameters_set()); + auto cntxt_data = context.get_context_data(ct.parms_id()); + SPU_ENFORCE(cntxt_data != nullptr); + + const auto* ntt_tables = cntxt_data->small_ntt_tables(); + + for (size_t k = 0; k < ct.size(); ++k) { + if (lazy) { + sut::ntt_negacyclic_harvey_lazy(sut::iter(ct)[k], ct.coeff_modulus_size(), + ntt_tables); + } else { + sut::ntt_negacyclic_harvey(sut::iter(ct)[k], ct.coeff_modulus_size(), + ntt_tables); + } + } + ct.is_ntt_form() = true; +} + +void InvNttInplace(RLWECt& ct, const seal::SEALContext& context, bool lazy) { + namespace sut = seal::util; + if (not ct.is_ntt_form()) { + return; + } + + SPU_ENFORCE(context.parameters_set()); + auto cntxt_data = context.get_context_data(ct.parms_id()); + SPU_ENFORCE(cntxt_data != nullptr); + + const auto* ntt_tables = cntxt_data->small_ntt_tables(); + size_t L = ct.coeff_modulus_size(); + for (size_t k = 0; k < ct.size(); ++k) { + if (lazy) { + sut::inverse_ntt_negacyclic_harvey_lazy(sut::iter(ct)[k], L, ntt_tables); + } else { + sut::inverse_ntt_negacyclic_harvey(sut::iter(ct)[k], L, ntt_tables); + } + } + ct.is_ntt_form() = false; +} + +void NttInplace(RLWEPt& pt, const seal::SEALContext& context, bool lazy) { SPU_ENFORCE(context.parameters_set()); auto cntxt_data = context.get_context_data(pt.parms_id()); SPU_ENFORCE(cntxt_data != nullptr); @@ -154,13 +174,16 @@ void NttInplace(seal::Plaintext& pt, const seal::SEALContext& context) { size_t n = pt.coeff_count() / L; auto* pt_ptr = pt.data(); for (size_t l = 0; l < L; ++l) { - ntt_negacyclic_harvey(pt_ptr, ntt_tables[l]); + if (lazy) { + seal::util::ntt_negacyclic_harvey_lazy(pt_ptr, ntt_tables[l]); + } else { + seal::util::ntt_negacyclic_harvey(pt_ptr, ntt_tables[l]); + } pt_ptr += n; } } -void InvNttInplace(seal::Plaintext& pt, const seal::SEALContext& context) { - using namespace seal::util; +void InvNttInplace(RLWEPt& pt, const seal::SEALContext& context, bool lazy) { SPU_ENFORCE(context.parameters_set()); auto cntxt_data = context.get_context_data(pt.parms_id()); SPU_ENFORCE(cntxt_data != nullptr); @@ -172,9 +195,251 @@ void InvNttInplace(seal::Plaintext& pt, const seal::SEALContext& context) { size_t n = pt.coeff_count() / L; auto* pt_ptr = pt.data(); for (size_t l = 0; l < L; ++l) { - inverse_ntt_negacyclic_harvey(pt_ptr, ntt_tables[l]); + if (lazy) { + seal::util::inverse_ntt_negacyclic_harvey_lazy(pt_ptr, ntt_tables[l]); + } else { + seal::util::inverse_ntt_negacyclic_harvey(pt_ptr, ntt_tables[l]); + } pt_ptr += n; } } +// ct <- ct + pt +// Handling ct.parms_id = context.key_context_id +void AddPlainInplace(RLWECt& ct, const RLWEPt& pt, + const seal::SEALContext& context) { + namespace sut = seal::util; + SPU_ENFORCE(context.parameters_set()); + auto cntxt_data = context.get_context_data(ct.parms_id()); + SPU_ENFORCE(cntxt_data != nullptr); + SPU_ENFORCE(ct.parms_id() == pt.parms_id()); + + const auto& modulus = cntxt_data->parms().coeff_modulus(); + + sut::RNSIter ct_iter(ct.data(0), ct.poly_modulus_degree()); + sut::ConstRNSIter pt_iter(pt.data(), ct.poly_modulus_degree()); + + sut::add_poly_coeffmod(ct_iter, pt_iter, modulus.size(), modulus, ct_iter); +} + +// ct 
<- ct - pt +// Handling ct.parms_id = context.key_context_id +void SubPlainInplace(RLWECt& ct, const RLWEPt& pt, + const seal::SEALContext& context) { + namespace sut = seal::util; + SPU_ENFORCE(context.parameters_set()); + auto cntxt_data = context.get_context_data(ct.parms_id()); + SPU_ENFORCE(cntxt_data != nullptr); + SPU_ENFORCE(ct.parms_id() == pt.parms_id()); + + const auto& modulus = cntxt_data->parms().coeff_modulus(); + + sut::RNSIter ct_iter(ct.data(0), ct.poly_modulus_degree()); + sut::ConstRNSIter pt_iter(pt.data(), ct.poly_modulus_degree()); + + sut::sub_poly_coeffmod(ct_iter, pt_iter, modulus.size(), modulus, ct_iter); +} + +void ModulusSwtichInplace(RLWECt& ct, size_t num_modulus_to_keep, + const seal::SEALContext& context) { + namespace sut = seal::util; + SPU_ENFORCE(num_modulus_to_keep >= 1 && + num_modulus_to_keep <= ct.coeff_modulus_size()); + if (num_modulus_to_keep == ct.coeff_modulus_size()) { + // nothing to do + return; + } + + auto cntxt = context.get_context_data(ct.parms_id()); + YACL_ENFORCE(cntxt != nullptr); + size_t index = cntxt->chain_index(); + YACL_ENFORCE((index + 1) >= num_modulus_to_keep); + + auto target_context = cntxt; + auto pool = seal::MemoryManager::GetPool(); + + while (target_context->chain_index() >= num_modulus_to_keep) { + const auto* rns_tool = target_context->rns_tool(); + const auto* ntt_tables = target_context->small_ntt_tables(); + if (ct.is_ntt_form()) { + SEAL_ITERATE(sut::iter(ct), ct.size(), [&](auto I) { + rns_tool->divide_and_round_q_last_ntt_inplace(I, ntt_tables, pool); + }); + } else { + SEAL_ITERATE(sut::iter(ct), ct.size(), [&](auto I) { + rns_tool->divide_and_round_q_last_inplace(I, pool); + }); + } + + auto next_context = target_context->next_context_data(); + SPU_ENFORCE(next_context != nullptr); + + RLWECt next_ct(pool); + next_ct.resize(context, next_context->parms_id(), ct.size()); + SEAL_ITERATE(sut::iter(ct, next_ct), ct.size(), [&](auto I) { + sut::set_poly(std::get<0>(I), ct.poly_modulus_degree(), + ct.coeff_modulus_size() - 1, std::get<1>(I)); + }); + next_ct.is_ntt_form() = ct.is_ntt_form(); + target_context = next_context; + std::swap(next_ct, ct); + } + SPU_ENFORCE_EQ(num_modulus_to_keep, ct.coeff_modulus_size()); +} + +// NOTE(lwj): the following code is modified from +// seal/util/rlwe.cpp#encrypt_zero_symmetric +// The symmetric RLWE encryption of m is given as (m + e - a*sk, a) +// The origin SEAL uses a general API encrypt_zero_symmetric() to generate (e +// - a*sk, a) first. 
Then the encryption of `m` is followed by addition `m + e +// - a*sk` +// +// Observation: we can perform the addition `m + e` first to save one NTT +// That is when `need_ntt=true` NTT(m) + NTT(e) is replaced by NTT(m + e) +void SymmetricRLWEEncrypt(const RLWESecretKey& sk, + const seal::SEALContext& context, + absl::Span msg_non_ntt, bool need_ntt, + bool need_seed, absl::Span out_ct) { + using namespace seal; + using namespace seal::util; + size_t n = msg_non_ntt.size(); + SPU_ENFORCE(out_ct.size() >= n); + SPU_ENFORCE(seal::is_metadata_valid_for(sk, context)); + if (n == 0) { + return; + } + + SPU_ENFORCE(context.parameters_set(), "invalid SEALContext"); + SPU_ENFORCE(std::all_of(msg_non_ntt.data(), msg_non_ntt.data() + n, + [&](const RLWEPt& pt) { + return context.get_context_data(pt.parms_id()) != + nullptr; + }), + "invalid plaintext to encrypt"); + + auto pool = MemoryManager::GetPool(mm_prof_opt::mm_force_thread_local, true); + std::shared_ptr bootstrap_prng = nullptr; + + for (size_t i = 0; i < n; ++i) { + const RLWEPt& msg = msg_non_ntt[i]; + RLWECt& destination = out_ct[i]; + + auto parms_id = msg.parms_id(); + auto& context_data = *context.get_context_data(parms_id); + auto& parms = context_data.parms(); + auto& coeff_modulus = parms.coeff_modulus(); + size_t coeff_modulus_size = coeff_modulus.size(); + size_t coeff_count = parms.poly_modulus_degree(); + auto ntt_tables = context_data.small_ntt_tables(); + size_t encrypted_size = 2; + + // If a polynomial is too small to store UniformRandomGeneratorInfo, + // it is best to just disable save_seed. Note that the size needed is + // the size of UniformRandomGeneratorInfo plus one (uint64_t) because + // of an indicator word that indicates a seeded ciphertext. + size_t poly_uint64_count = mul_safe(coeff_count, coeff_modulus_size); + if (msg.coeff_count() != poly_uint64_count) { + throw std::invalid_argument("msg coeff_count mismatch"); + } + size_t prng_info_byte_count = static_cast( + UniformRandomGeneratorInfo::SaveSize(compr_mode_type::none)); + size_t prng_info_uint64_count = divide_round_up( + prng_info_byte_count, static_cast(bytes_per_uint64)); + if (need_ntt && poly_uint64_count < prng_info_uint64_count + 1) { + need_ntt = false; + } + + destination.resize(context, parms_id, encrypted_size); + destination.is_ntt_form() = need_ntt; + destination.scale() = 1.0; + destination.correction_factor() = 1; + + // Create an instance of a random number generator. We use this for + // sampling a seed for a second PRNG used for sampling u (the seed can be + // public information. This PRNG is also used for sampling the noise/error + // below. 
+ if (!bootstrap_prng) { + bootstrap_prng = parms.random_generator()->create(); + } + + // Sample a public seed for generating uniform randomness + prng_seed_type public_prng_seed; + bootstrap_prng->generate( + prng_seed_byte_count, + reinterpret_cast(public_prng_seed.data())); + + // Set up a new default PRNG for expanding u from the seed sampled above + auto ciphertext_prng = + UniformRandomGeneratorFactory::DefaultFactory()->create( + public_prng_seed); + + // Generate ciphertext: (c[0], c[1]) = ([msg + e - a*s]_q, a) in BFV/CKKS + uint64_t* c0 = destination.data(); + uint64_t* c1 = destination.data(1); + + // Sample a uniformly at random + if (need_ntt || !need_seed) { + // Sample the NTT form directly + sample_poly_uniform(ciphertext_prng, parms, c1); + } else if (need_seed) { + // Sample non-NTT form and store the seed + sample_poly_uniform(ciphertext_prng, parms, c1); + for (size_t i = 0; i < coeff_modulus_size; i++) { + // Transform the c1 into NTT representation + ntt_negacyclic_harvey(c1 + i * coeff_count, ntt_tables[i]); + } + } + + // Sample e <-- chi + auto noise(allocate_poly(coeff_count, coeff_modulus_size, pool)); + SEAL_NOISE_SAMPLER(bootstrap_prng, parms, noise.get()); + + // Calculate -(as + e) (mod q) and store in c[0] in BFV/CKKS + for (size_t i = 0; i < coeff_modulus_size; i++) { + dyadic_product_coeffmod(sk.data().data() + i * coeff_count, + c1 + i * coeff_count, coeff_count, + coeff_modulus[i], c0 + i * coeff_count); + if (need_ntt) { + // Peform the addition m + e first + // NOTE: lazy reduction here which will be obsorbed by + // ntt_negacyclic_harvey + std::transform(noise.get() + i * coeff_count, + noise.get() + i * coeff_count + coeff_count, + msg.data() + i * coeff_count, + noise.get() + i * coeff_count, std::plus()); + + // Then transform m + e to NTT form + // noise <- m + e + ntt_negacyclic_harvey(noise.get() + i * coeff_count, ntt_tables[i]); + } else { + // c0 <- a*s - m + inverse_ntt_negacyclic_harvey(c0 + i * coeff_count, ntt_tables[i]); + sub_poly_coeffmod(c0 + i * coeff_count, msg.data() + i * coeff_count, + coeff_count, coeff_modulus[i], c0 + i * coeff_count); + } + + // c0 <- noise - c0 + // <- m + e - a*s (need_ntt=true) + // <- e - (a*s - m) (need_ntt=false) + sub_poly_coeffmod(noise.get() + i * coeff_count, c0 + i * coeff_count, + coeff_count, coeff_modulus[i], c0 + i * coeff_count); + } + + if (!need_ntt && !need_seed) { + for (size_t i = 0; i < coeff_modulus_size; i++) { + // Transform the c1 into non-NTT representation + inverse_ntt_negacyclic_harvey(c1 + i * coeff_count, ntt_tables[i]); + } + } + + if (need_seed) { + UniformRandomGeneratorInfo prng_info = ciphertext_prng->info(); + + // Write prng_info to destination.data(1) after an indicator word + c1[0] = static_cast(0xFFFFFFFFFFFFFFFFULL); + prng_info.save(reinterpret_cast(c1 + 1), prng_info_byte_count, + compr_mode_type::none); + } + } +} } // namespace spu::mpc::cheetah diff --git a/libspu/mpc/cheetah/rlwe/utils.h b/libspu/mpc/cheetah/rlwe/utils.h index fc04d1df..d8eab673 100644 --- a/libspu/mpc/cheetah/rlwe/utils.h +++ b/libspu/mpc/cheetah/rlwe/utils.h @@ -116,9 +116,17 @@ void DecodeSEALObjects(const std::vector &buf_view, void TruncateBFVForDecryption(seal::Ciphertext &ct, const seal::SEALContext &context); -void NttInplace(seal::Plaintext &pt, const seal::SEALContext &context); +void NttInplace(seal::Plaintext &pt, const seal::SEALContext &context, + bool lazy = false); -void InvNttInplace(seal::Plaintext &pt, const seal::SEALContext &context); +void InvNttInplace(seal::Plaintext 
&pt, const seal::SEALContext &context, + bool lazy = false); + +void NttInplace(RLWECt &ct, const seal::SEALContext &context, + bool lazy = false); + +void InvNttInplace(RLWECt &ct, const seal::SEALContext &context, + bool lazy = false); // requires ciphertext.is_ntt_form() is `false` // ciphertext.size() is `2` @@ -128,6 +136,17 @@ void RemoveCoefficientsInplace(RLWECt &ciphertext, void KeepCoefficientsInplace(RLWECt &ciphertext, const std::set &to_keep); +// ct mod Q' <- ct mod Q +void ModulusSwtichInplace(RLWECt &ct, size_t num_modulus_to_keep, + const seal::SEALContext &context); +// ct <- ct + pt +void AddPlainInplace(seal::Ciphertext &ct, const seal::Plaintext &pt, + const seal::SEALContext &context); + +// ct <- ct - pt +void SubPlainInplace(seal::Ciphertext &ct, const seal::Plaintext &pt, + const seal::SEALContext &context); + // x mod prime template uint64_t BarrettReduce(T x, const seal::Modulus &prime) { @@ -139,16 +158,9 @@ uint64_t BarrettReduce(T x, const seal::Modulus &prime) { } } -// Erase the memory automatically -struct AutoMemGuard { - explicit AutoMemGuard(ArrayRef *obj); - - explicit AutoMemGuard(RLWEPt *pt); - - ~AutoMemGuard(); - - ArrayRef *obj_{nullptr}; - RLWEPt *pt_{nullptr}; -}; +void SymmetricRLWEEncrypt(const RLWESecretKey &sk, + const seal::SEALContext &context, + absl::Span msg_non_ntt, bool need_ntt, + bool need_seed, absl::Span out_ct); } // namespace spu::mpc::cheetah diff --git a/libspu/mpc/cheetah/state.cc b/libspu/mpc/cheetah/state.cc index fcafdf2a..0c70a28e 100644 --- a/libspu/mpc/cheetah/state.cc +++ b/libspu/mpc/cheetah/state.cc @@ -16,9 +16,30 @@ #include "spdlog/spdlog.h" +#include "libspu/core/context.h" #include "libspu/mpc/utils/ring_ops.h" namespace spu::mpc::cheetah { +size_t InitOTState(KernelEvalContext* ctx, size_t njobs) { + constexpr size_t kMinWorkSize = 1500; + if (njobs == 0) { + return 0; + } + auto* comm = ctx->getState(); + auto* ot_state = ctx->getState(); + size_t nworker = ot_state->parallel_size(); + nworker = std::min(nworker, CeilDiv(njobs, kMinWorkSize)); + for (size_t w = 0; w < nworker; ++w) { + ot_state->LazyInit(comm, w); + } + return nworker; +} + +size_t CheetahOTState::parallel_size() const { + // NOTE(lwj): div-2 to prevent making too many threads on small CPU cores. 
+ return std::max( + 1, std::min(getNumberOfProc() / 2, kMaxOTParallel)); +} void CheetahMulState::makeSureCacheSize(FieldType field, int64_t numel) { // NOTE(juhou): make sure the lock is obtained @@ -41,39 +62,43 @@ void CheetahMulState::makeSureCacheSize(FieldType field, int64_t numel) { // where c0 = a0*b0 + + // c1 = a1*b1 + + const int rank = mul_prot_->Rank(); - const size_t ole_sze = mul_prot_->OLEBatchSize(); - const auto num_ole = CeilDiv(2 * numel, ole_sze); - const size_t num_beaver = (num_ole * ole_sze) / 2; - auto rand = - flatten(ring_rand(field, {static_cast(num_ole * ole_sze)})); + const int64_t ole_sze = mul_prot_->OLEBatchSize(); + const int64_t num_ole = CeilDiv(2 * numel, ole_sze); + const int64_t num_beaver = (num_ole * ole_sze) / 2; + auto rand = ring_rand(field, {num_ole * ole_sze}); auto cross = mul_prot_->MulOLE(rand, rank == 0); - ArrayRef beaver[3]; - ArrayRef a0b1, a1b0; + NdArrayRef beaver[3]; + NdArrayRef a0b1; + NdArrayRef a1b0; if (rank == 0) { - beaver[0] = rand.slice(0, num_beaver); - beaver[1] = rand.slice(num_beaver, num_beaver * 2); + beaver[0] = rand.slice({0}, {num_beaver}, {1}); + beaver[1] = rand.slice({num_beaver}, {num_beaver * 2}, {1}); } else { - beaver[0] = rand.slice(num_beaver, num_beaver * 2); - beaver[1] = rand.slice(0, num_beaver); + beaver[0] = rand.slice({num_beaver}, {num_beaver * 2}, {1}); + beaver[1] = rand.slice({0}, {num_beaver}, {1}); } - a0b1 = cross.slice(0, num_beaver); - a1b0 = cross.slice(num_beaver, num_beaver * 2); + a0b1 = cross.slice({0}, {num_beaver}, {1}); + a1b0 = cross.slice({num_beaver}, {num_beaver * 2}, {1}); - beaver[2] = flatten( - ring_add(ring_add(toNdArray(cross.slice(0, num_beaver)), - toNdArray(cross.slice(num_beaver, 2 * num_beaver))), - ring_mul(toNdArray(beaver[0]), toNdArray(beaver[1])))); + beaver[2] = + ring_add(ring_add(cross.slice({0}, {num_beaver}, {1}), + cross.slice({num_beaver}, {2 * num_beaver}, {1})), + ring_mul(beaver[0], beaver[1])); DISPATCH_ALL_FIELDS(field, "makeSureCacheSize", [&]() { for (size_t i : {0, 1, 2}) { - auto tmp = - flatten(ring_zeros(field, {(int64_t)num_beaver + cached_sze_})); - ArrayView old_cache(cached_beaver_[i]); - ArrayView _beaver(beaver[i]); - ArrayView new_cache(tmp); + auto tmp = ring_zeros(field, {num_beaver + cached_sze_}); + NdArrayView new_cache(tmp); + // concate two array - pforeach(0, cached_sze_, [&](int64_t j) { new_cache[j] = old_cache[j]; }); + if (cached_sze_ > 0) { + NdArrayView old_cache(cached_beaver_[i]); + pforeach(0, cached_sze_, + [&](int64_t j) { new_cache[j] = old_cache[j]; }); + } + + NdArrayView _beaver(beaver[i]); pforeach(0, num_beaver, [&](int64_t j) { new_cache[cached_sze_ + j] = _beaver[j]; }); cached_beaver_[i] = tmp; @@ -86,21 +111,22 @@ void CheetahMulState::makeSureCacheSize(FieldType field, int64_t numel) { SPU_ENFORCE(cached_sze_ >= numel); } -std::array CheetahMulState::TakeCachedBeaver(FieldType field, - int64_t numel) { +std::array CheetahMulState::TakeCachedBeaver(FieldType field, + int64_t numel) { SPU_ENFORCE(numel > 0); std::unique_lock guard(lock_); makeSureCacheSize(field, numel); - std::array ret; + std::array ret; for (size_t i : {0, 1, 2}) { SPU_ENFORCE(cached_beaver_[i].numel() >= numel); - ret[i] = cached_beaver_[i].slice(0, numel); + ret[i] = cached_beaver_[i].slice({0}, {numel}, {1}); if (cached_sze_ == numel) { // empty cache now - cached_beaver_[i] = ArrayRef(); + // NOTE(lwj): should use `{0}` as empty while `{}` is a scalar + cached_beaver_[i] = NdArrayRef(cached_beaver_[i].eltype(), {0}); } else { - 
cached_beaver_[i] = cached_beaver_[i].slice(numel, cached_sze_); + cached_beaver_[i] = cached_beaver_[i].slice({numel}, {cached_sze_}, {1}); } } diff --git a/libspu/mpc/cheetah/state.h b/libspu/mpc/cheetah/state.h index 6c5c6de8..c221753a 100644 --- a/libspu/mpc/cheetah/state.h +++ b/libspu/mpc/cheetah/state.h @@ -24,15 +24,19 @@ namespace spu::mpc::cheetah { +// Return num_workers for the given size of jobs +size_t InitOTState(KernelEvalContext* ctx, size_t njobs); + class CheetahMulState : public State { private: mutable std::mutex lock_; // a[2] = a[0] * a[1] mutable int64_t cached_sze_{0}; FieldType field_{FT_INVALID}; - ArrayRef cached_beaver_[3]; + NdArrayRef cached_beaver_[3]; std::unique_ptr mul_prot_; + std::shared_ptr duplx_; // NOTE(juhou): make sure the lock is obtained void makeSureCacheSize(FieldType, int64_t numel); @@ -43,15 +47,19 @@ class CheetahMulState : public State { public: static constexpr char kBindName[] = "CheetahMul"; - explicit CheetahMulState(const std::shared_ptr& lctx) { - mul_prot_ = std::make_unique(lctx); + explicit CheetahMulState(const std::shared_ptr& lctx, + bool allow_mul_error = false) { + mul_prot_ = std::make_unique(lctx, allow_mul_error); + duplx_ = lctx->Spawn(); } ~CheetahMulState() override = default; CheetahMul* get() { return mul_prot_.get(); } - std::array TakeCachedBeaver(FieldType field, int64_t num); + std::shared_ptr duplx() { return duplx_; } + + std::array TakeCachedBeaver(FieldType field, int64_t num); }; class CheetahDotState : public State { @@ -64,8 +72,9 @@ class CheetahDotState : public State { public: static constexpr char kBindName[] = "CheetahDot"; - explicit CheetahDotState(const std::shared_ptr& lctx) { - dot_prot_ = std::make_unique(lctx); + explicit CheetahDotState(const std::shared_ptr& lctx, + bool enable_matmul_pack = true) { + dot_prot_ = std::make_unique(lctx, enable_matmul_pack); } ~CheetahDotState() override = default; @@ -82,13 +91,13 @@ class CheetahOTState : public State { public: static constexpr char kBindName[] = "CheetahOT"; - static constexpr size_t kParallel = 16; + static constexpr size_t kMaxOTParallel = 32; - explicit CheetahOTState() : basic_ot_prot_(kParallel) {} + explicit CheetahOTState() : basic_ot_prot_(kMaxOTParallel) {} ~CheetahOTState() override = default; - constexpr size_t parallel_size() const { return kParallel; } + size_t parallel_size() const; void LazyInit(Communicator* comm, size_t idx = 0) { SPU_ENFORCE(idx < parallel_size(), "idx={} out-of-bound", idx); diff --git a/libspu/mpc/cheetah/type.cc b/libspu/mpc/cheetah/type.cc new file mode 100644 index 00000000..1497189e --- /dev/null +++ b/libspu/mpc/cheetah/type.cc @@ -0,0 +1,32 @@ +// Copyright 2021 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
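Note on the OT state changes above (libspu/mpc/cheetah/state.cc / state.h): the fixed kParallel = 16 is replaced by a dynamic cap, and InitOTState further limits the worker count by the job size. The following is a minimal standalone sketch of how the two limits combine; only the constants (kMinWorkSize = 1500, kMaxOTParallel = 32) come from the diff, while the helper names and the ncores parameter are illustrative stand-ins for getNumberOfProc() and the real state objects.

// Standalone sketch of the OT worker-sizing heuristic, under the assumptions above.
#include <algorithm>
#include <cstddef>
#include <cstdio>

static constexpr size_t kMinWorkSize = 1500;  // minimum #jobs per OT worker
static constexpr size_t kMaxOTParallel = 32;  // hard cap on OT instances

static size_t CeilDiv(size_t a, size_t b) { return (a + b - 1) / b; }

// Mirrors CheetahOTState::parallel_size(): use half the cores to avoid
// oversubscribing small machines, but never more than kMaxOTParallel.
static size_t ParallelSize(size_t ncores) {
  return std::max<size_t>(1, std::min<size_t>(ncores / 2, kMaxOTParallel));
}

// Mirrors InitOTState(): the worker count is further limited by the job size,
// so tiny batches do not pay for spinning up many OT instances.
static size_t NumWorkers(size_t njobs, size_t ncores) {
  if (njobs == 0) return 0;
  return std::min(ParallelSize(ncores), CeilDiv(njobs, kMinWorkSize));
}

int main() {
  // e.g. 10000 jobs on a 16-core box -> min(8, ceil(10000/1500) = 7) = 7 workers
  std::printf("%zu\n", NumWorkers(10000, 16));
}
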
+ +#include "libspu/mpc/cheetah/type.h" + +#include + +#include "libspu/mpc/common/pv2k.h" + +namespace spu::mpc::cheetah { + +void registerTypes() { + regPV2kTypes(); + + static std::once_flag flag; + std::call_once(flag, []() { + TypeContext::getTypeContext()->addTypes(); + }); +} + +} // namespace spu::mpc::cheetah \ No newline at end of file diff --git a/libspu/mpc/cheetah/type.h b/libspu/mpc/cheetah/type.h new file mode 100644 index 00000000..e681581a --- /dev/null +++ b/libspu/mpc/cheetah/type.h @@ -0,0 +1,61 @@ +// Copyright 2021 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "libspu/core/type.h" + +namespace spu::mpc::cheetah { + +class AShrTy : public TypeImpl { + using Base = TypeImpl; + + public: + using Base::Base; + static std::string_view getStaticId() { return "cheetah.AShr"; } + explicit AShrTy(FieldType field) { field_ = field; } +}; + +class BShrTy : public TypeImpl { + using Base = TypeImpl; + + static constexpr size_t kDefaultNumBits = std::numeric_limits::max(); + + public: + using Base::Base; + explicit BShrTy(FieldType field, size_t nbits = kDefaultNumBits) { + field_ = field; + nbits_ = nbits == kDefaultNumBits ? SizeOf(field) * 8 : nbits; + SPU_ENFORCE(nbits_ <= SizeOf(field) * 8); + } + + static std::string_view getStaticId() { return "cheetah.BShr"; } + + void fromString(std::string_view detail) override { + auto comma = detail.find_first_of(','); + auto field_str = detail.substr(0, comma); + auto nbits_str = detail.substr(comma + 1); + SPU_ENFORCE(FieldType_Parse(std::string(field_str), &field_), + "parse failed from={}", detail); + nbits_ = std::stoul(std::string(nbits_str)); + }; + + std::string toString() const override { + return fmt::format("{},{}", FieldType_Name(field()), nbits_); + } +}; + +void registerTypes(); + +} // namespace spu::mpc::cheetah \ No newline at end of file diff --git a/libspu/mpc/spdz2k/BUILD.bazel b/libspu/mpc/spdz2k/BUILD.bazel index af9af9b6..cad7c4ed 100644 --- a/libspu/mpc/spdz2k/BUILD.bazel +++ b/libspu/mpc/spdz2k/BUILD.bazel @@ -178,7 +178,6 @@ spu_cc_library( deps = [ ":type", "//libspu/core", - "//libspu/mpc/cheetah:array_ref", "//libspu/mpc/common:pv2k", "//libspu/mpc/utils:ring_ops", ], diff --git a/libspu/mpc/utils/BUILD.bazel b/libspu/mpc/utils/BUILD.bazel index cce3ca41..f8351338 100644 --- a/libspu/mpc/utils/BUILD.bazel +++ b/libspu/mpc/utils/BUILD.bazel @@ -55,6 +55,12 @@ spu_cc_library( name = "ring_ops", srcs = ["ring_ops.cc"], hdrs = ["ring_ops.h"], + copts = select({ + "@platforms//cpu:x86_64": [ + "-mavx", + ], + "//conditions:default": [], + }), deps = [ ":linalg", "//libspu/core", diff --git a/libspu/pir/BUILD.bazel b/libspu/pir/BUILD.bazel index e3b66852..514ebba6 100644 --- a/libspu/pir/BUILD.bazel +++ b/libspu/pir/BUILD.bazel @@ -13,6 +13,7 @@ # limitations under the License. 
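Note on the new cheetah type.h above: BShrTy serializes itself as "<field>,<nbits>" via toString()/fromString(). A self-contained sketch of that round trip follows; BShrRepr, Encode and Decode are hypothetical names, and a plain string stands in for the FieldType enum that the real code parses with FieldType_Parse.

// Minimal sketch of the "<field>,<nbits>" encoding, assuming the stand-ins above.
#include <cstddef>
#include <stdexcept>
#include <string>

struct BShrRepr {
  std::string field;  // e.g. "FM64" (stands in for spu::FieldType)
  size_t nbits = 0;   // number of significant bits in the boolean share
};

// toString(): "<field>,<nbits>"
std::string Encode(const BShrRepr& t) {
  return t.field + "," + std::to_string(t.nbits);
}

// fromString(): split on the first comma; everything after it is nbits.
BShrRepr Decode(const std::string& detail) {
  auto comma = detail.find_first_of(',');
  if (comma == std::string::npos) {
    throw std::invalid_argument("expect '<field>,<nbits>', got: " + detail);
  }
  BShrRepr t;
  t.field = detail.substr(0, comma);
  t.nbits = std::stoul(detail.substr(comma + 1));
  return t;
}

int main() {
  auto t = Decode(Encode({"FM64", 17}));  // round-trips to {FM64, 17}
  return t.nbits == 17 ? 0 : 1;
}
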
load("//bazel:spu.bzl", "spu_cc_library", "spu_cc_test") +load("@rules_cc//cc:defs.bzl", "cc_proto_library") package(default_visibility = ["//visibility:public"]) diff --git a/libspu/psi/core/BUILD.bazel b/libspu/psi/core/BUILD.bazel index 1ab22927..719d29e9 100644 --- a/libspu/psi/core/BUILD.bazel +++ b/libspu/psi/core/BUILD.bazel @@ -13,6 +13,7 @@ # limitations under the License. load("//bazel:spu.bzl", "spu_cc_binary", "spu_cc_library", "spu_cc_test") +load("@rules_cc//cc:defs.bzl", "cc_proto_library") package(default_visibility = ["//visibility:public"]) diff --git a/libspu/spu.proto b/libspu/spu.proto index 36affd27..67a8e40c 100644 --- a/libspu/spu.proto +++ b/libspu/spu.proto @@ -274,6 +274,9 @@ message RuntimeConfig { // Enable a simpler rsqrt approximation bool enable_lower_accuracy_rsqrt = 57; + // Sine/Cosine approximation iterations + int64 sine_cosine_iters = 58; + /// - MPC protocol related definitions. enum BeaverType { diff --git a/sml/decomposition/tests/pca_test.py b/sml/decomposition/tests/pca_test.py index f0afc975..2e3f398a 100644 --- a/sml/decomposition/tests/pca_test.py +++ b/sml/decomposition/tests/pca_test.py @@ -96,7 +96,98 @@ def proc_reconstruct(X): X_reconstructed_sklearn = sklearn_pca.inverse_transform(X_transformed_sklearn) # Compare the results - self.assertTrue(np.allclose(X_reconstructed_sklearn, result, atol=1e-3)) + self.assertTrue(np.allclose(X_reconstructed_sklearn, result, atol=0.01)) + + def test_rsvd(self): + config = spu_pb2.RuntimeConfig( + protocol=spu_pb2.ProtocolKind.ABY3, + field=spu_pb2.FieldType.FM128, + fxp_fraction_bits=30, + ) + sim = spsim.Simulator(3, config) + + # Test fit_transform + def proc_transform(X, random_matrix): + model = PCA( + method='rsvd', + n_components=n_components, + n_oversamples=n_oversamples, + random_matrix=random_matrix, + scale=[10000000, 10000], + ) + + model.fit(X) + X_transformed = model.transform(X) + X_variances = model._variances + + return X_transformed, X_variances + + # Create a simple dataset + X = random.normal(random.PRNGKey(0), (1000, 20)) + n_components = 5 + n_oversamples = 10 + + # Create random_matrix + random_state = np.random.RandomState(0) + random_matrix = random_state.normal( + size=(X.shape[1], n_components + n_oversamples) + ) + + # Run the simulation + result = spsim.sim_jax(sim, proc_transform)(X, random_matrix) + + # The transformed data should have 2 dimensions + self.assertEqual(result[0].shape[1], n_components) + + # The mean of the transformed data should be approximately 0 + self.assertTrue(jnp.allclose(jnp.mean(result[0], axis=0), 0, atol=1e-3)) + + X_np = np.array(X) + + # Run fit_transform using sklearn + sklearn_pca = SklearnPCA( + n_components=n_components, + svd_solver="randomized", + power_iteration_normalizer="QR", + random_state=0, + ) + sklearn_pca.fit(X_np) + X_transformed_sklearn = sklearn_pca.transform(X_np) + + # Compare the transform results + print("X_transformed_sklearn: ", X_transformed_sklearn) + print("X_transformed_jax", result[0]) + + # Compare the variance results + print( + "X_transformed_sklearn.explained_variance_: ", + sklearn_pca.explained_variance_, + ) + print("X_transformed_jax.explained_variance_: ", result[1]) + + # Test inverse_transform + def proc_reconstruct(X, random_matrix): + model = PCA( + method='rsvd', + n_components=n_components, + n_oversamples=n_oversamples, + random_matrix=random_matrix, + scale=[10000000, 10000], + ) + + model.fit(X) + X_reconstructed = model.inverse_transform(model.transform(X)) + + return X_reconstructed + + # Run 
the simulation + result = spsim.sim_jax(sim, proc_reconstruct)(X, random_matrix) + + # Run inverse_transform using sklearn + X_reconstructed_sklearn = sklearn_pca.inverse_transform(X_transformed_sklearn) + + # Compare the results + self.assertTrue(np.allclose(X_reconstructed_sklearn, result, atol=1e-1)) def test_rsvd(self): config = spu_pb2.RuntimeConfig( diff --git a/sml/utils/BUILD.bazel b/sml/utils/BUILD.bazel index fd977923..26866239 100644 --- a/sml/utils/BUILD.bazel +++ b/sml/utils/BUILD.bazel @@ -18,13 +18,15 @@ package(default_visibility = ["//visibility:public"]) py_library( name = "emulation", - srcs = ["emulation.py"], + srcs = [ + "emulation.py", + "//libspu:spu_py_proto", + ], data = [ "//examples/python/conf", # FIXME ], deps = [ "//spu:init", - "//spu:spu_py_proto", "//spu/utils:distributed", ], ) @@ -32,11 +34,9 @@ py_library( py_library( name = "fxp_approx", srcs = ["fxp_approx.py"], - deps = [], ) py_library( name = "extmath", srcs = ["extmath.py"], - deps = [], ) diff --git a/spu/BUILD.bazel b/spu/BUILD.bazel index c9222b8b..9685b774 100644 --- a/spu/BUILD.bazel +++ b/spu/BUILD.bazel @@ -14,7 +14,6 @@ load("@pybind11_bazel//:build_defs.bzl", "pybind_extension") load("@rules_python//python:defs.bzl", "py_library") -load("@com_github_grpc_grpc//bazel:python_rules.bzl", "py_proto_library") load("@rules_python//python:packaging.bzl", "py_package") package(default_visibility = ["//visibility:public"]) @@ -54,52 +53,40 @@ pybind_extension( ], ) -py_proto_library( - name = "spu_py_proto", - deps = ["//libspu:spu_proto"], -) - -py_proto_library( - name = "psi_py_proto", - deps = ["//libspu/psi:psi_proto"], -) - -py_proto_library( - name = "pir_py_proto", - deps = ["//libspu/pir:pir_proto"], -) - py_library( name = "api", - srcs = ["api.py"], + srcs = [ + "api.py", + "spu_pb2.py", + "//libspu:spu_py_proto", + ], data = [ ":libspu.so", ], - deps = [ - ":spu_py_proto", - ], ) py_library( name = "psi", - srcs = ["psi.py"], + srcs = [ + "psi.py", + "psi_pb2.py", + "//libspu:psi_py_proto", + ], data = [ ":libspu.so", ], - deps = [ - ":psi_py_proto", - ], ) py_library( name = "pir", - srcs = ["pir.py"], + srcs = [ + "pir.py", + "pir_pb2.py", + "//libspu:pir_py_proto", + ], data = [ ":libspu.so", ], - deps = [ - ":pir_py_proto", - ], ) py_library( @@ -109,10 +96,7 @@ py_library( "version.py", ":api", ":pir", - ":pir_py_proto", ":psi", - ":psi_py_proto", - ":spu_py_proto", "//spu/intrinsic:all_intrinsics", "//spu/utils:simulation", ], diff --git a/spu/pir_pb2.py b/spu/pir_pb2.py new file mode 100644 index 00000000..4d56b827 --- /dev/null +++ b/spu/pir_pb2.py @@ -0,0 +1 @@ +from libspu.pir.pir_pb2 import * diff --git a/spu/psi_pb2.py b/spu/psi_pb2.py new file mode 100644 index 00000000..b70528d5 --- /dev/null +++ b/spu/psi_pb2.py @@ -0,0 +1 @@ +from libspu.psi.psi_pb2 import * diff --git a/spu/spu_pb2.py b/spu/spu_pb2.py new file mode 100644 index 00000000..43260825 --- /dev/null +++ b/spu/spu_pb2.py @@ -0,0 +1 @@ +from libspu.spu_pb2 import * diff --git a/spu/tests/jnp_testbase.py b/spu/tests/jnp_testbase.py index 27705389..7ce841eb 100644 --- a/spu/tests/jnp_testbase.py +++ b/spu/tests/jnp_testbase.py @@ -140,24 +140,8 @@ def post(x): REC("subtract", 2, number_dtypes, all_shapes, rand_default), REC("signbit", 1, number_dtypes, all_shapes, rand_default), REC("trunc", 1, number_dtypes, all_shapes, rand_default), - REC( - "sin", - 1, - number_dtypes, - all_shapes, - rand_default, - Status.SysError, - "stablehlo.sine", - ), # FIXME: stablehlo.sine - REC( - "cos", - 1, - number_dtypes, - 
all_shapes, - rand_default, - Status.SysError, - "stablehlo.cosine", - ), # FIXME: stablehlo.cosine + REC("sin", 1, number_dtypes, all_shapes, rand_default), + REC("cos", 1, number_dtypes, all_shapes, rand_default), REC( "tan", 1, @@ -165,17 +149,9 @@ def post(x): all_shapes, partial(jtu.rand_uniform, low=-1.5, high=1.5), Status.SysError, - "stablehlo.sine", - ), # FIXME: stablehlo.sine - REC( - "sinh", - 1, - number_dtypes, - all_shapes, - rand_default, - Status.SysError, - "stablehlo.cosine", - ), # FIXME: stablehlo.cosine + "stablehlo.tan", + ), # FIXME: stablehlo.tan + REC("sinh", 1, number_dtypes, all_shapes, jtu.rand_small), REC("cosh", 1, number_dtypes, all_shapes, jtu.rand_small), REC("tanh", 1, number_dtypes, all_shapes, jtu.rand_default), REC( @@ -341,15 +317,7 @@ def post(x): REC( "copysign", 2, number_dtypes, all_shapes, rand_default, Status.SysError, "shift" ), # FIXME: shift - REC( - "sinc", - 1, - number_dtypes, - all_shapes, - rand_default, - Status.SysError, - "stablehlo.sine", - ), # FIXME: stablehlo.sine + REC("sinc", 1, number_dtypes, all_shapes, rand_default), REC("square", 1, number_dtypes, all_shapes, rand_default), REC("sqrt", 1, number_dtypes, all_shapes, jtu.rand_positive), REC("transpose", 1, all_dtypes, all_shapes, rand_default), diff --git a/spu/utils/BUILD.bazel b/spu/utils/BUILD.bazel index c46ac918..2dfbe7f8 100644 --- a/spu/utils/BUILD.bazel +++ b/spu/utils/BUILD.bazel @@ -14,7 +14,7 @@ load("@rules_python//python:defs.bzl", "py_library") load("@rules_proto//proto:defs.bzl", "proto_library") -load("@com_github_grpc_grpc//bazel:python_rules.bzl", "py_grpc_library", "py_proto_library") +load("@rules_proto_grpc//python:defs.bzl", "python_grpc_compile") package(default_visibility = ["//visibility:public"]) @@ -32,30 +32,23 @@ proto_library( srcs = ["distributed.proto"], ) -py_proto_library( - name = "distributed_py_proto", - deps = [ - ":distributed_proto", - ], -) - -py_grpc_library( +python_grpc_compile( name = "distributed_py_proto_grpc", - srcs = ["distributed_proto"], - deps = [ - ":distributed_py_proto", - ], + output_mode = "NO_PREFIX", + prefix_path = "../..", + protos = ["distributed_proto"], ) py_library( name = "distributed", - srcs = ["distributed.py"], - deps = [ - ":distributed_py_proto", + srcs = [ + "distributed.py", ":distributed_py_proto_grpc", + "//libspu:spu_py_proto", + ], + deps = [ ":frontend", "//spu:api", - "//spu:spu_py_proto", ], )
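Note on SymmetricRLWEEncrypt (libspu/mpc/cheetah/rlwe, earlier in this diff): when need_seed is set, the uniform polynomial a (the c1 component) is not shipped at all; the sender writes an indicator word followed by the serialized PRNG info, and the receiver re-expands a from that seed, roughly halving ciphertext traffic. The toy model below illustrates only this idea; std::mt19937_64 stands in for SEAL's seeded UniformRandomGenerator, and the on-wire layout is illustrative, not SEAL's actual serialization format.

// Toy model of the need_seed path, under the assumptions stated above.
#include <cstdint>
#include <random>
#include <vector>

static constexpr uint64_t kSeedIndicator = 0xFFFFFFFFFFFFFFFFULL;

// Expand the "uniform polynomial" a from a seed (toy stand-in for
// sample_poly_uniform driven by a seeded ciphertext_prng).
std::vector<uint64_t> ExpandA(uint64_t seed, size_t coeff_count) {
  std::mt19937_64 prng(seed);
  std::vector<uint64_t> a(coeff_count);
  for (auto& c : a) c = prng();
  return a;
}

// Sender: emit [indicator, seed] instead of the full coeff_count-word c1.
std::vector<uint64_t> SerializeSeededC1(uint64_t seed) {
  return {kSeedIndicator, seed};
}

// Receiver: detect the indicator and regenerate c1 from the embedded seed.
std::vector<uint64_t> DeserializeC1(const std::vector<uint64_t>& wire,
                                    size_t coeff_count) {
  if (!wire.empty() && wire[0] == kSeedIndicator) {
    return ExpandA(wire[1], coeff_count);
  }
  return wire;  // non-seeded ciphertexts carry c1 verbatim
}

int main() {
  const size_t coeff_count = 4096;
  const uint64_t seed = 0x1234;
  auto c1 = ExpandA(seed, coeff_count);
  auto rt = DeserializeC1(SerializeSeededC1(seed), coeff_count);
  return (rt == c1) ? 0 : 1;  // both sides derive the same polynomial a
}
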
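Note on CheetahMulState::makeSureCacheSize (libspu/mpc/cheetah/state.cc, earlier in this diff): each party holds (a_i, b_i) and combines its MulOLE cross-term shares with the local product a_i*b_i, so the two c shares sum to (a0+a1)*(b0+b1). The sketch below is only a local arithmetic check of that identity over Z_2^64; the OLE protocol that actually produces the cross-term shares is replaced by a trivial local split.

// Local check of the Beaver-triple identity, under the simplification above.
#include <cstdint>
#include <random>

int main() {
  std::mt19937_64 rng(0);
  uint64_t a0 = rng(), b0 = rng();  // rank 0's random shares
  uint64_t a1 = rng(), b1 = rng();  // rank 1's random shares

  // Additively share each cross term (in the protocol this comes from MulOLE).
  uint64_t r0 = rng();
  uint64_t a0b1_share0 = r0, a0b1_share1 = a0 * b1 - r0;
  uint64_t r1 = rng();
  uint64_t a1b0_share0 = r1, a1b0_share1 = a1 * b0 - r1;

  // Each party finishes its c share locally: c_i = a_i*b_i + cross-term shares,
  // matching beaver[2] = a0b1 + a1b0 + a_i*b_i in the diff.
  uint64_t c0 = a0 * b0 + a0b1_share0 + a1b0_share0;
  uint64_t c1 = a1 * b1 + a0b1_share1 + a1b0_share1;

  // (a0 + a1) * (b0 + b1) == c0 + c1 over Z_2^64 (wrap-around is intended).
  return ((a0 + a1) * (b0 + b1) == c0 + c1) ? 0 : 1;
}
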