Skip to content

Commit

Permalink
replace t.error/fatal with assert/request (raft_paper_test.go)
Browse files Browse the repository at this point in the history
Signed-off-by: Xinyuan Du <[email protected]>
  • Loading branch information
MrDXY committed Mar 11, 2024
1 parent 60482e0 commit 7f0764c
Showing 1 changed file with 61 additions and 149 deletions.
210 changes: 61 additions & 149 deletions raft_paper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,12 @@ package raft

import (
"fmt"
"reflect"
"sort"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

pb "go.etcd.io/raft/v3/raftpb"
)

Expand Down Expand Up @@ -64,12 +66,8 @@ func testUpdateTermFromMessage(t *testing.T, state StateType) {

r.Step(pb.Message{Type: pb.MsgApp, Term: 2})

if r.Term != 2 {
t.Errorf("term = %d, want %d", r.Term, 2)
}
if r.state != StateFollower {
t.Errorf("state = %v, want %v", r.state, StateFollower)
}
assert.Equal(t, uint64(2), r.Term)
assert.Equal(t, StateFollower, r.state)
}

// TestRejectStaleTermMessage tests that if a server receives a request with
Expand All @@ -88,18 +86,14 @@ func TestRejectStaleTermMessage(t *testing.T) {

r.Step(pb.Message{Type: pb.MsgApp, Term: r.Term - 1})

if called {
t.Errorf("stepFunc called = %v, want %v", called, false)
}
assert.False(t, called)
}

// TestStartAsFollower tests that when servers start up, they begin as followers.
// Reference: section 5.2
func TestStartAsFollower(t *testing.T) {
r := newTestRaft(1, 10, 1, newTestMemoryStorage(withPeers(1, 2, 3)))
if r.state != StateFollower {
t.Errorf("state = %s, want %s", r.state, StateFollower)
}
assert.Equal(t, StateFollower, r.state)
}

// TestLeaderBcastBeat tests that if the leader receives a heartbeat tick,
Expand All @@ -122,13 +116,10 @@ func TestLeaderBcastBeat(t *testing.T) {

msgs := r.readMessages()
sort.Sort(messageSlice(msgs))
wmsgs := []pb.Message{
assert.Equal(t, []pb.Message{
{From: 1, To: 2, Term: 1, Type: pb.MsgHeartbeat},
{From: 1, To: 3, Term: 1, Type: pb.MsgHeartbeat},
}
if !reflect.DeepEqual(msgs, wmsgs) {
t.Errorf("msgs = %v, want %v", msgs, wmsgs)
}
}, msgs)
}

func TestFollowerStartElection(t *testing.T) {
Expand Down Expand Up @@ -164,24 +155,16 @@ func testNonleaderStartElection(t *testing.T, state StateType) {
}
r.advanceMessagesAfterAppend()

if r.Term != 2 {
t.Errorf("term = %d, want 2", r.Term)
}
if r.state != StateCandidate {
t.Errorf("state = %s, want %s", r.state, StateCandidate)
}
if !r.trk.Votes[r.id] {
t.Errorf("vote for self = false, want true")
}
assert.Equal(t, uint64(2), r.Term)
assert.Equal(t, StateCandidate, r.state)
assert.True(t, r.trk.Votes[r.id])

msgs := r.readMessages()
sort.Sort(messageSlice(msgs))
wmsgs := []pb.Message{
assert.Equal(t, []pb.Message{
{From: 1, To: 2, Term: 2, Type: pb.MsgVote},
{From: 1, To: 3, Term: 2, Type: pb.MsgVote},
}
if !reflect.DeepEqual(msgs, wmsgs) {
t.Errorf("msgs = %v, want %v", msgs, wmsgs)
}
}, msgs)
}

// TestLeaderElectionInOneRoundRPC tests all cases that may happen in
Expand Down Expand Up @@ -224,12 +207,8 @@ func TestLeaderElectionInOneRoundRPC(t *testing.T) {
r.Step(pb.Message{From: id, To: 1, Term: r.Term, Type: pb.MsgVoteResp, Reject: !vote})
}

if r.state != tt.state {
t.Errorf("#%d: state = %s, want %s", i, r.state, tt.state)
}
if g := r.Term; g != 1 {
t.Errorf("#%d: term = %d, want %d", i, g, 1)
}
assert.Equal(t, tt.state, r.state, "#%d", i)
assert.Equal(t, uint64(1), r.Term, "#%d", i)
}
}

