Skip to content

Commit

Permalink
Python object wrappers (#185)
Browse files Browse the repository at this point in the history
* Python wrapper for CKKSTensor + some refactor

* skeleton context wrapper + more wrapping ckkstensor

* complete context with keys

* skeleton for PlainTensor

* finalize plaintensor

* decrypt tensor

* CKKSTensor wrapper complete

* factorizing code into abstract_tensor

* shape as an attr, not a function

* tolist part of plain_tensor

* fix copy

* fix negate

* vectors

* fix decryption with secretkey

* fix vector packing

* fix copy

* reorganized packages

* complete the plaintensor implementation

* lint

* useless lines

* Bazel fix

* bazel paths issue

* remove keys from public API

* docs and type hints

* SCHEME_TYPE as a pytohn enum + some fixes

* type fix

Co-authored-by: Cebere Bogdan <bogdan.cebere@gmail.com>
  • Loading branch information
youben11 and bcebere authored Dec 8, 2020
1 parent 67e2979 commit 6ee6f33
Show file tree
Hide file tree
Showing 19 changed files with 1,078 additions and 337 deletions.
1 change: 1 addition & 0 deletions tenseal/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ py_library(
name = "tenseal",
srcs = [
"__init__.py",
"enc_context.py",
"version.py",
],
data = ["//tenseal:_tenseal_cpp.so"],
Expand Down
153 changes: 85 additions & 68 deletions tenseal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,90 +5,108 @@
import _tenseal_cpp as _ts_cpp
except ImportError:
import tenseal._tenseal_cpp as _ts_cpp
from tenseal.tensors import CKKSTensor, CKKSVector, BFVVector, PlainTensor

from tenseal.tensors import (
bfv_vector,
bfv_vector_from,
ckks_vector,
ckks_vector_from,
ckks_tensor,
ckks_tensor_from,
plain_tensor,
tolist,
)
from tenseal.enc_context import Context, SCHEME_TYPE
from tenseal.version import __version__


SCHEME_TYPE = _ts_cpp.SCHEME_TYPE
PublicKey = _ts_cpp.PublicKey
SecretKey = _ts_cpp.SecretKey
RelinKeys = _ts_cpp.RelinKeys
GaloisKeys = _ts_cpp.GaloisKeys

# Vectors
BFVVector = _ts_cpp.BFVVector
CKKSVector = _ts_cpp.CKKSVector
CKKSTensor = _ts_cpp.CKKSTensor

# utils
im2col_encoding = _ts_cpp.im2col_encoding
enc_matmul_encoding = _ts_cpp.enc_matmul_encoding


def context(
scheme, poly_modulus_degree, plain_modulus=None, coeff_mod_bit_sizes=None, n_threads=None
):
"""Construct a context that holds keys and parameters needed for operating
encrypted tensors using either BFV or CKKS scheme.
def im2col_encoding(
context: Context, tensor, kernel_n_rows: int, kernel_n_cols: int, stride: int
) -> CKKSVector:
"""Encoding an image into a CKKSVector. This serves for doing efficient Conv2d evaluation.
Args:
scheme : define the scheme to be used, either SCHEME_TYPE.BFV or SCHEME_TYPE.CKKS.
poly_modulus_degree (int): The degree of the polynomial modulus, must be a power of two.
plain_modulus (int): The plaintext modulus. Should not be passed when the scheme is CKKS.
coeff_mod_bit_sizes (list of int): List of bit size for each coeffecient modulus.
Can be an empty list for BFV, a default value will be given.
context: a Context object, holding the encryption parameters and keys.
tensor: tensor-like object.
kernel_n_rows: number of rows in the kernel that will be used for conv2d.
kernel_n_cols: number of columns in the kernel that will be used for conv2d.
stride: stride that will be used for conv2d.
Returns:
A TenSEALContext object.
Encrypted image into a CKKSVector.
"""
if scheme == SCHEME_TYPE.BFV:
if plain_modulus is None:
raise ValueError("plain_modulus must be provided")
if coeff_mod_bit_sizes is None:
coeff_mod_bit_sizes = []

elif scheme == SCHEME_TYPE.CKKS:
# must be int, but the value doesn't matter for ckks
plain_modulus = 0
if coeff_mod_bit_sizes is None:
raise ValueError("coeff_mod_bit_sizes must be provided")

else:
raise ValueError("Invalid scheme type, use either SCHEME_TYPE.BFV or SCHEME_TYPE.CKKS")

# We can't pass None here, everything should be set prior to this call
if isinstance(n_threads, int) and n_threads > 0:
return _ts_cpp.TenSEALContext.new(
scheme, poly_modulus_degree, plain_modulus, coeff_mod_bit_sizes, n_threads
)

return _ts_cpp.TenSEALContext.new(
scheme, poly_modulus_degree, plain_modulus, coeff_mod_bit_sizes
if not isinstance(context, Context):
raise TypeError("context must be of type tenseal.Context")
if not isinstance(tensor, PlainTensor):
tensor = plain_tensor(tensor)
if len(tensor.shape) != 2:
raise ValueError("tensor must be a matrix")
matrix = tensor.tolist()

ckks_vec, windows_nb = _ts_cpp.im2col_encoding(
context.data, matrix, kernel_n_cols, kernel_n_rows, stride
)
return CKKSVector._wrap(ckks_vec), windows_nb


def context_from(buff, n_threads=None):
"""Construct a context from a serialized buffer.
def enc_matmul_encoding(context: Context, tensor) -> CKKSVector:
"""Encode a matrix into a CKKSVector for later matrix(encrypted)-vector(plain) multiplication.
Args:
buff : bytes buffer from the original context .
context: a Context object, holding the encryption parameters and keys.
tensor: tensor-like object of shape 2 (matrix).
Returns:
A TenSEALContext object.
Encrypted matrix into a CKKSVector.
"""
if n_threads:
return _ts_cpp.TenSEALContext.deserialize(buff, n_threads)
return _ts_cpp.TenSEALContext.deserialize(buff)
if not isinstance(context, Context):
raise TypeError("context must be of type tenseal.Context")
if not isinstance(tensor, PlainTensor):
tensor = plain_tensor(tensor)
if len(tensor.shape) != 2:
raise ValueError("tensor must be a matrix")
matrix = tensor.tolist()
return CKKSVector._wrap(_ts_cpp.enc_matmul_encoding(context.data, matrix))


def context(*args, **kwargs) -> Context:
"""Constructor function for tenseal.Context"""
return Context(*args, **kwargs)


def context_from(data: bytes, n_threads: int = None) -> Context:
"""Load a Context from a protocol buffer.
n_threads set the concurrency for the context if parallel computation is requested."""
return Context.load(data, n_threads)


def plain_tensor(*args, **kwargs) -> PlainTensor:
"""Constructor function for tenseal.PlainTensor"""
return PlainTensor(*args, **kwargs)


def bfv_vector(*args, **kwargs) -> BFVVector:
"""Constructor function for tenseal.BFVVector"""
return BFVVector(*args, **kwargs)


def bfv_vector_from(context: Context, data: bytes) -> BFVVector:
"""Load a BFVVector from a protocol buffer.
Requires the context to be linked with."""
return BFVVector.load(context, data)


def ckks_vector(*args, **kwargs) -> CKKSVector:
"""Constructor function for tenseal.CKKSVector"""
return CKKSVector(*args, **kwargs)


def ckks_vector_from(context: Context, data: bytes) -> CKKSVector:
"""Load a CKKSVector from a protocol buffer.
Requires the context to be linked with."""
return CKKSVector.load(context, data)


def ckks_tensor(*args, **kwargs) -> CKKSTensor:
"""Constructor function for tenseal.CKKSTensor"""
return CKKSTensor(*args, **kwargs)


def ckks_tensor_from(context: Context, data: bytes) -> CKKSTensor:
"""Load a CKKSTensor from a protocol buffer.
Requires the context to be linked with."""
return CKKSTensor.load(context, data)


__all__ = [
Expand All @@ -102,6 +120,5 @@ def context_from(buff, n_threads=None):
"context_from",
"im2col_encoding",
"plain_tensor",
"tolist",
"__version__",
]
4 changes: 1 addition & 3 deletions tenseal/deps.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@ load("@rules_foreign_cc//:workspace_definitions.bzl", "rules_foreign_cc_dependen

load("@pybind11_bazel//:python_configure.bzl", "python_configure")
load("@rules_python//python:repositories.bzl", "py_repositories")
load("@rules_python_external//:repositories.bzl", "rules_python_external_dependencies")
load("@rules_python_external//:defs.bzl", "pip_install")
load("@rules_python//python:pip.bzl", "pip_install")

def tenseal_deps():
if "com_google_googletest" not in native.existing_rules():
Expand Down Expand Up @@ -45,7 +44,6 @@ def tenseal_deps():
python_configure(name = "local_config_python", python_version = "3")

# Install pip requirements for Python tests.
rules_python_external_dependencies()
pip_install(
name = "org_openmined_tenseal_python_deps",
requirements = "@org_openmined_tenseal//:requirements_dev.txt",
Expand Down
Loading

0 comments on commit 6ee6f33

Please sign in to comment.