Skip to content

Commit

Permalink
arbo: add CalculateProofNodes and CheckProofBatch
Browse files Browse the repository at this point in the history
  • Loading branch information
altergui committed Nov 4, 2024
1 parent d969541 commit 41af43a
Show file tree
Hide file tree
Showing 3 changed files with 210 additions and 4 deletions.
32 changes: 32 additions & 0 deletions tree/arbo/circomproofs.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
package arbo

import (
"bytes"
"encoding/json"
"fmt"
"slices"
)

// CircomVerifierProof contains the needed data to check a Circom Verifier Proof
Expand Down Expand Up @@ -89,3 +92,32 @@ func (t *Tree) GenerateCircomVerifierProof(k []byte) (*CircomVerifierProof, erro

return &cp, nil
}

// CalculateProofNodes calculates the chain of hashes in the path of the proof.
// In the returned list, first item is the root, and last item is the hash of the leaf.
func (cvp CircomVerifierProof) CalculateProofNodes(hashFunc HashFunction) ([][]byte, error) {
paddedSiblings := slices.Clone(cvp.Siblings)
for k, v := range paddedSiblings {
if bytes.Equal(v, []byte{0}) {
paddedSiblings[k] = make([]byte, hashFunc.Len())
}
}
packedSiblings, err := PackSiblings(hashFunc, paddedSiblings)
if err != nil {
return nil, err
}
return CalculateProofNodes(hashFunc, cvp.Key, cvp.Value, packedSiblings)
}

