-
Notifications
You must be signed in to change notification settings - Fork 48
/
test_ml_kem.py
145 lines (114 loc) · 4.56 KB
/
test_ml_kem.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import unittest
import json
from kyber_py.ml_kem import ML_KEM_512, ML_KEM_768, ML_KEM_1024
class TestML_KEM(unittest.TestCase):
"""
Test ML_KEM levels for internal
consistency by generating key pairs
and shared secrets.
"""
def generic_test_ML_KEM(self, ML_KEM, count):
for _ in range(count):
(ek, dk) = ML_KEM.keygen()
for _ in range(count):
(K, c) = ML_KEM.encaps(ek)
K_prime = ML_KEM.decaps(dk, c)
self.assertEqual(K, K_prime)
def test_ML_KEM_512(self):
self.generic_test_ML_KEM(ML_KEM_512, 5)
def test_ML_KEM_768(self):
self.generic_test_ML_KEM(ML_KEM_768, 5)
def test_ML_KEM_1024(self):
self.generic_test_ML_KEM(ML_KEM_1024, 5)
def test_encaps_type_check_failure(self):
"""
Send an ecaps key of the wrong length
"""
self.assertRaises(ValueError, lambda: ML_KEM_512.encaps(b"1"))
def test_encaps_modulus_check_failure(self):
"""
We create a vector of polynomials with non-canonical values for
coefficents to fail the modulus check
"""
(ek, _) = ML_KEM_512.keygen()
rho = ek[-32:]
bad_f_hat = ML_KEM_512.R([3329] * 256)
bad_t_hat = ML_KEM_512.M.vector([bad_f_hat, bad_f_hat])
bad_t_hat_bytes = bad_t_hat.encode(12)
bad_ek = bad_t_hat_bytes + rho
self.assertEqual(len(bad_ek), len(ek))
self.assertRaises(ValueError, lambda: ML_KEM_512.encaps(bad_ek))
def test_xof_failure(self):
self.assertRaises(
ValueError, lambda: ML_KEM_512._xof(b"1", b"2", b"3")
)
def test_prf_failure(self):
self.assertRaises(ValueError, lambda: ML_KEM_512._prf(2, b"1", b"2"))
def test_decaps_ct_type_check_failure(self):
"""
Send a ciphertext of the wrong length
"""
ek, dk = ML_KEM_512.keygen()
K, c = ML_KEM_512.encaps(ek)
self.assertRaises(ValueError, lambda: ML_KEM_512.decaps(dk, b"1"))
def test_decaps_dk_type_check_failure(self):
"""
Send a ciphertext of the wrong length
"""
ek, dk = ML_KEM_512.keygen()
K, c = ML_KEM_512.encaps(ek)
self.assertRaises(ValueError, lambda: ML_KEM_512.decaps(b"1", c))
def test_decaps_hash_check_failure(self):
"""
Send a ciphertext of the wrong length
"""
ek, dk = ML_KEM_512.keygen()
K, c = ML_KEM_512.encaps(ek)
dk_bad = b"0" * len(dk)
self.assertRaises(ValueError, lambda: ML_KEM_512.decaps(dk_bad, c))
class TestML_KEM_KAT(unittest.TestCase):
"""
Test ML-KEM against test vectors collected from
https://github.com/usnistgov/ACVP-Server/releases/tag/v1.1.0.35
"""
def generic_keygen_kat(self, ML_KEM, index):
with open("assets/ML-KEM-keyGen-FIPS203/internalProjection.json") as f:
data = json.load(f)
kat_data = data["testGroups"][index]["tests"]
for test in kat_data:
d_kat = bytes.fromhex(test["d"])
z_kat = bytes.fromhex(test["z"])
ek_kat = bytes.fromhex(test["ek"])
dk_kat = bytes.fromhex(test["dk"])
ek, dk = ML_KEM._keygen_internal(d_kat, z_kat)
self.assertEqual(ek, ek_kat)
self.assertEqual(dk, dk_kat)
def generic_encap_decap_kat(self, ML_KEM, index):
with open(
"assets/ML-KEM-encapDecap-FIPS203/internalProjection.json"
) as f:
data = json.load(f)
kat_data = data["testGroups"][index]["tests"]
for test in kat_data:
ek_kat = bytes.fromhex(test["ek"])
dk_kat = bytes.fromhex(test["dk"])
c_kat = bytes.fromhex(test["c"])
k_kat = bytes.fromhex(test["k"])
m_kat = bytes.fromhex(test["m"])
K, c = ML_KEM._encaps_internal(ek_kat, m_kat)
self.assertEqual(K, k_kat)
self.assertEqual(c, c_kat)
K_prime = ML_KEM.decaps(dk_kat, c_kat)
self.assertEqual(K_prime, k_kat)
def test_ML_KEM_512_keygen(self):
self.generic_keygen_kat(ML_KEM_512, 0)
def test_ML_KEM_768_keygen(self):
self.generic_keygen_kat(ML_KEM_768, 1)
def test_ML_KEM_1024_keygen(self):
self.generic_keygen_kat(ML_KEM_1024, 2)
def test_ML_KEM_512_encap_decap(self):
self.generic_encap_decap_kat(ML_KEM_512, 0)
def test_ML_KEM_768_encap_decap(self):
self.generic_encap_decap_kat(ML_KEM_768, 1)
def test_ML_KEM_1024_encap_decap(self):
self.generic_encap_decap_kat(ML_KEM_1024, 2)