diff --git a/pkg/proof/proof.go b/pkg/proof/proof.go index 5845d3b1ae..5c4a118c60 100644 --- a/pkg/proof/proof.go +++ b/pkg/proof/proof.go @@ -5,6 +5,9 @@ import ( "errors" "fmt" + tmbytes "github.com/tendermint/tendermint/libs/bytes" + coretypes "github.com/tendermint/tendermint/proto/tendermint/types" + "github.com/celestiaorg/rsmt2d" "github.com/celestiaorg/celestia-app/pkg/appconsts" @@ -14,8 +17,6 @@ import ( "github.com/celestiaorg/celestia-app/pkg/square" "github.com/celestiaorg/celestia-app/pkg/wrapper" "github.com/tendermint/tendermint/crypto/merkle" - tmbytes "github.com/tendermint/tendermint/libs/bytes" - tmproto "github.com/tendermint/tendermint/proto/tendermint/types" "github.com/tendermint/tendermint/types" ) @@ -98,9 +99,11 @@ func NewShareInclusionProofFromEDS( _, allProofs := merkle.ProofsFromByteSlices(append(edsRowRoots, edsColRoots...)) rowProofs := make([]*merkle.Proof, endRow-startRow+1) rowRoots := make([]tmbytes.HexBytes, endRow-startRow+1) + bzRowRoots := make([][]byte, endRow-startRow+1) for i := startRow; i <= endRow; i++ { rowProofs[i-startRow] = allProofs[i] rowRoots[i-startRow] = edsRowRoots[i] + bzRowRoots[i-startRow] = edsRowRoots[i] } // get the extended rows containing the shares. @@ -113,9 +116,30 @@ func NewShareInclusionProofFromEDS( rows[i-startRow] = shares } - var shareProofs []*tmproto.NMTProof //nolint:prealloc + shareProofs, rawShares, err := CreateShareToRowRootProofs(squareSize, rows, bzRowRoots, startLeaf, endLeaf) + if err != nil { + return types.ShareProof{}, err + } + return types.ShareProof{ + RowProof: types.RowProof{ + RowRoots: rowRoots, + Proofs: rowProofs, + StartRow: uint32(startRow), + EndRow: uint32(endRow), + }, + Data: rawShares, + ShareProofs: shareProofs, + NamespaceID: namespace.ID, + NamespaceVersion: uint32(namespace.Version), + }, nil +} + +// CreateShareToRowRootProofs takes a set of shares and their corresponding row roots, and generates +// an NMT inclusion proof of a set of shares, defined by startLeaf and endLeaf, to their corresponding row roots. +func CreateShareToRowRootProofs(squareSize int, rowShares [][]shares.Share, rowRoots [][]byte, startLeaf, endLeaf int) ([]*coretypes.NMTProof, [][]byte, error) { + shareProofs := make([]*coretypes.NMTProof, 0, len(rowRoots)) var rawShares [][]byte - for i, row := range rows { + for i, row := range rowShares { // create an nmt to generate a proof. // we have to re-create the tree as the eds one is not accessible. tree := wrapper.NewErasuredNamespacedMerkleTree(uint64(squareSize), uint(i)) @@ -124,17 +148,17 @@ func NewShareInclusionProofFromEDS( share.ToBytes(), ) if err != nil { - return types.ShareProof{}, err + return nil, nil, err } } // make sure that the generated root is the same as the eds row root. root, err := tree.Root() if err != nil { - return types.ShareProof{}, err + return nil, nil, err } - if !bytes.Equal(rowRoots[i].Bytes(), root) { - return types.ShareProof{}, errors.New("eds row root is different than tree root") + if !bytes.Equal(rowRoots[i], root) { + return nil, nil, errors.New("eds row root is different than tree root") } startLeafPos := startLeaf @@ -145,34 +169,22 @@ func NewShareInclusionProofFromEDS( startLeafPos = 0 } // if this is not the last row, then select for the rest of the row - if i != (len(rows) - 1) { + if i != (len(rowShares) - 1) { endLeafPos = squareSize - 1 } rawShares = append(rawShares, shares.ToBytes(row[startLeafPos:endLeafPos+1])...) - proof, err := tree.ProveRange(int(startLeafPos), int(endLeafPos+1)) + proof, err := tree.ProveRange(startLeafPos, endLeafPos+1) if err != nil { - return types.ShareProof{}, err + return nil, nil, err } - shareProofs = append(shareProofs, &tmproto.NMTProof{ + shareProofs = append(shareProofs, &coretypes.NMTProof{ Start: int32(proof.Start()), End: int32(proof.End()), Nodes: proof.Nodes(), LeafHash: proof.LeafHash(), }) } - - return types.ShareProof{ - RowProof: types.RowProof{ - RowRoots: rowRoots, - Proofs: rowProofs, - StartRow: uint32(startRow), - EndRow: uint32(endRow), - }, - Data: rawShares, - ShareProofs: shareProofs, - NamespaceID: namespace.ID, - NamespaceVersion: uint32(namespace.Version), - }, nil + return shareProofs, rawShares, nil }