diff --git a/setup.py b/setup.py index b04cb1e9..b66a2b65 100644 --- a/setup.py +++ b/setup.py @@ -57,16 +57,18 @@ def build_extension(self, ext): cfg = "Debug" if self.debug else "Release" build_args = ["--config", cfg] + env = os.environ.copy() + if platform.system() == "Windows": cmake_args += [f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{cfg.upper()}={extdir}"] - if sys.maxsize > 2 ** 32: + if sys.maxsize > 2**32: cmake_args += ["-A", "x64"] build_args += ["--", "/m", "/p:TrackFileAccess=false"] else: cmake_args += ["-DCMAKE_BUILD_TYPE=" + cfg] - build_args += ["--", "-j"] + if "MAKEFLAGS" not in env: + build_args += ["--", "-j"] - env = os.environ.copy() env["CXXFLAGS"] = '{} -DVERSION_INFO=\\"{}\\"'.format( env.get("CXXFLAGS", ""), self.distribution.get_version() ) diff --git a/tenseal/cpp/tensors/encrypted_vector.h b/tenseal/cpp/tensors/encrypted_vector.h index 9df2191e..19763539 100644 --- a/tenseal/cpp/tensors/encrypted_vector.h +++ b/tenseal/cpp/tensors/encrypted_vector.h @@ -186,28 +186,36 @@ class EncryptedVector : public EncryptedTensor { auto diag = matrix.get_diagonal( -local_i, this->tenseal_context()->template slot_count()); - replicate_vector( - diag, - this->tenseal_context()->template slot_count()); - - rotate(diag.begin(), diag.begin() + diag.size() - local_i, - diag.end()); - - this->tenseal_context()->template encode(diag, - pt_diag); - if (this->_ciphertexts[0].parms_id() != pt_diag.parms_id()) { - this->set_to_same_mod(pt_diag, _ciphertexts[0]); + // don't add zero diagonals to (a) improve performance and (b) + // avoid transparent ciphertext issues + bool is_diag_nonzero = std::any_of( + diag.begin(), diag.end(), [](plain_t x) { return x != 0; }); + if (is_diag_nonzero) { + replicate_vector(diag, + this->tenseal_context() + ->template slot_count()); + + rotate(diag.begin(), diag.begin() + diag.size() - local_i, + diag.end()); + + this->tenseal_context()->template encode( + diag, pt_diag); + + if (this->_ciphertexts[0].parms_id() != + pt_diag.parms_id()) { + this->set_to_same_mod(pt_diag, _ciphertexts[0]); + } + this->tenseal_context()->evaluator->multiply_plain( + this->_ciphertexts[0], pt_diag, ct); + + this->tenseal_context()->evaluator->rotate_vector_inplace( + ct, local_i, *this->tenseal_context()->galois_keys()); + + // accumulate thread results + this->tenseal_context()->evaluator->add_inplace( + thread_result, ct); } - this->tenseal_context()->evaluator->multiply_plain( - this->_ciphertexts[0], pt_diag, ct); - - this->tenseal_context()->evaluator->rotate_vector_inplace( - ct, local_i, *this->tenseal_context()->galois_keys()); - - // accumulate thread results - this->tenseal_context()->evaluator->add_inplace(thread_result, - ct); } return thread_result; }; diff --git a/tests/python/tenseal/tensors/test_ckks_tensor.py b/tests/python/tenseal/tensors/test_ckks_tensor.py index 15011a0e..74984098 100644 --- a/tests/python/tenseal/tensors/test_ckks_tensor.py +++ b/tests/python/tenseal/tensors/test_ckks_tensor.py @@ -121,19 +121,39 @@ def test_reshape_batching(context, data, new_shape): @pytest.mark.parametrize( "data, slices, new_shape", [ - ([0, 1, 2, 3, 4, 5], [slice(1, 4, None)], [3]), - ([0, 1, 2, 3, 4, 5], [slice(1, None, None)], [5]), - ([0, 1, 2, 3, 4, 5], [slice(None, 4, None)], [4]), - ([[0, 1, 2], [0, 1, 2], [0, 1, 2]], [slice(1, 3, None), slice(0, 2, None)], [2, 2]), - ([[0, 1, 2], [0, 1, 2], [0, 1, 2]], [slice(1, None, None), slice(0, 2, None)], [2, 2]), + ([0, 1, 2, 3, 4, 5], (slice(1, 4, None),), [3]), + ([0, 1, 2, 3, 4, 5], (slice(1, None, None),), [5]), + ([0, 1, 2, 3, 4, 5], (slice(None, 4, None),), [4]), ( [[0, 1, 2], [0, 1, 2], [0, 1, 2]], - [slice(1, None, None), slice(None, None, None)], + ( + slice(1, 3, None), + slice(0, 2, None), + ), + [2, 2], + ), + ( + [[0, 1, 2], [0, 1, 2], [0, 1, 2]], + ( + slice(1, None, None), + slice(0, 2, None), + ), + [2, 2], + ), + ( + [[0, 1, 2], [0, 1, 2], [0, 1, 2]], + ( + slice(1, None, None), + slice(None, None, None), + ), [2, 3], ), ( [[0, 1, 2], [0, 1, 2], [0, 1, 2]], - [slice(None, None, None), slice(None, None, None)], + ( + slice(None, None, None), + slice(None, None, None), + ), [3, 3], ), ([[0, 1, 2], [0, 1, 2], [0, 1, 2]], 1, [1, 3]),