Skip to content

Commit

Permalink
clean up drbg
Browse files Browse the repository at this point in the history
  • Loading branch information
GiacomoPope committed Jul 25, 2024
1 parent ab4f388 commit befa134
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 21 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ have made a child class `PolynomialRingKyber(PolynomialRing)` which has the
following additional methods:

- `PolynomialRingKyber`
- `parse(bytes)` takes $3n$ bytes and produces a random polynomial in $R_q$
- `ntt_sample(bytes)` takes $3n$ bytes and produces a random polynomial in $R_q$
- `decode(bytes, l)` takes $\ell n$ bits and produces a polynomial in $R_q$
- `cbd(beta, eta)` takes $\eta \cdot n / 4$ bytes and produces a polynomial in
$R_q$ with coefficents taken from a centered binomial distribution
Expand Down
72 changes: 56 additions & 16 deletions src/kyber_py/drbg/aes256_ctr_drbg.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,38 @@
import os
from ..utilities.utils import xor_bytes
from Crypto.Cipher import AES
from typing import Optional


class AES256_CTR_DRBG:
def __init__(self, seed=None, personalization=b""):
def __init__(
self, seed: Optional[bytes] = None, personalization: bytes = b""
):
"""
DRBG implementation based on AES-256 CTR following the document NIST SP
800-90A Section 10.2.1
https://csrc.nist.gov/pubs/sp/800/90/a/r1/final
Used for deterministic randomness, particularly used for comparing the
output of Kyber/ML-KEM against known answer tests.
:param bytes seed: 48 byte seed, if none is supplied a seed is generated
using ``os.urandom(48)``.
:param bytes personalization: optional bytes, of length at most 48 used
during instantiation of the DRBG
"""
self.seed_length = 48
self.reseed_interval = 2**48
self.key = bytes([0]) * 32
self.V = bytes([0]) * 16
self.entropy_input = self.__check_entropy_input(seed)

seed_material = self.__instantiate(personalization=personalization)
self.ctr_drbg_update(seed_material)
self.__ctr_drbg_update(seed_material)
self.reseed_ctr = 1

