diff --git a/activation/e2e/nipost_test.go b/activation/e2e/nipost_test.go index 927a89851d..a9a13970e9 100644 --- a/activation/e2e/nipost_test.go +++ b/activation/e2e/nipost_test.go @@ -97,7 +97,7 @@ func launchPostSupervisor( provingOpts := activation.DefaultPostProvingOpts() provingOpts.RandomXMode = activation.PostRandomXModeLight - builder := activation.NewMockAtxBuilder(gomock.NewController(tb)) + builder := activation.NewMockatxBuilder(gomock.NewController(tb)) builder.EXPECT().Register(gomock.Any()) ps := activation.NewPostSupervisor(log, postCfg, provingOpts, mgr, builder) require.NoError(tb, ps.Start(cmdCfg, postOpts, sig)) diff --git a/activation/handler.go b/activation/handler.go index 48dcd79ee8..82466fadd8 100644 --- a/activation/handler.go +++ b/activation/handler.go @@ -20,8 +20,6 @@ import ( "github.com/spacemeshos/go-spacemesh/p2p" "github.com/spacemeshos/go-spacemesh/p2p/pubsub" "github.com/spacemeshos/go-spacemesh/signing" - "github.com/spacemeshos/go-spacemesh/sql" - "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/system" ) @@ -128,7 +126,7 @@ func NewHandler( beacon: beacon, tortoise: tortoise, malPublisher: legacyMalPublisher, - malPublisher2: &MalfeasancePublisher{}, // TODO(mafa): pass real publisher when available + malPublisher2: &MalfeasanceHandlerV2{}, // TODO(mafa): pass real publisher when available signers: make(map[types.NodeID]*signing.EdSigner), }, @@ -145,7 +143,7 @@ func NewHandler( fetcher: fetcher, beacon: beacon, tortoise: tortoise, - malPublisher: &MalfeasancePublisher{}, // TODO(mafa): pass real publisher when available + malPublisher: &MalfeasanceHandlerV2{}, // TODO(mafa): pass real publisher when available }, } @@ -292,28 +290,3 @@ func (h *Handler) handleAtx(ctx context.Context, expHash types.Hash32, peer p2p. h.inProgress.Forget(key) return err } - -// Obtain the atxSignature of the given ATX. -func atxSignature(ctx context.Context, db sql.Executor, id types.ATXID) (types.EdSignature, error) { - var blob sql.Blob - v, err := atxs.LoadBlob(ctx, db, id.Bytes(), &blob) - if err != nil { - return types.EmptyEdSignature, err - } - - if len(blob.Bytes) == 0 { - // An empty blob indicates a golden ATX (after a checkpoint-recovery). - return types.EmptyEdSignature, fmt.Errorf("can't get signature for a golden (checkpointed) ATX: %s", id) - } - - // TODO: implement for ATX V2 - switch v { - case types.AtxV1: - var atx wire.ActivationTxV1 - if err := codec.Decode(blob.Bytes, &atx); err != nil { - return types.EmptyEdSignature, fmt.Errorf("decoding atx v1: %w", err) - } - return atx.Signature, nil - } - return types.EmptyEdSignature, fmt.Errorf("unsupported ATX version: %v", v) -} diff --git a/activation/handler_v1.go b/activation/handler_v1.go index 2b8c82a793..653727bc1d 100644 --- a/activation/handler_v1.go +++ b/activation/handler_v1.go @@ -510,9 +510,8 @@ func (h *HandlerV1) storeAtx(ctx context.Context, atx *types.ActivationTx, watx return fmt.Errorf("store atx: %w", err) } - added := h.cacheAtx(ctx, atx, malicious) h.beacon.OnAtx(atx) - if added != nil { + if added := h.cacheAtx(ctx, atx, malicious); added != nil { h.tortoise.OnAtx(atx.TargetEpoch(), atx.ID(), added) } @@ -631,3 +630,28 @@ func collectAtxDeps(goldenAtxId types.ATXID, atx *wire.ActivationTxV1) (types.Ha return types.BytesToHash(atx.NIPost.PostMetadata.Challenge), maps.Keys(filtered) } + +// Obtain the signature of the given ATX. +func atxSignature(ctx context.Context, db sql.Executor, id types.ATXID) (types.EdSignature, error) { + var blob sql.Blob + v, err := atxs.LoadBlob(ctx, db, id.Bytes(), &blob) + if err != nil { + return types.EmptyEdSignature, err + } + + if len(blob.Bytes) == 0 { + // An empty blob indicates a golden ATX (after a checkpoint-recovery). + return types.EmptyEdSignature, fmt.Errorf("can't get signature for a golden (checkpointed) ATX: %s", id) + } + + switch v { + case types.AtxV1: + var atx wire.ActivationTxV1 + if err := codec.Decode(blob.Bytes, &atx); err != nil { + return types.EmptyEdSignature, fmt.Errorf("decoding atx v1: %w", err) + } + return atx.Signature, nil + default: // only needed for V1 ATXs + return types.EmptyEdSignature, fmt.Errorf("unsupported ATX version: %v", v) + } +} diff --git a/activation/interface.go b/activation/interface.go index ae410e6291..6b61383cf0 100644 --- a/activation/interface.go +++ b/activation/interface.go @@ -198,7 +198,7 @@ var ( ErrPostClientNotConnected = errors.New("post service not registered") ) -type AtxBuilder interface { +type atxBuilder interface { Register(sig *signing.EdSigner) } diff --git a/activation/malfeasance2.go b/activation/malfeasance2.go index ff44452b35..f6ea4ecee2 100644 --- a/activation/malfeasance2.go +++ b/activation/malfeasance2.go @@ -2,15 +2,43 @@ package activation import ( "context" + "fmt" "github.com/spacemeshos/go-spacemesh/activation/wire" + "github.com/spacemeshos/go-spacemesh/codec" "github.com/spacemeshos/go-spacemesh/common/types" ) -// MalfeasancePublisher is the publisher for ATX proofs. -type MalfeasancePublisher struct{} +type MalfeasanceHandlerV2 struct{} -func (p *MalfeasancePublisher) Publish(ctx context.Context, id types.NodeID, proof wire.Proof) error { +func NewMalfeasanceHandlerV2() *MalfeasanceHandlerV2 { + return &MalfeasanceHandlerV2{} +} + +func (mh *MalfeasanceHandlerV2) decodeProof(data []byte) (wire.Proof, error) { + var atxProof wire.ATXProof + if err := codec.Decode(data, &atxProof); err != nil { + return nil, err + } + + proof, err := atxProof.Decode() + if err != nil { + return nil, err + } + return proof, nil +} + +func (mh *MalfeasanceHandlerV2) Info(data []byte) (map[string]string, error) { + proof, err := mh.decodeProof(data) + if err != nil { + return nil, fmt.Errorf("decoding ATX malfeasance proof: %w", err) + } + info := proof.Info() + info["type"] = proof.String() + return info, nil +} + +func (p *MalfeasanceHandlerV2) Publish(ctx context.Context, id types.NodeID, proof wire.Proof) error { // TODO(mafa): implement me return nil } diff --git a/activation/malfeasance2_test.go b/activation/malfeasance2_test.go new file mode 100644 index 0000000000..99f9a482ac --- /dev/null +++ b/activation/malfeasance2_test.go @@ -0,0 +1,91 @@ +package activation + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/spacemeshos/go-spacemesh/activation/wire" + "github.com/spacemeshos/go-spacemesh/codec" +) + +type testMalHandler struct { + *MalfeasanceHandlerV2 +} + +func newTestMalHandler(tb testing.TB) *testMalHandler { + handler := NewMalfeasanceHandlerV2() + + return &testMalHandler{ + MalfeasanceHandlerV2: handler, + } +} + +func TestHandler_Info(t *testing.T) { + t.Parallel() + + t.Run("decode proof error", func(t *testing.T) { + t.Parallel() + th := newTestMalHandler(t) + + info, err := th.Info([]byte("invalid proof")) + require.Error(t, err) + require.Contains(t, err.Error(), "decoding ATX malfeasance proof") + require.Nil(t, info) + }) + + tt := []struct { + name string + proofType wire.ProofType + proof wire.Proof + }{ + { + name: "double marry proof", + proofType: wire.DoubleMarry, + proof: &wire.ProofDoubleMarry{}, + }, + { + name: "double merge proof", + proofType: wire.DoubleMerge, + proof: &wire.ProofDoubleMerge{}, + }, + { + name: "invalid post", + proofType: wire.InvalidPost, + proof: &wire.ProofInvalidPost{}, + }, + { + name: "invalid prev atx v1", + proofType: wire.InvalidPreviousV1, + proof: &wire.ProofInvalidPrevAtxV1{}, + }, + { + name: "invalid prev atx v2", + proofType: wire.InvalidPreviousV2, + proof: &wire.ProofInvalidPrevAtxV2{}, + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + th := newTestMalHandler(t) + + atxProof := &wire.ATXProof{ + Version: wire.ProofVersion(1), + + ProofType: tc.proofType, + Proof: codec.MustEncode(tc.proof), + } + data, err := codec.Encode(atxProof) + require.NoError(t, err) + + expectedInfo := tc.proof.Info() + expectedInfo["type"] = tc.proof.String() + + info, err := th.Info(data) + require.NoError(t, err) + require.Equal(t, expectedInfo, info) + }) + } +} diff --git a/activation/mocks.go b/activation/mocks.go index 00ba677320..594e7f382e 100644 --- a/activation/mocks.go +++ b/activation/mocks.go @@ -2229,62 +2229,62 @@ func (c *MockpoetDbAPIValidateAndStoreCall) DoAndReturn(f func(context.Context, return c } -// MockAtxBuilder is a mock of AtxBuilder interface. -type MockAtxBuilder struct { +// MockatxBuilder is a mock of atxBuilder interface. +type MockatxBuilder struct { ctrl *gomock.Controller - recorder *MockAtxBuilderMockRecorder + recorder *MockatxBuilderMockRecorder isgomock struct{} } -// MockAtxBuilderMockRecorder is the mock recorder for MockAtxBuilder. -type MockAtxBuilderMockRecorder struct { - mock *MockAtxBuilder +// MockatxBuilderMockRecorder is the mock recorder for MockatxBuilder. +type MockatxBuilderMockRecorder struct { + mock *MockatxBuilder } -// NewMockAtxBuilder creates a new mock instance. -func NewMockAtxBuilder(ctrl *gomock.Controller) *MockAtxBuilder { - mock := &MockAtxBuilder{ctrl: ctrl} - mock.recorder = &MockAtxBuilderMockRecorder{mock} +// NewMockatxBuilder creates a new mock instance. +func NewMockatxBuilder(ctrl *gomock.Controller) *MockatxBuilder { + mock := &MockatxBuilder{ctrl: ctrl} + mock.recorder = &MockatxBuilderMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockAtxBuilder) EXPECT() *MockAtxBuilderMockRecorder { +func (m *MockatxBuilder) EXPECT() *MockatxBuilderMockRecorder { return m.recorder } // Register mocks base method. -func (m *MockAtxBuilder) Register(sig *signing.EdSigner) { +func (m *MockatxBuilder) Register(sig *signing.EdSigner) { m.ctrl.T.Helper() m.ctrl.Call(m, "Register", sig) } // Register indicates an expected call of Register. -func (mr *MockAtxBuilderMockRecorder) Register(sig any) *MockAtxBuilderRegisterCall { +func (mr *MockatxBuilderMockRecorder) Register(sig any) *MockatxBuilderRegisterCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Register", reflect.TypeOf((*MockAtxBuilder)(nil).Register), sig) - return &MockAtxBuilderRegisterCall{Call: call} + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Register", reflect.TypeOf((*MockatxBuilder)(nil).Register), sig) + return &MockatxBuilderRegisterCall{Call: call} } -// MockAtxBuilderRegisterCall wrap *gomock.Call -type MockAtxBuilderRegisterCall struct { +// MockatxBuilderRegisterCall wrap *gomock.Call +type MockatxBuilderRegisterCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return -func (c *MockAtxBuilderRegisterCall) Return() *MockAtxBuilderRegisterCall { +func (c *MockatxBuilderRegisterCall) Return() *MockatxBuilderRegisterCall { c.Call = c.Call.Return() return c } // Do rewrite *gomock.Call.Do -func (c *MockAtxBuilderRegisterCall) Do(f func(*signing.EdSigner)) *MockAtxBuilderRegisterCall { +func (c *MockatxBuilderRegisterCall) Do(f func(*signing.EdSigner)) *MockatxBuilderRegisterCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockAtxBuilderRegisterCall) DoAndReturn(f func(*signing.EdSigner)) *MockAtxBuilderRegisterCall { +func (c *MockatxBuilderRegisterCall) DoAndReturn(f func(*signing.EdSigner)) *MockatxBuilderRegisterCall { c.Call = c.Call.DoAndReturn(f) return c } diff --git a/activation/post_supervisor.go b/activation/post_supervisor.go index bc49dd63f1..f3ea4b6300 100644 --- a/activation/post_supervisor.go +++ b/activation/post_supervisor.go @@ -70,7 +70,7 @@ type PostSupervisor struct { provingOpts PostProvingOpts postSetupProvider postSetupProvider - atxBuilder AtxBuilder + atxBuilder atxBuilder pid atomic.Int64 // pid of the running post service, only for tests. @@ -85,7 +85,7 @@ func NewPostSupervisor( postCfg PostConfig, provingOpts PostProvingOpts, postSetupProvider postSetupProvider, - atxBuilder AtxBuilder, + atxBuilder atxBuilder, ) *PostSupervisor { return &PostSupervisor{ logger: logger, diff --git a/activation/post_supervisor_test.go b/activation/post_supervisor_test.go index 597ce47fe2..07dc303293 100644 --- a/activation/post_supervisor_test.go +++ b/activation/post_supervisor_test.go @@ -105,7 +105,7 @@ func Test_PostSupervisor_Start_FailPrepare(t *testing.T) { mgr := NewMockpostSetupProvider(ctrl) testErr := errors.New("test error") mgr.EXPECT().PrepareInitializer(gomock.Any(), postOpts, sig.NodeID()).Return(testErr) - builder := NewMockAtxBuilder(ctrl) + builder := NewMockatxBuilder(ctrl) ps := NewPostSupervisor(log.Named("supervisor"), postCfg, provingOpts, mgr, builder) require.NoError(t, ps.Start(cmdCfg, postOpts, sig)) @@ -141,7 +141,7 @@ func Test_PostSupervisor_Start_FailStartSession(t *testing.T) { mgr := NewMockpostSetupProvider(ctrl) mgr.EXPECT().PrepareInitializer(gomock.Any(), postOpts, sig.NodeID()).Return(nil) mgr.EXPECT().StartSession(gomock.Any(), sig.NodeID()).Return(errors.New("failed start session")) - builder := NewMockAtxBuilder(ctrl) + builder := NewMockatxBuilder(ctrl) ps := NewPostSupervisor(log.Named("supervisor"), postCfg, provingOpts, mgr, builder) require.NoError(t, ps.Start(cmdCfg, postOpts, sig)) @@ -160,7 +160,7 @@ func Test_PostSupervisor_StartsServiceCmd(t *testing.T) { ctrl := gomock.NewController(t) mgr := newPostManager(t, postCfg, postOpts) - builder := NewMockAtxBuilder(ctrl) + builder := NewMockatxBuilder(ctrl) builder.EXPECT().Register(sig) ps := NewPostSupervisor(log.Named("supervisor"), postCfg, provingOpts, mgr, builder) @@ -197,7 +197,7 @@ func Test_PostSupervisor_Restart_Possible(t *testing.T) { ctrl := gomock.NewController(t) mgr := newPostManager(t, postCfg, postOpts) - builder := NewMockAtxBuilder(ctrl) + builder := NewMockatxBuilder(ctrl) builder.EXPECT().Register(sig) ps := NewPostSupervisor(log.Named("supervisor"), postCfg, provingOpts, mgr, builder) @@ -228,7 +228,7 @@ func Test_PostSupervisor_LogFatalOnCrash(t *testing.T) { ctrl := gomock.NewController(t) mgr := newPostManager(t, postCfg, postOpts) - builder := NewMockAtxBuilder(ctrl) + builder := NewMockatxBuilder(ctrl) builder.EXPECT().Register(sig) ps := NewPostSupervisor(log.Named("supervisor"), postCfg, provingOpts, mgr, builder) @@ -261,7 +261,7 @@ func Test_PostSupervisor_LogFatalOnInvalidConfig(t *testing.T) { ctrl := gomock.NewController(t) mgr := newPostManager(t, postCfg, postOpts) - builder := NewMockAtxBuilder(ctrl) + builder := NewMockatxBuilder(ctrl) builder.EXPECT().Register(sig) ps := NewPostSupervisor(log.Named("supervisor"), postCfg, provingOpts, mgr, builder) @@ -301,7 +301,7 @@ func Test_PostSupervisor_StopOnError(t *testing.T) { require.NoError(t, err) return nil }) - builder := NewMockAtxBuilder(ctrl) + builder := NewMockatxBuilder(ctrl) builder.EXPECT().Register(sig) ps := NewPostSupervisor(log.Named("supervisor"), postCfg, provingOpts, mgr, builder) @@ -322,7 +322,7 @@ func Test_PostSupervisor_Providers_includesCPU(t *testing.T) { ctrl := gomock.NewController(t) mgr := NewMockpostSetupProvider(ctrl) - builder := NewMockAtxBuilder(ctrl) + builder := NewMockatxBuilder(ctrl) ps := NewPostSupervisor(log.Named("supervisor"), postCfg, provingOpts, mgr, builder) providers, err := ps.Providers() @@ -344,7 +344,7 @@ func Test_PostSupervisor_Benchmark(t *testing.T) { ctrl := gomock.NewController(t) mgr := NewMockpostSetupProvider(ctrl) - builder := NewMockAtxBuilder(ctrl) + builder := NewMockatxBuilder(ctrl) ps := NewPostSupervisor(log.Named("supervisor"), postCfg, provingOpts, mgr, builder) providers, err := ps.Providers() diff --git a/activation/post_verifier_test.go b/activation/post_verifier_test.go index c02b329387..34a26667ca 100644 --- a/activation/post_verifier_test.go +++ b/activation/post_verifier_test.go @@ -23,9 +23,9 @@ func TestOffloadingPostVerifier(t *testing.T) { verifier := NewMockPostVerifier(gomock.NewController(t)) offloadingVerifier := newOffloadingPostVerifier(verifier, 1, zaptest.NewLogger(t)) defer offloadingVerifier.Close() - verifier.EXPECT().Close().Return(nil) + verifier.EXPECT().Close() - verifier.EXPECT().Verify(gomock.Any(), &proof, &metadata, gomock.Any()).Return(nil) + verifier.EXPECT().Verify(gomock.Any(), &proof, &metadata, gomock.Any()) err := offloadingVerifier.Verify(context.Background(), &proof, &metadata) require.NoError(t, err) @@ -49,12 +49,12 @@ func TestPostVerifierVerifyAfterStop(t *testing.T) { offloadingVerifier := newOffloadingPostVerifier(verifier, 1, zaptest.NewLogger(t)) defer offloadingVerifier.Close() - verifier.EXPECT().Verify(gomock.Any(), &proof, &metadata, gomock.Any()).Return(nil) + verifier.EXPECT().Verify(gomock.Any(), &proof, &metadata, gomock.Any()) err := offloadingVerifier.Verify(context.Background(), &proof, &metadata) require.NoError(t, err) // Stop the verifier - verifier.EXPECT().Close().Return(nil) + verifier.EXPECT().Close() offloadingVerifier.Close() err = offloadingVerifier.Verify(context.Background(), &proof, &metadata) @@ -69,8 +69,8 @@ func TestPostVerifierNoRaceOnClose(t *testing.T) { offloadingVerifier := newOffloadingPostVerifier(verifier, 1, zaptest.NewLogger(t)) defer offloadingVerifier.Close() - verifier.EXPECT().Close().AnyTimes().Return(nil) - verifier.EXPECT().Verify(gomock.Any(), &proof, &metadata, gomock.Any()).AnyTimes().Return(nil) + verifier.EXPECT().Close().AnyTimes() + verifier.EXPECT().Verify(gomock.Any(), &proof, &metadata, gomock.Any()).AnyTimes() // Stop the verifier var eg errgroup.Group @@ -95,7 +95,7 @@ func TestPostVerifierClose(t *testing.T) { // 0 workers - no one will verify the proof v := newOffloadingPostVerifier(verifier, 0, zaptest.NewLogger(t)) - verifier.EXPECT().Close().Return(nil) + verifier.EXPECT().Close() require.NoError(t, v.Close()) err := v.Verify(context.Background(), &shared.Proof{}, &shared.ProofMetadata{}) @@ -107,28 +107,17 @@ func TestPostVerifierPrioritization(t *testing.T) { verifier := NewMockPostVerifier(gomock.NewController(t)) v := newOffloadingPostVerifier(verifier, 2, zaptest.NewLogger(t), nodeID) - verifier.EXPECT(). - Verify(gomock.Any(), gomock.Any(), &shared.ProofMetadata{NodeId: nodeID.Bytes()}, gomock.Any()). - Return(nil) + verifier.EXPECT().Verify(gomock.Any(), gomock.Any(), &shared.ProofMetadata{NodeId: nodeID.Bytes()}, gomock.Any()) err := v.Verify(context.Background(), &shared.Proof{}, &shared.ProofMetadata{NodeId: nodeID.Bytes()}) require.NoError(t, err) - verifier.EXPECT(). - Verify( - context.Background(), - gomock.Any(), - &shared.ProofMetadata{}, gomock.Any()). - Return(nil) - - err = v.Verify( - context.Background(), - &shared.Proof{}, - &shared.ProofMetadata{}, - PrioritizedCall()) + verifier.EXPECT().Verify(context.Background(), gomock.Any(), &shared.ProofMetadata{}, gomock.Any()) + + err = v.Verify(context.Background(), &shared.Proof{}, &shared.ProofMetadata{}, PrioritizedCall()) require.NoError(t, err) - verifier.EXPECT().Close().Return(nil) + verifier.EXPECT().Close() require.NoError(t, v.Close()) } diff --git a/activation/wire/interface.go b/activation/wire/interface.go index ba5006e3cd..9475113a60 100644 --- a/activation/wire/interface.go +++ b/activation/wire/interface.go @@ -2,6 +2,9 @@ package wire import ( "context" + "fmt" + + "github.com/spacemeshos/go-scale" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/signing" @@ -25,3 +28,15 @@ type MalfeasanceValidator interface { // Signature validates the given signature against the given message and public key. Signature(d signing.Domain, nodeID types.NodeID, m []byte, sig types.EdSignature) bool } + +// Proof is an interface for all types of proofs that can be provided in an ATXProof. +// Generally the proof should be able to validate itself and be scale encoded. +type Proof interface { + scale.Encodable + scale.Decodable + fmt.Stringer + + Type() ProofType + Info() map[string]string + Valid(ctx context.Context, malHandler MalfeasanceValidator) (types.NodeID, error) +} diff --git a/activation/wire/malfeasance.go b/activation/wire/malfeasance.go index 7e2aa97b98..c38b51baf7 100644 --- a/activation/wire/malfeasance.go +++ b/activation/wire/malfeasance.go @@ -1,11 +1,9 @@ package wire import ( - "context" + "fmt" - "github.com/spacemeshos/go-scale" - - "github.com/spacemeshos/go-spacemesh/common/types" + "github.com/spacemeshos/go-spacemesh/codec" ) //go:generate scalegen @@ -62,16 +60,27 @@ type ProofType byte const ( // TODO(mafa): legacy types for future migration to new malfeasance proofs. - LegacyDoublePublish ProofType = 0x00 - LegacyInvalidPost ProofType = 0x01 - LegacyInvalidPrevATX ProofType = 0x02 - - DoubleMarry ProofType = 0x10 - DoubleMerge ProofType = 0x11 - InvalidPost ProofType = 0x12 - InvalidPrevious ProofType = 0x13 + LegacyDoublePublish ProofType = 0x01 + LegacyInvalidPost ProofType = 0x02 + LegacyInvalidPrevATX ProofType = 0x03 + + DoubleMarry ProofType = 0x11 + DoubleMerge ProofType = 0x12 + InvalidPost ProofType = 0x13 + InvalidPreviousV1 ProofType = 0x14 + InvalidPreviousV2 ProofType = 0x15 ) +var proofTypes = map[ProofType]Proof{ + // TODO(mafa): legacy proofs + + DoubleMarry: &ProofDoubleMarry{}, + DoubleMerge: &ProofDoubleMerge{}, + InvalidPost: &ProofInvalidPost{}, + InvalidPreviousV1: &ProofInvalidPrevAtxV1{}, + InvalidPreviousV2: &ProofInvalidPrevAtxV2{}, +} + // ProofVersion is an identifier for the version of the proof that is encoded in the ATXProof. type ProofVersion byte @@ -84,10 +93,13 @@ type ATXProof struct { Proof []byte `scale:"max=1048576"` // max size of proof is 1MiB } -// Proof is an interface for all types of proofs that can be provided in an ATXProof. -// Generally the proof should be able to validate itself and be scale encoded. -type Proof interface { - scale.Encodable - - Valid(ctx context.Context, malHandler MalfeasanceValidator) (types.NodeID, error) +func (p *ATXProof) Decode() (Proof, error) { + rst, ok := proofTypes[p.ProofType] + if !ok { + return nil, fmt.Errorf("unknown ATX malfeasance proof type: 0x%x", p.ProofType) + } + if err := codec.Decode(p.Proof, rst); err != nil { + return nil, fmt.Errorf("decoding ATX malfeasance proof of type 0x%x: %w", p.ProofType, err) + } + return rst, nil } diff --git a/activation/wire/malfeasance_double_marry.go b/activation/wire/malfeasance_double_marry.go index fc2a98e545..32b6cba763 100644 --- a/activation/wire/malfeasance_double_marry.go +++ b/activation/wire/malfeasance_double_marry.go @@ -23,8 +23,8 @@ type ProofDoubleMarry struct { // NodeID is the node ID that married twice. NodeID types.NodeID - // ATX1 is the ID of the ATX being proven to have the marriage certificate of interest. - ATX1 types.ATXID + // ATXID1 is the ID of the ATX being proven to have the marriage certificate of interest. + ATXID1 types.ATXID // SmesherID1 is the ID of the smesher that published ATX1. SmesherID1 types.NodeID // Signature1 is the signature of the ATXID by the smesher. @@ -32,8 +32,8 @@ type ProofDoubleMarry struct { // Proof1 is the proof that the marriage certificate is contained in the ATX1. Proof1 MarryProof - // ATX2 is the ID of the ATX being proven to have the marriage certificate of interest. - ATX2 types.ATXID + // ATXID2 is the ID of the ATX being proven to have the marriage certificate of interest. + ATXID2 types.ATXID // SmesherID2 is the ID of the smesher that published ATX2. SmesherID2 types.NodeID // Signature2 is the signature of the ATXID by the smesher. @@ -42,6 +42,24 @@ type ProofDoubleMarry struct { Proof2 MarryProof } +func (p ProofDoubleMarry) String() string { + return "DoubleMarryProof" +} + +func (p ProofDoubleMarry) Type() ProofType { + return DoubleMarry +} + +func (p ProofDoubleMarry) Info() map[string]string { + return map[string]string{ + "node_id": p.NodeID.String(), + "atx1": p.ATXID1.String(), + "smesher_id1": p.SmesherID1.String(), + "atx2": p.ATXID2.String(), + "smesher_id2": p.SmesherID2.String(), + } +} + var _ Proof = &ProofDoubleMarry{} func NewDoubleMarryProof(db sql.Executor, atx1, atx2 *ActivationTxV2, nodeID types.NodeID) (*ProofDoubleMarry, error) { @@ -61,12 +79,12 @@ func NewDoubleMarryProof(db sql.Executor, atx1, atx2 *ActivationTxV2, nodeID typ return &ProofDoubleMarry{ NodeID: nodeID, - ATX1: atx1.ID(), + ATXID1: atx1.ID(), SmesherID1: atx1.SmesherID, Signature1: atx1.Signature, Proof1: proof1, - ATX2: atx2.ID(), + ATXID2: atx2.ID(), SmesherID2: atx2.SmesherID, Signature2: atx2.Signature, Proof2: proof2, @@ -74,19 +92,19 @@ func NewDoubleMarryProof(db sql.Executor, atx1, atx2 *ActivationTxV2, nodeID typ } func (p ProofDoubleMarry) Valid(_ context.Context, malValidator MalfeasanceValidator) (types.NodeID, error) { - if p.ATX1 == p.ATX2 { + if p.ATXID1 == p.ATXID2 { return types.EmptyNodeID, errors.New("proofs have the same ATX ID") } - if !malValidator.Signature(signing.ATX, p.SmesherID1, p.ATX1.Bytes(), p.Signature1) { + if !malValidator.Signature(signing.ATX, p.SmesherID1, p.ATXID1.Bytes(), p.Signature1) { return types.EmptyNodeID, errors.New("invalid signature for ATX1") } - if !malValidator.Signature(signing.ATX, p.SmesherID2, p.ATX2.Bytes(), p.Signature2) { + if !malValidator.Signature(signing.ATX, p.SmesherID2, p.ATXID2.Bytes(), p.Signature2) { return types.EmptyNodeID, errors.New("invalid signature for ATX2") } - if err := p.Proof1.Valid(malValidator, p.ATX1, p.SmesherID1, p.NodeID); err != nil { + if err := p.Proof1.Valid(malValidator, p.ATXID1, p.SmesherID1, p.NodeID); err != nil { return types.EmptyNodeID, fmt.Errorf("proof 1 is invalid: %w", err) } - if err := p.Proof2.Valid(malValidator, p.ATX2, p.SmesherID2, p.NodeID); err != nil { + if err := p.Proof2.Valid(malValidator, p.ATXID2, p.SmesherID2, p.NodeID); err != nil { return types.EmptyNodeID, fmt.Errorf("proof 2 is invalid: %w", err) } return p.NodeID, nil diff --git a/activation/wire/malfeasance_double_marry_scale.go b/activation/wire/malfeasance_double_marry_scale.go index d7f3855020..89dcf45a25 100644 --- a/activation/wire/malfeasance_double_marry_scale.go +++ b/activation/wire/malfeasance_double_marry_scale.go @@ -16,7 +16,7 @@ func (t *ProofDoubleMarry) EncodeScale(enc *scale.Encoder) (total int, err error total += n } { - n, err := scale.EncodeByteArray(enc, t.ATX1[:]) + n, err := scale.EncodeByteArray(enc, t.ATXID1[:]) if err != nil { return total, err } @@ -44,7 +44,7 @@ func (t *ProofDoubleMarry) EncodeScale(enc *scale.Encoder) (total int, err error total += n } { - n, err := scale.EncodeByteArray(enc, t.ATX2[:]) + n, err := scale.EncodeByteArray(enc, t.ATXID2[:]) if err != nil { return total, err } @@ -83,7 +83,7 @@ func (t *ProofDoubleMarry) DecodeScale(dec *scale.Decoder) (total int, err error total += n } { - n, err := scale.DecodeByteArray(dec, t.ATX1[:]) + n, err := scale.DecodeByteArray(dec, t.ATXID1[:]) if err != nil { return total, err } @@ -111,7 +111,7 @@ func (t *ProofDoubleMarry) DecodeScale(dec *scale.Decoder) (total int, err error total += n } { - n, err := scale.DecodeByteArray(dec, t.ATX2[:]) + n, err := scale.DecodeByteArray(dec, t.ATXID2[:]) if err != nil { return total, err } diff --git a/activation/wire/malfeasance_double_marry_test.go b/activation/wire/malfeasance_double_marry_test.go index f52f8c8559..42ffcb839d 100644 --- a/activation/wire/malfeasance_double_marry_test.go +++ b/activation/wire/malfeasance_double_marry_test.go @@ -117,8 +117,8 @@ func Test_DoubleMarryProof(t *testing.T) { // manually construct an invalid proof proof = &ProofDoubleMarry{ - ATX1: atx1.ID(), - ATX2: atx1.ID(), + ATXID1: atx1.ID(), + ATXID2: atx1.ID(), } ctrl := gomock.NewController(t) @@ -189,17 +189,17 @@ func Test_DoubleMarryProof(t *testing.T) { proof.SmesherID2 = atx2.SmesherID // invalid ATX ID for ATX1 - proof.ATX1 = types.RandomATXID() + proof.ATXID1 = types.RandomATXID() id, err = proof.Valid(context.Background(), verifier) require.ErrorContains(t, err, "invalid signature for ATX1") require.Equal(t, types.EmptyNodeID, id) - proof.ATX1 = atx1.ID() + proof.ATXID1 = atx1.ID() // invalid ATX ID for ATX2 - proof.ATX2 = types.RandomATXID() + proof.ATXID2 = types.RandomATXID() id, err = proof.Valid(context.Background(), verifier) require.ErrorContains(t, err, "invalid signature for ATX2") require.Equal(t, types.EmptyNodeID, id) - proof.ATX2 = atx2.ID() + proof.ATXID2 = atx2.ID() }) } diff --git a/activation/wire/malfeasance_double_merge.go b/activation/wire/malfeasance_double_merge.go index 3b3f73194a..b84f191d6c 100644 --- a/activation/wire/malfeasance_double_merge.go +++ b/activation/wire/malfeasance_double_merge.go @@ -59,6 +59,25 @@ type ProofDoubleMerge struct { SmesherID2MarryProof MarryProof } +func (p ProofDoubleMerge) String() string { + return "DoubleMergeProof" +} + +func (p ProofDoubleMerge) Type() ProofType { + return DoubleMerge +} + +func (p ProofDoubleMerge) Info() map[string]string { + return map[string]string{ + "publish_epoch": p.PublishEpoch.String(), + "marriage_atx": p.MarriageATX.String(), + "atx1": p.ATXID1.String(), + "smesher_id1": p.SmesherID1.String(), + "atx2": p.ATXID2.String(), + "smesher_id2": p.SmesherID2.String(), + } +} + var _ Proof = &ProofDoubleMerge{} func NewDoubleMergeProof(db sql.Executor, atx1, atx2 *ActivationTxV2) (*ProofDoubleMerge, error) { diff --git a/activation/wire/malfeasance_invalid_post.go b/activation/wire/malfeasance_invalid_post.go index cb380ba88d..9dfd7a6cd8 100644 --- a/activation/wire/malfeasance_invalid_post.go +++ b/activation/wire/malfeasance_invalid_post.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "slices" + "strconv" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/signing" @@ -38,6 +39,23 @@ type ProofInvalidPost struct { InvalidPostProof InvalidPostProof } +func (p ProofInvalidPost) String() string { + return "InvalidPoSTProof" +} + +func (p ProofInvalidPost) Type() ProofType { + return InvalidPost +} + +func (p ProofInvalidPost) Info() map[string]string { + return map[string]string{ + "atx": p.ATXID.String(), + "index": strconv.FormatUint(uint64(p.InvalidPostProof.InvalidPostIndex), 10), + "post_node_id": p.NodeID.String(), + "smesher_id": p.SmesherID.String(), + } +} + var _ Proof = &ProofInvalidPost{} func NewInvalidPostProof( diff --git a/activation/wire/malfeasance_invalid_prev_atx.go b/activation/wire/malfeasance_invalid_prev_atx.go index abb5acb5f6..69c8558538 100644 --- a/activation/wire/malfeasance_invalid_prev_atx.go +++ b/activation/wire/malfeasance_invalid_prev_atx.go @@ -26,12 +26,31 @@ type ProofInvalidPrevAtxV2 struct { // NodeID is the node ID that referenced the same previous ATX twice. NodeID types.NodeID - // PrevATX is the ATX that was referenced twice. - PrevATX types.ATXID + // PrevATXID is the ATX that was referenced twice. + PrevATXID types.ATXID Proofs [2]InvalidPrevAtxProof } +func (p ProofInvalidPrevAtxV2) String() string { + return "InvalidPreviousATXProofV2" +} + +func (p ProofInvalidPrevAtxV2) Type() ProofType { + return InvalidPreviousV2 +} + +func (p ProofInvalidPrevAtxV2) Info() map[string]string { + return map[string]string{ + "prev_atx": p.PrevATXID.String(), + "node_id": p.NodeID.String(), + "atx1": p.Proofs[0].ATXID.String(), + "smesher_id1": p.Proofs[0].SmesherID.String(), + "atx2": p.Proofs[1].ATXID.String(), + "smesher_id2": p.Proofs[1].SmesherID.String(), + } +} + var _ Proof = &ProofInvalidPrevAtxV2{} func NewInvalidPrevAtxProofV2( @@ -114,9 +133,9 @@ func NewInvalidPrevAtxProofV2( } proof := &ProofInvalidPrevAtxV2{ - NodeID: nodeID, - PrevATX: prevATX1, - Proofs: [2]InvalidPrevAtxProof{proof1, proof2}, + NodeID: nodeID, + PrevATXID: prevATX1, + Proofs: [2]InvalidPrevAtxProof{proof1, proof2}, } return proof, nil } @@ -161,10 +180,10 @@ func (p ProofInvalidPrevAtxV2) Valid(_ context.Context, malValidator Malfeasance if p.Proofs[0].ATXID == p.Proofs[1].ATXID { return types.EmptyNodeID, errors.New("proofs have the same ATX ID") } - if err := p.Proofs[0].Valid(p.PrevATX, p.NodeID, malValidator); err != nil { + if err := p.Proofs[0].Valid(p.PrevATXID, p.NodeID, malValidator); err != nil { return types.EmptyNodeID, fmt.Errorf("proof 1 is invalid: %w", err) } - if err := p.Proofs[1].Valid(p.PrevATX, p.NodeID, malValidator); err != nil { + if err := p.Proofs[1].Valid(p.PrevATXID, p.NodeID, malValidator); err != nil { return types.EmptyNodeID, fmt.Errorf("proof 2 is invalid: %w", err) } return p.NodeID, nil @@ -183,13 +202,32 @@ type ProofInvalidPrevAtxV1 struct { // NodeID is the node ID that referenced the same previous ATX twice. NodeID types.NodeID - // PrevATX is the ATX that was referenced twice. - PrevATX types.ATXID + // PrevATXID is the ATX that was referenced twice. + PrevATXID types.ATXID Proof InvalidPrevAtxProof ATXv1 ActivationTxV1 } +func (p ProofInvalidPrevAtxV1) String() string { + return "InvalidPreviousATXProofV1" +} + +func (p ProofInvalidPrevAtxV1) Type() ProofType { + return InvalidPreviousV1 +} + +func (p ProofInvalidPrevAtxV1) Info() map[string]string { + return map[string]string{ + "prev_atx": p.PrevATXID.String(), + "node_id": p.NodeID.String(), + "atx1": p.Proof.ATXID.String(), + "smesher_id1": p.Proof.SmesherID.String(), + "atx2": p.ATXv1.ID().String(), + "smesher_id2": p.ATXv1.SmesherID.String(), + } +} + var _ Proof = &ProofInvalidPrevAtxV1{} func NewInvalidPrevAtxProofV1( @@ -240,15 +278,15 @@ func NewInvalidPrevAtxProofV1( } return &ProofInvalidPrevAtxV1{ - NodeID: nodeID, - PrevATX: prevATX1, - Proof: proof, - ATXv1: *atx2, + NodeID: nodeID, + PrevATXID: prevATX1, + Proof: proof, + ATXv1: *atx2, }, nil } func (p ProofInvalidPrevAtxV1) Valid(_ context.Context, malValidator MalfeasanceValidator) (types.NodeID, error) { - if err := p.Proof.Valid(p.PrevATX, p.NodeID, malValidator); err != nil { + if err := p.Proof.Valid(p.PrevATXID, p.NodeID, malValidator); err != nil { return types.EmptyNodeID, fmt.Errorf("proof is invalid: %w", err) } if !malValidator.Signature(signing.ATX, p.ATXv1.SmesherID, p.ATXv1.SignedBytes(), p.ATXv1.Signature) { @@ -257,7 +295,7 @@ func (p ProofInvalidPrevAtxV1) Valid(_ context.Context, malValidator Malfeasance if p.NodeID != p.ATXv1.SmesherID { return types.EmptyNodeID, errors.New("ATXv1 has not been signed by the same identity") } - if p.ATXv1.PrevATXID != p.PrevATX { + if p.ATXv1.PrevATXID != p.PrevATXID { return types.EmptyNodeID, errors.New("ATXv1 references a different previous ATX") } return p.NodeID, nil diff --git a/activation/wire/malfeasance_invalid_prev_atx_scale.go b/activation/wire/malfeasance_invalid_prev_atx_scale.go index 15b06acc28..f2c455b2cd 100644 --- a/activation/wire/malfeasance_invalid_prev_atx_scale.go +++ b/activation/wire/malfeasance_invalid_prev_atx_scale.go @@ -17,7 +17,7 @@ func (t *ProofInvalidPrevAtxV2) EncodeScale(enc *scale.Encoder) (total int, err total += n } { - n, err := scale.EncodeByteArray(enc, t.PrevATX[:]) + n, err := scale.EncodeByteArray(enc, t.PrevATXID[:]) if err != nil { return total, err } @@ -42,7 +42,7 @@ func (t *ProofInvalidPrevAtxV2) DecodeScale(dec *scale.Decoder) (total int, err total += n } { - n, err := scale.DecodeByteArray(dec, t.PrevATX[:]) + n, err := scale.DecodeByteArray(dec, t.PrevATXID[:]) if err != nil { return total, err } @@ -67,7 +67,7 @@ func (t *ProofInvalidPrevAtxV1) EncodeScale(enc *scale.Encoder) (total int, err total += n } { - n, err := scale.EncodeByteArray(enc, t.PrevATX[:]) + n, err := scale.EncodeByteArray(enc, t.PrevATXID[:]) if err != nil { return total, err } @@ -99,7 +99,7 @@ func (t *ProofInvalidPrevAtxV1) DecodeScale(dec *scale.Decoder) (total int, err total += n } { - n, err := scale.DecodeByteArray(dec, t.PrevATX[:]) + n, err := scale.DecodeByteArray(dec, t.PrevATXID[:]) if err != nil { return total, err } diff --git a/activation/wire/malfeasance_invalid_prev_atx_test.go b/activation/wire/malfeasance_invalid_prev_atx_test.go index 1c829e74eb..c07529c64f 100644 --- a/activation/wire/malfeasance_invalid_prev_atx_test.go +++ b/activation/wire/malfeasance_invalid_prev_atx_test.go @@ -317,11 +317,11 @@ func Test_InvalidPrevAtxProofV2(t *testing.T) { proof.Proofs[0].ATXID = atx1.ID() // invalid prev ATX - proof.PrevATX = types.RandomATXID() + proof.PrevATXID = types.RandomATXID() id, err = proof.Valid(context.Background(), verifier) require.ErrorContains(t, err, "invalid previous ATX proof") require.Equal(t, types.EmptyNodeID, id) - proof.PrevATX = prevATXID + proof.PrevATXID = prevATXID // invalid node ID proof.NodeID = types.RandomNodeID() @@ -935,11 +935,11 @@ func Test_InvalidPrevAtxProofV1(t *testing.T) { }).AnyTimes() // invalid PrevATX - proof.PrevATX = types.RandomATXID() + proof.PrevATXID = types.RandomATXID() id, err := proof.Valid(context.Background(), verifier) require.ErrorContains(t, err, "invalid previous ATX proof") require.Equal(t, types.EmptyNodeID, id) - proof.PrevATX = prevATX + proof.PrevATXID = prevATX // invalid SmesherID for atxv1 proof.ATXv1.SmesherID = types.RandomNodeID() diff --git a/activation/wire/malfeasance_test.go b/activation/wire/malfeasance_test.go new file mode 100644 index 0000000000..2cd41288a4 --- /dev/null +++ b/activation/wire/malfeasance_test.go @@ -0,0 +1,89 @@ +package wire + +import ( + "testing" + + fuzz "github.com/google/gofuzz" + "github.com/stretchr/testify/require" + + "github.com/spacemeshos/go-spacemesh/codec" +) + +func fuzzDecoding[T Proof](t *testing.T, data []byte, proof T) { + fuzzer := fuzz.NewFromGoFuzz(data) + fuzzer.Fuzz(proof) + + atxProof := &ATXProof{ + Version: 0x01, + ProofType: proof.Type(), + + Proof: codec.MustEncode(proof), + } + + encodedAtxProof := codec.MustEncode(atxProof) + decodedAtxProof := &ATXProof{} + codec.MustDecode(encodedAtxProof, decodedAtxProof) + + decodedProof, err := decodedAtxProof.Decode() + require.NoError(t, err) + + require.Equal(t, proof, decodedProof.(T)) +} + +func FuzzATXProofDecodeDoubleMarry(f *testing.F) { + f.Add([]byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06}) + f.Fuzz(func(t *testing.T, data []byte) { + fuzzDecoding(t, data, &ProofDoubleMarry{}) + }) +} + +func FuzzATXProofDecodeDoubleMerge(f *testing.F) { + f.Add([]byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06}) + f.Fuzz(func(t *testing.T, data []byte) { + fuzzDecoding(t, data, &ProofDoubleMerge{}) + }) +} + +func FuzzATXProofDecodeInvalidPost(f *testing.F) { + f.Add([]byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06}) + f.Fuzz(func(t *testing.T, data []byte) { + fuzzDecoding(t, data, &ProofInvalidPost{}) + }) +} + +func FuzzATXProofDecodeInvalidPrevAtxV1(f *testing.F) { + f.Add([]byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06}) + f.Fuzz(func(t *testing.T, data []byte) { + fuzzDecoding(t, data, &ProofInvalidPrevAtxV1{}) + }) +} + +func FuzzATXProofDecodeInvalidPrevAtxV2(f *testing.F) { + f.Add([]byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06}) + f.Fuzz(func(t *testing.T, data []byte) { + fuzzDecoding(t, data, &ProofInvalidPrevAtxV2{}) + }) +} + +func TestDecode(t *testing.T) { + t.Run("unknown proof type", func(t *testing.T) { + atxProof := &ATXProof{ + Version: 0x01, + ProofType: 0x42, // unknown proof type + } + + _, err := atxProof.Decode() + require.ErrorContains(t, err, "unknown ATX malfeasance proof type") + }) + + t.Run("atx proof fails decoding", func(t *testing.T) { + atxProof := &ATXProof{ + Version: 0x01, + ProofType: DoubleMarry, + Proof: []byte{}, // invalid proof + } + + _, err := atxProof.Decode() + require.ErrorContains(t, err, "decoding ATX malfeasance proof of type 0x11") + }) +} diff --git a/activation/wire/mocks.go b/activation/wire/mocks.go index ae0fd1be61..a6f97f8087 100644 --- a/activation/wire/mocks.go +++ b/activation/wire/mocks.go @@ -13,6 +13,7 @@ import ( context "context" reflect "reflect" + scale "github.com/spacemeshos/go-scale" types "github.com/spacemeshos/go-spacemesh/common/types" signing "github.com/spacemeshos/go-spacemesh/signing" gomock "go.uber.org/mock/gomock" @@ -117,3 +118,258 @@ func (c *MockMalfeasanceValidatorSignatureCall) DoAndReturn(f func(signing.Domai c.Call = c.Call.DoAndReturn(f) return c } + +// MockProof is a mock of Proof interface. +type MockProof struct { + ctrl *gomock.Controller + recorder *MockProofMockRecorder + isgomock struct{} +} + +// MockProofMockRecorder is the mock recorder for MockProof. +type MockProofMockRecorder struct { + mock *MockProof +} + +// NewMockProof creates a new mock instance. +func NewMockProof(ctrl *gomock.Controller) *MockProof { + mock := &MockProof{ctrl: ctrl} + mock.recorder = &MockProofMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockProof) EXPECT() *MockProofMockRecorder { + return m.recorder +} + +// DecodeScale mocks base method. +func (m *MockProof) DecodeScale(dec *scale.Decoder) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DecodeScale", dec) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// DecodeScale indicates an expected call of DecodeScale. +func (mr *MockProofMockRecorder) DecodeScale(dec any) *MockProofDecodeScaleCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecodeScale", reflect.TypeOf((*MockProof)(nil).DecodeScale), dec) + return &MockProofDecodeScaleCall{Call: call} +} + +// MockProofDecodeScaleCall wrap *gomock.Call +type MockProofDecodeScaleCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockProofDecodeScaleCall) Return(arg0 int, arg1 error) *MockProofDecodeScaleCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockProofDecodeScaleCall) Do(f func(*scale.Decoder) (int, error)) *MockProofDecodeScaleCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockProofDecodeScaleCall) DoAndReturn(f func(*scale.Decoder) (int, error)) *MockProofDecodeScaleCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// EncodeScale mocks base method. +func (m *MockProof) EncodeScale(enc *scale.Encoder) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "EncodeScale", enc) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// EncodeScale indicates an expected call of EncodeScale. +func (mr *MockProofMockRecorder) EncodeScale(enc any) *MockProofEncodeScaleCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EncodeScale", reflect.TypeOf((*MockProof)(nil).EncodeScale), enc) + return &MockProofEncodeScaleCall{Call: call} +} + +// MockProofEncodeScaleCall wrap *gomock.Call +type MockProofEncodeScaleCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockProofEncodeScaleCall) Return(arg0 int, arg1 error) *MockProofEncodeScaleCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockProofEncodeScaleCall) Do(f func(*scale.Encoder) (int, error)) *MockProofEncodeScaleCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockProofEncodeScaleCall) DoAndReturn(f func(*scale.Encoder) (int, error)) *MockProofEncodeScaleCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Info mocks base method. +func (m *MockProof) Info() map[string]string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Info") + ret0, _ := ret[0].(map[string]string) + return ret0 +} + +// Info indicates an expected call of Info. +func (mr *MockProofMockRecorder) Info() *MockProofInfoCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Info", reflect.TypeOf((*MockProof)(nil).Info)) + return &MockProofInfoCall{Call: call} +} + +// MockProofInfoCall wrap *gomock.Call +type MockProofInfoCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockProofInfoCall) Return(arg0 map[string]string) *MockProofInfoCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockProofInfoCall) Do(f func() map[string]string) *MockProofInfoCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockProofInfoCall) DoAndReturn(f func() map[string]string) *MockProofInfoCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// String mocks base method. +func (m *MockProof) String() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "String") + ret0, _ := ret[0].(string) + return ret0 +} + +// String indicates an expected call of String. +func (mr *MockProofMockRecorder) String() *MockProofStringCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "String", reflect.TypeOf((*MockProof)(nil).String)) + return &MockProofStringCall{Call: call} +} + +// MockProofStringCall wrap *gomock.Call +type MockProofStringCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockProofStringCall) Return(arg0 string) *MockProofStringCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockProofStringCall) Do(f func() string) *MockProofStringCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockProofStringCall) DoAndReturn(f func() string) *MockProofStringCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Type mocks base method. +func (m *MockProof) Type() ProofType { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Type") + ret0, _ := ret[0].(ProofType) + return ret0 +} + +// Type indicates an expected call of Type. +func (mr *MockProofMockRecorder) Type() *MockProofTypeCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Type", reflect.TypeOf((*MockProof)(nil).Type)) + return &MockProofTypeCall{Call: call} +} + +// MockProofTypeCall wrap *gomock.Call +type MockProofTypeCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockProofTypeCall) Return(arg0 ProofType) *MockProofTypeCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockProofTypeCall) Do(f func() ProofType) *MockProofTypeCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockProofTypeCall) DoAndReturn(f func() ProofType) *MockProofTypeCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Valid mocks base method. +func (m *MockProof) Valid(ctx context.Context, malHandler MalfeasanceValidator) (types.NodeID, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Valid", ctx, malHandler) + ret0, _ := ret[0].(types.NodeID) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Valid indicates an expected call of Valid. +func (mr *MockProofMockRecorder) Valid(ctx, malHandler any) *MockProofValidCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Valid", reflect.TypeOf((*MockProof)(nil).Valid), ctx, malHandler) + return &MockProofValidCall{Call: call} +} + +// MockProofValidCall wrap *gomock.Call +type MockProofValidCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockProofValidCall) Return(arg0 types.NodeID, arg1 error) *MockProofValidCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockProofValidCall) Do(f func(context.Context, MalfeasanceValidator) (types.NodeID, error)) *MockProofValidCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockProofValidCall) DoAndReturn(f func(context.Context, MalfeasanceValidator) (types.NodeID, error)) *MockProofValidCall { + c.Call = c.Call.DoAndReturn(f) + return c +} diff --git a/activation/wire/wire_v2.go b/activation/wire/wire_v2.go index 696c4b332d..eb7d222feb 100644 --- a/activation/wire/wire_v2.go +++ b/activation/wire/wire_v2.go @@ -465,7 +465,7 @@ type SubPostV2 struct { // Can be used to extract the nodeID and verify if it is married with the smesher of the ATX. // Must be 0 for non-merged ATXs. MarriageIndex uint32 - PrevATXIndex uint32 // Index of the previous ATX in the `InnerActivationTxV2.PreviousATXs` slice + PrevATXIndex uint32 // Index of the previous ATX in the `ActivationTxV2.PreviousATXs` slice // Index of the leaf for this ID's challenge in the poet membership tree. // IDs might shared the same index if their nipost challenges are equal. // This happens when the IDs are continuously merged (they share the previous ATX). @@ -630,6 +630,9 @@ type MarriageCertificate struct { // An ATX of the NodeID that marries. It proves that the NodeID exists. // Note: the reference ATX does not need to be from the previous epoch. // It only needs to prove the existence of the Identity. + // + // In the case of a self signed certificate that is included in the Marriage ATX by the Smesher signing the ATX, + // this can be `types.EmptyATXID`. ReferenceAtx types.ATXID // Signature over the other ID that this ID marries with // If Alice marries Bob, then Alice signs Bob's ID diff --git a/api/grpcserver/post_service_test.go b/api/grpcserver/post_service_test.go index df6b6b1717..856520dfa9 100644 --- a/api/grpcserver/post_service_test.go +++ b/api/grpcserver/post_service_test.go @@ -64,7 +64,7 @@ func launchPostSupervisor( require.NoError(tb, err) // start post supervisor - builder := activation.NewMockAtxBuilder(ctrl) + builder := activation.NewMockatxBuilder(ctrl) builder.EXPECT().Register(sig) ps := activation.NewPostSupervisor(log, postCfg, provingOpts, mgr, builder) require.NoError(tb, ps.Start(serviceCfg, postOpts, sig)) @@ -108,7 +108,7 @@ func launchPostSupervisorTLS( require.NoError(tb, err) // start post supervisor - builder := activation.NewMockAtxBuilder(ctrl) + builder := activation.NewMockatxBuilder(ctrl) builder.EXPECT().Register(sig) ps := activation.NewPostSupervisor(log, postCfg, provingOpts, mgr, builder) require.NoError(tb, ps.Start(serviceCfg, postOpts, sig)) diff --git a/common/types/poet.go b/common/types/poet.go index 27e5764243..8eb24c5330 100644 --- a/common/types/poet.go +++ b/common/types/poet.go @@ -67,7 +67,7 @@ func (p *PoetProof) MarshalLogObject(encoder zapcore.ObjectEncoder) error { type PoetProofMessage struct { PoetProof PoetServiceID []byte `scale:"max=32"` // public key of the PoET service - RoundID string `scale:"max=32"` // TODO(mafa): convert to uint64 + RoundID string `scale:"max=32"` // The input to Poet's POSW. // It's the root of a merkle tree built from all of the members // that are included in the proof. diff --git a/fetch/wire_types.go b/fetch/wire_types.go index 14c8eeafe3..bd28b52f13 100644 --- a/fetch/wire_types.go +++ b/fetch/wire_types.go @@ -29,7 +29,7 @@ func init() { // RequestMessage is sent to the peer for hash query. type RequestMessage struct { - Hint datastore.Hint `scale:"max=256"` // TODO(mafa): covert to an enum + Hint datastore.Hint `scale:"max=256"` Hash types.Hash32 } diff --git a/genvm/core/types.go b/genvm/core/types.go index 7c56487754..41d55a75cb 100644 --- a/genvm/core/types.go +++ b/genvm/core/types.go @@ -27,7 +27,7 @@ type ( // Signature is an alias to types.EdSignature. Signature = types.EdSignature - // Account is an alis to types.Account. + // Account is an alias to types.Account. Account = types.Account // Header is an alias to types.TxHeader. Header = types.TxHeader diff --git a/hare3/hare.go b/hare3/hare.go index 50ace9df0a..0241be0666 100644 --- a/hare3/hare.go +++ b/hare3/hare.go @@ -472,7 +472,7 @@ func (h *Hare) run(session *session) error { if err := h.onOutput(session, current, out); err != nil { return err } - // we are logginng stats 1 network delay after new iteration start + // we are logging stats 1 network delay after new iteration start // so that we can receive notify messages from previous iteration if session.proto.Round == softlock && h.config.LogStats { h.log.Debug("stats", zap.Uint32("lid", session.lid.Uint32()), zap.Inline(session.proto.Stats())) diff --git a/hare4/hare.go b/hare4/hare.go index b3f48dba17..34700ca2db 100644 --- a/hare4/hare.go +++ b/hare4/hare.go @@ -684,7 +684,7 @@ func (h *Hare) run(session *session) error { if err := h.onOutput(session, current, out); err != nil { return err } - // we are logginng stats 1 network delay after new iteration start + // we are logging stats 1 network delay after new iteration start // so that we can receive notify messages from previous iteration if session.proto.Round == softlock && h.config.LogStats { h.log.Info("stats", zap.Uint32("lid", session.lid.Uint32()), zap.Inline(session.proto.Stats())) diff --git a/malfeasance2/handler.go b/malfeasance2/handler.go index 392299a8a1..3b8a5255b8 100644 --- a/malfeasance2/handler.go +++ b/malfeasance2/handler.go @@ -2,6 +2,8 @@ package malfeasance2 import ( "context" + "fmt" + "strconv" "go.uber.org/zap" @@ -9,6 +11,8 @@ import ( "github.com/spacemeshos/go-spacemesh/p2p" "github.com/spacemeshos/go-spacemesh/signing" "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/malfeasance" + "github.com/spacemeshos/go-spacemesh/sql/marriage" ) // nolint:unused @@ -19,6 +23,8 @@ type Handler struct { nodeIDs []types.NodeID edVerifier *signing.EdVerifier tortoise tortoise + + handlers map[ProofDomain]MalfeasanceHandler } func NewHandler( @@ -36,9 +42,49 @@ func NewHandler( nodeIDs: nodeIDs, edVerifier: edVerifier, tortoise: tortoise, + + handlers: make(map[ProofDomain]MalfeasanceHandler), + } +} + +func (h *Handler) RegisterHandler(malfeasanceType ProofDomain, handler MalfeasanceHandler) { + if _, ok := h.handlers[malfeasanceType]; ok { + h.logger.Panic("handler already registered", zap.Int("malfeasanceType", int(malfeasanceType))) } + h.handlers[malfeasanceType] = handler } func (h *Handler) Info(ctx context.Context, nodeID types.NodeID) (map[string]string, error) { - return nil, sql.ErrNotFound + var ( + isMarried = false + domain int + proof []byte + ) + marriageID, err := marriage.FindIDByNodeID(h.db, nodeID) + if err == nil { + isMarried = true + proof, domain, err = malfeasance.MarriageProof(h.db, marriageID) + if err != nil { + return nil, fmt.Errorf("get malfeasance proof for married node ID %s: %w", nodeID, err) + } + } else { + proof, domain, err = malfeasance.NodeIDProof(h.db, nodeID) + if err != nil { + return nil, fmt.Errorf("get malfeasance proof for node ID %s: %w", nodeID, err) + } + } + + mh, ok := h.handlers[ProofDomain(domain)] + if !ok { + return nil, fmt.Errorf("unknown malfeasance domain %d", domain) + } + properties, err := mh.Info(proof) + if err != nil { + return nil, fmt.Errorf("malfeasance info: %w", err) + } + properties["domain"] = strconv.FormatUint(uint64(domain), 10) + if isMarried { + properties["malicious_id"] = nodeID.String() + } + return properties, nil } diff --git a/malfeasance2/handler_test.go b/malfeasance2/handler_test.go new file mode 100644 index 0000000000..ac8861c9b3 --- /dev/null +++ b/malfeasance2/handler_test.go @@ -0,0 +1,207 @@ +package malfeasance2_test + +import ( + "context" + "errors" + "maps" + "strconv" + "testing" + "time" + + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + "go.uber.org/zap/zaptest" + "go.uber.org/zap/zaptest/observer" + + "github.com/spacemeshos/go-spacemesh/common/types" + "github.com/spacemeshos/go-spacemesh/malfeasance2" + "github.com/spacemeshos/go-spacemesh/signing" + "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/malfeasance" + "github.com/spacemeshos/go-spacemesh/sql/marriage" + "github.com/spacemeshos/go-spacemesh/sql/statesql" +) + +type testHandler struct { + *malfeasance2.Handler + + observedLogs *observer.ObservedLogs + ctrl *gomock.Controller + db sql.StateDatabase +} + +func newTestHandler(tb testing.TB) *testHandler { + db := statesql.InMemory() + edVerifier := signing.NewEdVerifier() + + observer, observedLogs := observer.New(zap.WarnLevel) + logger := zaptest.NewLogger(tb, zaptest.WrapOptions(zap.WrapCore( + func(core zapcore.Core) zapcore.Core { + return zapcore.NewTee(core, observer) + }, + ))) + + ctrl := gomock.NewController(tb) + mockTrt := malfeasance2.NewMocktortoise(ctrl) + + h := malfeasance2.NewHandler( + db, + logger, + "self", + []types.NodeID{types.RandomNodeID()}, + edVerifier, + mockTrt, + ) + return &testHandler{ + Handler: h, + + observedLogs: observedLogs, + ctrl: ctrl, + db: db, + } +} + +func TestRegister(t *testing.T) { + t.Parallel() + + t.Run("register", func(t *testing.T) { + t.Parallel() + th := newTestHandler(t) + + handler := malfeasance2.NewMockMalfeasanceHandler(th.ctrl) + th.RegisterHandler(malfeasance2.InvalidActivation, handler) + }) + + t.Run("already registered", func(t *testing.T) { + t.Parallel() + th := newTestHandler(t) + + handler := malfeasance2.NewMockMalfeasanceHandler(th.ctrl) + th.RegisterHandler(malfeasance2.InvalidActivation, handler) + + require.Panics(t, func() { + th.RegisterHandler(malfeasance2.InvalidActivation, handler) + }) + + logs := th.observedLogs.FilterLevelExact(zap.PanicLevel) + + require.Equal(t, 1, logs.Len()) + require.Equal(t, zap.PanicLevel, logs.All()[0].Level) + require.Contains(t, logs.All()[0].Message, "handler already registered") + }) +} + +func TestHandler_Info(t *testing.T) { + t.Run("unknown identity", func(t *testing.T) { + h := newTestHandler(t) + + info, err := h.Info(context.Background(), types.RandomNodeID()) + require.ErrorContains(t, err, "get malfeasance proof") + require.ErrorIs(t, err, sql.ErrNotFound) + require.Nil(t, info) + }) + + t.Run("unknown malfeasance type", func(t *testing.T) { + h := newTestHandler(t) + + nodeID := types.RandomNodeID() + proofBytes := types.RandomBytes(100) + err := malfeasance.AddProof(h.db, nodeID, nil, proofBytes, 999, time.Now()) + require.NoError(t, err) + + info, err := h.Info(context.Background(), nodeID) + require.ErrorContains(t, err, "unknown malfeasance domain 999") + require.Nil(t, info) + }) + + t.Run("invalid proof", func(t *testing.T) { + h := newTestHandler(t) + invalidProof := []byte("invalid") + infoError := errors.New("invalid proof") + mockHandler := malfeasance2.NewMockMalfeasanceHandler(h.ctrl) + mockHandler.EXPECT().Info(invalidProof).Return(nil, infoError) + h.RegisterHandler(malfeasance2.InvalidActivation, mockHandler) + + nodeID := types.RandomNodeID() + err := malfeasance.AddProof(h.db, nodeID, nil, invalidProof, int(malfeasance2.InvalidActivation), time.Now()) + require.NoError(t, err) + + info, err := h.Info(context.Background(), nodeID) + require.ErrorIs(t, err, infoError) + require.Nil(t, info) + }) + + t.Run("valid proof for node", func(t *testing.T) { + h := newTestHandler(t) + validProof := []byte("valid") + properties := map[string]string{ + "type": "DoubleMarry", + "key": "value", + } + mockHandler := malfeasance2.NewMockMalfeasanceHandler(h.ctrl) + mockHandler.EXPECT().Info(validProof).Return(properties, nil) + h.RegisterHandler(malfeasance2.InvalidActivation, mockHandler) + + nodeID := types.RandomNodeID() + err := malfeasance.AddProof(h.db, nodeID, nil, validProof, int(malfeasance2.InvalidActivation), time.Now()) + require.NoError(t, err) + + expectedProperties := maps.Clone(properties) + expectedProperties["domain"] = strconv.FormatUint(uint64(malfeasance2.InvalidActivation), 10) + + info, err := h.Info(context.Background(), nodeID) + require.NoError(t, err) + require.Equal(t, expectedProperties, info) + }) + + t.Run("valid proof for married node", func(t *testing.T) { + h := newTestHandler(t) + validProof := []byte("valid") + properties := map[string]string{ + "type": "InvalidPost", + "key": "value", + } + mockHandler := malfeasance2.NewMockMalfeasanceHandler(h.ctrl) + mockHandler.EXPECT().Info(validProof).Return(properties, nil) + h.RegisterHandler(malfeasance2.InvalidActivation, mockHandler) + + maliciousID := types.RandomNodeID() + nodeID := types.RandomNodeID() + + id, err := marriage.NewID(h.db) + require.NoError(t, err) + + err = marriage.Add(h.db, marriage.Info{ + ID: id, + NodeID: maliciousID, + ATX: types.RandomATXID(), + MarriageIndex: 0, + Target: types.RandomNodeID(), + Signature: types.RandomEdSignature(), + }) + require.NoError(t, err) + + err = malfeasance.AddProof( + h.db, + nodeID, + &id, + validProof, + int(malfeasance2.InvalidActivation), + time.Now(), + ) + require.NoError(t, err) + + err = malfeasance.SetMalicious(h.db, maliciousID, id, time.Now()) + require.NoError(t, err) + + expectedProperties := maps.Clone(properties) + expectedProperties["domain"] = strconv.FormatUint(uint64(malfeasance2.InvalidActivation), 10) + expectedProperties["malicious_id"] = maliciousID.String() + + info, err := h.Info(context.Background(), maliciousID) + require.NoError(t, err) + require.Equal(t, expectedProperties, info) + }) +} diff --git a/malfeasance2/interface.go b/malfeasance2/interface.go index 12bd87e98a..b7232dcc86 100644 --- a/malfeasance2/interface.go +++ b/malfeasance2/interface.go @@ -9,3 +9,8 @@ import ( type tortoise interface { OnMalfeasance(types.NodeID) } + +type MalfeasanceHandler interface { + // Info returns a map of key-value pairs that serve as metadata for the proof + Info(data []byte) (map[string]string, error) +} diff --git a/malfeasance2/mocks.go b/malfeasance2/mocks.go index c1c68d418e..5cd772fa52 100644 --- a/malfeasance2/mocks.go +++ b/malfeasance2/mocks.go @@ -75,3 +75,66 @@ func (c *MocktortoiseOnMalfeasanceCall) DoAndReturn(f func(types.NodeID)) *Mockt c.Call = c.Call.DoAndReturn(f) return c } + +// MockMalfeasanceHandler is a mock of MalfeasanceHandler interface. +type MockMalfeasanceHandler struct { + ctrl *gomock.Controller + recorder *MockMalfeasanceHandlerMockRecorder + isgomock struct{} +} + +// MockMalfeasanceHandlerMockRecorder is the mock recorder for MockMalfeasanceHandler. +type MockMalfeasanceHandlerMockRecorder struct { + mock *MockMalfeasanceHandler +} + +// NewMockMalfeasanceHandler creates a new mock instance. +func NewMockMalfeasanceHandler(ctrl *gomock.Controller) *MockMalfeasanceHandler { + mock := &MockMalfeasanceHandler{ctrl: ctrl} + mock.recorder = &MockMalfeasanceHandlerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockMalfeasanceHandler) EXPECT() *MockMalfeasanceHandlerMockRecorder { + return m.recorder +} + +// Info mocks base method. +func (m *MockMalfeasanceHandler) Info(data []byte) (map[string]string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Info", data) + ret0, _ := ret[0].(map[string]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Info indicates an expected call of Info. +func (mr *MockMalfeasanceHandlerMockRecorder) Info(data any) *MockMalfeasanceHandlerInfoCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Info", reflect.TypeOf((*MockMalfeasanceHandler)(nil).Info), data) + return &MockMalfeasanceHandlerInfoCall{Call: call} +} + +// MockMalfeasanceHandlerInfoCall wrap *gomock.Call +type MockMalfeasanceHandlerInfoCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockMalfeasanceHandlerInfoCall) Return(arg0 map[string]string, arg1 error) *MockMalfeasanceHandlerInfoCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockMalfeasanceHandlerInfoCall) Do(f func([]byte) (map[string]string, error)) *MockMalfeasanceHandlerInfoCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockMalfeasanceHandlerInfoCall) DoAndReturn(f func([]byte) (map[string]string, error)) *MockMalfeasanceHandlerInfoCall { + c.Call = c.Call.DoAndReturn(f) + return c +} diff --git a/malfeasance2/wire.go b/malfeasance2/wire.go new file mode 100644 index 0000000000..35d2598580 --- /dev/null +++ b/malfeasance2/wire.go @@ -0,0 +1,10 @@ +package malfeasance2 + +// ProofDomain encodes the type of malfeasance proof. It is used to decide which domain generated the proof. +type ProofDomain byte + +const ( + InvalidActivation ProofDomain = 0x01 + InvalidBallot ProofDomain = 0x02 + InvalidHareMsg ProofDomain = 0x03 +) diff --git a/node/node.go b/node/node.go index b1711e3a3f..20b8905485 100644 --- a/node/node.go +++ b/node/node.go @@ -138,6 +138,7 @@ const ( ConStateLogger = "conState" ExecutorLogger = "executor" MalfeasanceLogger = "malfeasance" + Malfeasance2Logger = "malfeasance2" BootstrapLogger = "bootstrap" ) @@ -838,15 +839,18 @@ func (app *App) initServices(ctx context.Context) error { beaconProtocol.SetSyncState(syncer) hOracle.SetSync(syncer) - malfeasanceLogger := app.addLogger(MalfeasanceLogger, lg).Zap() + legacyMalfeasanceLogger := app.addLogger(MalfeasanceLogger, lg).Zap() legacyMalPublisher := malfeasance.NewPublisher( - malfeasanceLogger, + legacyMalfeasanceLogger, app.cachedDB, syncer, trtl, app.host, ) + malfeasanceLogger := app.addLogger(Malfeasance2Logger, lg).Zap() + atxMalHandler := activation.NewMalfeasanceHandlerV2() + atxHandler := activation.NewHandler( app.host.ID(), app.cachedDB, @@ -1139,18 +1143,18 @@ func (app *App) initServices(ctx context.Context) error { activationMH := activation.NewMalfeasanceHandler( app.cachedDB, - malfeasanceLogger, + legacyMalfeasanceLogger, app.edVerifier, ) meshMH := mesh.NewMalfeasanceHandler( app.cachedDB, app.edVerifier, - mesh.WithMalfeasanceLogger(malfeasanceLogger), + mesh.WithMalfeasanceLogger(legacyMalfeasanceLogger), ) hareMH := hare3.NewMalfeasanceHandler( app.cachedDB, app.edVerifier, - hare3.WithMalfeasanceLogger(malfeasanceLogger), + hare3.WithMalfeasanceLogger(legacyMalfeasanceLogger), ) invalidPostMH := activation.NewInvalidPostIndexHandler( app.cachedDB, @@ -1165,7 +1169,7 @@ func (app *App) initServices(ctx context.Context) error { } malHandler := malfeasance.NewHandler( app.cachedDB, - malfeasanceLogger, + legacyMalfeasanceLogger, app.host.ID(), nodeIDs, trtl, @@ -1184,6 +1188,7 @@ func (app *App) initServices(ctx context.Context) error { app.edVerifier, trtl, ) + malHandler2.RegisterHandler(malfeasance2.InvalidActivation, atxMalHandler) fetcher.SetValidators( fetch.ValidatorFunc( diff --git a/node/node_test.go b/node/node_test.go index a777bfd0c9..835bb69fbb 100644 --- a/node/node_test.go +++ b/node/node_test.go @@ -1225,7 +1225,7 @@ func launchPostSupervisor( provingOpts := activation.DefaultPostProvingOpts() provingOpts.RandomXMode = activation.PostRandomXModeLight - builder := activation.NewMockAtxBuilder(gomock.NewController(tb)) + builder := activation.NewMockatxBuilder(gomock.NewController(tb)) builder.EXPECT().Register(sig) ps := activation.NewPostSupervisor(log, postCfg, provingOpts, mgr, builder) require.NoError(tb, ps.Start(cmdCfg, postOpts, sig)) diff --git a/p2p/server/server.go b/p2p/server/server.go index 3a3925735c..ebc424cc24 100644 --- a/p2p/server/server.go +++ b/p2p/server/server.go @@ -138,7 +138,7 @@ func (err *ServerError) Error() string { type Response struct { // keep in line with limit of ResponseMessage.Data in `fetch/wire_types.go` Data []byte `scale:"max=272629760"` // 260 MiB > 8.0 mio ATX * 32 bytes per ID - Error string `scale:"max=1024"` // TODO(mafa): make error code instead of string + Error string `scale:"max=1024"` } // Server for the Handler. diff --git a/sql/malfeasance/malfeasance.go b/sql/malfeasance/malfeasance.go index 416743e153..ce313721ca 100644 --- a/sql/malfeasance/malfeasance.go +++ b/sql/malfeasance/malfeasance.go @@ -95,3 +95,59 @@ func IterateOps( ) return err } + +// NodeIDProof returns the malfeasance proof and its domain for the given node ID. Returns sql.ErrNotFound if no proof +// for the given node ID exists. To return a proof for a marriage set use MarriageProof instead. +func NodeIDProof(db sql.Executor, nodeID types.NodeID) ([]byte, int, error) { + var ( + proof []byte + domain int + ) + rows, err := db.Exec(` + SELECT proof, domain + FROM malfeasance + WHERE pubkey = ?1 AND marriage_id IS NULL + `, func(stmt *sql.Statement) { + stmt.BindBytes(1, nodeID.Bytes()) + }, func(stmt *sql.Statement) bool { + proof = make([]byte, stmt.ColumnLen(0)) + stmt.ColumnBytes(0, proof) + domain = int(stmt.ColumnInt64(1)) + return false + }) + if err != nil { + return nil, 0, fmt.Errorf("proof %v: %w", nodeID, err) + } + if rows == 0 { + return nil, 0, sql.ErrNotFound + } + return proof, domain, nil +} + +// MarriageProof returns the malfeasance proof for the marriage set. Returns sql.ErrNotFound if no proof for the given +// marriage ID exists. To return a proof for a node ID use NodeIDProof instead. +func MarriageProof(db sql.Executor, marriageID marriage.ID) ([]byte, int, error) { + var ( + proof []byte + domain int + ) + rows, err := db.Exec(` + SELECT proof, domain + FROM malfeasance + WHERE marriage_id = ?1 AND proof IS NOT NULL + `, func(stmt *sql.Statement) { + stmt.BindInt64(1, int64(marriageID)) + }, func(stmt *sql.Statement) bool { + proof = make([]byte, stmt.ColumnLen(0)) + stmt.ColumnBytes(0, proof) + domain = int(stmt.ColumnInt64(1)) + return false + }) + if err != nil { + return nil, 0, fmt.Errorf("marriage proof %v: %w", marriageID, err) + } + if rows == 0 { + return nil, 0, sql.ErrNotFound + } + return proof, domain, nil +} diff --git a/sql/malfeasance/malfeasance_test.go b/sql/malfeasance/malfeasance_test.go index e191bca7e6..078f021f70 100644 --- a/sql/malfeasance/malfeasance_test.go +++ b/sql/malfeasance/malfeasance_test.go @@ -8,6 +8,7 @@ import ( "github.com/stretchr/testify/require" "github.com/spacemeshos/go-spacemesh/common/types" + "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/builder" "github.com/spacemeshos/go-spacemesh/sql/malfeasance" "github.com/spacemeshos/go-spacemesh/sql/marriage" @@ -306,3 +307,122 @@ func Test_IterateMaliciousOps_Married(t *testing.T) { require.Zero(t, got[i].domain) } } + +func TestNodeIDProof(t *testing.T) { + t.Parallel() + + t.Run("unknown node has no proof", func(t *testing.T) { + t.Parallel() + db := statesql.InMemoryTest(t) + + _, _, err := malfeasance.NodeIDProof(db, types.RandomNodeID()) + require.ErrorIs(t, err, sql.ErrNotFound) + }) + + t.Run("node with proof has proof", func(t *testing.T) { + t.Parallel() + db := statesql.InMemoryTest(t) + + nodeID := types.RandomNodeID() + proof := types.RandomBytes(100) + err := malfeasance.AddProof(db, nodeID, nil, proof, 1, time.Now()) + require.NoError(t, err) + + p, domain, err := malfeasance.NodeIDProof(db, nodeID) + require.NoError(t, err) + require.Equal(t, proof, p) + require.Equal(t, 1, domain) + }) + + t.Run("node with proof and marriage ID returns no proof", func(t *testing.T) { + t.Parallel() + db := statesql.InMemoryTest(t) + + id, err := marriage.NewID(db) + require.NoError(t, err) + + nodeID := types.RandomNodeID() + err = marriage.Add(db, marriage.Info{ + ID: id, + NodeID: nodeID, + ATX: types.RandomATXID(), + MarriageIndex: 0, + Target: types.RandomNodeID(), + Signature: types.RandomEdSignature(), + }) + require.NoError(t, err) + + err = malfeasance.AddProof(db, nodeID, &id, types.RandomBytes(100), 1, time.Now()) + require.NoError(t, err) + + _, _, err = malfeasance.NodeIDProof(db, nodeID) + require.ErrorIs(t, err, sql.ErrNotFound) + }) + + t.Run("node without proof and marriage ID returns no proof", func(t *testing.T) { + t.Parallel() + db := statesql.InMemoryTest(t) + + id, err := marriage.NewID(db) + require.NoError(t, err) + + nodeID := types.RandomNodeID() + err = marriage.Add(db, marriage.Info{ + ID: id, + NodeID: nodeID, + ATX: types.RandomATXID(), + MarriageIndex: 0, + Target: types.RandomNodeID(), + Signature: types.RandomEdSignature(), + }) + require.NoError(t, err) + + err = malfeasance.AddProof(db, nodeID, &id, nil, 1, time.Now()) + require.NoError(t, err) + + _, _, err = malfeasance.NodeIDProof(db, nodeID) + require.ErrorIs(t, err, sql.ErrNotFound) + }) +} + +func TestMarriageProof(t *testing.T) { + t.Parallel() + + t.Run("unknown marriage ID has no proof", func(t *testing.T) { + t.Parallel() + db := statesql.InMemoryTest(t) + + _, _, err := malfeasance.MarriageProof(db, marriage.ID(0)) + require.ErrorIs(t, err, sql.ErrNotFound) + }) + + t.Run("known marriage ID has proof", func(t *testing.T) { + t.Parallel() + db := statesql.InMemoryTest(t) + + id, err := marriage.NewID(db) + require.NoError(t, err) + + nodeID := types.RandomNodeID() + err = marriage.Add(db, marriage.Info{ + ID: id, + NodeID: nodeID, + ATX: types.RandomATXID(), + MarriageIndex: 0, + Target: types.RandomNodeID(), + Signature: types.RandomEdSignature(), + }) + require.NoError(t, err) + + proof := types.RandomBytes(100) + require.NoError(t, malfeasance.AddProof(db, nodeID, &id, proof, 1, time.Now())) + + nodeID2 := types.RandomNodeID() + require.NoError(t, malfeasance.SetMalicious(db, nodeID2, id, time.Now())) + + p, domain, err := malfeasance.MarriageProof(db, id) + require.NoError(t, err) + require.Equal(t, proof, p) + require.Equal(t, 1, domain) + }) +} diff --git a/sync2/multipeer/multipeer.go b/sync2/multipeer/multipeer.go index 90bd15a002..2b6400aaa2 100644 --- a/sync2/multipeer/multipeer.go +++ b/sync2/multipeer/multipeer.go @@ -165,7 +165,7 @@ func DefaultConfig() MultiPeerReconcilerConfig { } } -// MultiPeerReconciler reconcilies the local set against multiple remote sets. +// MultiPeerReconciler reconciles the local set against multiple remote sets. type MultiPeerReconciler struct { logger *zap.Logger cfg MultiPeerReconcilerConfig diff --git a/syncer/find_fork_test.go b/syncer/find_fork_test.go index 7276cbda36..f714b82305 100644 --- a/syncer/find_fork_test.go +++ b/syncer/find_fork_test.go @@ -3,7 +3,6 @@ package syncer_test import ( "context" "encoding/binary" - "fmt" "math/rand/v2" "strconv" "testing" @@ -108,7 +107,7 @@ func serveHashReq(tb testing.TB, req *fetch.MeshHashRequest) (*fetch.MeshHashes, hashes = append(hashes, layerHash(int(req.To.Uint32()), true)) expCount := int(req.Count()) - require.Len(tb, hashes, expCount, fmt.Sprintf("%#v; count exp: %v, got %v", req, expCount, len(hashes))) + require.Lenf(tb, hashes, expCount, "%#v; count exp: %v, got %v", req, expCount, len(hashes)) mh := &fetch.MeshHashes{ Hashes: hashes, } @@ -133,7 +132,7 @@ func TestForkFinder_FindFork_Permutation(t *testing.T) { }).AnyTimes() fork, err := tf.FindFork(context.Background(), peer, types.LayerID(uint32(lid)), layerHash(lid, true)) - require.NoError(t, err, fmt.Sprintf("lid: %v", lid)) + require.NoErrorf(t, err, "lid: %v", lid) require.Equal(t, expected, int(fork)) } } diff --git a/systest/tests/distributed_post_verification_test.go b/systest/tests/distributed_post_verification_test.go index 94784843f9..513466bdde 100644 --- a/systest/tests/distributed_post_verification_test.go +++ b/systest/tests/distributed_post_verification_test.go @@ -163,7 +163,7 @@ func TestPostMalfeasanceProof(t *testing.T) { ) require.NoError(t, err) - builder := activation.NewMockAtxBuilder(ctrl) + builder := activation.NewMockatxBuilder(ctrl) builder.EXPECT().Register(signer) postSupervisor := activation.NewPostSupervisor( logger.Named("post-supervisor"), diff --git a/systest/tests/steps_test.go b/systest/tests/steps_test.go index dd1369b106..5d284ee264 100644 --- a/systest/tests/steps_test.go +++ b/systest/tests/steps_test.go @@ -189,10 +189,10 @@ func TestStepReplaceNodes(t *testing.T) { require.NoError(t, err) var ( - delete = rand.Intn(cctx.ClusterSize*2/10) + 1 + toDelete = rand.Intn(cctx.ClusterSize*2/10) + 1 deleting []*cluster.NodeClient ) - for i := cl.Bootnodes(); i < cl.Total() && len(deleting) < delete; i++ { + for i := cl.Bootnodes(); i < cl.Total() && len(deleting) < toDelete; i++ { node := cl.Client(i) // don't replace non-synced nodes if !isSynced(cctx, node) {