Skip to content

Commit

Permalink
Merge pull request #5 from CQCL/review-sburton
Browse files Browse the repository at this point in the history
  • Loading branch information
y-richie-y authored Aug 9, 2023
2 parents 305df19 + 67fe267 commit 48101e2
Show file tree
Hide file tree
Showing 12 changed files with 499 additions and 411 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ __pycache__/
build/
dist/
htmlcov/
.idea/
cmake-build-*/

# ignore the built documentation
docs/_*
Expand Down
6 changes: 6 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,14 @@ cmake_minimum_required(VERSION 3.22.1)

project(cryptomite)

include(CTest)

set(CMAKE_CXX_STANDARD 20)

add_subdirectory(src)

add_subdirectory(cryptomite)

if(BUILD_TESTING)
add_subdirectory(test)
endif()
2 changes: 1 addition & 1 deletion cryptomite/pycryptomite.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <ntt.cpp>
#include <ntt.h>
#include <trevisan.cpp>


Expand Down
2 changes: 1 addition & 1 deletion src/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
add_library(trevisan trevisan.cpp irreducible_poly.hpp)
add_library(trevisan trevisan.cpp irreducible_poly.cpp ntt.cpp)

target_include_directories(trevisan PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})

Expand Down
10 changes: 4 additions & 6 deletions src/irreducible_poly.hpp → src/irreducible_poly.cpp
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
#include "irreducible_poly.h"

#include <vector>

// Minimal weight primitive polynomials over Z/2Z
// In addition the coefficients/bits (apart from the
// constant and the leading term) are as close to
// the low end as possible.
//
// Generated by Joerg Arndt, 2003-January-30

#include <vector>

using namespace std;