// CheckProof verifies the given proof. The proof verification depends on the
// HashFunction passed as parameter.
func (cvp CircomVerifierProof) CheckProof(hashFunc HashFunction) (bool, error) {
hashes, err := cvp.CalculateProofNodes(hashFunc)
if err != nil {
return false, err
}
if !bytes.Equal(hashes[0], cvp.Root) {
return false, fmt.Errorf("calculated root doesn't match expected root")
}
return true, nil
}
71 changes: 67 additions & 4 deletions tree/arbo/proof.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package arbo
import (
"bytes"
"encoding/binary"
"encoding/hex"
"fmt"
"math"
"slices"
Expand Down Expand Up @@ -160,19 +161,31 @@ func bytesToBitmap(b []byte) []bool {
// CheckProof verifies the given proof. The proof verification depends on the
// HashFunction passed as parameter.
func CheckProof(hashFunc HashFunction, k, v, root, packedSiblings []byte) (bool, error) {
siblings, err := UnpackSiblings(hashFunc, packedSiblings)
hashes, err := CalculateProofNodes(hashFunc, k, v, packedSiblings)
if err != nil {
return false, err
}
return bytes.Equal(hashes[0], root), nil
}

// CalculateProofNodes calculates the chain of hashes in the path of the given proof.
// In the returned list, first item is the root, and last item is the hash of the leaf.
func CalculateProofNodes(hashFunc HashFunction, k, v, packedSiblings []byte) ([][]byte, error) {
siblings, err := UnpackSiblings(hashFunc, packedSiblings)
if err != nil {
return nil, err
}

keyPath := make([]byte, int(math.Ceil(float64(len(siblings))/float64(8))))
copy(keyPath, k)

key, _, err := newLeafValue(hashFunc, k, v)
if err != nil {
return false, err
return nil, err
}

hashes := [][]byte{key}

path := getPath(len(siblings), keyPath)
for i, sibling := range slices.Backward(siblings) {
if path[i] {
Expand All @@ -181,8 +194,58 @@ func CheckProof(hashFunc HashFunction, k, v, root, packedSiblings []byte) (bool,
key, _, err = newIntermediate(hashFunc, key, sibling)
}
if err != nil {
return false, err
return nil, err
}
hashes = append(hashes, key)
}
slices.Reverse(hashes)
return hashes, nil
}

// CheckProofBatch verifies a batch of N proofs pairs (old and new). The proof verification depends on the
// HashFunction passed as parameter.
// Returns nil if the batch is valid, or an error otherwise.
func CheckProofBatch(hashFunc HashFunction, oldProofs, newProofs []*CircomVerifierProof) error {
newBranches := make(map[string]int)
newSiblings := make(map[string]int)

if len(oldProofs) != len(newProofs) {
return fmt.Errorf("batch of proofs incomplete")
}

for i := range oldProofs {
// Check all old proofs are valid
if valid, err := oldProofs[i].CheckProof(hashFunc); !valid {
return fmt.Errorf("old proof invalid: %w", err)
}

// Map all new branches
nodes, err := newProofs[i].CalculateProofNodes(hashFunc)
if err != nil {
return fmt.Errorf("new proof invalid: %w", err)
}
// and check they are valid
if !bytes.Equal(newProofs[i].Root, nodes[0]) {
return fmt.Errorf("new proof invalid: root doesn't match")
}

for level, hash := range nodes {
newBranches[hex.EncodeToString(hash)] = level
}

for level := range newProofs[i].Siblings {
if !slices.Equal(oldProofs[i].Siblings[level], newProofs[i].Siblings[level]) {
// since in newBranch the root is level 0, we shift siblings to level + 1
newSiblings[hex.EncodeToString(newProofs[i].Siblings[level])] = level + 1
}
}
}
return bytes.Equal(key, root), nil

for hash, level := range newSiblings {
if newBranches[hash] != newSiblings[hash] {
return fmt.Errorf("sibling %s (at level %d) changed but there's no proof why", hash, level)
}
}

return nil
}
111 changes: 111 additions & 0 deletions tree/arbo/proof_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
package arbo

import (
"math/big"
"slices"
"testing"

qt "github.com/frankban/quicktest"
"go.vocdoni.io/dvote/db/metadb"
)

func TestCheckProofBatch(t *testing.T) {
database := metadb.NewTest(t)
c := qt.New(t)

keyLen := 1
maxLevels := keyLen * 8
tree, err := NewTree(Config{
Database: database, MaxLevels: maxLevels,
HashFunction: HashFunctionBlake3,
})
c.Assert(err, qt.IsNil)

processID := []byte("01234567890123456789012345678900")
censusRoot := []byte("01234567890123456789012345678901")
ballotMode := []byte("1234")

err = tree.Add(BigIntToBytesLE(keyLen, big.NewInt(0x00)), processID)
c.Assert(err, qt.IsNil)

err = tree.Add(BigIntToBytesLE(keyLen, big.NewInt(0x01)), censusRoot)
c.Assert(err, qt.IsNil)

err = tree.Add(BigIntToBytesLE(keyLen, big.NewInt(0x02)), ballotMode)
c.Assert(err, qt.IsNil)

var oldProofs, newProofs []*CircomVerifierProof

for i := int64(0x00); i <= int64(0x02); i++ {
proof, err := tree.GenerateCircomVerifierProof(BigIntToBytesLE(keyLen, big.NewInt(i)))
c.Assert(err, qt.IsNil)
oldProofs = append(oldProofs, proof)
}

censusRoot[0] = byte(0x02)
ballotMode[0] = byte(0x02)

err = tree.Update(BigIntToBytesLE(keyLen, big.NewInt(0x01)), censusRoot)
c.Assert(err, qt.IsNil)

err = tree.Update(BigIntToBytesLE(keyLen, big.NewInt(0x02)), ballotMode)
c.Assert(err, qt.IsNil)

for i := int64(0x00); i <= int64(0x02); i++ {
proof, err := tree.GenerateCircomVerifierProof(BigIntToBytesLE(keyLen, big.NewInt(i)))
c.Assert(err, qt.IsNil)
newProofs = append(newProofs, proof)
}

// this mix should pass: proof 0 is unchanged, proof 1 + 2 verify together
err = CheckProofBatch(HashFunctionBlake3, oldProofs, newProofs)
c.Assert(err, qt.IsNil)

// omitting proof 0 (unchanged) should also pass
err = CheckProofBatch(HashFunctionBlake3, oldProofs[1:], newProofs[1:])
c.Assert(err, qt.IsNil)

// providing just proof 0 (unchanged) should also pass
err = CheckProofBatch(HashFunctionBlake3, oldProofs[:0], newProofs[:0])
c.Assert(err, qt.IsNil)

// length mismatch
err = CheckProofBatch(HashFunctionBlake3, oldProofs, newProofs[:1])
c.Assert(err, qt.ErrorMatches, "batch of proofs incomplete")

// omitting proof 2 should fail (since changed siblings in proof 1 can't be explained)
err = CheckProofBatch(HashFunctionBlake3, oldProofs[:1], newProofs[:1])
c.Assert(err, qt.ErrorMatches, ".*changed but there's no proof why.*")

// the rest is mangling proofs to simulate other unexplained changes in the tree, all of these should fail
badProofs := deepClone(oldProofs)
badProofs[0].Root = []byte("01234567890123456789012345678900")
err = CheckProofBatch(HashFunctionBlake3, badProofs, newProofs)
c.Assert(err, qt.ErrorMatches, "old proof invalid: calculated root doesn't match expected root")

badProofs = deepClone(oldProofs)
badProofs[0].Siblings[0] = []byte("01234567890123456789012345678900")
err = CheckProofBatch(HashFunctionBlake3, badProofs, newProofs)
c.Assert(err, qt.ErrorMatches, "old proof invalid: calculated root doesn't match expected root")

badProofs = deepClone(newProofs)
badProofs[0].Root = []byte("01234567890123456789012345678900")
err = CheckProofBatch(HashFunctionBlake3, oldProofs, badProofs)
c.Assert(err, qt.ErrorMatches, "new proof invalid: root doesn't match")

badProofs = deepClone(newProofs)
badProofs[0].Siblings[0] = []byte("01234567890123456789012345678900")
err = CheckProofBatch(HashFunctionBlake3, oldProofs, badProofs)
c.Assert(err, qt.ErrorMatches, "new proof invalid: root doesn't match")
}

func deepClone(src []*CircomVerifierProof) []*CircomVerifierProof {
dst := slices.Clone(src)
for i := range src {
proof := *src[i]
dst[i] = &proof

dst[i].Siblings = slices.Clone(src[i].Siblings)
}
return dst
}

0 comments on commit 41af43a

Please sign in to comment.