diff --git a/fuzz_test.go b/fuzz_test.go index e7fe8af..3197cad 100644 --- a/fuzz_test.go +++ b/fuzz_test.go @@ -8,7 +8,7 @@ import ( "github.com/celestiaorg/nmt" "github.com/celestiaorg/nmt/namespace" - "github.com/google/gofuzz" + fuzz "github.com/google/gofuzz" ) func TestFuzzProveVerifyNameSpace(t *testing.T) { @@ -72,7 +72,7 @@ func TestFuzzProveVerifyNameSpace(t *testing.T) { if err != nil { t.Fatalf("error on Prove(%v): %v", i, err) } - if ok := singleItemProof.VerifyInclusion(hash, data[i][:size], data[i][size:], treeRoot); !ok { + if ok := singleItemProof.VerifyInclusion(hash, data[i][:size], [][]byte{data[i][size:]}, treeRoot); !ok { t.Fatalf("expected VerifyInclusion() == true; data = %#v; proof = %#v", data[i], singleItemProof) } leafIdx++ diff --git a/nmt_test.go b/nmt_test.go index 3587393..7c32e1c 100644 --- a/nmt_test.go +++ b/nmt_test.go @@ -277,7 +277,7 @@ func TestNamespacedMerkleTree_ProveNamespace_Ranges_And_Verify(t *testing.T) { if err != nil { t.Fatalf("unexpected error on Prove(): %v", err) } - gotChecksOut := gotSingleProof.VerifyInclusion(sha256.New(), data.ID, data.Data, n.Root()) + gotChecksOut := gotSingleProof.VerifyInclusion(sha256.New(), data.ID, [][]byte{data.Data}, n.Root()) if !gotChecksOut { t.Errorf("Proof.VerifyInclusion() gotChecksOut: %v, want: true", gotChecksOut) } @@ -467,7 +467,7 @@ func TestIgnoreMaxNamespace(t *testing.T) { if err != nil { t.Fatalf("ProveNamespace() unexpected error: %v", err) } - if !singleProof.VerifyInclusion(hash, d.NamespaceID(), d.Data(), tree.Root()) { + if !singleProof.VerifyInclusion(hash, d.NamespaceID(), [][]byte{d.Data()}, tree.Root()) { t.Errorf("VerifyInclusion() failed on data: %#v with index: %v", d, idx) } if gotIgnored := singleProof.IsMaxNamespaceIDIgnored(); gotIgnored != tc.ignoreMaxNamespace { diff --git a/proof.go b/proof.go index d5cac31..6f6e277 100644 --- a/proof.go +++ b/proof.go @@ -206,10 +206,19 @@ func (proof Proof) verifyLeafHashes(nth *Hasher, verifyCompleteness bool, nID na return bytes.Equal(tree.Root(), root) } -func (proof Proof) VerifyInclusion(h hash.Hash, nid namespace.ID, data []byte, root []byte) bool { +// VerifyInclusion checks that the inclusion proof is valid by using leaf data +// and the provided proof to regenerate and compare the root. Note that the leaf +// data should not contain the prefixed namespace, unlike the tree.Push method, +// which takes prefixed data. All leaves implicitly have the same namespace ID: `nid`. +func (proof Proof) VerifyInclusion(h hash.Hash, nid namespace.ID, leaves [][]byte, root []byte) bool { nth := NewNmtHasher(h, nid.Size(), proof.isMaxNamespaceIDIgnored) - leafData := append(nid, data...) - return proof.verifyLeafHashes(nth, false, nid, [][]byte{nth.HashLeaf(leafData)}, root) + hashes := make([][]byte, len(leaves)) + for i, d := range leaves { + leafData := append(append(make([]byte, 0, len(d)+len(nid)), nid...), d...) + hashes[i] = nth.HashLeaf(leafData) + } + + return proof.verifyLeafHashes(nth, false, nid, hashes, root) } // nextSubtreeSize returns the size of the subtree adjacent to start that does diff --git a/proof_test.go b/proof_test.go index a496a2a..44a7bd6 100644 --- a/proof_test.go +++ b/proof_test.go @@ -1,6 +1,7 @@ package nmt import ( + "bytes" "crypto/sha256" "testing" @@ -90,3 +91,67 @@ func rangeProof(t *testing.T, n *NamespacedMerkleTree, start, end int) [][]byte } return incompleteRange } + +func TestProof_MultipleLeaves(t *testing.T) { + n := New(sha256.New()) + ns := []byte{1, 2, 3, 4, 5, 6, 7, 8} + rawData := [][]byte{ + bytes.Repeat([]byte{1}, 100), + bytes.Repeat([]byte{2}, 100), + bytes.Repeat([]byte{3}, 100), + bytes.Repeat([]byte{4}, 100), + bytes.Repeat([]byte{5}, 100), + bytes.Repeat([]byte{6}, 100), + bytes.Repeat([]byte{7}, 100), + bytes.Repeat([]byte{8}, 100), + } + + for _, d := range rawData { + err := n.Push(safeAppend(ns, d)) + if err != nil { + t.Fatal(err) + } + } + + type args struct { + start, end int + root []byte + } + tests := []struct { + name string + args args + want bool + }{ + { + "3rd through 5th leaf", args{2, 4, n.Root()}, true, + }, + { + "single leaf", args{2, 3, n.Root()}, true, + }, + { + "first leaf", args{0, 1, n.Root()}, true, + }, + { + "most leaves", args{0, 7, n.Root()}, true, + }, + { + "most leaves", args{0, 7, bytes.Repeat([]byte{1}, 48)}, false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + proof, err := n.ProveRange(tt.args.start, tt.args.end) + if err != nil { + t.Fatal(err) + } + got := proof.VerifyInclusion(sha256.New(), ns, rawData[tt.args.start:tt.args.end], tt.args.root) + if got != tt.want { + t.Errorf("VerifyInclusion() got = %v, want %v", got, tt.want) + } + }) + } +} + +func safeAppend(id, data []byte) []byte { + return append(append(make([]byte, 0, len(id)+len(data)), id...), data...) +}