Skip to content

Commit

Permalink
clean up the module code
Browse files Browse the repository at this point in the history
  • Loading branch information
GiacomoPope committed Jul 19, 2024
1 parent 5390ada commit 5eb35bf
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 74 deletions.
30 changes: 14 additions & 16 deletions kyber.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
from hashlib import sha3_256, sha3_512, shake_128, shake_256
from polynomials import *
from modules import *
from polynomials import PolynomialRing
from modules import Module
from ntt_helper import NTTHelperKyber
try:
from aes256_ctr_drbg import AES256_CTR_DRBG
Expand Down Expand Up @@ -76,7 +76,7 @@ def reseed_drbg(self, seed):
(Seemed overkill to code my own AES for Kyber)
"""
if self.drbg is None:
raise Warning(f"Cannot reseed DRBG without first initialising. Try using `set_drbg_seed`")
raise Warning("Cannot reseed DRBG without first initialising. Try using `set_drbg_seed`")
else:
self.drbg.reseed(seed)

Expand All @@ -87,7 +87,7 @@ def _xof(bytes32, a, b, length):
"""
input_bytes = bytes32 + a + b
if len(input_bytes) != 34:
raise ValueError(f"Input bytes should be one 32 byte array and 2 single bytes.")
raise ValueError("Input bytes should be one 32 byte array and 2 single bytes.")
return shake_128(input_bytes).digest(length)

@staticmethod
Expand All @@ -112,7 +112,7 @@ def _prf(s, b, length):
"""
input_bytes = s + b
if len(input_bytes) != 33:
raise ValueError(f"Input bytes should be one 32 byte array and one single byte.")
raise ValueError("Input bytes should be one 32 byte array and one single byte.")
return shake_256(input_bytes).digest(length)

@staticmethod
Expand All @@ -128,12 +128,12 @@ def _generate_error_vector(self, sigma, eta, N, is_ntt=False):
module from the Centered Binomial Distribution.
"""
elements = []
for i in range(self.k):
for _ in range(self.k):
input_bytes = self._prf(sigma, bytes([N]), 64*eta)
poly = self.R.cbd(input_bytes, eta, is_ntt=is_ntt)
elements.append(poly)
N = N + 1
v = self.M(elements).transpose()
v = self.M.vector(elements)
return v, N

def _generate_matrix_from_seed(self, rho, transpose=False, is_ntt=False):
Expand Down Expand Up @@ -189,8 +189,8 @@ def _cpapke_keygen(self):
t = (A @ s) + e

# Reduce vectors mod^+ q
t.reduce_coefficents()
s.reduce_coefficents()
t.reduce_coefficients()
s.reduce_coefficients()

