From 1e866571296c3eae0702e37bcd18eb030820fe86 Mon Sep 17 00:00:00 2001 From: Mostafa Date: Tue, 24 Sep 2024 20:17:22 +0800 Subject: [PATCH] refactor(sync): define errors for sync package --- sync/bundle/bundle.go | 11 ++++----- sync/bundle/message/blocks_request.go | 5 ++-- sync/bundle/message/blocks_request_test.go | 7 +++--- sync/bundle/message/errors.go | 21 ++++++++++++++++ sync/bundle/message/hello.go | 5 ++-- sync/bundle/message/hello_test.go | 10 ++++---- sync/bundle/message/message.go | 28 ++++++++++++---------- sync/bundle/message/message_test.go | 10 +++++++- sync/bundle/message/query_proposal.go | 3 +-- sync/bundle/message/query_proposal_test.go | 4 ++-- sync/bundle/message/query_votes.go | 3 +-- sync/bundle/message/query_votes_test.go | 4 ++-- sync/bundle/message/transactions.go | 3 +-- sync/bundle/message/transactions_test.go | 4 ++-- sync/firewall/errors.go | 15 +++++++----- sync/firewall/firewall.go | 14 ++++------- sync/firewall/firewall_test.go | 27 ++++++++++++++------- util/errors/errors.go | 6 ----- www/grpc/wallet_test.go | 2 +- 19 files changed, 107 insertions(+), 75 deletions(-) create mode 100644 sync/bundle/message/errors.go diff --git a/sync/bundle/bundle.go b/sync/bundle/bundle.go index 46047a439..6f45afde7 100644 --- a/sync/bundle/bundle.go +++ b/sync/bundle/bundle.go @@ -7,7 +7,6 @@ import ( "github.com/fxamacker/cbor/v2" "github.com/pactus-project/pactus/sync/bundle/message" "github.com/pactus-project/pactus/util" - "github.com/pactus-project/pactus/util/errors" ) const ( @@ -84,19 +83,19 @@ func (b *Bundle) Decode(r io.Reader) (int, error) { err := d.Decode(&bdl) bytesRead := d.NumBytesRead() if err != nil { - return bytesRead, errors.Errorf(errors.ErrInvalidMessage, "%s", err.Error()) + return bytesRead, err } data := bdl.MessageData - msg := message.MakeMessage(bdl.MessageType) - if msg == nil { - return bytesRead, errors.Errorf(errors.ErrInvalidMessage, "invalid data") + msg, err := message.MakeMessage(bdl.MessageType) + if err != nil { + return bytesRead, err } if util.IsFlagSet(bdl.Flags, BundleFlagCompressed) { c, err := util.DecompressBuffer(bdl.MessageData) if err != nil { - return bytesRead, errors.Errorf(errors.ErrInvalidMessage, "%s", err.Error()) + return bytesRead, err } data = c } diff --git a/sync/bundle/message/blocks_request.go b/sync/bundle/message/blocks_request.go index dcb2ce923..e7f0ab692 100644 --- a/sync/bundle/message/blocks_request.go +++ b/sync/bundle/message/blocks_request.go @@ -4,7 +4,6 @@ import ( "fmt" "github.com/pactus-project/pactus/network" - "github.com/pactus-project/pactus/util/errors" ) type BlocksRequestMessage struct { @@ -27,10 +26,10 @@ func (m *BlocksRequestMessage) To() uint32 { func (m *BlocksRequestMessage) BasicCheck() error { if m.From == 0 { - return errors.Errorf(errors.ErrInvalidHeight, "height is zero") + return BasicCheckError{Reason: "invalid height"} } if m.Count == 0 { - return errors.Errorf(errors.ErrInvalidMessage, "count is zero") + return BasicCheckError{Reason: "count is zero"} } return nil diff --git a/sync/bundle/message/blocks_request_test.go b/sync/bundle/message/blocks_request_test.go index 764b94d6e..6b4a5cb4e 100644 --- a/sync/bundle/message/blocks_request_test.go +++ b/sync/bundle/message/blocks_request_test.go @@ -3,7 +3,6 @@ package message import ( "testing" - "github.com/pactus-project/pactus/util/errors" "github.com/stretchr/testify/assert" ) @@ -16,12 +15,14 @@ func TestBlocksRequestMessage(t *testing.T) { t.Run("Invalid height", func(t *testing.T) { m := NewBlocksRequestMessage(1, 0, 0) - assert.Equal(t, errors.ErrInvalidHeight, errors.Code(m.BasicCheck())) + err := m.BasicCheck() + assert.ErrorIs(t, err, BasicCheckError{Reason: "invalid height"}) }) t.Run("Invalid count", func(t *testing.T) { m := NewBlocksRequestMessage(1, 200, 0) - assert.Equal(t, errors.ErrInvalidMessage, errors.Code(m.BasicCheck())) + err := m.BasicCheck() + assert.ErrorIs(t, err, BasicCheckError{Reason: "count is zero"}) }) t.Run("OK", func(t *testing.T) { diff --git a/sync/bundle/message/errors.go b/sync/bundle/message/errors.go new file mode 100644 index 000000000..cf4d3f33d --- /dev/null +++ b/sync/bundle/message/errors.go @@ -0,0 +1,21 @@ +package message + +import "fmt" + +// BasicCheckError is returned when the basic check on the message fails. +type BasicCheckError struct { + Reason string +} + +func (e BasicCheckError) Error() string { + return e.Reason +} + +// InvalidMessageTypeError is returned when the message type is not valid. +type InvalidMessageTypeError struct { + Type int +} + +func (e InvalidMessageTypeError) Error() string { + return fmt.Sprintf("invalid message type: %d", e.Type) +} diff --git a/sync/bundle/message/hello.go b/sync/bundle/message/hello.go index 5457228a3..c53d61af7 100644 --- a/sync/bundle/message/hello.go +++ b/sync/bundle/message/hello.go @@ -9,7 +9,6 @@ import ( "github.com/pactus-project/pactus/crypto/hash" "github.com/pactus-project/pactus/network" "github.com/pactus-project/pactus/sync/peerset/peer/service" - "github.com/pactus-project/pactus/util/errors" "github.com/pactus-project/pactus/version" ) @@ -43,10 +42,10 @@ func NewHelloMessage(pid peer.ID, moniker string, func (m *HelloMessage) BasicCheck() error { if m.Signature == nil { - return errors.Error(errors.ErrInvalidSignature) + return BasicCheckError{"no signature"} } if len(m.PublicKeys) == 0 { - return errors.Error(errors.ErrInvalidPublicKey) + return BasicCheckError{"no public key"} } aggPublicKey := bls.PublicKeyAggregate(m.PublicKeys...) diff --git a/sync/bundle/message/hello_test.go b/sync/bundle/message/hello_test.go index d28f694a4..c964ee68f 100644 --- a/sync/bundle/message/hello_test.go +++ b/sync/bundle/message/hello_test.go @@ -7,7 +7,6 @@ import ( "github.com/pactus-project/pactus/crypto" "github.com/pactus-project/pactus/crypto/bls" "github.com/pactus-project/pactus/sync/peerset/peer/service" - "github.com/pactus-project/pactus/util/errors" "github.com/pactus-project/pactus/util/testsuite" "github.com/stretchr/testify/assert" ) @@ -27,7 +26,8 @@ func TestHelloMessage(t *testing.T) { m.Sign([]*bls.ValidatorKey{valKey}) m.Signature = ts.RandBLSSignature() - assert.ErrorIs(t, crypto.ErrInvalidSignature, m.BasicCheck()) + err := m.BasicCheck() + assert.ErrorIs(t, err, crypto.ErrInvalidSignature) }) t.Run("Signature is nil", func(t *testing.T) { @@ -37,7 +37,8 @@ func TestHelloMessage(t *testing.T) { m.Sign([]*bls.ValidatorKey{valKey}) m.Signature = nil - assert.Equal(t, errors.ErrInvalidSignature, errors.Code(m.BasicCheck())) + err := m.BasicCheck() + assert.ErrorIs(t, err, BasicCheckError{"no signature"}) }) t.Run("PublicKeys are empty", func(t *testing.T) { @@ -47,7 +48,8 @@ func TestHelloMessage(t *testing.T) { m.Sign([]*bls.ValidatorKey{valKey}) m.PublicKeys = make([]*bls.PublicKey, 0) - assert.Equal(t, errors.ErrInvalidPublicKey, errors.Code(m.BasicCheck())) + err := m.BasicCheck() + assert.ErrorIs(t, err, BasicCheckError{"no public key"}) }) t.Run("MyTimeUnixMilli of time1 is less or equal than hello message time", func(t *testing.T) { diff --git a/sync/bundle/message/message.go b/sync/bundle/message/message.go index 0f9262837..1db1e2bbd 100644 --- a/sync/bundle/message/message.go +++ b/sync/bundle/message/message.go @@ -89,41 +89,45 @@ func (t Type) String() string { } } -func MakeMessage(t Type) Message { +func MakeMessage(t Type) (Message, error) { + var msg Message switch t { case TypeHello: - return &HelloMessage{} + msg = &HelloMessage{} case TypeHelloAck: - return &HelloAckMessage{} + msg = &HelloAckMessage{} case TypeTransaction: - return &TransactionsMessage{} + msg = &TransactionsMessage{} case TypeQueryProposal: - return &QueryProposalMessage{} + msg = &QueryProposalMessage{} case TypeProposal: - return &ProposalMessage{} + msg = &ProposalMessage{} case TypeQueryVote: - return &QueryVoteMessage{} + msg = &QueryVoteMessage{} case TypeVote: - return &VoteMessage{} + msg = &VoteMessage{} case TypeBlockAnnounce: - return &BlockAnnounceMessage{} + msg = &BlockAnnounceMessage{} case TypeBlocksRequest: - return &BlocksRequestMessage{} + msg = &BlocksRequestMessage{} case TypeBlocksResponse: - return &BlocksResponseMessage{} + msg = &BlocksResponseMessage{} + + default: + return nil, InvalidMessageTypeError{Type: int(t)} } // - return nil + return msg, nil } type Message interface { diff --git a/sync/bundle/message/message_test.go b/sync/bundle/message/message_test.go index 4e4b18f70..f4a80c725 100644 --- a/sync/bundle/message/message_test.go +++ b/sync/bundle/message/message_test.go @@ -5,6 +5,7 @@ import ( "github.com/pactus-project/pactus/network" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestMessage(t *testing.T) { @@ -27,9 +28,16 @@ func TestMessage(t *testing.T) { } for _, tc := range testCases { - msg := MakeMessage(tc.msgType) + msg, err := MakeMessage(tc.msgType) + require.NoError(t, err) + assert.Equal(t, tc.typeName, msg.Type().String()) assert.Equal(t, tc.topicID, msg.TopicID()) assert.Equal(t, tc.shouldBroadcast, msg.ShouldBroadcast()) } } + +func TestInvalidMessageType(t *testing.T) { + _, err := MakeMessage(66) + assert.ErrorIs(t, err, InvalidMessageTypeError{Type: 66}) +} diff --git a/sync/bundle/message/query_proposal.go b/sync/bundle/message/query_proposal.go index 8e0cc945f..6b918fe5e 100644 --- a/sync/bundle/message/query_proposal.go +++ b/sync/bundle/message/query_proposal.go @@ -5,7 +5,6 @@ import ( "github.com/pactus-project/pactus/crypto" "github.com/pactus-project/pactus/network" - "github.com/pactus-project/pactus/util/errors" ) type QueryProposalMessage struct { @@ -24,7 +23,7 @@ func NewQueryProposalMessage(height uint32, round int16, querier crypto.Address) func (m *QueryProposalMessage) BasicCheck() error { if m.Round < 0 { - return errors.Error(errors.ErrInvalidRound) + return BasicCheckError{Reason: "invalid round"} } return nil diff --git a/sync/bundle/message/query_proposal_test.go b/sync/bundle/message/query_proposal_test.go index 733c33f2d..bfe2a451d 100644 --- a/sync/bundle/message/query_proposal_test.go +++ b/sync/bundle/message/query_proposal_test.go @@ -3,7 +3,6 @@ package message import ( "testing" - "github.com/pactus-project/pactus/util/errors" "github.com/pactus-project/pactus/util/testsuite" "github.com/stretchr/testify/assert" ) @@ -19,7 +18,8 @@ func TestQueryProposalMessage(t *testing.T) { t.Run("Invalid round", func(t *testing.T) { m := NewQueryProposalMessage(0, -1, ts.RandValAddress()) - assert.Equal(t, errors.ErrInvalidRound, errors.Code(m.BasicCheck())) + err := m.BasicCheck() + assert.ErrorIs(t, err, BasicCheckError{"invalid round"}) }) t.Run("OK", func(t *testing.T) { diff --git a/sync/bundle/message/query_votes.go b/sync/bundle/message/query_votes.go index a013b9eb9..53b3d44c5 100644 --- a/sync/bundle/message/query_votes.go +++ b/sync/bundle/message/query_votes.go @@ -5,7 +5,6 @@ import ( "github.com/pactus-project/pactus/crypto" "github.com/pactus-project/pactus/network" - "github.com/pactus-project/pactus/util/errors" ) type QueryVoteMessage struct { @@ -24,7 +23,7 @@ func NewQueryVoteMessage(height uint32, round int16, querier crypto.Address) *Qu func (m *QueryVoteMessage) BasicCheck() error { if m.Round < 0 { - return errors.Error(errors.ErrInvalidRound) + return BasicCheckError{Reason: "invalid round"} } return nil diff --git a/sync/bundle/message/query_votes_test.go b/sync/bundle/message/query_votes_test.go index c31b6fc5a..e7b520a98 100644 --- a/sync/bundle/message/query_votes_test.go +++ b/sync/bundle/message/query_votes_test.go @@ -3,7 +3,6 @@ package message import ( "testing" - "github.com/pactus-project/pactus/util/errors" "github.com/pactus-project/pactus/util/testsuite" "github.com/stretchr/testify/assert" ) @@ -19,7 +18,8 @@ func TestQueryVoteMessage(t *testing.T) { t.Run("Invalid round", func(t *testing.T) { m := NewQueryVoteMessage(0, -1, ts.RandValAddress()) - assert.Equal(t, errors.ErrInvalidRound, errors.Code(m.BasicCheck())) + err := m.BasicCheck() + assert.ErrorIs(t, err, BasicCheckError{Reason: "invalid round"}) }) t.Run("OK", func(t *testing.T) { diff --git a/sync/bundle/message/transactions.go b/sync/bundle/message/transactions.go index 2d5a3ed26..922c0d5fd 100644 --- a/sync/bundle/message/transactions.go +++ b/sync/bundle/message/transactions.go @@ -6,7 +6,6 @@ import ( "github.com/pactus-project/pactus/network" "github.com/pactus-project/pactus/types/tx" - "github.com/pactus-project/pactus/util/errors" ) type TransactionsMessage struct { @@ -21,7 +20,7 @@ func NewTransactionsMessage(trxs []*tx.Tx) *TransactionsMessage { func (m *TransactionsMessage) BasicCheck() error { if len(m.Transactions) == 0 { - return errors.Errorf(errors.ErrInvalidMessage, "no transaction") + return BasicCheckError{Reason: "no transaction"} } for _, trx := range m.Transactions { if err := trx.BasicCheck(); err != nil { diff --git a/sync/bundle/message/transactions_test.go b/sync/bundle/message/transactions_test.go index b13f2ccaa..30ea9d82e 100644 --- a/sync/bundle/message/transactions_test.go +++ b/sync/bundle/message/transactions_test.go @@ -4,7 +4,6 @@ import ( "testing" "github.com/pactus-project/pactus/types/tx" - "github.com/pactus-project/pactus/util/errors" "github.com/pactus-project/pactus/util/testsuite" "github.com/stretchr/testify/assert" ) @@ -20,7 +19,8 @@ func TestTransactionsMessage(t *testing.T) { t.Run("No transactions", func(t *testing.T) { m := NewTransactionsMessage(nil) - assert.Equal(t, errors.ErrInvalidMessage, errors.Code(m.BasicCheck())) + err := m.BasicCheck() + assert.ErrorIs(t, err, BasicCheckError{Reason: "no transaction"}) }) t.Run("OK", func(t *testing.T) { diff --git a/sync/firewall/errors.go b/sync/firewall/errors.go index da3bf78e3..c9e579c45 100644 --- a/sync/firewall/errors.go +++ b/sync/firewall/errors.go @@ -7,6 +7,15 @@ import ( lp2pcore "github.com/libp2p/go-libp2p/core" ) +// ErrGossipMessage is returned when a stream message sends as gossip message. +var ErrGossipMessage = errors.New("receive stream message as gossip message") + +// ErrStreamMessage is returned when a gossip message sends as stream message. +var ErrStreamMessage = errors.New("receive gossip message as stream message") + +// ErrNetworkMismatch is returned when the bundle doesn't belong to this network. +var ErrNetworkMismatch = errors.New("bundle is not for this network") + // PeerBannedError is returned when a message received from a banned peer-id or banned address. type PeerBannedError struct { PeerID lp2pcore.PeerID @@ -16,9 +25,3 @@ type PeerBannedError struct { func (e PeerBannedError) Error() string { return fmt.Sprintf("peer is banned, peer-id: %s, remote-address: %s", e.PeerID, e.Address) } - -// ErrGossipMessage is returned when a stream message sends as gossip message. -var ErrGossipMessage = errors.New("receive stream message as gossip message") - -// ErrStreamMessage is returned when a gossip message sends as stream message. -var ErrStreamMessage = errors.New("receive gossip message as stream message") diff --git a/sync/firewall/firewall.go b/sync/firewall/firewall.go index 19d97ae52..38a8c9268 100644 --- a/sync/firewall/firewall.go +++ b/sync/firewall/firewall.go @@ -13,7 +13,6 @@ import ( "github.com/pactus-project/pactus/sync/peerset" "github.com/pactus-project/pactus/sync/peerset/peer" "github.com/pactus-project/pactus/sync/peerset/peer/status" - "github.com/pactus-project/pactus/util/errors" "github.com/pactus-project/pactus/util/ipblocker" "github.com/pactus-project/pactus/util/logger" "github.com/pactus-project/pactus/util/ratelimit" @@ -152,7 +151,7 @@ func (f *Firewall) decodeBundle(r io.Reader, pid peer.ID) (*bundle.Bundle, error bdl := new(bundle.Bundle) bytesRead, err := bdl.Decode(r) if err != nil { - return nil, errors.Errorf(errors.ErrInvalidMessage, "%s", err.Error()) + return nil, err } f.peerSet.IncreaseReceivedBytesCounter(pid, bdl.Message.Type(), int64(bytesRead)) @@ -161,26 +160,23 @@ func (f *Firewall) decodeBundle(r io.Reader, pid peer.ID) (*bundle.Bundle, error func (f *Firewall) checkBundle(bdl *bundle.Bundle) error { if err := bdl.BasicCheck(); err != nil { - return errors.Errorf(errors.ErrInvalidMessage, "%s", err.Error()) + return err } switch f.state.Genesis().ChainType() { case genesis.Mainnet: if bdl.Flags&0x3 != bundle.BundleFlagNetworkMainnet { - return errors.Errorf(errors.ErrInvalidMessage, - "bundle is not for the mainnet") + return ErrNetworkMismatch } case genesis.Testnet: if bdl.Flags&0x3 != bundle.BundleFlagNetworkTestnet { - return errors.Errorf(errors.ErrInvalidMessage, - "bundle is not for the testnet") + return ErrNetworkMismatch } case genesis.Localnet: if bdl.Flags&0x3 != 0 { - return errors.Errorf(errors.ErrInvalidMessage, - "bundle is not for the localnet") + return ErrNetworkMismatch } } diff --git a/sync/firewall/firewall_test.go b/sync/firewall/firewall_test.go index 6e2a53697..6e95ce79b 100644 --- a/sync/firewall/firewall_test.go +++ b/sync/firewall/firewall_test.go @@ -328,13 +328,16 @@ func TestNetworkFlagsMainnet(t *testing.T) { bdl := bundle.NewBundle(message.NewQueryVoteMessage(td.RandHeight(), td.RandRound(), td.RandValAddress())) bdl.Flags = util.SetFlag(bdl.Flags, bundle.BundleFlagNetworkMainnet) - assert.NoError(t, td.firewall.checkBundle(bdl)) + err := td.firewall.checkBundle(bdl) + assert.NoError(t, err) bdl.Flags = util.SetFlag(bdl.Flags, bundle.BundleFlagNetworkTestnet) - assert.Error(t, td.firewall.checkBundle(bdl)) + err = td.firewall.checkBundle(bdl) + assert.ErrorIs(t, err, ErrNetworkMismatch) bdl.Flags = 0 - assert.Error(t, td.firewall.checkBundle(bdl)) + err = td.firewall.checkBundle(bdl) + assert.ErrorIs(t, err, ErrNetworkMismatch) } func TestNetworkFlagsTestnet(t *testing.T) { @@ -343,13 +346,16 @@ func TestNetworkFlagsTestnet(t *testing.T) { bdl := bundle.NewBundle(message.NewQueryVoteMessage(td.RandHeight(), td.RandRound(), td.RandValAddress())) bdl.Flags = util.SetFlag(bdl.Flags, bundle.BundleFlagNetworkTestnet) - assert.NoError(t, td.firewall.checkBundle(bdl)) + err := td.firewall.checkBundle(bdl) + assert.NoError(t, err) bdl.Flags = util.SetFlag(bdl.Flags, bundle.BundleFlagNetworkMainnet) - assert.Error(t, td.firewall.checkBundle(bdl)) + err = td.firewall.checkBundle(bdl) + assert.ErrorIs(t, err, ErrNetworkMismatch) bdl.Flags = 0 - assert.Error(t, td.firewall.checkBundle(bdl)) + err = td.firewall.checkBundle(bdl) + assert.ErrorIs(t, err, ErrNetworkMismatch) } func TestNetworkFlagsLocalnet(t *testing.T) { @@ -358,13 +364,16 @@ func TestNetworkFlagsLocalnet(t *testing.T) { bdl := bundle.NewBundle(message.NewQueryVoteMessage(td.RandHeight(), td.RandRound(), td.RandValAddress())) bdl.Flags = util.SetFlag(bdl.Flags, bundle.BundleFlagNetworkTestnet) - assert.Error(t, td.firewall.checkBundle(bdl)) + err := td.firewall.checkBundle(bdl) + assert.ErrorIs(t, err, ErrNetworkMismatch) bdl.Flags = util.SetFlag(bdl.Flags, bundle.BundleFlagNetworkMainnet) - assert.Error(t, td.firewall.checkBundle(bdl)) + err = td.firewall.checkBundle(bdl) + assert.ErrorIs(t, err, ErrNetworkMismatch) bdl.Flags = 0 - assert.NoError(t, td.firewall.checkBundle(bdl)) + err = td.firewall.checkBundle(bdl) + assert.NoError(t, err) } func TestParseP2PAddr(t *testing.T) { diff --git a/util/errors/errors.go b/util/errors/errors.go index 5ba9a28fb..aceea67f9 100644 --- a/util/errors/errors.go +++ b/util/errors/errors.go @@ -10,9 +10,6 @@ const ( ErrInvalidPublicKey ErrInvalidPrivateKey ErrInvalidSignature - ErrInvalidHeight - ErrInvalidRound - ErrInvalidMessage ErrCount ) @@ -23,9 +20,6 @@ var messages = map[int]string{ ErrInvalidPublicKey: "invalid public key", ErrInvalidPrivateKey: "invalid private key", ErrInvalidSignature: "invalid signature", - ErrInvalidHeight: "invalid height", - ErrInvalidRound: "invalid round", - ErrInvalidMessage: "invalid message", } type withCodeError struct { diff --git a/www/grpc/wallet_test.go b/www/grpc/wallet_test.go index 0a7317998..eea1270c3 100644 --- a/www/grpc/wallet_test.go +++ b/www/grpc/wallet_test.go @@ -179,7 +179,7 @@ func TestLoadWallet(t *testing.T) { signedTx, err := tx.FromBytes(td.DecodingHex(res.SignedRawTransaction)) assert.NoError(t, err) assert.NotNil(t, signedTx.Signature()) - assert.Nil(t, signedTx.BasicCheck()) + assert.NoError(t, signedTx.BasicCheck()) }) t.Run("Sign raw transaction using not loaded wallet", func(t *testing.T) {