Skip to content

Commit

Permalink
Merge pull request #76 from GiacomoPope/optimise_encode_decode
Browse files Browse the repository at this point in the history
optimise encoding and decoding
  • Loading branch information
GiacomoPope authored Jul 25, 2024
2 parents 52ca101 + e2f86e6 commit ab4f388
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 47 deletions.
12 changes: 6 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,9 @@ The above example would also work with `ML_KEM_768` and `ML_KEM_1024`.

| Params | keygen | keygen/s | encap | encap/s | decap | decap/s |
|------------|---------:|-----------:|--------:|----------:|--------:|--------:|
| ML-KEM-512 | 3.84ms | 260.47 | 4.99ms | 200.44 | 6.40ms | 156.15 |
| ML-KEM-768 | 5.67ms | 176.26 | 7.15ms | 139.84 | 8.99ms | 111.27 |
| ML-KEM-1024| 8.32ms | 120.15 | 10.10ms | 99.02 | 12.40ms | 80.66 |
| ML-KEM-512 | 1.96ms | 511.30 | 2.92ms | 342.26 | 4.20ms | 237.91 |
| ML-KEM-768 | 3.31ms | 302.51 | 4.48ms | 223.04 | 6.14ms | 162.86 |
| ML-KEM-1024| 5.02ms | 199.07 | 6.41ms | 155.89 | 8.47ms | 118.01 |

All times recorded using a Intel Core i7-9750H CPU and averaged over 1000 runs.

Expand Down Expand Up @@ -146,9 +146,9 @@ currently only support $q = 3329$ and $n = 256$.

| Params | keygen | keygen/s | encap | encap/s | decap | decap/s |
|------------|---------:|-----------:|--------:|----------:|--------:|--------:|
| Kyber512 | 3.86ms | 258.85 | 4.43ms | 225.78 | 5.82ms | 171.72 |
| Kyber768 | 5.75ms | 173.96 | 6.38ms | 156.68 | 8.20ms | 121.93 |
| Kyber1024 | 8.26ms | 121.01 | 8.88ms | 112.60 | 11.15ms | 89.71 |
| Kyber512 | 2.02ms | 493.99 | 2.84ms | 352.53 | 4.12ms | 242.82 |
| Kyber768 | 3.40ms | 294.13 | 4.38ms | 228.41 | 6.06ms | 165.13 |
| Kyber1024 | 5.09ms | 196.61 | 6.22ms | 160.72 | 8.29ms | 120.68 |

All times recorded using a Intel Core i7-9750H CPU and averaged over 1000 runs.

Expand Down
44 changes: 20 additions & 24 deletions src/kyber_py/polynomials/polynomials.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from ..utilities.utils import bit_count
from .polynomials_generic import PolynomialRing, Polynomial
from ..utilities.utils import bytes_to_bits, bitstring_to_bytes


class PolynomialRingKyber(PolynomialRing):
Expand Down Expand Up @@ -63,10 +63,14 @@ def cbd(self, input_bytes, eta, is_ntt=False):
"""
assert 64 * eta == len(input_bytes)
coefficients = [0 for _ in range(256)]
list_of_bits = bytes_to_bits(input_bytes)
b_int = int.from_bytes(input_bytes, "little")
mask = (1 << eta) - 1
mask2 = (1 << 2 * eta) - 1
for i in range(256):
a = sum(list_of_bits[eta * 2 * i : eta * (2 * i + 1)])
b = sum(list_of_bits[eta * (2 * i + 1) : eta * (2 * i + 2)])
x = b_int & mask2
a = bit_count(x & mask)
b = bit_count((x >> eta) & mask)
b_int >>= 2 * eta
coefficients[i] = (a - b) % 3329
return self(coefficients, is_ntt=is_ntt)

Expand All @@ -86,27 +90,15 @@ def decode(self, input_bytes, d, is_ntt=False):
if d == 12:
m = 3329
else:
m = 2**d
m = 1 << d

# Helper values
tmp, idx = 0, 0
bit_index = 0
mask = (1 << d) - 1
coeffs = [0 for _ in range(256)]
b_int = int.from_bytes(input_bytes, "little")
mask = (1 << d) - 1
for i in range(256):
coeffs[i] = (b_int & mask) % m
b_int >>= d

# Iterate through all bytes
for b in input_bytes:
tmp |= b << bit_index
bit_index += 8

while bit_index >= d:
# Set the coefficient
coeffs[idx] = (tmp & mask) % m

# Update helpers
bit_index -= d
tmp >>= d
idx += 1
return self(coeffs, is_ntt=is_ntt)

def __call__(self, coefficients, is_ntt=False):
Expand All @@ -133,8 +125,12 @@ def encode(self, d):
"""
Encode (Inverse of Algorithm 3)
"""
bit_string = "".join(format(c, f"0{d}b")[::-1] for c in self.coeffs)
return bitstring_to_bytes(bit_string)
t = 0
for i in range(255):
t |= self.coeffs[256 - i - 1]
t <<= d
t |= self.coeffs[0]
return t.to_bytes(32 * d, "little")

def _compress_ele(self, x, d):
"""
Expand Down
32 changes: 15 additions & 17 deletions src/kyber_py/utilities/utils.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,21 @@
def bytes_to_bits(input_bytes):
"""
FIPS 203: Algorithm 3
import sys

Convert bytes to an array of bits. Bytes are converted little endianness
following the paper
"""
b = [0 for _ in range(8 * len(input_bytes))]
for i, byte in enumerate(input_bytes):
for j in range(8):
b[8 * i + j] = byte % 2
byte //= 2
return b
# int.bit_count() was only made available in 3.10
if sys.version_info >= (3, 10):

def bit_count(x: int) -> int:
"""
Count the number of bits in x
"""
return x.bit_count()

def bitstring_to_bytes(s):
"""
Convert a string of bits to bytes with bytes stored little endian
"""
return bytes([int(s[i : i + 8][::-1], 2) for i in range(0, len(s), 8)])
else:

def bit_count(x: int) -> int:
"""
Count the number of bits in x
"""
return bin(x).count("1")


def xor_bytes(a, b):
Expand Down

0 comments on commit ab4f388

Please sign in to comment.