From 901b02fa6dff3509f78e24527380fc9fd8606a95 Mon Sep 17 00:00:00 2001 From: Matthias Fasching <5011972+fasmat@users.noreply.github.com> Date: Tue, 15 Oct 2024 17:58:22 +0000 Subject: [PATCH] Fix creating invalid malfeasance proofs (#6387) ## Motivation On rare occasions a node can create a malfeasance proof against a malicious identity that isn't valid. This only happens when the identity is actually malicious; the proof that is created simply cannot be validated by other nodes. --- .gitignore | 2 +- CHANGELOG.md | 16 ++ activation/handler_v1.go | 27 ++- activation/handler_v1_test.go | 75 +++++- checkpoint/recovery.go | 14 +- fetch/handler.go | 2 +- malfeasance/handler.go | 38 ++- malfeasance/handler_test.go | 4 +- mesh/ballotwriter/ballotwriter.go | 3 +- mesh/ballotwriter/ballotwriter_test.go | 27 +-- node/node.go | 2 +- node/node_integrity_test.go | 99 ++++++++ p2p/pubsub/wrapper.go | 2 +- sql/atxs/atxs.go | 2 +- sql/blocks/blocks.go | 2 +- sql/localsql/migrations/schema_test.go | 22 ++ sql/schemagen/main.go | 8 +- sql/statesql/migrations/interfaces.go | 14 ++ sql/statesql/migrations/mocks.go | 81 +++++++ sql/statesql/migrations/schema.go | 7 +- sql/statesql/migrations/schema_test.go | 2 +- .../migrations/state_0025_migration.go | 100 ++++++++ .../migrations/state_0025_migration_test.go | 217 ++++++++++++++++++ .../0017_atxs_prev_id_nonce_placeholder.sql | 2 +- .../migrations/0025_check_malfeasance.sql | 1 + sql/statesql/schema/schema.sql | 2 +- 26 files changed, 703 insertions(+), 68 deletions(-) create mode 100644 node/node_integrity_test.go create mode 100644 sql/localsql/migrations/schema_test.go create mode 100644 sql/statesql/migrations/interfaces.go create mode 100644 sql/statesql/migrations/mocks.go create mode 100644 sql/statesql/migrations/state_0025_migration.go create mode 100644 sql/statesql/migrations/state_0025_migration_test.go create mode 100644 sql/statesql/schema/migrations/0025_check_malfeasance.sql diff --git a/.gitignore b/.gitignore 
index 6e0d199a8a..27cb917ab8 100644 --- a/.gitignore +++ b/.gitignore @@ -36,4 +36,4 @@ database/data vendor/* systest/vendor/* -.run/* \ No newline at end of file +.run/* diff --git a/CHANGELOG.md b/CHANGELOG.md index 2aa75b9cb5..0464827b45 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,22 @@ See [RELEASE](./RELEASE.md) for workflow instructions. +## UNRELEASED + +### Upgrade information + +### Highlights + +### Features + +### Improvements + +* [#6378](https://github.com/spacemeshos/go-spacemesh/pull/6387) Improved handling of malicious identities. This reduces + the number of DB queries needed during ATX validation. + +* [#6387](https://github.com/spacemeshos/go-spacemesh/pull/6387) Fix an issue where in rare cases invalid proofs for + malicious identities were created. + ## v1.7.4 ### Improvements diff --git a/activation/handler_v1.go b/activation/handler_v1.go index 85cd805685..c96756ba9c 100644 --- a/activation/handler_v1.go +++ b/activation/handler_v1.go @@ -226,6 +226,13 @@ func (h *HandlerV1) syntacticallyValidateDeps( zap.Stringer("atx_id", atx.ID()), zap.Int("index", invalidIdx.Index), ) + malicious, err := identities.IsMalicious(h.cdb, atx.SmesherID) + if err != nil { + return 0, 0, nil, fmt.Errorf("check if smesher is malicious: %w", err) + } + if malicious { + return 0, 0, nil, fmt.Errorf("smesher %s is known malfeasant", atx.SmesherID.ShortString()) + } proof := &mwire.MalfeasanceProof{ Layer: atx.PublishEpoch.FirstLayer(), Proof: mwire.Proof{ @@ -261,10 +268,11 @@ func (h *HandlerV1) validateNonInitialAtx( return err } - needRecheck := atx.VRFNonce != nil || atx.NumUnits > previous.NumUnits - if atx.VRFNonce == nil { - atx.VRFNonce = new(uint64) - *atx.VRFNonce = uint64(previous.VRFNonce) + vrfNonce := atx.VRFNonce + needRecheck := vrfNonce != nil || atx.NumUnits > previous.NumUnits + if vrfNonce == nil { + vrfNonce = new(uint64) + *vrfNonce = uint64(previous.VRFNonce) } if needRecheck { @@ -274,8 +282,13 @@ func (h *HandlerV1) 
validateNonInitialAtx( zap.Bool("post increased", atx.NumUnits > previous.NumUnits), zap.Stringer("smesher", atx.SmesherID), ) - err := h.nipostValidator. - VRFNonce(atx.SmesherID, commitment, *atx.VRFNonce, atx.NIPost.PostMetadata.LabelsPerUnit, atx.NumUnits) + err := h.nipostValidator.VRFNonce( + atx.SmesherID, + commitment, + *vrfNonce, + atx.NIPost.PostMetadata.LabelsPerUnit, + atx.NumUnits, + ) if err != nil { return fmt.Errorf("invalid vrf nonce: %w", err) } @@ -517,7 +530,7 @@ func (h *HandlerV1) processATX( existing, _ := h.cdb.GetAtx(watx.ID()) if existing != nil { - return nil, fmt.Errorf("%w atx %s", errKnownAtx, watx.ID()) + return nil, fmt.Errorf("%w: %s", errKnownAtx, watx.ID()) } h.logger.Debug("processing atx", diff --git a/activation/handler_v1_test.go b/activation/handler_v1_test.go index 56b9ffb7f1..1f88396407 100644 --- a/activation/handler_v1_test.go +++ b/activation/handler_v1_test.go @@ -7,6 +7,7 @@ import ( "testing" "time" + "github.com/spacemeshos/post/verifying" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" "go.uber.org/zap/zaptest" @@ -81,6 +82,7 @@ func TestHandlerV1_SyntacticallyValidateAtx(t *testing.T) { require.NoError(t, err) return atxHdlr, prevAtx, posAtx } + t.Run("valid atx", func(t *testing.T) { t.Parallel() atxHdlr, prevAtx, posAtx := setup(t) @@ -103,6 +105,7 @@ func TestHandlerV1_SyntacticallyValidateAtx(t *testing.T) { require.Equal(t, atx.NumUnits, units) require.Nil(t, proof) }) + t.Run("valid atx with new VRF nonce", func(t *testing.T) { t.Parallel() atxHdlr, prevAtx, posAtx := setup(t) @@ -128,6 +131,7 @@ func TestHandlerV1_SyntacticallyValidateAtx(t *testing.T) { require.Equal(t, atx.NumUnits, units) require.Nil(t, proof) }) + t.Run("valid atx with decreasing num units", func(t *testing.T) { t.Parallel() atxHdlr, prevAtx, posAtx := setup(t) @@ -149,6 +153,7 @@ func TestHandlerV1_SyntacticallyValidateAtx(t *testing.T) { require.Equal(t, atx.NumUnits, units) require.Nil(t, proof) }) + t.Run("atx with 
increasing num units, no new VRF, old valid", func(t *testing.T) { t.Parallel() atxHdlr, prevAtx, posAtx := setup(t) @@ -171,6 +176,7 @@ func TestHandlerV1_SyntacticallyValidateAtx(t *testing.T) { require.Equal(t, prevAtx.NumUnits, units) require.Nil(t, proof) }) + t.Run("atx with increasing num units, no new VRF, old invalid for new size", func(t *testing.T) { t.Parallel() atxHdlr, prevAtx, posAtx := setup(t) @@ -190,6 +196,7 @@ func TestHandlerV1_SyntacticallyValidateAtx(t *testing.T) { require.ErrorContains(t, err, "invalid VRF") require.Nil(t, proof) }) + t.Run("valid initial atx", func(t *testing.T) { t.Parallel() atxHdlr, _, posAtx := setup(t) @@ -216,6 +223,7 @@ func TestHandlerV1_SyntacticallyValidateAtx(t *testing.T) { require.Equal(t, atx.NumUnits, units) require.Nil(t, proof) }) + t.Run("atx targeting wrong publish epoch", func(t *testing.T) { t.Parallel() atxHdlr, prevAtx, posAtx := setup(t) @@ -227,6 +235,7 @@ func TestHandlerV1_SyntacticallyValidateAtx(t *testing.T) { err := atxHdlr.syntacticallyValidate(context.Background(), atx) require.ErrorContains(t, err, "atx publish epoch is too far in the future") }) + t.Run("failing nipost challenge validation", func(t *testing.T) { t.Parallel() atxHdlr, prevAtx, posAtx := setup(t) @@ -244,6 +253,7 @@ func TestHandlerV1_SyntacticallyValidateAtx(t *testing.T) { require.EqualError(t, err, "nipost error") require.Nil(t, proof) }) + t.Run("failing positioning atx validation", func(t *testing.T) { t.Parallel() atxHdlr, prevAtx, posAtx := setup(t) @@ -262,6 +272,7 @@ func TestHandlerV1_SyntacticallyValidateAtx(t *testing.T) { require.EqualError(t, err, "bad positioning atx") require.Nil(t, proof) }) + t.Run("bad initial nipost challenge", func(t *testing.T) { t.Parallel() atxHdlr, _, posAtx := setup(t) @@ -283,6 +294,7 @@ func TestHandlerV1_SyntacticallyValidateAtx(t *testing.T) { require.EqualError(t, err, "bad initial nipost") require.Nil(t, proof) }) + t.Run("bad NIPoST", func(t *testing.T) { t.Parallel() 
atxHdlr, prevATX, postAtx := setup(t) @@ -298,9 +310,55 @@ func TestHandlerV1_SyntacticallyValidateAtx(t *testing.T) { atxHdlr.mValidator.EXPECT(). NIPost(gomock.Any(), atx.SmesherID, goldenATXID, gomock.Any(), gomock.Any(), atx.NumUnits, gomock.Any()). Return(0, errors.New("bad nipost")) - _, _, _, err := atxHdlr.syntacticallyValidateDeps(context.Background(), atx) + _, _, proof, err := atxHdlr.syntacticallyValidateDeps(context.Background(), atx) require.EqualError(t, err, "validating nipost: bad nipost") + require.Nil(t, proof) }) + + t.Run("invalid NIPoST", func(t *testing.T) { + t.Parallel() + atxHdlr, prevATX, postAtx := setup(t) + + atx := newChainedActivationTxV1(t, prevATX, postAtx.ID()) + atx.Sign(sig) + + atxHdlr.mclock.EXPECT().CurrentLayer().Return(atx.PublishEpoch.FirstLayer()) + require.NoError(t, atxHdlr.syntacticallyValidate(context.Background(), atx)) + + atxHdlr.mValidator.EXPECT().NIPostChallengeV1(gomock.Any(), gomock.Any(), atx.SmesherID) + atxHdlr.mValidator.EXPECT().PositioningAtx(atx.PositioningATXID, gomock.Any(), goldenATXID, atx.PublishEpoch) + atxHdlr.mValidator.EXPECT(). + NIPost(gomock.Any(), atx.SmesherID, goldenATXID, gomock.Any(), gomock.Any(), atx.NumUnits, gomock.Any()). 
+ Return(0, &verifying.ErrInvalidIndex{Index: 2}) + atxHdlr.mtortoise.EXPECT().OnMalfeasance(atx.SmesherID) + _, _, proof, err := atxHdlr.syntacticallyValidateDeps(context.Background(), atx) + require.NoError(t, err) + require.NotNil(t, proof) + require.Equal(t, mwire.InvalidPostIndex, proof.Proof.Type) + }) + + t.Run("invalid NIPoST of known malfeasant", func(t *testing.T) { + t.Parallel() + atxHdlr, prevATX, postAtx := setup(t) + + atx := newChainedActivationTxV1(t, prevATX, postAtx.ID()) + atx.Sign(sig) + + require.NoError(t, identities.SetMalicious(atxHdlr.cdb, atx.SmesherID, []byte("proof"), time.Now())) + + atxHdlr.mclock.EXPECT().CurrentLayer().Return(atx.PublishEpoch.FirstLayer()) + require.NoError(t, atxHdlr.syntacticallyValidate(context.Background(), atx)) + + atxHdlr.mValidator.EXPECT().NIPostChallengeV1(gomock.Any(), gomock.Any(), atx.SmesherID) + atxHdlr.mValidator.EXPECT().PositioningAtx(atx.PositioningATXID, gomock.Any(), goldenATXID, atx.PublishEpoch) + atxHdlr.mValidator.EXPECT(). + NIPost(gomock.Any(), atx.SmesherID, goldenATXID, gomock.Any(), gomock.Any(), atx.NumUnits, gomock.Any()). 
+ Return(0, &verifying.ErrInvalidIndex{Index: 2}) + _, _, proof, err := atxHdlr.syntacticallyValidateDeps(context.Background(), atx) + require.EqualError(t, err, fmt.Sprintf("smesher %s is known malfeasant", atx.SmesherID.ShortString())) + require.Nil(t, proof) + }) + t.Run("missing NodeID in initial atx", func(t *testing.T) { t.Parallel() atxHdlr, _, _ := setup(t) @@ -313,6 +371,7 @@ func TestHandlerV1_SyntacticallyValidateAtx(t *testing.T) { err := atxHdlr.syntacticallyValidate(context.Background(), atx) require.ErrorContains(t, err, "node id is missing") }) + t.Run("missing VRF nonce in initial atx", func(t *testing.T) { t.Parallel() atxHdlr, _, _ := setup(t) @@ -325,6 +384,7 @@ func TestHandlerV1_SyntacticallyValidateAtx(t *testing.T) { err := atxHdlr.syntacticallyValidate(context.Background(), atx) require.ErrorContains(t, err, "vrf nonce is missing") }) + t.Run("invalid VRF nonce in initial atx", func(t *testing.T) { t.Parallel() atxHdlr, _, _ := setup(t) @@ -339,6 +399,7 @@ func TestHandlerV1_SyntacticallyValidateAtx(t *testing.T) { err := atxHdlr.syntacticallyValidate(context.Background(), atx) require.ErrorContains(t, err, "invalid VRF nonce") }) + t.Run("prevAtx not declared but initial Post not included", func(t *testing.T) { t.Parallel() atxHdlr, prevAtx, posAtx := setup(t) @@ -351,6 +412,7 @@ func TestHandlerV1_SyntacticallyValidateAtx(t *testing.T) { err := atxHdlr.syntacticallyValidate(context.Background(), atx) require.EqualError(t, err, "no prev atx declared, but initial post is not included") }) + t.Run("prevAtx not declared but commitment ATX is not included", func(t *testing.T) { t.Parallel() atxHdlr, _, _ := setup(t) @@ -363,6 +425,7 @@ func TestHandlerV1_SyntacticallyValidateAtx(t *testing.T) { err := atxHdlr.syntacticallyValidate(context.Background(), atx) require.EqualError(t, err, "no prev atx declared, but commitment atx is missing") }) + t.Run("prevAtx not declared but commitment ATX is empty", func(t *testing.T) { t.Parallel() atxHdlr, 
_, _ := setup(t) @@ -375,6 +438,7 @@ func TestHandlerV1_SyntacticallyValidateAtx(t *testing.T) { err := atxHdlr.syntacticallyValidate(context.Background(), atx) require.EqualError(t, err, "empty commitment atx") }) + t.Run("prevAtx not declared but sequence not zero", func(t *testing.T) { t.Parallel() atxHdlr, _, _ := setup(t) @@ -387,6 +451,7 @@ func TestHandlerV1_SyntacticallyValidateAtx(t *testing.T) { err := atxHdlr.syntacticallyValidate(context.Background(), atx) require.EqualError(t, err, "no prev atx declared, but sequence number not zero") }) + t.Run("prevAtx not declared but validation of initial post fails", func(t *testing.T) { t.Parallel() atxHdlr, _, _ := setup(t) @@ -403,6 +468,7 @@ func TestHandlerV1_SyntacticallyValidateAtx(t *testing.T) { err := atxHdlr.syntacticallyValidate(context.Background(), atx) require.ErrorContains(t, err, "failed post validation") }) + t.Run("empty positioning ATX", func(t *testing.T) { t.Parallel() atxHdlr, _, _ := setup(t) @@ -415,6 +481,7 @@ func TestHandlerV1_SyntacticallyValidateAtx(t *testing.T) { err := atxHdlr.syntacticallyValidate(context.Background(), atx) require.EqualError(t, err, "empty positioning atx") }) + t.Run("prevAtx declared but initial Post is included", func(t *testing.T) { t.Parallel() atxHdlr, prevAtx, _ := setup(t) @@ -427,6 +494,7 @@ func TestHandlerV1_SyntacticallyValidateAtx(t *testing.T) { err := atxHdlr.syntacticallyValidate(context.Background(), atx) require.EqualError(t, err, "prev atx declared, but initial post is included") }) + t.Run("prevAtx declared but NodeID is included", func(t *testing.T) { t.Parallel() atxHdlr, prevAtx, posAtx := setup(t) @@ -439,6 +507,7 @@ func TestHandlerV1_SyntacticallyValidateAtx(t *testing.T) { err := atxHdlr.syntacticallyValidate(context.Background(), atx) require.EqualError(t, err, "prev atx declared, but node id is included") }) + t.Run("prevAtx declared but commitmentATX is included", func(t *testing.T) { t.Parallel() atxHdlr, prevAtx, posAtx := setup(t) 
@@ -716,6 +785,7 @@ func TestHandlerV1_StoreAtx(t *testing.T) { func TestHandlerV1_RegistersHashesInPeer(t *testing.T) { goldenATXID := types.RandomATXID() peer := p2p.Peer("buddy") + t.Run("registers poet and atxs", func(t *testing.T) { t.Parallel() atxHdlr := newV1TestHandler(t, goldenATXID) @@ -727,6 +797,7 @@ func TestHandlerV1_RegistersHashesInPeer(t *testing.T) { RegisterPeerHashes(peer, gomock.InAnyOrder([]types.Hash32{poet, atxs[0].Hash32(), atxs[1].Hash32()})) atxHdlr.registerHashes(peer, poet, atxs) }) + t.Run("registers poet", func(t *testing.T) { t.Parallel() atxHdlr := newV1TestHandler(t, goldenATXID) @@ -740,6 +811,7 @@ func TestHandlerV1_RegistersHashesInPeer(t *testing.T) { func TestHandlerV1_FetchesReferences(t *testing.T) { goldenATXID := types.RandomATXID() + t.Run("fetch poet and atxs", func(t *testing.T) { t.Parallel() atxHdlr := newV1TestHandler(t, goldenATXID) @@ -776,6 +848,7 @@ func TestHandlerV1_FetchesReferences(t *testing.T) { atxHdlr.mockFetch.EXPECT().GetAtxs(gomock.Any(), atxs, gomock.Any()).Return(errors.New("oh")) require.Error(t, atxHdlr.fetchReferences(context.Background(), poet, atxs)) }) + t.Run("reject ATX when dependency ATX is rejected", func(t *testing.T) { t.Parallel() atxHdlr := newV1TestHandler(t, goldenATXID) diff --git a/checkpoint/recovery.go b/checkpoint/recovery.go index aeb1a51c7b..7fb0cb775d 100644 --- a/checkpoint/recovery.go +++ b/checkpoint/recovery.go @@ -24,14 +24,12 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/atxsync" "github.com/spacemeshos/go-spacemesh/sql/localsql" - localmigrations "github.com/spacemeshos/go-spacemesh/sql/localsql/migrations" "github.com/spacemeshos/go-spacemesh/sql/localsql/nipost" "github.com/spacemeshos/go-spacemesh/sql/malsync" "github.com/spacemeshos/go-spacemesh/sql/marriage" "github.com/spacemeshos/go-spacemesh/sql/poets" "github.com/spacemeshos/go-spacemesh/sql/recovery" "github.com/spacemeshos/go-spacemesh/sql/statesql" - 
statemigrations "github.com/spacemeshos/go-spacemesh/sql/statesql/migrations" ) const recoveryDir = "recovery" @@ -123,25 +121,17 @@ func Recover( return nil, errors.New("restore layer not set") } logger.Info("recovering from checkpoint", zap.String("url", cfg.Uri), zap.Stringer("restore", cfg.Restore)) - schema, err := statemigrations.SchemaWithInCodeMigrations() - if err != nil { - return nil, fmt.Errorf("error loading db schema: %w", err) - } db, err := statesql.Open( "file:"+cfg.DbPath(), - sql.WithDatabaseSchema(schema), + sql.WithMigrationsDisabled(), ) if err != nil { return nil, fmt.Errorf("open old database: %w", err) } defer db.Close() - lSchema, err := localmigrations.SchemaWithInCodeMigrations() - if err != nil { - return nil, fmt.Errorf("get schema with in-code migrations: %w", err) - } localDB, err := localsql.Open( "file:"+filepath.Join(cfg.DataDir, cfg.LocalDbFile), - sql.WithDatabaseSchema(lSchema), + sql.WithMigrationsDisabled(), ) if err != nil { return nil, fmt.Errorf("open old local database: %w", err) diff --git a/fetch/handler.go b/fetch/handler.go index 411480ee44..371377cc50 100644 --- a/fetch/handler.go +++ b/fetch/handler.go @@ -339,7 +339,7 @@ func (h *handler) doHandleHashReqStream( ) error { var requestBatch RequestBatch if err := codec.Decode(msg, &requestBatch); err != nil { - return fmt.Errorf("%w: decooding request: %w", errBadRequest, err) + return fmt.Errorf("%w: decoding request: %w", errBadRequest, err) } if hint != datastore.NoHint && len(requestBatch.Requests) > 1 { diff --git a/malfeasance/handler.go b/malfeasance/handler.go index 5d4eea62a9..a6e1d0f847 100644 --- a/malfeasance/handler.go +++ b/malfeasance/handler.go @@ -23,11 +23,11 @@ import ( ) var ( - ErrKnownProof = errors.New("known proof") - + errKnownProof = errors.New("known proof") errMalformedData = fmt.Errorf("%w: malformed data", pubsub.ErrValidationReject) errWrongHash = fmt.Errorf("%w: incorrect hash", pubsub.ErrValidationReject) errUnknownProof = 
fmt.Errorf("%w: unknown proof type", pubsub.ErrValidationReject) + errInvalidProof = fmt.Errorf("%w: invalid proof", pubsub.ErrValidationReject) ) type MalfeasanceType byte @@ -115,7 +115,7 @@ func (h *Handler) Info(data []byte) (map[string]string, error) { func (h *Handler) HandleSyncedMalfeasanceProof( ctx context.Context, expHash types.Hash32, - _ p2p.Peer, + peer p2p.Peer, data []byte, ) error { var p wire.MalfeasanceProof @@ -126,6 +126,14 @@ func (h *Handler) HandleSyncedMalfeasanceProof( } nodeID, err := h.validateAndSave(ctx, &p) if err == nil && types.Hash32(nodeID) != expHash { + // we log & return because libp2p will ignore the message if we return an error, + // but only log "validation ignored" instead of the error we return + h.logger.Warn("malfeasance proof for wrong identity", + log.ZContext(ctx), + log.ZShortStringer("expected", expHash), + log.ZShortStringer("got", nodeID), + zap.Stringer("peer", peer), + ) return fmt.Errorf( "%w: malfeasance proof want %s, got %s", errWrongHash, @@ -133,6 +141,9 @@ func (h *Handler) HandleSyncedMalfeasanceProof( nodeID.ShortString(), ) } + if errors.Is(err, errKnownProof) { + return nil + } return err } @@ -151,8 +162,13 @@ func (h *Handler) HandleMalfeasanceProof(ctx context.Context, peer p2p.Peer, dat if peer == h.self { id, err := h.Validate(ctx, &p.MalfeasanceProof) if err != nil { + h.logger.Warn("malfeasance proof failed validation during publish", + log.ZContext(ctx), + zap.Inline(&p.MalfeasanceProof), + zap.Error(err), + ) h.countInvalidProof(&p.MalfeasanceProof) - return err + return fmt.Errorf("%w: %s", pubsub.ErrValidationReject, err) } h.reportMalfeasance(id, codec.MustEncode(&p.MalfeasanceProof)) // node saves malfeasance proof eagerly/atomically with the malicious data. 
@@ -161,6 +177,9 @@ func (h *Handler) HandleMalfeasanceProof(ctx context.Context, peer p2p.Peer, dat return nil } _, err := h.validateAndSave(ctx, &p.MalfeasanceProof) + if errors.Is(err, errKnownProof) { + return nil + } return err } @@ -183,14 +202,14 @@ func (h *Handler) validateAndSave(ctx context.Context, p *wire.MalfeasanceProof) } if malicious { h.logger.Debug("known malicious identity", log.ZContext(ctx), zap.Stringer("smesher", nodeID)) - return ErrKnownProof + return errKnownProof } if err := identities.SetMalicious(dbtx, nodeID, proofBytes, time.Now()); err != nil { return fmt.Errorf("add malfeasance proof: %w", err) } return nil }); err != nil { - if !errors.Is(err, ErrKnownProof) { + if !errors.Is(err, errKnownProof) { h.logger.Error("failed to save MalfeasanceProof", log.ZContext(ctx), zap.Stringer("smesher", nodeID), @@ -221,10 +240,5 @@ func (h *Handler) Validate(ctx context.Context, p *wire.MalfeasanceProof) (types if err == nil { return nodeID, nil } - h.logger.Debug("malfeasance proof failed validation", - log.ZContext(ctx), - zap.Inline(p), - zap.Error(err), - ) - return nodeID, err + return nodeID, fmt.Errorf("%w: %v", errInvalidProof, err) } diff --git a/malfeasance/handler_test.go b/malfeasance/handler_test.go index f594138f93..9c45a4879c 100644 --- a/malfeasance/handler_test.go +++ b/malfeasance/handler_test.go @@ -185,7 +185,7 @@ func TestHandler_HandleMalfeasanceProof(t *testing.T) { } err := h.HandleMalfeasanceProof(context.Background(), "peer", codec.MustEncode(gossip)) - require.ErrorIs(t, ErrKnownProof, err) + require.NoError(t, err) var blob sql.Blob require.NoError(t, identities.LoadMalfeasanceBlob(context.Background(), h.db, nodeID.Bytes(), &blob)) @@ -363,7 +363,7 @@ func TestHandler_HandleSyncedMalfeasanceProof(t *testing.T) { require.NotEqual(t, proofBytes, newProofBytes) err := h.HandleSyncedMalfeasanceProof(context.Background(), types.Hash32(nodeID), "peer", newProofBytes) - require.ErrorIs(t, ErrKnownProof, err) + 
require.NoError(t, err) var blob sql.Blob require.NoError(t, identities.LoadMalfeasanceBlob(context.Background(), h.db, nodeID.Bytes(), &blob)) diff --git a/mesh/ballotwriter/ballotwriter.go b/mesh/ballotwriter/ballotwriter.go index 1252be9cfc..35cb2ec43c 100644 --- a/mesh/ballotwriter/ballotwriter.go +++ b/mesh/ballotwriter/ballotwriter.go @@ -165,6 +165,7 @@ type batchResult struct { } type db interface { - WithTx(context.Context, func(sql.Transaction) error) error sql.Executor + + WithTx(context.Context, func(sql.Transaction) error) error } diff --git a/mesh/ballotwriter/ballotwriter_test.go b/mesh/ballotwriter/ballotwriter_test.go index aeb742c68d..b9a19880a3 100644 --- a/mesh/ballotwriter/ballotwriter_test.go +++ b/mesh/ballotwriter/ballotwriter_test.go @@ -20,7 +20,6 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/ballots" "github.com/spacemeshos/go-spacemesh/sql/identities" "github.com/spacemeshos/go-spacemesh/sql/statesql" - "github.com/spacemeshos/go-spacemesh/sql/statesql/migrations" ) func init() { @@ -75,8 +74,8 @@ func TestWriteCoalesce_OnePerSmesher(t *testing.T) { require.NotNil(t, blob.Bytes) } -func BenchmarkWriteCoalesing(b *testing.B) { - a := make([]*types.Ballot, 100000) +func BenchmarkWriteCoalescing(b *testing.B) { + a := make([]*types.Ballot, 1000000) for i := 0; i < len(a); i++ { a[i] = genBallot(b) } @@ -119,7 +118,8 @@ func BenchmarkWriteCoalesing(b *testing.B) { return nil } b.ResetTimer() - b.Run("No Coalesing", func(b *testing.B) { + + b.Run("No Coalescing", func(b *testing.B) { db := newDiskSqlite(b) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -134,13 +134,10 @@ func BenchmarkWriteCoalesing(b *testing.B) { } }) - // with the coalesing tests, one must take the "ns/op" metrics and divide it - // by the number of entries written together to see how many items we're doing - // per time unit. 
- b.Run("Coalesing 1000 entries", func(b *testing.B) { + b.Run("Coalescing 1000 entries", func(b *testing.B) { db := newDiskSqlite(b) b.ResetTimer() - for j := 0; j < b.N; j++ { + for j := 0; j < b.N/1000; j++ { if err := db.WithTx(context.Background(), func(tx sql.Transaction) error { var err error for i := (j * 1000); i < (j*1000)+1000; i++ { @@ -155,10 +152,10 @@ func BenchmarkWriteCoalesing(b *testing.B) { } }) - b.Run("Coalesing 5000 entries", func(b *testing.B) { + b.Run("Coalescing 5000 entries", func(b *testing.B) { db := newDiskSqlite(b) b.ResetTimer() - for j := 0; j < b.N; j++ { + for j := 0; j < b.N/5000; j++ { if err := db.WithTx(context.Background(), func(tx sql.Transaction) error { var err error for i := (j * 5000); i < (j*5000)+5000; i++ { @@ -206,15 +203,9 @@ func newTestBallotWriter(t testing.TB) (*ballotwriter.BallotWriter, sql.Database func newDiskSqlite(tb testing.TB) sql.Database { tb.Helper() - schema, err := migrations.SchemaWithInCodeMigrations() - require.NoError(tb, err) - dbopts := []sql.Opt{ - sql.WithDatabaseSchema(schema), - sql.WithForceMigrations(true), - } dir := tb.TempDir() - sqlDB, err := sql.Open("file:"+filepath.Join(dir, "sql.sql"), dbopts...) 
+ sqlDB, err := statesql.Open("file:" + filepath.Join(dir, "state.sql")) if err != nil { tb.Fatal(err) } diff --git a/node/node.go b/node/node.go index 0690bcf9e6..44ad9db4c0 100644 --- a/node/node.go +++ b/node/node.go @@ -1954,7 +1954,7 @@ func (app *App) setupDBs(ctx context.Context, lg log.Log) error { return fmt.Errorf("failed to create %s: %w", dbPath, err) } dbLog := app.addLogger(StateDbLogger, lg).Zap() - schema, err := statemigrations.SchemaWithInCodeMigrations() + schema, err := statemigrations.SchemaWithInCodeMigrations(app.malfeasanceHandler) if err != nil { return fmt.Errorf("error loading db schema: %w", err) } diff --git a/node/node_integrity_test.go b/node/node_integrity_test.go new file mode 100644 index 0000000000..79a3a2017a --- /dev/null +++ b/node/node_integrity_test.go @@ -0,0 +1,99 @@ +package node + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + "go.uber.org/zap/zaptest" + + "github.com/spacemeshos/go-spacemesh/activation" + "github.com/spacemeshos/go-spacemesh/codec" + "github.com/spacemeshos/go-spacemesh/common/types" + "github.com/spacemeshos/go-spacemesh/config" + "github.com/spacemeshos/go-spacemesh/datastore" + "github.com/spacemeshos/go-spacemesh/hare3" + "github.com/spacemeshos/go-spacemesh/malfeasance" + "github.com/spacemeshos/go-spacemesh/malfeasance/wire" + "github.com/spacemeshos/go-spacemesh/mesh" + "github.com/spacemeshos/go-spacemesh/signing" + "github.com/spacemeshos/go-spacemesh/sql/builder" + "github.com/spacemeshos/go-spacemesh/sql/identities" + "github.com/spacemeshos/go-spacemesh/sql/statesql" +) + +func TestCheckDBValidity(t *testing.T) { + t.Skip("This test is intended to be run manually to check the validity of malfeasance proofs in the database") + + db, err := statesql.Open("file:state.sql") + require.NoError(t, err) + + logger := zaptest.NewLogger(t) + cfg := config.MainnetConfig() + edVerifier := signing.NewEdVerifier( + 
signing.WithVerifierPrefix(cfg.Genesis.GenesisID().Bytes()), + ) + postVerifier, err := activation.NewPostVerifier( + cfg.POST, + logger.Named("post_verifier"), + ) + require.NoError(t, err) + + malfeasanceLogger := logger.Named("malfeasance") + activationMH := activation.NewMalfeasanceHandler( + db, + malfeasanceLogger, + edVerifier, + ) + meshMH := mesh.NewMalfeasanceHandler( + db, + edVerifier, + mesh.WithMalfeasanceLogger(malfeasanceLogger), + ) + hareMH := hare3.NewMalfeasanceHandler( + db, + edVerifier, + hare3.WithMalfeasanceLogger(malfeasanceLogger), + ) + invalidPostMH := activation.NewInvalidPostIndexHandler( + db, + edVerifier, + postVerifier, + ) + invalidPrevMH := activation.NewInvalidPrevATXHandler(db, edVerifier) + + nodeIDs := make([]types.NodeID, 0) + + ctrl := gomock.NewController(t) + trtl := malfeasance.NewMocktortoise(ctrl) + handler := malfeasance.NewHandler( + datastore.NewCachedDB(db, logger.Named("cached_db")), + malfeasanceLogger, + "self", + nodeIDs, + trtl, + ) + handler.RegisterHandler(malfeasance.MultipleATXs, activationMH) + handler.RegisterHandler(malfeasance.MultipleBallots, meshMH) + handler.RegisterHandler(malfeasance.HareEquivocation, hareMH) + handler.RegisterHandler(malfeasance.InvalidPostIndex, invalidPostMH) + handler.RegisterHandler(malfeasance.InvalidPrevATX, invalidPrevMH) + + i := 0 + err = identities.IterateOps(db, builder.Operations{}, func(nodeID types.NodeID, bytes []byte, _ time.Time) bool { + proof := &wire.MalfeasanceProof{} + err := codec.Decode(bytes, proof) + require.NoError(t, err) + + id, err := handler.Validate(context.Background(), proof) + require.NoError(t, err) + require.Equal(t, nodeID, id) + + t.Logf("Proof %d is valid for %s\n", i, nodeID.ShortString()) + i++ + return true + }) + require.NoError(t, err) +} diff --git a/p2p/pubsub/wrapper.go b/p2p/pubsub/wrapper.go index 480dad6058..0e52656dea 100644 --- a/p2p/pubsub/wrapper.go +++ b/p2p/pubsub/wrapper.go @@ -105,7 +105,7 @@ func (ps *GossipPubSub) 
Publish(ctx context.Context, topic string, msg []byte) e ps.logger.Sugar().Panicf("Publish is called before Register for topic %s", topic) } if err := topich.Publish(ctx, msg); err != nil { - return fmt.Errorf("failed to publish to topic %v: %w", topic, err) + return fmt.Errorf("failed to publish to topic %s: %w", topic, err) } return nil } diff --git a/sql/atxs/atxs.go b/sql/atxs/atxs.go index 21b3ed0815..a60707b17e 100644 --- a/sql/atxs/atxs.go +++ b/sql/atxs/atxs.go @@ -83,7 +83,7 @@ func Get(db sql.Executor, id types.ATXID) (*types.ActivationTx, error) { enc := func(stmt *sql.Statement) { stmt.BindBytes(1, id.Bytes()) } - q := fmt.Sprintf("%v where id =?1;", fullQuery) + q := fmt.Sprintf("%s where id =?1;", fullQuery) v, err := load(db, q, enc) if err != nil { return nil, fmt.Errorf("get id %s: %w", id.String(), err) diff --git a/sql/blocks/blocks.go b/sql/blocks/blocks.go index f3f17f43a1..a65120a7f2 100644 --- a/sql/blocks/blocks.go +++ b/sql/blocks/blocks.go @@ -58,7 +58,7 @@ func Has(db sql.Executor, id types.BlockID) (bool, error) { } // GetBlobSizes returns the sizes of the blobs corresponding to blocks with specified -// ids. For non-existent balots, the corresponding items are set to -1. +// ids. For non-existent ballots, the corresponding items are set to -1. 
func GetBlobSizes(db sql.Executor, ids [][]byte) (sizes []int, err error) { return sql.GetBlobSizes(db, "select id, length(block) from blocks where id in", ids) } diff --git a/sql/localsql/migrations/schema_test.go b/sql/localsql/migrations/schema_test.go new file mode 100644 index 0000000000..c94f2b549c --- /dev/null +++ b/sql/localsql/migrations/schema_test.go @@ -0,0 +1,22 @@ +package migrations + +import ( + "testing" + + "github.com/stretchr/testify/require" + "go.uber.org/zap/zaptest" + + "github.com/spacemeshos/go-spacemesh/sql" +) + +func TestCodedMigrations(t *testing.T) { + schema, err := SchemaWithInCodeMigrations() + require.NoError(t, err) + + db := sql.InMemory( + sql.WithDatabaseSchema(schema), + sql.WithLogger(zaptest.NewLogger(t)), + sql.WithForceMigrations(true), + ) + require.NotNil(t, db) +} diff --git a/sql/schemagen/main.go b/sql/schemagen/main.go index 6838f79046..c55d2b3ad4 100644 --- a/sql/schemagen/main.go +++ b/sql/schemagen/main.go @@ -8,8 +8,8 @@ import ( "go.uber.org/zap/zapcore" "github.com/spacemeshos/go-spacemesh/sql" - localmigrations "github.com/spacemeshos/go-spacemesh/sql/localsql/migrations" - statemigrations "github.com/spacemeshos/go-spacemesh/sql/statesql/migrations" + "github.com/spacemeshos/go-spacemesh/sql/localsql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) var ( @@ -32,9 +32,9 @@ func main() { logger := zap.New(core).With(zap.String("dbType", *dbType)) switch *dbType { case "state": - schema, err = statemigrations.SchemaWithInCodeMigrations() + schema, err = statesql.Schema() case "local": - schema, err = localmigrations.SchemaWithInCodeMigrations() + schema, err = localsql.Schema() default: logger.Fatal("unknown database type, must be state or local") } diff --git a/sql/statesql/migrations/interfaces.go b/sql/statesql/migrations/interfaces.go new file mode 100644 index 0000000000..c1580b8582 --- /dev/null +++ b/sql/statesql/migrations/interfaces.go @@ -0,0 +1,14 @@ +package migrations + +import ( + 
"context" + + "github.com/spacemeshos/go-spacemesh/common/types" + "github.com/spacemeshos/go-spacemesh/malfeasance/wire" +) + +//go:generate mockgen -typed -package=migrations -destination=./mocks.go -source=interfaces.go + +type malfeasanceValidator interface { + Validate(ctx context.Context, proof *wire.MalfeasanceProof) (types.NodeID, error) +} diff --git a/sql/statesql/migrations/mocks.go b/sql/statesql/migrations/mocks.go new file mode 100644 index 0000000000..c5924a43f6 --- /dev/null +++ b/sql/statesql/migrations/mocks.go @@ -0,0 +1,81 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: interfaces.go +// +// Generated by this command: +// +// mockgen -typed -package=migrations -destination=./mocks.go -source=interfaces.go +// + +// Package migrations is a generated GoMock package. +package migrations + +import ( + context "context" + reflect "reflect" + + types "github.com/spacemeshos/go-spacemesh/common/types" + wire "github.com/spacemeshos/go-spacemesh/malfeasance/wire" + gomock "go.uber.org/mock/gomock" +) + +// MockmalfeasanceValidator is a mock of malfeasanceValidator interface. +type MockmalfeasanceValidator struct { + ctrl *gomock.Controller + recorder *MockmalfeasanceValidatorMockRecorder +} + +// MockmalfeasanceValidatorMockRecorder is the mock recorder for MockmalfeasanceValidator. +type MockmalfeasanceValidatorMockRecorder struct { + mock *MockmalfeasanceValidator +} + +// NewMockmalfeasanceValidator creates a new mock instance. +func NewMockmalfeasanceValidator(ctrl *gomock.Controller) *MockmalfeasanceValidator { + mock := &MockmalfeasanceValidator{ctrl: ctrl} + mock.recorder = &MockmalfeasanceValidatorMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockmalfeasanceValidator) EXPECT() *MockmalfeasanceValidatorMockRecorder { + return m.recorder +} + +// Validate mocks base method. 
+func (m *MockmalfeasanceValidator) Validate(ctx context.Context, proof *wire.MalfeasanceProof) (types.NodeID, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Validate", ctx, proof) + ret0, _ := ret[0].(types.NodeID) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Validate indicates an expected call of Validate. +func (mr *MockmalfeasanceValidatorMockRecorder) Validate(ctx, proof any) *MockmalfeasanceValidatorValidateCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Validate", reflect.TypeOf((*MockmalfeasanceValidator)(nil).Validate), ctx, proof) + return &MockmalfeasanceValidatorValidateCall{Call: call} +} + +// MockmalfeasanceValidatorValidateCall wrap *gomock.Call +type MockmalfeasanceValidatorValidateCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockmalfeasanceValidatorValidateCall) Return(arg0 types.NodeID, arg1 error) *MockmalfeasanceValidatorValidateCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockmalfeasanceValidatorValidateCall) Do(f func(context.Context, *wire.MalfeasanceProof) (types.NodeID, error)) *MockmalfeasanceValidatorValidateCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockmalfeasanceValidatorValidateCall) DoAndReturn(f func(context.Context, *wire.MalfeasanceProof) (types.NodeID, error)) *MockmalfeasanceValidatorValidateCall { + c.Call = c.Call.DoAndReturn(f) + return c +} diff --git a/sql/statesql/migrations/schema.go b/sql/statesql/migrations/schema.go index ad4c9cbc01..bf9bf57b56 100644 --- a/sql/statesql/migrations/schema.go +++ b/sql/statesql/migrations/schema.go @@ -5,6 +5,9 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/statesql" ) -func SchemaWithInCodeMigrations() (*sql.Schema, error) { - return statesql.Schema(New0021Migration(1_000_000)) +func SchemaWithInCodeMigrations(malHandler malfeasanceValidator) (*sql.Schema, error) { 
+ return statesql.Schema( + New0021Migration(1_000_000), + New0025Migration(malHandler), + ) } diff --git a/sql/statesql/migrations/schema_test.go b/sql/statesql/migrations/schema_test.go index c94f2b549c..938036b7ff 100644 --- a/sql/statesql/migrations/schema_test.go +++ b/sql/statesql/migrations/schema_test.go @@ -10,7 +10,7 @@ import ( ) func TestCodedMigrations(t *testing.T) { - schema, err := SchemaWithInCodeMigrations() + schema, err := SchemaWithInCodeMigrations(nil) require.NoError(t, err) db := sql.InMemory( diff --git a/sql/statesql/migrations/state_0025_migration.go b/sql/statesql/migrations/state_0025_migration.go new file mode 100644 index 0000000000..6d675981ba --- /dev/null +++ b/sql/statesql/migrations/state_0025_migration.go @@ -0,0 +1,100 @@ +package migrations + +import ( + "context" + "encoding/hex" + "time" + + "go.uber.org/zap" + + "github.com/spacemeshos/go-spacemesh/codec" + "github.com/spacemeshos/go-spacemesh/common/types" + "github.com/spacemeshos/go-spacemesh/log" + "github.com/spacemeshos/go-spacemesh/malfeasance/wire" + "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/builder" + "github.com/spacemeshos/go-spacemesh/sql/identities" +) + +type migration0025 struct { + handler malfeasanceValidator +} + +var _ sql.Migration = &migration0025{} + +func New0025Migration(handler malfeasanceValidator) *migration0025 { + return &migration0025{ + handler: handler, + } +} + +func (*migration0025) Name() string { + return "check DB for invalid malfeasance proofs" +} + +func (*migration0025) Order() int { + return 25 +} + +func (*migration0025) Rollback() error { + return nil +} + +func (m *migration0025) Apply(db sql.Executor, logger *zap.Logger) error { + updates := map[types.NodeID][]byte{} + + err := identities.IterateOps(db, builder.Operations{}, func(nodeID types.NodeID, bytes []byte, t time.Time) bool { + proof := &wire.MalfeasanceProof{} + codec.MustDecode(bytes, proof) + + id, err := 
m.handler.Validate(context.Background(), proof) + if err == nil && id == nodeID { + logger.Debug("Proof is valid", log.ZShortStringer("smesherID", nodeID)) + return true + } + + if proof.Proof.Type != wire.InvalidPrevATX { + logger.Warn("Found invalid proof during migration that cannot be fixed", + log.ZShortStringer("smesherID", nodeID), + zap.String("proof", hex.EncodeToString(bytes)), + zap.Error(err), + ) + return true + } + + proof.Proof.Data.(*wire.InvalidPrevATXProof).Atx1.VRFNonce = nil + id, err = m.handler.Validate(context.Background(), proof) + if err == nil && id == nodeID { + updates[nodeID] = codec.MustEncode(proof) + return true + } + logger.Error("Failed to fix invalid malfeasance proof during migration", + log.ZShortStringer("smesherID", nodeID), + zap.String("proof", hex.EncodeToString(bytes)), + zap.Error(err), + ) + return true + }) + if err != nil { + return err + } + for nodeID, proofBytes := range updates { + if _, err := db.Exec(` + UPDATE identities + SET proof = ?2 + WHERE pubkey = ?1 + `, func(stmt *sql.Statement) { + stmt.BindBytes(1, nodeID.Bytes()) + stmt.BindBytes(2, proofBytes) + }, nil); err != nil { + logger.Error("Failed to update invalid proof", + log.ZShortStringer("smesherID", nodeID), + zap.Error(err), + ) + } + logger.Info("Fixed invalid proof during migration", + log.ZShortStringer("smesherID", nodeID), + ) + } + return nil +} diff --git a/sql/statesql/migrations/state_0025_migration_test.go b/sql/statesql/migrations/state_0025_migration_test.go new file mode 100644 index 0000000000..ba6b16e2a3 --- /dev/null +++ b/sql/statesql/migrations/state_0025_migration_test.go @@ -0,0 +1,217 @@ +package migrations + +import ( + "context" + "errors" + "slices" + "testing" + "time" + + "github.com/stretchr/testify/require" + gomock "go.uber.org/mock/gomock" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + "go.uber.org/zap/zaptest" + "go.uber.org/zap/zaptest/observer" + + "github.com/spacemeshos/go-spacemesh/activation/wire" + 
"github.com/spacemeshos/go-spacemesh/codec" + "github.com/spacemeshos/go-spacemesh/common/types" + mwire "github.com/spacemeshos/go-spacemesh/malfeasance/wire" + "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/identities" + "github.com/spacemeshos/go-spacemesh/sql/statesql" +) + +func Test0025Migration(t *testing.T) { + setup := func(tb testing.TB) sql.StateDatabase { + schema, err := statesql.Schema() + require.NoError(tb, err) + schema.Migrations = slices.DeleteFunc(schema.Migrations, func(m sql.Migration) bool { + return m.Order() >= 25 + }) + db := statesql.InMemoryTest( + tb, + sql.WithDatabaseSchema(schema), + sql.WithNoCheckSchemaDrift(), + sql.WithForceMigrations(true), + ) + return db + } + + t.Run("valid proof is noop", func(t *testing.T) { + db := setup(t) + + // store valid proof + nodeID := types.RandomNodeID() + proof := &mwire.MalfeasanceProof{ + Proof: mwire.Proof{ + Type: mwire.MultipleATXs, + Data: &mwire.AtxProof{}, + }, + } + require.NoError(t, identities.SetMalicious(db, nodeID, codec.MustEncode(proof), time.Now())) + ctrl := gomock.NewController(t) + mHandler := NewMockmalfeasanceValidator(ctrl) + mHandler.EXPECT().Validate(context.Background(), proof).Return(nodeID, nil) + + m := New0025Migration(mHandler) + require.Equal(t, 25, m.Order()) + require.NoError(t, m.Apply(db, zaptest.NewLogger(t))) + }) + + t.Run("invalid proof is logged", func(t *testing.T) { + db := setup(t) + + // store invalid proof + nodeID := types.RandomNodeID() + proof := &mwire.MalfeasanceProof{ + Proof: mwire.Proof{ + Type: mwire.MultipleATXs, + Data: &mwire.AtxProof{}, + }, + } + require.NoError(t, identities.SetMalicious(db, nodeID, codec.MustEncode(proof), time.Now())) + ctrl := gomock.NewController(t) + mHandler := NewMockmalfeasanceValidator(ctrl) + mHandler.EXPECT().Validate(context.Background(), proof). 
+ Return(types.EmptyNodeID, errors.New("invalid signature")) + + observer, observedLogs := observer.New(zapcore.WarnLevel) + logger := zaptest.NewLogger(t, zaptest.WrapOptions(zap.WrapCore( + func(core zapcore.Core) zapcore.Core { + return zapcore.NewTee(core, observer) + }, + ))) + + m := New0025Migration(mHandler) + require.Equal(t, 25, m.Order()) + require.NoError(t, m.Apply(db, logger)) + + require.Equal(t, 1, observedLogs.Len(), "expected 1 log message") + require.Equal(t, zapcore.WarnLevel, observedLogs.All()[0].Level) + require.Equal(t, "Found invalid proof during migration that cannot be fixed", observedLogs.All()[0].Message) + require.Equal(t, nodeID.ShortString(), observedLogs.All()[0].ContextMap()["smesherID"]) + require.Equal(t, "invalid signature", observedLogs.All()[0].ContextMap()["error"]) + }) + + t.Run("invalid proof is fixed", func(t *testing.T) { + db := setup(t) + + // store invalid proof + nonce := uint64(1337) + nodeID := types.RandomNodeID() + proof := &mwire.MalfeasanceProof{ + Proof: mwire.Proof{ + Type: mwire.InvalidPrevATX, + Data: &mwire.InvalidPrevATXProof{ + Atx1: wire.ActivationTxV1{ + InnerActivationTxV1: wire.InnerActivationTxV1{ + VRFNonce: &nonce, + }, + }, + Atx2: wire.ActivationTxV1{}, + }, + }, + } + require.NoError(t, identities.SetMalicious(db, nodeID, codec.MustEncode(proof), time.Now())) + ctrl := gomock.NewController(t) + mHandler := NewMockmalfeasanceValidator(ctrl) + + // first call to Validate returns error + mHandler.EXPECT().Validate(context.Background(), proof). 
+ Return(types.EmptyNodeID, errors.New("invalid signature")) + + // second call to Validate returns valid nodeID + mHandler.EXPECT().Validate(context.Background(), gomock.Cond(func(x any) bool { + return x.(*mwire.MalfeasanceProof).Proof.Type == mwire.InvalidPrevATX && + x.(*mwire.MalfeasanceProof).Proof.Data.(*mwire.InvalidPrevATXProof).Atx1.VRFNonce == nil + })).Return(nodeID, nil) + + observer, observedLogs := observer.New(zapcore.InfoLevel) + logger := zaptest.NewLogger(t, zaptest.WrapOptions(zap.WrapCore( + func(core zapcore.Core) zapcore.Core { + return zapcore.NewTee(core, observer) + }, + ))) + + m := New0025Migration(mHandler) + require.Equal(t, 25, m.Order()) + require.NoError(t, m.Apply(db, logger)) + + require.Equal(t, 1, observedLogs.Len(), "expected 1 log message") + require.Equal(t, zapcore.InfoLevel, observedLogs.All()[0].Level) + require.Equal(t, "Fixed invalid proof during migration", observedLogs.All()[0].Message) + require.Equal(t, nodeID.ShortString(), observedLogs.All()[0].ContextMap()["smesherID"]) + + // check proof was updated + blob := &sql.Blob{} + err := identities.LoadMalfeasanceBlob(context.Background(), db, nodeID.Bytes(), blob) + require.NoError(t, err) + updatedProof := &mwire.MalfeasanceProof{} + codec.MustDecode(blob.Bytes, updatedProof) + + require.NotEqual(t, proof, updatedProof) + require.Nil(t, updatedProof.Proof.Data.(*mwire.InvalidPrevATXProof).Atx1.VRFNonce) + }) + + t.Run("invalid proof cannot be fixed", func(t *testing.T) { + db := setup(t) + + // store invalid proof + nonce := uint64(1337) + nodeID := types.RandomNodeID() + proof := &mwire.MalfeasanceProof{ + Proof: mwire.Proof{ + Type: mwire.InvalidPrevATX, + Data: &mwire.InvalidPrevATXProof{ + Atx1: wire.ActivationTxV1{ + InnerActivationTxV1: wire.InnerActivationTxV1{ + VRFNonce: &nonce, + }, + }, + Atx2: wire.ActivationTxV1{}, + }, + }, + } + require.NoError(t, identities.SetMalicious(db, nodeID, codec.MustEncode(proof), time.Now())) + ctrl := gomock.NewController(t) 
+ mHandler := NewMockmalfeasanceValidator(ctrl) + + // first call to Validate returns error + mHandler.EXPECT().Validate(context.Background(), proof). + Return(types.EmptyNodeID, errors.New("invalid signature")) + + // second call to Validate still returns error + mHandler.EXPECT().Validate(context.Background(), gomock.Cond(func(x any) bool { + return x.(*mwire.MalfeasanceProof).Proof.Type == mwire.InvalidPrevATX && + x.(*mwire.MalfeasanceProof).Proof.Data.(*mwire.InvalidPrevATXProof).Atx1.VRFNonce == nil + })).Return(types.EmptyNodeID, errors.New("invalid signature")) + + observer, observedLogs := observer.New(zapcore.ErrorLevel) + logger := zaptest.NewLogger(t, zaptest.WrapOptions(zap.WrapCore( + func(core zapcore.Core) zapcore.Core { + return zapcore.NewTee(core, observer) + }, + ))) + + m := New0025Migration(mHandler) + require.Equal(t, 25, m.Order()) + require.NoError(t, m.Apply(db, logger)) + + require.Equal(t, 1, observedLogs.Len(), "expected 1 log message") + require.Equal(t, zapcore.ErrorLevel, observedLogs.All()[0].Level) + require.Equal(t, "Failed to fix invalid malfeasance proof during migration", observedLogs.All()[0].Message) + require.Equal(t, nodeID.ShortString(), observedLogs.All()[0].ContextMap()["smesherID"]) + require.Equal(t, "invalid signature", observedLogs.All()[0].ContextMap()["error"]) + + // check proof not updated + blob := &sql.Blob{} + err := identities.LoadMalfeasanceBlob(context.Background(), db, nodeID.Bytes(), blob) + require.NoError(t, err) + updatedProof := &mwire.MalfeasanceProof{} + codec.MustDecode(blob.Bytes, updatedProof) + + require.Equal(t, proof, updatedProof) + }) +} diff --git a/sql/statesql/schema/migrations/0017_atxs_prev_id_nonce_placeholder.sql b/sql/statesql/schema/migrations/0017_atxs_prev_id_nonce_placeholder.sql index 7e4a34351e..d39d54b6a2 100644 --- a/sql/statesql/schema/migrations/0017_atxs_prev_id_nonce_placeholder.sql +++ b/sql/statesql/schema/migrations/0017_atxs_prev_id_nonce_placeholder.sql @@ -1 +1 @@ 
--- Migration is done entirely in code +-- Migration is done entirely in code and doesn't change the schema. diff --git a/sql/statesql/schema/migrations/0025_check_malfeasance.sql b/sql/statesql/schema/migrations/0025_check_malfeasance.sql new file mode 100644 index 0000000000..d39d54b6a2 --- /dev/null +++ b/sql/statesql/schema/migrations/0025_check_malfeasance.sql @@ -0,0 +1 @@ +-- Migration is done entirely in code and doesn't change the schema. diff --git a/sql/statesql/schema/schema.sql b/sql/statesql/schema/schema.sql index c1cd6a427a..913b1f287f 100755 --- a/sql/statesql/schema/schema.sql +++ b/sql/statesql/schema/schema.sql @@ -1,4 +1,4 @@ -PRAGMA user_version = 24; +PRAGMA user_version = 25; CREATE TABLE accounts ( address CHAR(24),