Expand All @@ -255,13 +234,9 @@ func TestFollowerVote(t *testing.T) {

r.Step(pb.Message{From: tt.nvote, To: 1, Term: 1, Type: pb.MsgVote})

msgs := r.msgsAfterAppend
wmsgs := []pb.Message{
assert.Equal(t, []pb.Message{
{From: 1, To: tt.nvote, Term: 1, Type: pb.MsgVoteResp, Reject: tt.wreject},
}
if !reflect.DeepEqual(msgs, wmsgs) {
t.Errorf("#%d: msgs = %v, want %v", i, msgs, wmsgs)
}
}, r.msgsAfterAppend, "#%d", i)
}
}

Expand All @@ -278,18 +253,12 @@ func TestCandidateFallback(t *testing.T) {
for i, tt := range tests {
r := newTestRaft(1, 10, 1, newTestMemoryStorage(withPeers(1, 2, 3)))
r.Step(pb.Message{From: 1, To: 1, Type: pb.MsgHup})
if r.state != StateCandidate {
t.Fatalf("unexpected state = %s, want %s", r.state, StateCandidate)
}
require.Equal(t, StateCandidate, r.state, "#%d", i)

r.Step(tt)

if g := r.state; g != StateFollower {
t.Errorf("#%d: state = %s, want %s", i, g, StateFollower)
}
if g := r.Term; g != tt.Term {
t.Errorf("#%d: term = %d, want %d", i, g, tt.Term)
}
assert.Equal(t, StateFollower, r.state, "#%d", i)
assert.Equal(t, tt.Term, r.Term, "#%d", i)
}
}

Expand Down Expand Up @@ -328,9 +297,7 @@ func testNonleaderElectionTimeoutRandomized(t *testing.T, state StateType) {
}

for d := et; d < 2*et; d++ {
if !timeouts[d] {
t.Errorf("timeout in %d ticks should happen", d)
}
assert.True(t, timeouts[d], "timeout in %d ticks should happen", d)
}
}

Expand Down Expand Up @@ -383,9 +350,7 @@ func testNonleadersElectionTimeoutNonconflict(t *testing.T, state StateType) {
}
}

if g := float64(conflicts) / 1000; g > 0.3 {
t.Errorf("probability of conflicts = %v, want <= 0.3", g)
}
assert.LessOrEqual(t, float64(conflicts)/1000, 0.3)
}

// TestLeaderStartReplication tests that when receiving client proposals,
Expand All @@ -407,25 +372,18 @@ func TestLeaderStartReplication(t *testing.T) {
ents := []pb.Entry{{Data: []byte("some data")}}
r.Step(pb.Message{From: 1, To: 1, Type: pb.MsgProp, Entries: ents})

if g := r.raftLog.lastIndex(); g != li+1 {
t.Errorf("lastIndex = %d, want %d", g, li+1)
}
if g := r.raftLog.committed; g != li {
t.Errorf("committed = %d, want %d", g, li)
}
assert.Equal(t, li+1, r.raftLog.lastIndex())
assert.Equal(t, li, r.raftLog.committed)
msgs := r.readMessages()
sort.Sort(messageSlice(msgs))
wents := []pb.Entry{{Index: li + 1, Term: 1, Data: []byte("some data")}}
wmsgs := []pb.Message{
assert.Equal(t, []pb.Message{
{From: 1, To: 2, Term: 1, Type: pb.MsgApp, Index: li, LogTerm: 1, Entries: wents, Commit: li},
{From: 1, To: 3, Term: 1, Type: pb.MsgApp, Index: li, LogTerm: 1, Entries: wents, Commit: li},
}
if !reflect.DeepEqual(msgs, wmsgs) {
t.Errorf("msgs = %+v, want %+v", msgs, wmsgs)
}
if g := r.raftLog.nextUnstableEnts(); !reflect.DeepEqual(g, wents) {
t.Errorf("ents = %+v, want %+v", g, wents)
}
}, msgs)
assert.Equal(t, []pb.Entry{
{Index: li + 1, Term: 1, Data: []byte("some data")},
}, r.raftLog.nextUnstableEnts())
}

// TestLeaderCommitEntry tests that when the entry has been safely replicated,
Expand All @@ -448,25 +406,16 @@ func TestLeaderCommitEntry(t *testing.T) {
r.Step(acceptAndReply(m))
}

