From 7f0764c0c0499eea7333270782137a2bda0d4b37 Mon Sep 17 00:00:00 2001 From: Xinyuan Du Date: Mon, 11 Mar 2024 13:29:38 +0800 Subject: [PATCH] replace t.error/fatal with assert/request (raft_paper_test.go) Signed-off-by: Xinyuan Du --- raft_paper_test.go | 210 +++++++++++++-------------------------------- 1 file changed, 61 insertions(+), 149 deletions(-) diff --git a/raft_paper_test.go b/raft_paper_test.go index eff31f63..7936a276 100644 --- a/raft_paper_test.go +++ b/raft_paper_test.go @@ -28,10 +28,12 @@ package raft import ( "fmt" - "reflect" "sort" "testing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + pb "go.etcd.io/raft/v3/raftpb" ) @@ -64,12 +66,8 @@ func testUpdateTermFromMessage(t *testing.T, state StateType) { r.Step(pb.Message{Type: pb.MsgApp, Term: 2}) - if r.Term != 2 { - t.Errorf("term = %d, want %d", r.Term, 2) - } - if r.state != StateFollower { - t.Errorf("state = %v, want %v", r.state, StateFollower) - } + assert.Equal(t, uint64(2), r.Term) + assert.Equal(t, StateFollower, r.state) } // TestRejectStaleTermMessage tests that if a server receives a request with @@ -88,18 +86,14 @@ func TestRejectStaleTermMessage(t *testing.T) { r.Step(pb.Message{Type: pb.MsgApp, Term: r.Term - 1}) - if called { - t.Errorf("stepFunc called = %v, want %v", called, false) - } + assert.False(t, called) } // TestStartAsFollower tests that when servers start up, they begin as followers. // Reference: section 5.2 func TestStartAsFollower(t *testing.T) { r := newTestRaft(1, 10, 1, newTestMemoryStorage(withPeers(1, 2, 3))) - if r.state != StateFollower { - t.Errorf("state = %s, want %s", r.state, StateFollower) - } + assert.Equal(t, StateFollower, r.state) } // TestLeaderBcastBeat tests that if the leader receives a heartbeat tick, @@ -122,13 +116,10 @@ func TestLeaderBcastBeat(t *testing.T) { msgs := r.readMessages() sort.Sort(messageSlice(msgs)) - wmsgs := []pb.Message{ + assert.Equal(t, []pb.Message{ {From: 1, To: 2, Term: 1, Type: pb.MsgHeartbeat}, {From: 1, To: 3, Term: 1, Type: pb.MsgHeartbeat}, - } - if !reflect.DeepEqual(msgs, wmsgs) { - t.Errorf("msgs = %v, want %v", msgs, wmsgs) - } + }, msgs) } func TestFollowerStartElection(t *testing.T) { @@ -164,24 +155,16 @@ func testNonleaderStartElection(t *testing.T, state StateType) { } r.advanceMessagesAfterAppend() - if r.Term != 2 { - t.Errorf("term = %d, want 2", r.Term) - } - if r.state != StateCandidate { - t.Errorf("state = %s, want %s", r.state, StateCandidate) - } - if !r.trk.Votes[r.id] { - t.Errorf("vote for self = false, want true") - } + assert.Equal(t, uint64(2), r.Term) + assert.Equal(t, StateCandidate, r.state) + assert.True(t, r.trk.Votes[r.id]) + msgs := r.readMessages() sort.Sort(messageSlice(msgs)) - wmsgs := []pb.Message{ + assert.Equal(t, []pb.Message{ {From: 1, To: 2, Term: 2, Type: pb.MsgVote}, {From: 1, To: 3, Term: 2, Type: pb.MsgVote}, - } - if !reflect.DeepEqual(msgs, wmsgs) { - t.Errorf("msgs = %v, want %v", msgs, wmsgs) - } + }, msgs) } // TestLeaderElectionInOneRoundRPC tests all cases that may happen in @@ -224,12 +207,8 @@ func TestLeaderElectionInOneRoundRPC(t *testing.T) { r.Step(pb.Message{From: id, To: 1, Term: r.Term, Type: pb.MsgVoteResp, Reject: !vote}) } - if r.state != tt.state { - t.Errorf("#%d: state = %s, want %s", i, r.state, tt.state) - } - if g := r.Term; g != 1 { - t.Errorf("#%d: term = %d, want %d", i, g, 1) - } + assert.Equal(t, tt.state, r.state, "#%d", i) + assert.Equal(t, uint64(1), r.Term, "#%d", i) } } @@ -255,13 +234,9 @@ func TestFollowerVote(t *testing.T) { r.Step(pb.Message{From: tt.nvote, To: 1, Term: 1, Type: pb.MsgVote}) - msgs := r.msgsAfterAppend - wmsgs := []pb.Message{ + assert.Equal(t, []pb.Message{ {From: 1, To: tt.nvote, Term: 1, Type: pb.MsgVoteResp, Reject: tt.wreject}, - } - if !reflect.DeepEqual(msgs, wmsgs) { - t.Errorf("#%d: msgs = %v, want %v", i, msgs, wmsgs) - } + }, r.msgsAfterAppend, "#%d", i) } } @@ -278,18 +253,12 @@ func TestCandidateFallback(t *testing.T) { for i, tt := range tests { r := newTestRaft(1, 10, 1, newTestMemoryStorage(withPeers(1, 2, 3))) r.Step(pb.Message{From: 1, To: 1, Type: pb.MsgHup}) - if r.state != StateCandidate { - t.Fatalf("unexpected state = %s, want %s", r.state, StateCandidate) - } + require.Equal(t, StateCandidate, r.state, "#%d", i) r.Step(tt) - if g := r.state; g != StateFollower { - t.Errorf("#%d: state = %s, want %s", i, g, StateFollower) - } - if g := r.Term; g != tt.Term { - t.Errorf("#%d: term = %d, want %d", i, g, tt.Term) - } + assert.Equal(t, StateFollower, r.state, "#%d", i) + assert.Equal(t, tt.Term, r.Term, "#%d", i) } } @@ -328,9 +297,7 @@ func testNonleaderElectionTimeoutRandomized(t *testing.T, state StateType) { } for d := et; d < 2*et; d++ { - if !timeouts[d] { - t.Errorf("timeout in %d ticks should happen", d) - } + assert.True(t, timeouts[d], "timeout in %d ticks should happen", d) } } @@ -383,9 +350,7 @@ func testNonleadersElectionTimeoutNonconflict(t *testing.T, state StateType) { } } - if g := float64(conflicts) / 1000; g > 0.3 { - t.Errorf("probability of conflicts = %v, want <= 0.3", g) - } + assert.LessOrEqual(t, float64(conflicts)/1000, 0.3) } // TestLeaderStartReplication tests that when receiving client proposals, @@ -407,25 +372,18 @@ func TestLeaderStartReplication(t *testing.T) { ents := []pb.Entry{{Data: []byte("some data")}} r.Step(pb.Message{From: 1, To: 1, Type: pb.MsgProp, Entries: ents}) - if g := r.raftLog.lastIndex(); g != li+1 { - t.Errorf("lastIndex = %d, want %d", g, li+1) - } - if g := r.raftLog.committed; g != li { - t.Errorf("committed = %d, want %d", g, li) - } + assert.Equal(t, li+1, r.raftLog.lastIndex()) + assert.Equal(t, li, r.raftLog.committed) msgs := r.readMessages() sort.Sort(messageSlice(msgs)) wents := []pb.Entry{{Index: li + 1, Term: 1, Data: []byte("some data")}} - wmsgs := []pb.Message{ + assert.Equal(t, []pb.Message{ {From: 1, To: 2, Term: 1, Type: pb.MsgApp, Index: li, LogTerm: 1, Entries: wents, Commit: li}, {From: 1, To: 3, Term: 1, Type: pb.MsgApp, Index: li, LogTerm: 1, Entries: wents, Commit: li}, - } - if !reflect.DeepEqual(msgs, wmsgs) { - t.Errorf("msgs = %+v, want %+v", msgs, wmsgs) - } - if g := r.raftLog.nextUnstableEnts(); !reflect.DeepEqual(g, wents) { - t.Errorf("ents = %+v, want %+v", g, wents) - } + }, msgs) + assert.Equal(t, []pb.Entry{ + {Index: li + 1, Term: 1, Data: []byte("some data")}, + }, r.raftLog.nextUnstableEnts()) } // TestLeaderCommitEntry tests that when the entry has been safely replicated, @@ -448,25 +406,16 @@ func TestLeaderCommitEntry(t *testing.T) { r.Step(acceptAndReply(m)) } - if g := r.raftLog.committed; g != li+1 { - t.Errorf("committed = %d, want %d", g, li+1) - } - wents := []pb.Entry{{Index: li + 1, Term: 1, Data: []byte("some data")}} - if g := r.raftLog.nextCommittedEnts(true); !reflect.DeepEqual(g, wents) { - t.Errorf("nextCommittedEnts = %+v, want %+v", g, wents) - } + assert.Equal(t, li+1, r.raftLog.committed) + assert.Equal(t, []pb.Entry{ + {Index: li + 1, Term: 1, Data: []byte("some data")}, + }, r.raftLog.nextCommittedEnts(true)) msgs := r.readMessages() sort.Sort(messageSlice(msgs)) for i, m := range msgs { - if w := uint64(i + 2); m.To != w { - t.Errorf("to = %x, want %x", m.To, w) - } - if m.Type != pb.MsgApp { - t.Errorf("type = %v, want %v", m.Type, pb.MsgApp) - } - if m.Commit != li+1 { - t.Errorf("commit = %d, want %d", m.Commit, li+1) - } + assert.Equal(t, uint64(i+2), m.To) + assert.Equal(t, pb.MsgApp, m.Type) + assert.Equal(t, li+1, m.Commit) } } @@ -504,9 +453,7 @@ func TestLeaderAcknowledgeCommit(t *testing.T) { } } - if g := r.raftLog.committed > li; g != tt.wack { - t.Errorf("#%d: ack commit = %v, want %v", i, g, tt.wack) - } + assert.Equal(t, tt.wack, r.raftLog.committed > li, "#%d", i) } } @@ -536,10 +483,10 @@ func TestLeaderCommitPrecedingEntries(t *testing.T) { } li := uint64(len(tt)) - wents := append(tt, pb.Entry{Term: 3, Index: li + 1}, pb.Entry{Term: 3, Index: li + 2, Data: []byte("some data")}) - if g := r.raftLog.nextCommittedEnts(true); !reflect.DeepEqual(g, wents) { - t.Errorf("#%d: ents = %+v, want %+v", i, g, wents) - } + assert.Equal(t, append(tt, + pb.Entry{Term: 3, Index: li + 1}, + pb.Entry{Term: 3, Index: li + 2, Data: []byte("some data")}, + ), r.raftLog.nextCommittedEnts(true), "#%d", i) } } @@ -585,13 +532,8 @@ func TestFollowerCommitEntry(t *testing.T) { r.Step(pb.Message{From: 2, To: 1, Type: pb.MsgApp, Term: 1, Entries: tt.ents, Commit: tt.commit}) - if g := r.raftLog.committed; g != tt.commit { - t.Errorf("#%d: committed = %d, want %d", i, g, tt.commit) - } - wents := tt.ents[:int(tt.commit)] - if g := r.raftLog.nextCommittedEnts(true); !reflect.DeepEqual(g, wents) { - t.Errorf("#%d: nextCommittedEnts = %v, want %v", i, g, wents) - } + assert.Equal(t, tt.commit, r.raftLog.committed, "#%d", i) + assert.Equal(t, tt.ents[:int(tt.commit)], r.raftLog.nextCommittedEnts(true), "#%d", i) } } @@ -630,13 +572,9 @@ func TestFollowerCheckMsgApp(t *testing.T) { r.Step(pb.Message{From: 2, To: 1, Type: pb.MsgApp, Term: 2, LogTerm: tt.term, Index: tt.index}) - msgs := r.readMessages() - wmsgs := []pb.Message{ + assert.Equal(t, []pb.Message{ {From: 1, To: 2, Type: pb.MsgAppResp, Term: 2, Index: tt.windex, Reject: tt.wreject, RejectHint: tt.wrejectHint, LogTerm: tt.wlogterm}, - } - if !reflect.DeepEqual(msgs, wmsgs) { - t.Errorf("#%d: msgs = %+v, want %+v", i, msgs, wmsgs) - } + }, r.readMessages(), "#%d", i) } } @@ -685,12 +623,8 @@ func TestFollowerAppendEntries(t *testing.T) { r.Step(pb.Message{From: 2, To: 1, Type: pb.MsgApp, Term: 2, LogTerm: tt.term, Index: tt.index, Entries: tt.ents}) - if g := r.raftLog.allEntries(); !reflect.DeepEqual(g, tt.wents) { - t.Errorf("#%d: ents = %+v, want %+v", i, g, tt.wents) - } - if g := r.raftLog.nextUnstableEnts(); !reflect.DeepEqual(g, tt.wunstable) { - t.Errorf("#%d: unstableEnts = %+v, want %+v", i, g, tt.wunstable) - } + assert.Equal(t, tt.wents, r.raftLog.allEntries(), "#%d", i) + assert.Equal(t, tt.wunstable, r.raftLog.nextUnstableEnts(), "#%d", i) } } @@ -727,9 +661,7 @@ func TestLeaderSyncFollowerLog(t *testing.T) { n.send(pb.Message{From: 1, To: 1, Type: pb.MsgProp, Entries: []pb.Entry{{}}}) - if g := diffu(ltoa(lead.raftLog), ltoa(follower.raftLog)); g != "" { - t.Errorf("#%d: log diff:\n%s", i, g) - } + assert.Empty(t, diffu(ltoa(lead.raftLog), ltoa(follower.raftLog)), "#%d", i) } } @@ -757,26 +689,14 @@ func TestVoteRequest(t *testing.T) { msgs := r.readMessages() sort.Sort(messageSlice(msgs)) - if len(msgs) != 2 { - t.Fatalf("#%d: len(msg) = %d, want %d", j, len(msgs), 2) - } + require.Len(t, msgs, 2, "#%d", j) for i, m := range msgs { - if m.Type != pb.MsgVote { - t.Errorf("#%d: msgType = %d, want %d", i, m.Type, pb.MsgVote) - } - if m.To != uint64(i+2) { - t.Errorf("#%d: to = %d, want %d", i, m.To, i+2) - } - if m.Term != tt.wterm { - t.Errorf("#%d: term = %d, want %d", i, m.Term, tt.wterm) - } - windex, wlogterm := tt.ents[len(tt.ents)-1].Index, tt.ents[len(tt.ents)-1].Term - if m.Index != windex { - t.Errorf("#%d: index = %d, want %d", i, m.Index, windex) - } - if m.LogTerm != wlogterm { - t.Errorf("#%d: logterm = %d, want %d", i, m.LogTerm, wlogterm) - } + assert.Equal(t, pb.MsgVote, m.Type, "#%d.%d", j, i) + assert.Equal(t, uint64(i+2), m.To, "#%d.%d", j, i) + assert.Equal(t, tt.wterm, m.Term, "#%d.%d", j, i) + + assert.Equal(t, tt.ents[len(tt.ents)-1].Index, m.Index, "#%d.%d", j, i) + assert.Equal(t, tt.ents[len(tt.ents)-1].Term, m.LogTerm, "#%d.%d", j, i) } } } @@ -814,16 +734,10 @@ func TestVoter(t *testing.T) { r.Step(pb.Message{From: 2, To: 1, Type: pb.MsgVote, Term: 3, LogTerm: tt.logterm, Index: tt.index}) msgs := r.readMessages() - if len(msgs) != 1 { - t.Fatalf("#%d: len(msg) = %d, want %d", i, len(msgs), 1) - } + require.Len(t, msgs, 1, "#%d", i) m := msgs[0] - if m.Type != pb.MsgVoteResp { - t.Errorf("#%d: msgType = %d, want %d", i, m.Type, pb.MsgVoteResp) - } - if m.Reject != tt.wreject { - t.Errorf("#%d: reject = %t, want %t", i, m.Reject, tt.wreject) - } + assert.Equal(t, pb.MsgVoteResp, m.Type, "#%d", i) + assert.Equal(t, tt.wreject, m.Reject, "#%d", i) } } @@ -856,9 +770,7 @@ func TestLeaderOnlyCommitsLogFromCurrentTerm(t *testing.T) { r.Step(pb.Message{From: 2, To: 1, Type: pb.MsgAppResp, Term: r.Term, Index: tt.index}) r.advanceMessagesAfterAppend() - if r.raftLog.committed != tt.wcommit { - t.Errorf("#%d: commit = %d, want %d", i, r.raftLog.committed, tt.wcommit) - } + assert.Equal(t, tt.wcommit, r.raftLog.committed, "#%d", i) } }