def __check_entropy_input(self, entropy_input):
def __check_entropy_input(self, entropy_input: bytes) -> bytes:
"""
If no entropy given, us os.urandom, else
check that the input is of the right length.
Expand All @@ -29,32 +46,44 @@ def __check_entropy_input(self, entropy_input):
)
return entropy_input

def __instantiate(self, personalization=b""):
def __instantiate(self, personalization: bytes = b"") -> bytes:
"""
Combine the input seed and optional personalisation
string into the seed material for the DRBG
Section 10.2.1.3.1, Page 52 (CTR_DRBG_Instantiate_algorithm)
"""
if len(personalization) > self.seed_length:
raise ValueError(
f"The Personalization String must be at most length: "
f"{self.seed_length}. Input has length {len(personalization)}"
)
elif len(personalization) < self.seed_length:
personalization += bytes([0]) * (
self.seed_length - len(personalization)
)
# Ensure personalization has exactly seed_length bytes
personalization += bytes([0]) * (
self.seed_length - len(personalization)
)
# debugging
assert len(personalization) == self.seed_length
return xor_bytes(self.entropy_input, personalization)

def __increment_counter(self):
def __increment_counter(self) -> None:
"""
Increment the internal counter of the DRBG
"""
int_V = int.from_bytes(self.V, "big")
new_V = (int_V + 1) % 2 ** (8 * 16)
new_V = (int_V + 1) % 2**128
self.V = new_V.to_bytes(16, byteorder="big")

def ctr_drbg_update(self, provided_data):
def __ctr_drbg_update(self, provided_data: bytes) -> None:
"""
Updates the internal state of the CTR_DRBG using the
provided_data
Section 10.2.1.2, Page 51 (CTR_DRBG_Update)
"""
tmp = b""
cipher = AES.new(self.key, AES.MODE_ECB)

# Collect bytes from AES ECB
while len(tmp) != self.seed_length:
self.__increment_counter()
Expand All @@ -68,7 +97,19 @@ def ctr_drbg_update(self, provided_data):
self.key = tmp[:32]
self.V = tmp[32:]

def random_bytes(self, num_bytes, additional=None):
def random_bytes(
self, num_bytes: int, additional: Optional[bytes] = None
) -> bytes:
"""
Generate pseudorandom bytes without a generating function
Section 10.2.1.5.1, Page 56 (CTR_DRBG_Generate_algorithm)
:param int num_bytes: the number of random bytes requested
:param bytes additional: optional bytes to be mixed into the generation
:return: pseudorandom bytes extracted from the DRBG of length ``num_bytes``.
:rtype: bytes
"""
# We don't cover this in coverage as we would need to run the counter 2^48 times
if self.reseed_ctr >= self.reseed_interval: # pragma: no cover
raise Warning("The DRBG has been exhausted! Reseed!")
Expand All @@ -82,9 +123,8 @@ def random_bytes(self, num_bytes, additional=None):
f"The additional input must be of length at most: "
f"{self.seed_length}. Input has length {len(additional)}"
)
elif len(additional) < self.seed_length:
additional += bytes([0]) * (self.seed_length - len(additional))
self.ctr_drbg_update(additional)
additional += bytes([0]) * (self.seed_length - len(additional))
self.__ctr_drbg_update(additional)

# Collect bytes!
tmp = b""
Expand All @@ -95,6 +135,6 @@ def random_bytes(self, num_bytes, additional=None):

# Collect only the requested number of bits
output_bytes = tmp[:num_bytes]
self.ctr_drbg_update(additional)
self.__ctr_drbg_update(additional)
self.reseed_ctr += 1
return output_bytes
2 changes: 1 addition & 1 deletion src/kyber_py/kyber/kyber.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def _generate_matrix_from_seed(self, rho, transpose=False):
for i in range(self.k):
for j in range(self.k):
input_bytes = self._xof(rho, bytes([j]), bytes([i]))
A_data[i][j] = self.R.parse(input_bytes, is_ntt=True)
A_data[i][j] = self.R.ntt_sample(input_bytes)
A_hat = self.M(A_data, transpose=transpose)
return A_hat

Expand Down
2 changes: 1 addition & 1 deletion src/kyber_py/ml_kem/ml_kem.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def _generate_matrix_from_seed(self, rho, transpose=False):
for i in range(self.k):
for j in range(self.k):
xof_bytes = self._xof(rho, bytes([j]), bytes([i]))
A_data[i][j] = self.R.parse(xof_bytes, is_ntt=True)
A_data[i][j] = self.R.ntt_sample(xof_bytes)
A_hat = self.M(A_data, transpose=transpose)
return A_hat

Expand Down
10 changes: 8 additions & 2 deletions src/kyber_py/polynomials/polynomials.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,14 @@ def _br(i, k):
bin_i = bin(i & (2**k - 1))[2:].zfill(k)
return int(bin_i[::-1], 2)

def parse(self, input_bytes, is_ntt=False):
def ntt_sample(self, input_bytes):
"""
Algorithm 1 (Parse)
https://pq-crystals.org/kyber/data/kyber-specification-round3-20210804.pdf
Algorithm 6 (Sample NTT)
FIPS 203-ipd
Parse: B^* -> R
"""
i, j = 0, 0
Expand All @@ -51,13 +54,16 @@ def parse(self, input_bytes, is_ntt=False):
j = j + 1

i = i + 3
return self(coefficients, is_ntt=is_ntt)
return self(coefficients, is_ntt=True)

def cbd(self, input_bytes, eta, is_ntt=False):
"""
Algorithm 2 (Centered Binomial Distribution)
https://pq-crystals.org/kyber/data/kyber-specification-round3-20210804.pdf
Algorithm 6 (Sample Poly CBD)
FIPS 203-ipd
Expects a byte array of length (eta * deg / 4)
For Kyber, this is 64 eta.
"""
Expand Down

0 comments on commit befa134

Please sign in to comment.