Skip to content

Commit

Permalink
some docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
GiacomoPope committed Jul 25, 2024
1 parent 4b4e6bb commit 76e7a13
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 5 deletions.
24 changes: 24 additions & 0 deletions src/kyber_py/modules/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,14 @@ def __init__(self):
self.matrix = MatrixKyber

def decode_vector(self, input_bytes, k, d, is_ntt=False):
"""
Decode bytes into a a vector of polynomial elements.
Each element is assumed to be encoded as a polynomial with ``d``-bit
coefficients (hence a polynomial is encoded into ``256 * d`` bits).
A vector of length ``k`` then has ``256 * d * k`` bits.
"""
# Ensure the input bytes are the correct length to create k elements with
# d bits used for each coefficient
if self.ring.n * d * k != len(input_bytes) * 8:
Expand All @@ -32,28 +40,44 @@ def __init__(self, parent, matrix_data, transpose=False):
super().__init__(parent, matrix_data, transpose=transpose)

def encode(self, d):
"""
Encode every element of a matrix into bytes and concatenate
"""
output = b""
for row in self._data:
for ele in row:
output += ele.encode(d)
return output

def compress(self, d):
"""
Compress every element of the matrix to have at most ``d`` bits
"""
for row in self._data:
for ele in row:
ele.compress(d)
return self

def decompress(self, d):
"""
Perform (lossy) decompression of the polynomial assuming it has been
compressed to have at most ``d`` bits.
"""
for row in self._data:
for ele in row:
ele.decompress(d)
return self

def to_ntt(self):
"""
Convert every element of the matrix into NTT form
"""
data = [[x.to_ntt() for x in row] for row in self._data]
return self.parent(data, transpose=self._transpose)

def from_ntt(self):
"""
Convert every element of the matrix from NTT form
"""
data = [[x.from_ntt() for x in row] for row in self._data]
return self.parent(data, transpose=self._transpose)
24 changes: 20 additions & 4 deletions src/kyber_py/modules/modules_generic.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,19 @@
class Module:
def __init__(self, ring):
"""
Initialise a module over the ring ``ring``.
"""
self.ring = ring
self.matrix = Matrix

def random_element(self, m, n):
"""
Generate a random element of the module of dimension m x n
:param int m: the number of rows in the matrix
:param int m: the number of columns in tge matrix
:return: an element of the module with dimension `m times n`
"""
elements = [
[self.ring.random_element() for _ in range(n)] for _ in range(m)
]
Expand Down Expand Up @@ -47,7 +57,10 @@ def __call__(self, matrix_elements, transpose=False):

def vector(self, elements):
"""
Construct a vector with the given elements
Construct a vector given a list of elements of the module's ring
:param list: a list of elements of the ring
:return: a vector of the module
"""
return self.matrix(self, [elements], transpose=True)

Expand All @@ -64,6 +77,9 @@ def dim(self):
"""
Return the dimensions of the matrix with m rows
and n columns
:return: the dimension of the matrix ``(m, n)``
:rtype: tuple(int, int)
"""
if not self._transpose:
return len(self._data), len(self._data[0])
Expand All @@ -78,13 +94,13 @@ def _check_dimensions(self):

def transpose(self):
"""
Swap rows and columns of self
Return a matrix with the rows and columns of swapped
"""
return self.parent(self._data, not self._transpose)

def transpose_self(self):
"""
Transpose in place
Swap the rows and columns of the matrix in place
"""
self._transpose = not self._transpose
return
Expand Down Expand Up @@ -193,7 +209,7 @@ def __matmul__(self, other):

def dot(self, other):
"""
Inner product
Compute the inner product of two vectors
"""
if not isinstance(other, type(self)):
raise TypeError("Can only perform dot product with other matrices")
Expand Down
13 changes: 12 additions & 1 deletion src/kyber_py/polynomials/polynomials.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ def _decompress_ele(self, x, d):
def compress(self, d):
"""
Compress the polynomial by compressing each coefficient
NOTE: This is lossy compression
"""
self.coeffs = [self._compress_ele(c, d) for c in self.coeffs]
Expand All @@ -165,6 +166,7 @@ def compress(self, d):
def decompress(self, d):
"""
Decompress the polynomial by decompressing each coefficient
NOTE: This as compression is lossy, we have
x' = decompress(compress(x)), which x' != x, but is
close in magnitude.
Expand Down Expand Up @@ -198,6 +200,9 @@ def to_ntt(self):
return self.parent(coeffs, is_ntt=True)

def from_ntt(self):
"""
Not supported, raises a ``TypeError``
"""
raise TypeError(f"Polynomial not in the NTT domain: {type(self) = }")


Expand All @@ -207,6 +212,9 @@ def __init__(self, parent, coefficients):
self.coeffs = self._parse_coefficients(coefficients)

def to_ntt(self):
"""
Not supported, raises a ``TypeError``
"""
raise TypeError(
f"Polynomial is already in the NTT domain: {type(self) = }"
)
Expand Down Expand Up @@ -249,6 +257,10 @@ def _ntt_base_multiplication(a0, a1, b0, b1, zeta):
return r0, r1

def _ntt_coefficient_multiplication(self, f_coeffs, g_coeffs):
"""
Given the coefficients of two polynomials compute the coefficients of
their product
"""
new_coeffs = []
zetas = self.parent.ntt_zetas
for i in range(64):
Expand All @@ -272,7 +284,6 @@ def _ntt_coefficient_multiplication(self, f_coeffs, g_coeffs):
def _ntt_multiplication(self, other):
"""
Number Theoretic Transform multiplication.
Only implemented (currently) for n = 256
"""
new_coeffs = self._ntt_coefficient_multiplication(
self.coeffs, other.coeffs
Expand Down
7 changes: 7 additions & 0 deletions src/kyber_py/polynomials/polynomials_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,16 @@ def __init__(self, q, n):
self.element = Polynomial

def gen(self):
"""
Return the generator `x` of the polynomial ring
"""
return self([0, 1])

def random_element(self):
"""
Compute a random element of the polynomial ring with coefficients in the
canonical range: ``[0, q-1]``
"""
coefficients = [random.randint(0, self.q - 1) for _ in range(self.n)]
return self(coefficients)

Expand Down

0 comments on commit 76e7a13

Please sign in to comment.