if g := r.raftLog.committed; g != li+1 {
t.Errorf("committed = %d, want %d", g, li+1)
}
wents := []pb.Entry{{Index: li + 1, Term: 1, Data: []byte("some data")}}
if g := r.raftLog.nextCommittedEnts(true); !reflect.DeepEqual(g, wents) {
t.Errorf("nextCommittedEnts = %+v, want %+v", g, wents)
}
assert.Equal(t, li+1, r.raftLog.committed)
assert.Equal(t, []pb.Entry{
{Index: li + 1, Term: 1, Data: []byte("some data")},
}, r.raftLog.nextCommittedEnts(true))
msgs := r.readMessages()
sort.Sort(messageSlice(msgs))
for i, m := range msgs {
if w := uint64(i + 2); m.To != w {
t.Errorf("to = %x, want %x", m.To, w)
}
if m.Type != pb.MsgApp {
t.Errorf("type = %v, want %v", m.Type, pb.MsgApp)
}
if m.Commit != li+1 {
t.Errorf("commit = %d, want %d", m.Commit, li+1)
}
assert.Equal(t, uint64(i+2), m.To)
assert.Equal(t, pb.MsgApp, m.Type)
assert.Equal(t, li+1, m.Commit)
}
}

Expand Down Expand Up @@ -504,9 +453,7 @@ func TestLeaderAcknowledgeCommit(t *testing.T) {
}
}

if g := r.raftLog.committed > li; g != tt.wack {
t.Errorf("#%d: ack commit = %v, want %v", i, g, tt.wack)
}
assert.Equal(t, tt.wack, r.raftLog.committed > li, "#%d", i)
}
}

Expand Down Expand Up @@ -536,10 +483,10 @@ func TestLeaderCommitPrecedingEntries(t *testing.T) {
}

li := uint64(len(tt))
wents := append(tt, pb.Entry{Term: 3, Index: li + 1}, pb.Entry{Term: 3, Index: li + 2, Data: []byte("some data")})
if g := r.raftLog.nextCommittedEnts(true); !reflect.DeepEqual(g, wents) {
t.Errorf("#%d: ents = %+v, want %+v", i, g, wents)
}
assert.Equal(t, append(tt,
pb.Entry{Term: 3, Index: li + 1},
pb.Entry{Term: 3, Index: li + 2, Data: []byte("some data")},
), r.raftLog.nextCommittedEnts(true), "#%d", i)
}
}

Expand Down Expand Up @@ -585,13 +532,8 @@ func TestFollowerCommitEntry(t *testing.T) {

r.Step(pb.Message{From: 2, To: 1, Type: pb.MsgApp, Term: 1, Entries: tt.ents, Commit: tt.commit})

if g := r.raftLog.committed; g != tt.commit {
t.Errorf("#%d: committed = %d, want %d", i, g, tt.commit)
}
wents := tt.ents[:int(tt.commit)]
if g := r.raftLog.nextCommittedEnts(true); !reflect.DeepEqual(g, wents) {
t.Errorf("#%d: nextCommittedEnts = %v, want %v", i, g, wents)
}
assert.Equal(t, tt.commit, r.raftLog.committed, "#%d", i)
assert.Equal(t, tt.ents[:int(tt.commit)], r.raftLog.nextCommittedEnts(true), "#%d", i)
}
}

Expand Down Expand Up @@ -630,13 +572,9 @@ func TestFollowerCheckMsgApp(t *testing.T) {

r.Step(pb.Message{From: 2, To: 1, Type: pb.MsgApp, Term: 2, LogTerm: tt.term, Index: tt.index})

msgs := r.readMessages()
wmsgs := []pb.Message{
assert.Equal(t, []pb.Message{
{From: 1, To: 2, Type: pb.MsgAppResp, Term: 2, Index: tt.windex, Reject: tt.wreject, RejectHint: tt.wrejectHint, LogTerm: tt.wlogterm},
}
if !reflect.DeepEqual(msgs, wmsgs) {
t.Errorf("#%d: msgs = %+v, want %+v", i, msgs, wmsgs)
}
}, r.readMessages(), "#%d", i)
}
}