vector<vector<int> > minweight_primpoly_coeffs {
const std::vector<std::vector<int>> minweight_primpoly_coeffs {
// coeffs // (deg) [weight]
{0}, // (0) [1]
{1, 0}, // (1) [2]
Expand Down
5 changes: 5 additions & 0 deletions src/irreducible_poly.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#pragma once

#include <vector>

extern const std::vector<std::vector<int>> minweight_primpoly_coeffs;
142 changes: 59 additions & 83 deletions src/ntt.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
#include "ntt.h"

#include <stdexcept>
#include <vector>

#define P ((3u<<30) + 1)
Expand Down Expand Up @@ -29,7 +32,7 @@ uint32_t mul(uint32_t a, uint32_t b) {
return n % P;
}

std::vector<uint32_t> mul_vec(std::vector<uint32_t> a, std::vector<uint32_t> b) {
std::vector<uint32_t> mul_vec(const std::vector<uint32_t> &a, const std::vector<uint32_t> &b) {
std::vector<uint32_t> c(a.size());
for (uint32_t i = 0; i < a.size(); i++) {
c[i] = mul(a[i], b[i]);
Expand Down Expand Up @@ -66,106 +69,79 @@ static uint32_t reverse_bits(unsigned l, uint32_t x) {
return y;
}

NTT::NTT(unsigned l) : L(1<<l) {
if (l < 1 || l > 30) {
throw std::runtime_error("Must have 1 <= l <= 30.");
}

class NTT {
private:
/** Sequence length (power of 2) */
uint32_t L;

/** Inverse of L mod p */
uint32_t Linv;

/**
* Powers 1, r, r^2, ..., r^(L/2-1) mod p, where r is a primitive L'th
* root of unity mod p
*/
std::vector<uint32_t> R;

/**
* Inverse powers 1, r^{-1}, r^{-2}, ..., r{-^(L/2-1)} mod p
*/
std::vector<uint32_t> Rinv;

/**
* Lookup table for bit reversals
*/
std::vector<uint32_t> revbits;

public:
NTT(unsigned l) : L(1<<l) {
if (l < 1 || l > 30) {
throw "Must have 1 <= l <= 30.";
}

Linv = modexp(L, P-2);
Linv = modexp(L, P-2);

uint32_t half_L = L/2;
uint32_t half_L = L/2;

R = std::vector<uint32_t>(half_L);
Rinv = std::vector<uint32_t>(half_L);
revbits = std::vector<uint32_t>(L);
R = std::vector<uint32_t>(half_L);
Rinv = std::vector<uint32_t>(half_L);
revbits = std::vector<uint32_t>(L);

uint32_t r = modexp(G, (P - 1) >> l); // primitive L'th root of unity
uint32_t r = modexp(G, (P - 1) >> l); // primitive L'th root of unity

{
{
{
uint64_t t = 1;
for (uint32_t i = 0; i < half_L; i++) {
R[i] = t;
t = mul(t, r);
}
}

{
// r^(L/2) = -1
uint32_t t = P - 1;
for (uint32_t i = 1; i <= half_L; i++) {
t = mul(t, r);
Rinv[half_L - i] = t;
}
uint64_t t = 1;
for (uint32_t i = 0; i < half_L; i++) {
R[i] = t;
t = mul(t, r);
}
}

for (uint32_t i = 0; i < L; i++) {
revbits[i] = reverse_bits(l, i);
{
// r^(L/2) = -1
uint32_t t = P - 1;
for (uint32_t i = 1; i <= half_L; i++) {
t = mul(t, r);
Rinv[half_L - i] = t;
}
}
}

for (uint32_t i = 0; i < L; i++) {
revbits[i] = reverse_bits(l, i);
}

std::vector<uint32_t> ntt(const std::vector<uint32_t> x, bool inverse) {
const std::vector<uint32_t>& U = inverse ? Rinv : R;
}

std::vector<uint32_t> y(L, 0);
std::vector<uint32_t> NTT::ntt(const std::vector<uint32_t> &x, bool inverse) {
const std::vector<uint32_t>& U = inverse ? Rinv : R;

// Bit inversion
for (uint32_t i = 0; i < L; i++) {
y[revbits[i]] = x[i];
}
std::vector<uint32_t> y(L, 0);

// Main loop
for (
uint32_t h = 2, k = 1, u = L/2;
h <= L;
k = h, h <<= 1, u >>= 1)
{
for (uint32_t i = 0; i < L; i += h) {
for (uint32_t j = 0, v = 0; j < k; j++, v += u) {
uint32_t r = i + j;
uint32_t s = r + k;
uint32_t a = y[r];
uint32_t b = mul(y[s], U[v]);
y[r] = add(a, b);
y[s] = sub(a, b);
}
// Bit inversion
for (uint32_t i = 0; i < L; i++) {
y[revbits[i]] = x[i];
}

// Main loop
for (
uint32_t h = 2, k = 1, u = L/2;
h <= L;
k = h, h <<= 1, u >>= 1)
{
for (uint32_t i = 0; i < L; i += h) {
for (uint32_t j = 0, v = 0; j < k; j++, v += u) {
uint32_t r = i + j;
uint32_t s = r + k;
uint32_t a = y[r];
uint32_t b = mul(y[s], U[v]);
y[r] = add(a, b);
y[s] = sub(a, b);
}
}
}

// Normalization for inverse
if (inverse) {
for (uint32_t i = 0; i < L; i++) {
y[i] = mul(Linv, y[i]);
}
// Normalization for inverse
if (inverse) {
for (uint32_t i = 0; i < L; i++) {
y[i] = mul(Linv, y[i]);
}
return y;
}
};
return y;
}
36 changes: 36 additions & 0 deletions src/ntt.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#pragma once

#include <cstdint>
#include <vector>

std::vector<uint32_t> mul_vec(const std::vector<uint32_t> &a, const std::vector<uint32_t> &b);

class NTT {
private:
/** Sequence length (power of 2) */
uint32_t L;

/** Inverse of L mod p */
uint32_t Linv;

/**
* Powers 1, r, r^2, ..., r^(L/2-1) mod p, where r is a primitive L'th
* root of unity mod p
*/
std::vector<uint32_t> R;

/**
* Inverse powers 1, r^{-1}, r^{-2}, ..., r{-^(L/2-1)} mod p
*/
std::vector<uint32_t> Rinv;

/**
* Lookup table for bit reversals
*/
std::vector<uint32_t> revbits;

public:
explicit NTT(unsigned l);

std::vector<uint32_t> ntt(const std::vector<uint32_t> &x, bool inverse);
};
Loading

0 comments on commit 48101e2

Please sign in to comment.