Skip to content

Commit

Permalink
simplify logic
Browse files Browse the repository at this point in the history
Signed-off-by: Andres Taylor <[email protected]>
  • Loading branch information
systay committed Dec 20, 2024
1 parent 4565dca commit b2c1f9c
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 24 deletions.
8 changes: 4 additions & 4 deletions go/sqltypes/result.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ func (result *Result) Copy() *Result {
out := &Result{
RowsAffected: result.RowsAffected,
InsertID: result.InsertID,
InsertIDChanged: result.InsertIDUpdated(),
InsertIDChanged: result.InsertIDChanged,
SessionStateChanges: result.SessionStateChanges,
StatusFlags: result.StatusFlags,
Info: result.Info,
Expand Down Expand Up @@ -132,7 +132,7 @@ func (result *Result) Metadata() *Result {
return &Result{
Fields: result.Fields,
InsertID: result.InsertID,
InsertIDChanged: result.InsertIDUpdated(),
InsertIDChanged: result.InsertIDChanged,
RowsAffected: result.RowsAffected,
Info: result.Info,
SessionStateChanges: result.SessionStateChanges,
Expand All @@ -157,7 +157,7 @@ func (result *Result) Truncate(l int) *Result {

out := &Result{
InsertID: result.InsertID,
InsertIDChanged: result.InsertIDUpdated(),
InsertIDChanged: result.InsertIDChanged,
RowsAffected: result.RowsAffected,
Info: result.Info,
SessionStateChanges: result.SessionStateChanges,
Expand Down Expand Up @@ -333,8 +333,8 @@ func (result *Result) AppendResult(src *Result) {
result.RowsAffected += src.RowsAffected
if src.InsertIDUpdated() {
result.InsertID = src.InsertID
result.InsertIDChanged = true
}
result.InsertIDChanged = result.InsertIDUpdated() || src.InsertIDUpdated()
if result.Fields == nil {
result.Fields = src.Fields
}
Expand Down
35 changes: 15 additions & 20 deletions go/sqltypes/result_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,17 @@ import (

"github.com/stretchr/testify/assert"

"vitess.io/vitess/go/test/utils"

querypb "vitess.io/vitess/go/vt/proto/query"
)

func assertEqualResults(t *testing.T, a, b *Result) {
t.Helper()
assert.Truef(t, a.Equal(b), "Results are not equal: \n%v\n%v", a, b)
if !a.Equal(b) {
t.Errorf("Results are not equal: %v %v", a, b)
}
}

func TestRepair(t *testing.T) {
fields := []*querypb.Field{{
Type: Int64,
Expand All @@ -45,9 +51,7 @@ func TestRepair(t *testing.T) {
},
}
in.Repair(fields)
if !in.Equal(want) {
t.Errorf("Repair:\n%#v, want\n%#v", in, want)
}
assertEqualResults(t, in, want)
}

func TestCopy(t *testing.T) {
Expand All @@ -67,7 +71,7 @@ func TestCopy(t *testing.T) {
},
}
out := in.Copy()
utils.MustMatch(t, in, out)
assertEqualResults(t, in, out)
}

func TestTruncate(t *testing.T) {
Expand All @@ -88,9 +92,7 @@ func TestTruncate(t *testing.T) {
}

out := in.Truncate(0)
if !out.Equal(in) {
t.Errorf("Truncate(0):\n%v, want\n%v", out, in)
}
assertEqualResults(t, in, out)

out = in.Truncate(1)
want := &Result{
Expand All @@ -106,9 +108,7 @@ func TestTruncate(t *testing.T) {
{TestValue(Int64, "3")},
},
}
if !out.Equal(want) {
t.Errorf("Truncate(1):\n%v, want\n%v", out, want)
}
assertEqualResults(t, out, want)
}

func TestStripMetaData(t *testing.T) {
Expand Down Expand Up @@ -286,17 +286,15 @@ func TestStripMetaData(t *testing.T) {
t.Run(tcase.name, func(t *testing.T) {
inCopy := tcase.in.Copy()
out := inCopy.StripMetadata(tcase.includedFields)
if !out.Equal(tcase.expected) {
t.Errorf("StripMetaData unexpected result for %v: %v", tcase.name, out)
}
assertEqualResults(t, out, tcase.expected)
if len(tcase.in.Fields) > 0 {
// check the out array is different than the in array.
if out.Fields[0] == inCopy.Fields[0] && tcase.includedFields != querypb.ExecuteOptions_ALL {
t.Errorf("StripMetaData modified original Field for %v", tcase.name)
}
}
// check we didn't change the original result.
utils.MustMatch(t, tcase.in, inCopy)
assertEqualResults(t, tcase.in, inCopy)
})
}
}
Expand Down Expand Up @@ -348,10 +346,7 @@ func TestAppendResult(t *testing.T) {
}

result.AppendResult(src)

if !result.Equal(want) {
t.Errorf("Got:\n%#v, want:\n%#v", result, want)
}
assertEqualResults(t, result, want)
}

func TestReplaceKeyspace(t *testing.T) {
Expand Down

0 comments on commit b2c1f9c

Please sign in to comment.