# Encode elements to bytes and return
pk = t.encode(l=12) + rho
Expand All @@ -211,9 +211,7 @@ def _cpapke_enc(self, pk, m, coins):
"""
N = 0
rho = pk[-32:]

tt = self.M.decode(pk, 1, self.k, l=12, is_ntt=True)

t = self.M.decode_vector(pk, self.k, l=12, is_ntt=True)
# Encode message as polynomial
m_poly = self.R.decode(m, l=1).decompress(1)

Expand All @@ -233,7 +231,7 @@ def _cpapke_enc(self, pk, m, coins):

# Module/Polynomial arithmetic
u = (At @ r).from_ntt() + e1
v = (tt @ r)[0][0].from_ntt()
v = t.dot(r).from_ntt()
v = v + e2 + m_poly

# Ciphertext to bytes
Expand All @@ -258,17 +256,17 @@ def _cpapke_dec(self, sk, c):
c2 = c[index:]

# Recover the vector u and convert to NTT form
u = self.M.decode(c, self.k, 1, l=self.du).decompress(self.du)
u = self.M.decode_vector(c, self.k, l=self.du).decompress(self.du)
u.to_ntt()

# Recover the polynomial v
v = self.R.decode(c2, l=self.dv).decompress(self.dv)

# s_transpose (already in NTT form)
st = self.M.decode(sk, 1, self.k, l=12, is_ntt=True)
s = self.M.decode_vector(sk, self.k, l=12, is_ntt=True)

# Recover message as polynomial
m = (st @ u)[0][0].from_ntt()
m = s.dot(u).from_ntt()
m = v - m

# Return message as bytes
Expand Down
174 changes: 119 additions & 55 deletions modules.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
class Module:
def __init__(self, ring):
self.ring = ring

def decode(self, input_bytes, m, n, l=None, is_ntt=False):
if l is None:
# Input length must be 32*l*m*n bytes long
Expand All @@ -19,111 +19,161 @@ def decode(self, input_bytes, m, n, l=None, is_ntt=False):
mij = self.ring.decode(byte_chunks[n*i+j], l=l, is_ntt=is_ntt)
matrix[i][j] = mij
return self(matrix)

def decode_vector(self, input_bytes, k, l=None, is_ntt=False):
if l is None:
# Input length must be 32*l*k bytes long
l, check = divmod(8*len(input_bytes), self.ring.n*k)
if check != 0:
raise ValueError("input bytes must be a multiple of (polynomial degree) / 8")
else:
if self.ring.n*l*k > len(input_bytes)*8:
raise ValueError("Byte length is too short for given l")

# Bytes needed to decode a polynomial
chunk_length = 32*l

# Break input_bytes into blocks of length chunk_length
poly_bytes = [input_bytes[i:i+chunk_length] for i in range(0, len(input_bytes), chunk_length)]

# Encode each chunk of bytes as a polynomial, we iterate only the first k elements in case we've
# been sent too many bytes to decode for the vector
elements = [self.ring.decode(poly_bytes[i], l=l, is_ntt=is_ntt) for i in range(k)]

return self.vector(elements)

def __repr__(self):
return f"Module over the commutative ring: {self.ring}"

def __str__(self):
return f"Module over the commutative ring: {self.ring}"

def __call__(self, matrix_elements):
def __call__(self, matrix_elements, transpose=False):
if not isinstance(matrix_elements, list):
raise TypeError(f"Elements of a module are matrices, with elements .")
raise TypeError("elements of a module are matrices, built from elements of the base ring")

if isinstance(matrix_elements[0], list):
for element_list in matrix_elements:
if not all(isinstance(aij, self.ring.element) for aij in element_list):
raise TypeError(f"All elements of the matrix must be elements of the ring: {self.ring}")
return Module.Matrix(self, matrix_elements)
return Module.Matrix(self, matrix_elements, transpose=transpose)

elif isinstance(matrix_elements[0], self.ring.element):
if not all(isinstance(aij, self.ring.element) for aij in matrix_elements):
raise TypeError(f"All elements of the matrix must be elements of the ring: {self.ring}")
return Module.Matrix(self, [matrix_elements])
return Module.Matrix(self, [matrix_elements], transpose=transpose)

else:
raise TypeError(f"Elements of a module are matrices, built from elements of the base ring.")
raise TypeError("elements of a module are matrices, built from elements of the base ring")

def vector(self, elements):
"""
Construct a vector with the given elements
"""
return Module.Matrix(self, [elements], transpose=True)

class Matrix:
def __init__(self, parent, matrix_elements):
def __init__(self, parent, matrix_data, transpose=False):
self.parent = parent
self.rows = matrix_elements
self.m = len(matrix_elements)
self.n = len(matrix_elements[0])
self._data = matrix_data
self._transpose = transpose
if not self.check_dimensions():
raise ValueError("Inconsistent row lengths in matrix")

def get_dim(self):
return self.m, self.n
def dim(self):
"""
Return the dimensions of the matrix with m rows
and n columns"""
if not self._transpose:
return len(self._data), len(self._data[0])
else:
return len(self._data[0]), len(self._data)

def check_dimensions(self):
return all(len(row) == self.n for row in self.rows)
"""
Ensure that the matrix is rectangluar
"""
return len(set(map(len, self._data))) == 1

def transpose(self):
new_rows = [list(item) for item in zip(*self.rows)]
return self.parent(new_rows)
"""
Swap rows and columns of self
"""
return self.parent(self._data, not self._transpose)

def transpose_self(self):
self.m, self.n = self.n, self.m
self.rows = [list(item) for item in zip(*self.rows)]
return self
"""
Transpose in place
"""
self._transpose = not self._transpose
return

T = property(transpose)

def reduce_coefficents(self):
for row in self.rows:
def reduce_coefficients(self):
"""
Reduce every element in the polynomial
using the modulus of the PolynomialRing
"""
for row in self._data:
for ele in row:
ele.reduce_coefficents()
ele.reduce_coefficients()
return self

def encode(self, l=None):
output = b""
for row in self.rows:
for j in range(self.n):
output += row[j].encode(l=l)
for row in self._data:
for ele in row:
output += ele.encode(l=l)
return output

def compress(self, d):
for row in self.rows:
for row in self._data:
for ele in row:
ele.compress(d)
return self

def decompress(self, d):
for row in self.rows:
for row in self._data:
for ele in row:
ele.decompress(d)
return self

def to_ntt(self):
for row in self.rows:
for row in self._data:
for ele in row:
ele.to_ntt()
return self

def from_ntt(self):
for row in self.rows:
for row in self._data:
for ele in row:
ele.from_ntt()
return self

def __getitem__(self, i):
return self.rows[i]
def __getitem__(self, idx):
"""
matrix[i, j] returns the element on row i, column j
"""
assert isinstance(idx, tuple) and len(idx) == 2, "Can't access individual rows"
if not self._transpose:
return self._data[idx[0]][idx[1]]
else:
return self._data[idx[1]][idx[0]]

def __eq__(self, other):
return other.rows == self.rows
return other._data == self._data and other._transpose == self._transpose

def __add__(self, other):
if not isinstance(other, Module.Matrix):
raise TypeError("Can only add matrcies to other matrices")
raise TypeError("Can only add matrices to other matrices")
if self.parent != other.parent:
raise TypeError("Matricies must have the same base ring")
if self.get_dim() != other.get_dim():
raise TypeError("Matrices must have the same base ring")
if self.dim() != other.dim():
raise ValueError("Matrices are not of the same dimensions")

new_elements = []
for i in range(self.m):
new_elements.append([a+b for a,b in zip(self.rows[i], other.rows[i])])
return self.parent(new_elements)

m, n = self.dim()
return self.parent([[self[i, j] + other[i, j] for j in range(n)] for i in range(m)], False)

def __radd__(self, other):
return self.__add__(other)
Expand All @@ -134,16 +184,14 @@ def __iadd__(self, other):

def __sub__(self, other):
if not isinstance(other, Module.Matrix):
raise TypeError("Can only subtract matrcies from other matrices")
raise TypeError("Can only add matrices to other matrices")
if self.parent != other.parent:
raise TypeError("Matricies must have the same base ring")
if self.get_dim() != other.get_dim():
raise TypeError("Matrices must have the same base ring")
if self.dim() != other.dim():
raise ValueError("Matrices are not of the same dimensions")

new_elements = []
for i in range(self.m):
new_elements.append([a-b for a,b in zip(self.rows[i], other.rows[i])])
return self.parent(new_elements)
m, n = self.dim()
return self.parent([[self[i, j] - other[i, j] for j in range(n)] for i in range(m)], False)

def __rsub__(self, other):
return self.__sub__(other)
Expand All @@ -160,18 +208,34 @@ def __matmul__(self, other):
raise TypeError("Can only multiply matrcies with other matrices")
if self.parent != other.parent:
raise TypeError("Matricies must have the same base ring")
if self.n != other.m:

m, n = self.dim()
n_, l = other.dim()
if not n == n_:
raise ValueError("Matrices are of incompatible dimensions")

new_elements = [[sum(a*b for a,b in zip(A_row, B_col)) for B_col in other.transpose().rows] for A_row in self.rows]
return self.parent(new_elements)
return self.parent(
[
[sum(self[i, k] * other[k, j] for k in range(n)) for j in range(l)]
for i in range(m)
]
)

def dot(self, other):
"""
Inner product
"""
res = self.T @ other
assert res.dim() == (1, 1)
return res[0, 0]

def __repr__(self):
if len(self.rows) == 1:
return str(self.rows[0])
n, m = self.dim()

if n == 1:
return str(self._data[0])
max_col_width = []
for n_col in range(self.n):
max_col_width.append(max(len(str(row[n_col])) for row in self.rows))
info = ']\n['.join([', '.join([f'{str(x):>{max_col_width[i]}}' for i,x in enumerate(r)]) for r in self.rows])
for n_col in range(n):
max_col_width.append(max(len(str(row[n_col])) for row in self._data))
info = ']\n['.join([', '.join([f'{str(x):>{max_col_width[i]}}' for i,x in enumerate(r)]) for r in self._data])
return f"[{info}]"

6 changes: 3 additions & 3 deletions polynomials.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def decode(self, input_bytes, l=None, is_ntt=False):
raise ValueError("input bytes must be a multiple of (polynomial degree) / 8")
else:
if self.n*l != len(input_bytes)*8:
raise ValueError("input bytes must be a multiple of (polynomial degree) / 8")
raise ValueError(f"input bytes must be a multiple of (polynomial degree) / 8, {self.n*l = }, {len(input_bytes)*8 = }")
coefficients = [0 for _ in range(self.n)]
list_of_bits = bytes_to_bits(input_bytes)
for i in range(self.n):
Expand Down Expand Up @@ -121,9 +121,9 @@ def parse_coefficients(self, coefficients):
coefficients = coefficients + [0 for _ in range (self.parent.n - l)]
return coefficients

def reduce_coefficents(self):
def reduce_coefficients(self):
"""
Reduce all coefficents modulo q
Reduce all coefficients modulo q
"""
self.coeffs = [c % self.parent.q for c in self.coeffs]
return self
Expand Down

0 comments on commit 5eb35bf

Please sign in to comment.