Expand Down Expand Up @@ -685,12 +623,8 @@ func TestFollowerAppendEntries(t *testing.T) {

r.Step(pb.Message{From: 2, To: 1, Type: pb.MsgApp, Term: 2, LogTerm: tt.term, Index: tt.index, Entries: tt.ents})

if g := r.raftLog.allEntries(); !reflect.DeepEqual(g, tt.wents) {
t.Errorf("#%d: ents = %+v, want %+v", i, g, tt.wents)
}
if g := r.raftLog.nextUnstableEnts(); !reflect.DeepEqual(g, tt.wunstable) {
t.Errorf("#%d: unstableEnts = %+v, want %+v", i, g, tt.wunstable)
}
assert.Equal(t, tt.wents, r.raftLog.allEntries(), "#%d", i)
assert.Equal(t, tt.wunstable, r.raftLog.nextUnstableEnts(), "#%d", i)
}
}

Expand Down Expand Up @@ -727,9 +661,7 @@ func TestLeaderSyncFollowerLog(t *testing.T) {

n.send(pb.Message{From: 1, To: 1, Type: pb.MsgProp, Entries: []pb.Entry{{}}})

if g := diffu(ltoa(lead.raftLog), ltoa(follower.raftLog)); g != "" {
t.Errorf("#%d: log diff:\n%s", i, g)
}
assert.Empty(t, diffu(ltoa(lead.raftLog), ltoa(follower.raftLog)), "#%d", i)
}
}

Expand Down Expand Up @@ -757,26 +689,14 @@ func TestVoteRequest(t *testing.T) {

msgs := r.readMessages()
sort.Sort(messageSlice(msgs))
if len(msgs) != 2 {
t.Fatalf("#%d: len(msg) = %d, want %d", j, len(msgs), 2)
}
require.Len(t, msgs, 2, "#%d", j)
for i, m := range msgs {
if m.Type != pb.MsgVote {
t.Errorf("#%d: msgType = %d, want %d", i, m.Type, pb.MsgVote)
}
if m.To != uint64(i+2) {
t.Errorf("#%d: to = %d, want %d", i, m.To, i+2)
}
if m.Term != tt.wterm {
t.Errorf("#%d: term = %d, want %d", i, m.Term, tt.wterm)
}
windex, wlogterm := tt.ents[len(tt.ents)-1].Index, tt.ents[len(tt.ents)-1].Term
if m.Index != windex {
t.Errorf("#%d: index = %d, want %d", i, m.Index, windex)
}
if m.LogTerm != wlogterm {
t.Errorf("#%d: logterm = %d, want %d", i, m.LogTerm, wlogterm)
}
assert.Equal(t, pb.MsgVote, m.Type, "#%d.%d", j, i)
assert.Equal(t, uint64(i+2), m.To, "#%d.%d", j, i)
assert.Equal(t, tt.wterm, m.Term, "#%d.%d", j, i)

assert.Equal(t, tt.ents[len(tt.ents)-1].Index, m.Index, "#%d.%d", j, i)
assert.Equal(t, tt.ents[len(tt.ents)-1].Term, m.LogTerm, "#%d.%d", j, i)
}
}
}
Expand Down Expand Up @@ -814,16 +734,10 @@ func TestVoter(t *testing.T) {
r.Step(pb.Message{From: 2, To: 1, Type: pb.MsgVote, Term: 3, LogTerm: tt.logterm, Index: tt.index})

msgs := r.readMessages()
if len(msgs) != 1 {
t.Fatalf("#%d: len(msg) = %d, want %d", i, len(msgs), 1)
}
require.Len(t, msgs, 1, "#%d", i)
m := msgs[0]
if m.Type != pb.MsgVoteResp {
t.Errorf("#%d: msgType = %d, want %d", i, m.Type, pb.MsgVoteResp)
}
if m.Reject != tt.wreject {
t.Errorf("#%d: reject = %t, want %t", i, m.Reject, tt.wreject)
}
assert.Equal(t, pb.MsgVoteResp, m.Type, "#%d", i)
assert.Equal(t, tt.wreject, m.Reject, "#%d", i)
}
}

Expand Down Expand Up @@ -856,9 +770,7 @@ func TestLeaderOnlyCommitsLogFromCurrentTerm(t *testing.T) {

r.Step(pb.Message{From: 2, To: 1, Type: pb.MsgAppResp, Term: r.Term, Index: tt.index})
r.advanceMessagesAfterAppend()
if r.raftLog.committed != tt.wcommit {
t.Errorf("#%d: commit = %d, want %d", i, r.raftLog.committed, tt.wcommit)
}
assert.Equal(t, tt.wcommit, r.raftLog.committed, "#%d", i)
}
}

Expand Down

0 comments on commit 7f0764c

Please sign in to comment.