Caution
This is an educational resource and has not been designed to be secure against any form of side-channel attack. The intended use of this project is for learning and experimenting with ML-DSA and Dilithium.
This repository contains a pure-Python implementation of both:
- ML-DSA: the NIST Module-Lattice-Based Digital Signature Standard, following FIPS 204, which is based on the Dilithium submission to the NIST post-quantum project.
- CRYSTALS-Dilithium: following (at the time of writing) the most recent specification (v3.1).
Note: This project has followed kyber-py, which is a pure-Python implementation of CRYSTALS-Kyber and ML-KEM, and reuses a lot of its code.
I have written dilithium-py as a way to learn about the way the protocol works, and to try and create a clean, well-commented implementation which people can learn from.
This code is not constant time, or written to be performant. Rather, it was written so that reading through the pseudocode of the specification closely matches the code which we use within dilithium.py and supporting files.
This work started by simply implementing Dilithium for fun. However, after NIST picked Dilithium to standardise as ML-DSA, the repository grew and now includes implementations of both Dilithium and ML-DSA. I assume as this repository ages, the Dilithium implementation will get less useful and the ML-DSA one will be the focus, but for historical reasons we will include both, if only so that people can study the differences which NIST introduced during the standardisation of the protocol.
This implementation passes all the KAT vectors for dilithium and ml_dsa. For more information see the unit tests in test_ml_dsa.py and test_dilithium.py.
The KAT files were either downloaded or generated:
- For ML-DSA, the KAT files were downloaded from the GitHub repository usnistgov/ACVP-Server/ release 1.1.0.35, and are included in the assets/ML-DSA-* directories.
- For Dilithium, the KAT files were generated from the project's GitHub repository and are included in assets/PQCsignKAT_*.rsp.
This implementation is based off the most recent specification (v3.1). There were breaking changes to the KAT files submitted to NIST when Dilithium was updated to 3.1, so the NIST KAT files will not match our code.
To deal with this, we generated our own KAT files from the reference implementation for version 3.1. These are the files inside assets.
Originally, as with kyber-py, this project was planned to have zero dependencies. However, like kyber-py, to pass the KATs I need a deterministic CSRNG. The reference implementation uses AES256 CTR DRBG. I have implemented this in aes256_ctr_drbg.py. However, I have not implemented AES itself; instead I import this from pycryptodome.
To install dependencies, run pip install -r requirements.txt.
If you're happy to use system randomness (os.urandom) then you don't need this dependency.
There is an additional optional dependency of xoflib, which is a Python package with bindings to many Rust implementations of eXtendable-Output Functions (XOFs). The creation of this package was inspired by this repository, as Dilithium needs a streaming API from the SHAKE XOFs which hashlib doesn't support.
xoflib can be installed by running pip install xoflib or by installing from requirements as above.
If you do not wish to install this dependency, then we include a small shake_wrapper to mimic xoflib, but with a much higher memory consumption due to the limitations of hashlib.
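To illustrate the limitation mentioned above, here is a minimal sketch of how a streaming read can be mimicked on top of hashlib. The class name ShakeStream and its read() method are hypothetical names chosen for this example; this is not the repository's shake_wrapper. As hashlib's SHAKE objects only expose digest(length), which always starts again from the first output byte, the sketch requests an ever longer digest and caches all of it, which is where the extra memory use comes from.

```python
import hashlib


class ShakeStream:
    """Hypothetical helper: incremental reads from SHAKE-256 using only hashlib."""

    def __init__(self, data: bytes):
        self._xof = hashlib.shake_256(data)
        self._buffer = b""  # every output byte produced so far
        self._offset = 0    # number of bytes already handed to the caller

    def read(self, n: int) -> bytes:
        end = self._offset + n
        if end > len(self._buffer):
            # hashlib cannot continue squeezing, so ask for a longer digest
            # (at least doubling the cache) and keep all of it in memory.
            self._buffer = self._xof.digest(max(end, 2 * len(self._buffer)))
        out = self._buffer[self._offset:end]
        self._offset = end
        return out


stream = ShakeStream(b"seed")
first = stream.read(10)
second = stream.read(20)
# Two consecutive reads give the same bytes as one 30-byte digest
print(first + second == hashlib.shake_256(b"seed").digest(30))
```

A package with a true streaming API avoids this recomputation, as it can keep squeezing the sponge state instead of restarting from the first byte.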
There are three functions exposed on the ML_DSA class which are intended for use:
- ML_DSA.keygen(): generate a bit-packed keypair (pk, sk).
- ML_DSA.sign(sk, msg): generate a bit-packed signature sig from the message msg and bit-packed secret key sk.
- ML_DSA.verify(pk, msg, sig): verify that the bit-packed sig is valid for a given message msg and bit-packed public key pk.
To use ML_DSA(), it must be initialised with a dictionary of the protocol parameters. An example can be seen in DEFAULT_PARAMETERS in the file ml_dsa.py.
Additionally, the class has been initialised with these default parameters, so you can simply import the NIST level you want to play with:
>>> from dilithium_py.ml_dsa import ML_DSA_44
>>>
>>> # Example of signing
>>> pk, sk = ML_DSA_44.keygen()
>>> msg = b"Your message signed by ML_DSA"
>>> sig = ML_DSA_44.sign(sk, msg)
>>> assert ML_DSA_44.verify(pk, msg, sig)
>>>
>>> # Verification will fail with the wrong msg or pk
>>> assert not ML_DSA_44.verify(pk, b"", sig)
>>> pk_new, sk_new = ML_DSA_44.keygen()
>>> assert not ML_DSA_44.verify(pk_new, msg, sig)
The above example would also work with the other NIST levels ML_DSA_65 and ML_DSA_87.
Some very rough benchmarks to give an idea about performance:
1000 Iterations | ML_DSA_44 | ML_DSA_65 | ML_DSA_87 |
---|---|---|---|
KeyGen() Median Time | 6 ms | 10 ms | 14 ms |
Sign() Median Time | 29 ms | 49 ms | 59 ms |
Sign() Average Time | 36 ms | 62 ms | 75 ms |
Verify() Median Time | 8 ms | 11 ms | 17 ms |
All times recorded using an Intel Core i7-9750H CPU averaged over 1000 calls.
There are three functions exposed on the Dilithium class which are intended for use:
- Dilithium.keygen(): generate a bit-packed keypair (pk, sk).
- Dilithium.sign(sk, msg): generate a bit-packed signature sig from the message msg and bit-packed secret key sk.
- Dilithium.verify(pk, msg, sig): verify that the bit-packed sig is valid for a given message msg and bit-packed public key pk.
To use Dilithium(), it must be initialised with a dictionary of the protocol parameters. An example can be seen in DEFAULT_PARAMETERS in the file dilithium.py.
Additionally, the class has been initialised with these default parameters, so you can simply import the NIST level you want to play with:
>>> from dilithium_py.dilithium import Dilithium2
>>>
>>> # Example of signing
>>> pk, sk = Dilithium2.keygen()
>>> msg = b"Your message signed by Dilithium"
>>> sig = Dilithium2.sign(sk, msg)
>>> assert Dilithium2.verify(pk, msg, sig)
>>>
>>> # Verification will fail with the wrong msg or pk
>>> assert not Dilithium2.verify(pk, b"", sig)
>>> pk_new, sk_new = Dilithium2.keygen()
>>> assert not Dilithium2.verify(pk_new, msg, sig)
The above example would also work with the other NIST levels Dilithium3 and Dilithium5.
Some very rough benchmarks to give an idea about performance:
1000 Iterations | Dilithium2 | Dilithium3 | Dilithium5 |
---|---|---|---|
KeyGen() Median Time | 6 ms | 9 ms | 15 ms |
Sign() Median Time | 27 ms | 46 ms | 58 ms |
Sign() Average Time | 35 ms | 58 ms | 72 ms |
Verify() Median Time | 7 ms | 11 ms | 18 ms |
All times recorded using an Intel Core i7-9750H CPU averaged over 1000 calls.
You may notice that ML-DSA has marginally slower signing than the reported Dilithium times included above. This is because the ML-DSA implementation follows the NIST spec while the Dilithium implementation includes an optimisation of decomposition and hint generation. Details of this are given in Section 5.1 of the Dilithium specification. We discuss it informally below and would like to thank Keegan Ryan for helping us understand the differences between the two hint generation functions.
When it comes to implementing the optimisation, not only are slightly different vectors used during the computation, but the hint itself is generated differently.
For the NIST specification, the hint is generated by checking whether the high bits of r change when z is added to it, that is, by comparing the high bits of r and r + z. This is computed using Algorithm 39 from FIPS 204:
def make_hint(z, r, a, q):
    """
    Check whether the high bits of r will change when z is added
    """
    # High bits before and after adding z
    r1 = high_bits(r, a, q)
    v1 = high_bits(r + z, a, q)
    return int(r1 != v1)
This function is used pairwise for every coefficient of every polynomial in the two vectors.
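As a rough illustration of this coefficient-wise use, the sketch below applies make_hint() to two vectors of polynomials stored as nested lists of integer coefficients. The decompose() helper follows my reading of the decomposition algorithm in FIPS 204. The nested-list representation, the make_hint_vector() name and the tiny two-coefficient polynomials are assumptions made for this example and are not the repository's Polynomial and Matrix classes.

```python
Q = 8380417                  # modulus shared by Dilithium and ML-DSA
ALPHA = 2 * ((Q - 1) // 88)  # 2 * gamma_2 for the smallest parameter set


def decompose(r, a, q):
    """Return (r1, r0) with r = r1*a + r0 (mod q) and r0 centred around zero."""
    r = r % q
    r0 = r % a
    if r0 > a // 2:
        r0 -= a
    if r - r0 == q - 1:
        # Wrap-around case: the top bucket is folded into r1 = 0
        return 0, r0 - 1
    return (r - r0) // a, r0


def high_bits(r, a, q):
    return decompose(r, a, q)[0]


def make_hint(z, r, a, q):
    return int(high_bits(r, a, q) != high_bits(r + z, a, q))


def make_hint_vector(z_vec, r_vec, a, q):
    """Apply make_hint() to matching coefficients of two vectors of polynomials."""
    return [
        [make_hint(z, r, a, q) for z, r in zip(z_poly, r_poly)]
        for z_poly, r_poly in zip(z_vec, r_vec)
    ]


gamma2 = ALPHA // 2
z_vec = [[1, 1], [0, Q - 1]]
r_vec = [[gamma2, 0], [gamma2, gamma2 + 1]]
hints = make_hint_vector(z_vec, r_vec, ALPHA, Q)
print(hints)  # [[1, 0], [0, 1]]
```

The hint is only set for the coefficients where r sits right at the edge of a bucket, so that adding z pushes it into the neighbouring one.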
For the Dilithium optimisation, rather than computing only the high bits of $\mathbf{w}$, the vector is decomposed into both its high bits $\mathbf{w}_1$ and its low bits $\mathbf{w}_0$. As the inputs to the hint generation are now taken from this decomposition, where the top bits have already been removed, the make_hint() function has to check whether high bits are set in the result of low_bits(r) + z, which is computed using the following function:
def make_hint_optimised(z, r, a, q):
    """
    Optimised version of the above used when the low bits w0 are extracted from
    `w = (A_hat @ y_hat).from_ntt()` during signing
    """
    gamma2 = a >> 1
    if z <= gamma2 or z > (q - gamma2) or (z == (q - gamma2) and r == 0):
        return 0
    return 1
In particular, when z = q-1, make_hint() will return 1, while make_hint_optimised() returns 0.
As this optimisation is present in most implementations, this has caused some confusion about whether make_hint() is correct, as its output is different to that of make_hint_optimised().
It's important to realise that the outputs of these two hint functions differ when they are given the same input vectors, but the hint vector will be identical in the cases that:
- make_hint() is used with $-c\mathbf{t}_0$ and $\mathbf{w} - c\mathbf{s}_2 + c\mathbf{t}_0$
- make_hint_optimised() is used with $\mathbf{w}_0 - c\mathbf{s}_2 + c\mathbf{t}_0$ and $\mathbf{w}_1$
This means that for an implementation one can pick either the NIST documented algorithm or the optimisation in Section 5.1 of the Dilithium specification and the resulting signatures will be identical (indeed, you can use either method and have all KAT vectors pass).
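The agreement between the two functions can be checked numerically on single coefficients. The following is a sketch under several assumptions: decompose() follows my reading of FIPS 204, the inputs are taken as in the FIPS 204 signing step (make_hint() on $-ct_0$ and $w - cs_2 + ct_0$, make_hint_optimised() on $w_0 - cs_2 + ct_0$ and $w_1$), the values standing in for $cs_2$ and $ct_0$ are random integers within the bounds enforced during signing, and hints_agree() is a hypothetical helper name. The constants are those of the smallest parameter set.

```python
import random

Q = 8380417             # modulus shared by Dilithium and ML-DSA
GAMMA2 = (Q - 1) // 88  # gamma_2 for the smallest parameter set
ALPHA = 2 * GAMMA2
BETA = 78               # tau * eta = 39 * 2 for the smallest parameter set


def decompose(r, a, q):
    """Return (r1, r0) with r = r1*a + r0 (mod q) and r0 centred around zero."""
    r = r % q
    r0 = r % a
    if r0 > (a >> 1):
        r0 -= a
    if r - r0 == q - 1:
        return 0, r0 - 1
    return (r - r0) // a, r0


def high_bits(r, a, q):
    return decompose(r, a, q)[0]


def low_bits(r, a, q):
    return decompose(r, a, q)[1]


def make_hint(z, r, a, q):
    return int(high_bits(r, a, q) != high_bits(r + z, a, q))


def make_hint_optimised(z, r, a, q):
    gamma2 = a >> 1
    if z <= gamma2 or z > (q - gamma2) or (z == (q - gamma2) and r == 0):
        return 0
    return 1


def hints_agree(trials=20000, seed=0):
    """Compare both hint functions on random single coefficients."""
    rng = random.Random(seed)
    checked = 0
    for _ in range(trials):
        w = rng.randrange(Q)
        cs2 = rng.randint(-BETA, BETA)                # stand-in for a coefficient of c*s2
        ct0 = rng.randint(-(GAMMA2 - 1), GAMMA2 - 1)  # stand-in for a coefficient of c*t0
        # Signing rejects candidates whose low bits of w - c*s2 are too large
        if abs(low_bits(w - cs2, ALPHA, Q)) >= GAMMA2 - BETA:
            continue
        w1, w0 = decompose(w, ALPHA, Q)
        h_nist = make_hint((-ct0) % Q, (w - cs2 + ct0) % Q, ALPHA, Q)
        h_opt = make_hint_optimised((w0 - cs2 + ct0) % Q, w1, ALPHA, Q)
        if h_nist != h_opt:
            return False
        checked += 1
    return checked > 0


print(hints_agree())  # True
```

The rejection step matters here: the two hints are only expected to match for candidates that pass the bound on the low bits, which is exactly the set of candidates that can end up in a signature.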
The file polynomials.py contains the classes PolynomialRing and Polynomial. This implements the univariate polynomial ring $R_q = \mathbb{F}_q[X]/(X^n + 1)$.
The implementation is inspired by SageMath and you can create the ring $R_{11} = \mathbb{F}_{11}[X]/(X^8 + 1)$ in the following way:
>>> R = PolynomialRing(11, 8)
>>> x = R.gen()
>>> f = 3*x**3 + 4*x**7
>>> g = R.random_element(); g
5 + x^2 + 5*x^3 + 4*x^4 + x^5 + 3*x^6 + 8*x^7
>>> f*g
8 + 9*x + 10*x^3 + 7*x^4 + 2*x^5 + 5*x^6 + 10*x^7
>>> f + f
6*x^3 + 8*x^7
>>> g - g
0
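To see what the multiplication above is doing, here is a minimal standalone sketch that reproduces f*g with plain coefficient lists. It assumes that PolynomialRing(11, 8) is the ring of polynomials with coefficients modulo 11 reduced by x^8 + 1. The function name negacyclic_mul() is made up for this example and is not part of polynomials.py.

```python
def negacyclic_mul(f, g, q):
    """Multiply coefficient lists f and g (lowest degree first) modulo q and x^n + 1."""
    n = len(f)
    result = [0] * n
    for i, a in enumerate(f):
        for j, b in enumerate(g):
            k = i + j
            if k < n:
                result[k] = (result[k] + a * b) % q
            else:
                # x^n = -1, so terms that wrap around pick up a minus sign
                result[k - n] = (result[k - n] - a * b) % q
    return result


f = [0, 0, 0, 3, 0, 0, 0, 4]  # 3*x^3 + 4*x^7
g = [5, 0, 1, 5, 4, 1, 3, 8]  # the random element g printed above
print(negacyclic_mul(f, g, 11))  # [8, 9, 0, 10, 7, 2, 5, 10]
```

Reading the output list from the lowest degree upwards gives 8 + 9*x + 10*x^3 + 7*x^4 + 2*x^5 + 5*x^6 + 10*x^7, matching the result of f*g in the session above.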
The file modules.py contains the classes Module and Matrix.
A module is a generalisation of a vector space, where the field of scalars is replaced with a ring. In the case of Dilithium, we need the module with the polynomial ring described above.
Matrix allows elements of the module to be of size $m \times n$.
As an example of the operations we can perform with our Module, let's revisit the ring from the previous example:
>>> R = PolynomialRing(11, 8)
>>> x = R.gen()
>>>
>>> M = Module(R)
>>> # We create a matrix by feeding the coefficients to M
>>> A = M([[x + 3*x**2, 4 + 3*x**7], [3*x**3 + 9*x**7, x**4]])
>>> A
[ x + 3*x^2, 4 + 3*x^7]
[3*x^3 + 9*x^7, x^4]
>>> # We can add and subtract matrices of the same size
>>> A + A
[ 2*x + 6*x^2, 8 + 6*x^7]
[6*x^3 + 7*x^7, 2*x^4]
>>> A - A
[0, 0]
[0, 0]
>>> # A vector can be constructed by a list of coefficients
>>> v = M([3*x**5, x])
>>> v
[3*x^5, x]
>>> # We can compute the transpose
>>> v.transpose()
[3*x^5]
[ x]
>>> v + v
[6*x^5, 2*x]
>>> # We can also compute the transpose in place
>>> v.transpose_self()
[3*x^5]
[ x]
>>> v + v
[6*x^5]
[ 2*x]
>>> # Matrix multiplication follows python standards and is denoted by @
>>> A @ v
[8 + 4*x + 3*x^6 + 9*x^7]
[ 2 + 6*x^4 + x^5]
We can transform polynomials to NTT form and from NTT form with poly.to_ntt() and poly.from_ntt().
When we perform operations between polynomials (+, -, *), either both or neither must be in NTT form.
>>> f = R.random_element()
>>> f == f.to_ntt().from_ntt()
True
>>> g = R.random_element()
>>> h = f*g
>>> h == (f.to_ntt() * g.to_ntt()).from_ntt()
True
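The identity checked in the session above can be reproduced with a toy transform. The sketch below is not the repository's NTT: it works in a small ring with coefficients modulo 17 reduced by x^8 + 1 (chosen because 3 has order 16 modulo 17, whereas Dilithium uses q = 8380417 and n = 256), and it evaluates the polynomial directly in quadratic time instead of using the fast butterfly structure. All function names are assumptions for this example.

```python
import random

Q, N = 17, 8
PSI = 3  # 3 has multiplicative order 16 = 2N modulo 17, so PSI**N == -1 (mod Q)


def to_ntt(f):
    """Evaluate f at the odd powers of PSI, which are the roots of x^N + 1."""
    return [
        sum(c * pow(PSI, (2 * i + 1) * j, Q) for j, c in enumerate(f)) % Q
        for i in range(N)
    ]


def from_ntt(f_hat):
    """Interpolate the coefficients back from the N evaluations."""
    n_inv = pow(N, -1, Q)
    psi_inv = pow(PSI, -1, Q)
    return [
        n_inv * sum(v * pow(psi_inv, (2 * i + 1) * j, Q) for i, v in enumerate(f_hat)) % Q
        for j in range(N)
    ]


def schoolbook_mul(f, g):
    """Reference multiplication modulo Q and x^N + 1."""
    result = [0] * N
    for i, a in enumerate(f):
        for j, b in enumerate(g):
            sign = 1 if i + j < N else -1  # x^N = -1
            result[(i + j) % N] = (result[(i + j) % N] + sign * a * b) % Q
    return result


rng = random.Random(1)
f = [rng.randrange(Q) for _ in range(N)]
g = [rng.randrange(Q) for _ in range(N)]

# In NTT form, multiplication is just coefficient-wise
pointwise = [a * b % Q for a, b in zip(to_ntt(f), to_ntt(g))]
print(from_ntt(to_ntt(f)) == f)
print(from_ntt(pointwise) == schoolbook_mul(f, g))
```

Both checks are expected to print True. The speed-up in practice comes from the coefficient-wise product replacing the double loop of the schoolbook method.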
At the time of writing this README, multiplication of polynomials in NTT form is about 100x faster when working with the ring used by Dilithium.
>>> import timeit
>>> # Let's work in the ring we use for Dilithium
>>> R = Dilithium2.R
>>> # Generate some random elements
>>> f = R.random_element()
>>> g = R.random_element()
>>> # Takes about 10 seconds to perform 1000 multiplications
>>> timeit.timeit("f*g", globals=globals(), number=1000)
9.621509193995735
>>> # Now let's convert to NTT and try again
>>> f = f.to_ntt()
>>> g = g.to_ntt()
>>> # Now it only takes ~0.1s to perform 1000 multiplications!
>>> timeit.timeit("f*g", globals=globals(), number=1000)
0.12979038299818058
These functions extend to modules:
>>> M = Dilithium2.M
>>> R = Dilithium2.R
>>> v = M([R.random_element(), R.random_element()])
>>> u = M([R.random_element(), R.random_element()]).transpose()
>>> A = u @ v
>>> A == (u.to_ntt() @ v.to_ntt()).from_ntt()
True
As operations on the module are just operations between elements, we expect a similar 100x speed up when working in NTT form:
>>> u = M([R.random_element(), R.random_element()]).transpose()
>>> v = M([R.random_element(), R.random_element()])
>>> timeit.timeit("u@v", globals=globals(), number=1000)
38.39359304799291
>>> u = u.to_ntt()
>>> v = v.to_ntt()
>>> timeit.timeit("u@v", globals=globals(), number=1000)
0.495470915993792