Skip to content

Commit

Permalink
Merge pull request #43 from GiacomoPope/align_api
Browse files Browse the repository at this point in the history
align API between ML-KEM and Kyber
  • Loading branch information
GiacomoPope authored Jul 22, 2024
2 parents 3e27a32 + 2958c66 commit 3c41543
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 22 deletions.
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -105,16 +105,16 @@ There are three functions exposed on the `Kyber` class which are intended for
use:

- `Kyber.keygen()`: generate a keypair `(pk, sk)`
- `Kyber.enc(pk)`: generate a challenge and a shared key `(c, key)`
- `Kyber.dec(c, sk)`: generate the shared key `key`
- `Kyber.encaps(pk)`: generate shared key and challenge `(key, c)`
- `Kyber.decaps(c, sk)`: generate the shared key `key`

#### Example

```python
>>> from kyber import Kyber512
>>> pk, sk = Kyber512.keygen()
>>> c, key = Kyber512.enc(pk)
>>> _key = Kyber512.dec(c, sk)
>>> key, c = Kyber512.encaps(pk)
>>> _key = Kyber512.decaps(c, sk)
>>> assert key == _key
```

Expand Down
10 changes: 5 additions & 5 deletions benchmarks/benchmark_kyber.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

def profile_kyber(Kyber):
pk, sk = Kyber.keygen()
c, key = Kyber.enc(pk)
key, c = Kyber.encaps(pk)

gvars = {}
lvars = {"Kyber": Kyber, "c": c, "pk": pk, "sk": sk}
Expand All @@ -17,13 +17,13 @@ def profile_kyber(Kyber):
sort=1,
)
cProfile.runctx(
"[Kyber.enc(pk) for _ in range(100)]",
"[Kyber.encaps(pk) for _ in range(100)]",
globals=gvars,
locals=lvars,
sort=1,
)
cProfile.runctx(
"[Kyber.dec(c, sk) for _ in range(100)]",
"[Kyber.decaps(c, sk) for _ in range(100)]",
globals=gvars,
locals=lvars,
sort=1,
Expand All @@ -41,11 +41,11 @@ def benchmark_kyber(Kyber, name, count):
keygen_times.append(time() - t0)

t1 = time()
c, key = Kyber.enc(pk)
key, c = Kyber.encaps(pk)
enc_times.append(time() - t1)

t2 = time()
dec = Kyber.dec(c, sk)
dec = Kyber.decaps(c, sk)
dec_times.append(time() - t2)

avg_keygen = sum(keygen_times) / count
Expand Down
13 changes: 9 additions & 4 deletions kyber/kyber.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,25 +276,30 @@ def keygen(self):
sk = _sk + pk + self._h(pk) + z
return pk, sk

def enc(self, pk, key_length=32):
def encaps(self, pk, key_length=32):
"""
Algorithm 8 (CCA KEM Encapsulation)
https://pq-crystals.org/kyber/data/kyber-specification-round3-20210804.pdf
Input:
pk: Public Key
Output:
c: Ciphertext
K: Shared key
c: Ciphertext
NOTE::
We switch the order of the output (c, K) as (K, c) to align encaps output
with FIPS 203.
"""
m = self.random_bytes(32)
m_hash = self._h(m)
Kbar, r = self._g(m_hash + self._h(pk))
c = self._cpapke_enc(pk, m_hash, r)
K = self._kdf(Kbar + self._h(c), key_length)
return c, K
return K, c

def dec(self, c, sk, key_length=32):
def decaps(self, c, sk, key_length=32):
"""
Algorithm 9 (CCA KEM Decapsulation)
https://pq-crystals.org/kyber/data/kyber-specification-round3-20210804.pdf
Expand Down
17 changes: 8 additions & 9 deletions tests/test_kyber.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,16 @@ def parse_kat_data(data):

class TestKyber(unittest.TestCase):
"""
Test Kyber levels for internal
consistency by generating keypairs
and shared secrets.
Test Kyber levels for internal consistency by generating keypairs and
shared secrets.
"""

def generic_test_kyber(self, Kyber, count):
for _ in range(count):
pk, sk = Kyber.keygen()
for _ in range(count):
c, key = Kyber.enc(pk)
_key = Kyber.dec(c, sk)
key, c = Kyber.encaps(pk)
_key = Kyber.decaps(c, sk)
self.assertEqual(key, _key)

def test_kyber512(self):
Expand Down Expand Up @@ -80,8 +79,8 @@ def generic_test_kyber_deterministic(self, Kyber, count):
pk, sk = Kyber.keygen()
for _ in range(count):
Kyber.set_drbg_seed(seed)
c, key = Kyber.enc(pk)
_key = Kyber.dec(c, sk)
key, c = Kyber.encaps(pk)
_key = Kyber.decaps(c, sk)
# Check key derivation works
self.assertEqual(key, _key)
key_output.append(c + key)
Expand Down Expand Up @@ -125,12 +124,12 @@ def generic_test_kyber_known_answer(self, Kyber, filename):
self.assertEqual(sk, data["sk"])

# Assert encapsulation matches
ct, ss = Kyber.enc(pk)
ss, ct = Kyber.encaps(pk)
self.assertEqual(ct, data["ct"])
self.assertEqual(ss, data["ss"])

# Assert decapsulation matches
_ss = Kyber.dec(ct, sk)
_ss = Kyber.decaps(ct, sk)
self.assertEqual(ss, data["ss"])

def test_kyber512_known_answer(self):
Expand Down

0 comments on commit 3c41543

Please sign in to comment.