Skip to content

Commit

Permalink
small refactor + extended checks
Browse files Browse the repository at this point in the history
  • Loading branch information
sstanculeanu committed Jan 10, 2025
1 parent d4243bc commit 681f506
Show file tree
Hide file tree
Showing 8 changed files with 58 additions and 56 deletions.
25 changes: 25 additions & 0 deletions common/common.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
package common

import (
"fmt"

"github.com/multiversx/mx-chain-core-go/core"
"github.com/multiversx/mx-chain-core-go/core/check"
"github.com/multiversx/mx-chain-core-go/data"
)

Expand All @@ -24,3 +27,25 @@ func IsFlagEnabledAfterEpochsStartBlock(header data.HeaderHandler, enableEpochsH
func ShouldBlockHavePrevProof(header data.HeaderHandler, enableEpochsHandler EnableEpochsHandler, flag core.EnableEpochFlag) bool {
return IsFlagEnabledAfterEpochsStartBlock(header, enableEpochsHandler, flag) && header.GetNonce() > 1
}

// VerifyProofAgainstHeader verifies the fields on the proof match the ones on the header
func VerifyProofAgainstHeader(proof data.HeaderProofHandler, header data.HeaderHandler) error {
if check.IfNilReflect(proof) {
return ErrInvalidHeaderProof
}

if proof.GetHeaderNonce() != header.GetNonce() {
return fmt.Errorf("%w, nonce mismatch", ErrInvalidHeaderProof)
}
if proof.GetHeaderShardId() != header.GetShardID() {
return fmt.Errorf("%w, shard id mismatch", ErrInvalidHeaderProof)
}
if proof.GetHeaderEpoch() != header.GetEpoch() {
return fmt.Errorf("%w, epoch mismatch", ErrInvalidHeaderProof)
}
if proof.GetHeaderRound() != header.GetRound() {
return fmt.Errorf("%w, round mismatch", ErrInvalidHeaderProof)
}

return nil
}
3 changes: 3 additions & 0 deletions common/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,6 @@ var ErrNilWasmChangeLocker = errors.New("nil wasm change locker")

// ErrNilStateSyncNotifierSubscriber signals that a nil state sync notifier subscriber has been provided
var ErrNilStateSyncNotifierSubscriber = errors.New("nil state sync notifier subscriber")

// ErrInvalidHeaderProof signals that an invalid equivalent proof has been provided
var ErrInvalidHeaderProof = errors.New("invalid equivalent proof")
33 changes: 21 additions & 12 deletions consensus/spos/bls/v2/subroundBlock.go
Original file line number Diff line number Diff line change
Expand Up @@ -389,19 +389,27 @@ func isProofEmpty(proof data.HeaderProofHandler) bool {
len(proof.GetHeaderHash()) == 0
}

func (sr *subroundBlock) saveProofForPreviousHeaderIfNeeded(header data.HeaderHandler) {
func (sr *subroundBlock) saveProofForPreviousHeaderIfNeeded(header data.HeaderHandler, prevHeader data.HeaderHandler) {
hasProof := sr.EquivalentProofsPool().HasProof(sr.ShardCoordinator().SelfId(), header.GetPrevHash())
if hasProof {
log.Debug("saveProofForPreviousHeaderIfNeeded: no need to set proof since it is already saved")
return
}

proof := header.GetPreviousProof()
err := sr.EquivalentProofsPool().AddProof(proof)
err := common.VerifyProofAgainstHeader(proof, prevHeader)
if err != nil {
log.Debug("saveProofForPreviousHeaderIfNeeded: invalid proof, %w", err)
return
}

err = sr.EquivalentProofsPool().AddProof(proof)
if err != nil {
log.Debug("saveProofForPreviousHeaderIfNeeded: failed to add proof, %w", err)
return
}

return
}

// receivedBlockBody method is called when a block body is received through the block body channel
Expand Down Expand Up @@ -445,30 +453,30 @@ func (sr *subroundBlock) receivedBlockBody(ctx context.Context, cnsDta *consensu
return blockProcessedWithSuccess
}

func (sr *subroundBlock) isHeaderForCurrentConsensus(header data.HeaderHandler) bool {
func (sr *subroundBlock) isHeaderForCurrentConsensus(header data.HeaderHandler) (bool, data.HeaderHandler) {
if check.IfNil(header) {
return false
return false, nil
}
if header.GetShardID() != sr.ShardCoordinator().SelfId() {
return false
return false, nil
}
if header.GetRound() != uint64(sr.RoundHandler().Index()) {
return false
return false, nil
}

prevHeader, prevHash := sr.getPrevHeaderAndHash()
if check.IfNil(prevHeader) {
return false
return false, nil
}
if !bytes.Equal(header.GetPrevHash(), prevHash) {
return false
return false, nil
}
if header.GetNonce() != prevHeader.GetNonce()+1 {
return false
return false, nil
}
prevRandSeed := prevHeader.GetRandSeed()

return bytes.Equal(header.GetPrevRandSeed(), prevRandSeed)
return bytes.Equal(header.GetPrevRandSeed(), prevRandSeed), prevHeader
}

func (sr *subroundBlock) getLeaderForHeader(headerHandler data.HeaderHandler) ([]byte, error) {
Expand All @@ -495,7 +503,8 @@ func (sr *subroundBlock) receivedBlockHeader(headerHandler data.HeaderHandler) {
return
}

if !sr.isHeaderForCurrentConsensus(headerHandler) {
isHeaderForCurrentConsensus, prevHeader := sr.isHeaderForCurrentConsensus(headerHandler)
if !isHeaderForCurrentConsensus {
return
}

Expand Down Expand Up @@ -539,7 +548,7 @@ func (sr *subroundBlock) receivedBlockHeader(headerHandler data.HeaderHandler) {
sr.SetData(sr.Hasher().Compute(string(marshalledHeader)))
sr.SetHeader(headerHandler)

sr.saveProofForPreviousHeaderIfNeeded(headerHandler)
sr.saveProofForPreviousHeaderIfNeeded(headerHandler, prevHeader)

log.Debug("step 1: block header has been received",
"nonce", sr.GetHeader().GetNonce(),
Expand Down
23 changes: 1 addition & 22 deletions process/block/baseProcess.go
Original file line number Diff line number Diff line change
Expand Up @@ -231,28 +231,7 @@ func (bp *baseProcessor) checkPrevProofValidity(prevHeader, headerHandler data.H
}

prevProof := headerHandler.GetPreviousProof()
return bp.verifyProofAgainstHeader(prevProof, prevHeader)
}

func (bp *baseProcessor) verifyProofAgainstHeader(proof data.HeaderProofHandler, header data.HeaderHandler) error {
if check.IfNilReflect(proof) {
return process.ErrMissingHeaderProof
}

if proof.GetHeaderNonce() != header.GetNonce() {
return fmt.Errorf("%w, nonce mismatch", process.ErrInvalidHeaderProof)
}
if proof.GetHeaderShardId() != header.GetShardID() {
return fmt.Errorf("%w, shard id mismatch", process.ErrInvalidHeaderProof)
}
if proof.GetHeaderEpoch() != header.GetEpoch() {
return fmt.Errorf("%w, epoch mismatch", process.ErrInvalidHeaderProof)
}
if proof.GetHeaderRound() != header.GetRound() {
return fmt.Errorf("%w, round mismatch", process.ErrInvalidHeaderProof)
}

return nil
return common.VerifyProofAgainstHeader(prevProof, prevHeader)
}

// checkScheduledRootHash checks if the scheduled root hash from the given header is the same with the current user accounts state root hash
Expand Down
16 changes: 2 additions & 14 deletions process/block/interceptedBlocks/interceptedEquivalentProof.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/multiversx/mx-chain-core-go/data"
"github.com/multiversx/mx-chain-core-go/data/block"
"github.com/multiversx/mx-chain-core-go/marshal"
"github.com/multiversx/mx-chain-go/common"
"github.com/multiversx/mx-chain-go/consensus"
"github.com/multiversx/mx-chain-go/dataRetriever"
proofscache "github.com/multiversx/mx-chain-go/dataRetriever/dataPool/proofsCache"
Expand Down Expand Up @@ -136,20 +137,7 @@ func (iep *interceptedEquivalentProof) checkHeaderParamsFromProof() error {
return fmt.Errorf("%w while getting header for proof hash %s", err, hex.EncodeToString(iep.proof.GetHeaderHash()))
}

if iep.proof.GetHeaderNonce() != header.GetNonce() {
return fmt.Errorf("%w, nonce mismatch", ErrInvalidProof)
}
if iep.proof.GetHeaderShardId() != header.GetShardID() {
return fmt.Errorf("%w, shard id mismatch", ErrInvalidProof)
}
if iep.proof.GetHeaderEpoch() != header.GetEpoch() {
return fmt.Errorf("%w, epoch mismatch", ErrInvalidProof)
}
if iep.proof.GetHeaderRound() != header.GetRound() {
return fmt.Errorf("%w, round mismatch", ErrInvalidProof)
}

return nil
return common.VerifyProofAgainstHeader(iep.proof, header)
}

func (iep *interceptedEquivalentProof) integrity() error {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/multiversx/mx-chain-core-go/core"
"github.com/multiversx/mx-chain-core-go/data"
"github.com/multiversx/mx-chain-core-go/data/block"
"github.com/multiversx/mx-chain-go/common"
"github.com/multiversx/mx-chain-go/consensus/mock"
proofscache "github.com/multiversx/mx-chain-go/dataRetriever/dataPool/proofsCache"
"github.com/multiversx/mx-chain-go/process"
Expand Down Expand Up @@ -232,7 +233,7 @@ func TestInterceptedEquivalentProof_CheckValidity(t *testing.T) {
require.NoError(t, err)

err = iep.CheckValidity()
require.True(t, errors.Is(err, ErrInvalidProof))
require.True(t, errors.Is(err, common.ErrInvalidHeaderProof))
require.True(t, strings.Contains(err.Error(), "nonce mismatch"))
})

Expand All @@ -257,7 +258,7 @@ func TestInterceptedEquivalentProof_CheckValidity(t *testing.T) {
require.NoError(t, err)

err = iep.CheckValidity()
require.True(t, errors.Is(err, ErrInvalidProof))
require.True(t, errors.Is(err, common.ErrInvalidHeaderProof))
require.True(t, strings.Contains(err.Error(), "shard id mismatch"))
})

Expand All @@ -283,7 +284,7 @@ func TestInterceptedEquivalentProof_CheckValidity(t *testing.T) {
require.NoError(t, err)

err = iep.CheckValidity()
require.True(t, errors.Is(err, ErrInvalidProof))
require.True(t, errors.Is(err, common.ErrInvalidHeaderProof))
require.True(t, strings.Contains(err.Error(), "epoch mismatch"))
})

Expand All @@ -310,7 +311,7 @@ func TestInterceptedEquivalentProof_CheckValidity(t *testing.T) {
require.NoError(t, err)

err = iep.CheckValidity()
require.True(t, errors.Is(err, ErrInvalidProof))
require.True(t, errors.Is(err, common.ErrInvalidHeaderProof))
require.True(t, strings.Contains(err.Error(), "round mismatch"))
})

Expand Down
2 changes: 1 addition & 1 deletion process/block/metablock.go
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,7 @@ func (mp *metaProcessor) checkProofsForShardData(header *block.MetaBlock) error
return err
}

err = mp.verifyProofAgainstHeader(prevProof, prevHeader)
err = common.VerifyProofAgainstHeader(prevProof, prevHeader)
if err != nil {
return err
}
Expand Down
3 changes: 0 additions & 3 deletions process/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -1251,9 +1251,6 @@ var ErrEmptyChainParametersConfiguration = errors.New("empty chain parameters co
// ErrNoMatchingConfigForProvidedEpoch signals that there is no matching configuration for the provided epoch
var ErrNoMatchingConfigForProvidedEpoch = errors.New("no matching configuration")

// ErrInvalidHeader is raised when header is invalid
var ErrInvalidHeader = errors.New("header is invalid")

// ErrNilHeaderProof signals that a nil header proof has been provided
var ErrNilHeaderProof = errors.New("nil header proof")

Expand Down

0 comments on commit 681f506

Please sign in to comment.