From 9f3570838481bce479e0f7eb3e8cfa1e59270591 Mon Sep 17 00:00:00 2001 From: anakinxc Date: Mon, 8 Jul 2024 14:01:59 +0800 Subject: [PATCH] repo-sync-2024-07-08T14:01:52+0800 --- libspu/mpc/cheetah/boolean_semi2k.cc | 4 ++-- spu/tests/BUILD.bazel | 3 ++- spu/tests/jnp_cheetah_r64_test.py | 3 +-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/libspu/mpc/cheetah/boolean_semi2k.cc b/libspu/mpc/cheetah/boolean_semi2k.cc index a64da4cc..786d0c38 100644 --- a/libspu/mpc/cheetah/boolean_semi2k.cc +++ b/libspu/mpc/cheetah/boolean_semi2k.cc @@ -81,8 +81,8 @@ NdArrayRef P2B::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { if (comm->getRank() == 0) { ring_xor_(x, in); } - - return makeBShare(x, field, getNumBits(in)); + auto nbits = getNumBits(in) == 0 ? 1 : getNumBits(in); + return makeBShare(x, field, nbits); } NdArrayRef AndBP::proc(KernelEvalContext* ctx, const NdArrayRef& lhs, diff --git a/spu/tests/BUILD.bazel b/spu/tests/BUILD.bazel index d7453e60..cfd029b8 100644 --- a/spu/tests/BUILD.bazel +++ b/spu/tests/BUILD.bazel @@ -87,6 +87,7 @@ py_test( py_test( name = "jnp_cheetah_r64_test", + size = "enormous", timeout = "long", srcs = ["jnp_cheetah_r64_test.py"], deps = [ @@ -96,7 +97,7 @@ py_test( py_test( name = "jnp_cheetah_r64_test_x64", - timeout = "long", + size = "enormous", srcs = ["jnp_cheetah_r64_test.py"], env = { "ENABLE_X64_TEST": "1", diff --git a/spu/tests/jnp_cheetah_r64_test.py b/spu/tests/jnp_cheetah_r64_test.py index b113dbcd..10c02a53 100644 --- a/spu/tests/jnp_cheetah_r64_test.py +++ b/spu/tests/jnp_cheetah_r64_test.py @@ -22,8 +22,7 @@ from spu.tests.jnp_testbase import JnpTests -@unittest.skip("too slow, last run succeed") -class JnpTestAby3FM64(JnpTests.JnpTestBase): +class JnpTestCheetahFM64(JnpTests.JnpTestBase): def setUp(self): self._sim = ppsim.Simulator.simple( 2, spu_pb2.ProtocolKind.CHEETAH, spu_pb2.FieldType.FM64