Skip to content

Commit

Permalink
Cherry-pick c25802d with conflicts
Browse files Browse the repository at this point in the history
  • Loading branch information
vitess-bot[bot] authored and vitess-bot committed Dec 28, 2024
1 parent 722cbbd commit 3fd2c7f
Show file tree
Hide file tree
Showing 4 changed files with 161 additions and 7 deletions.
69 changes: 69 additions & 0 deletions go/test/endtoend/vtgate/queries/misc/misc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -480,3 +480,72 @@ func TestEnumSetVals(t *testing.T) {
mcmp.AssertMatches("select id, enum_col, cast(enum_col as signed) from tbl_enum_set order by enum_col, id", `[[INT64(4) ENUM("xsmall") INT64(1)] [INT64(2) ENUM("small") INT64(2)] [INT64(1) ENUM("medium") INT64(3)] [INT64(5) ENUM("medium") INT64(3)] [INT64(3) ENUM("large") INT64(4)]]`)
mcmp.AssertMatches("select id, set_col, cast(set_col as unsigned) from tbl_enum_set order by set_col, id", `[[INT64(4) SET("a,b") UINT64(3)] [INT64(3) SET("c") UINT64(4)] [INT64(5) SET("a,d") UINT64(9)] [INT64(1) SET("a,b,e") UINT64(19)] [INT64(2) SET("e,f,g") UINT64(112)]]`)
}
<<<<<<< HEAD
=======

func TestTimeZones(t *testing.T) {
testCases := []struct {
name string
targetTZ string
expectedDiff time.Duration
}{
{"UTC to +08:00", "+08:00", 8 * time.Hour},
{"UTC to -08:00", "-08:00", -8 * time.Hour},
{"UTC to +05:30", "+05:30", 5*time.Hour + 30*time.Minute},
{"UTC to -05:45", "-05:45", -(5*time.Hour + 45*time.Minute)},
{"UTC to +09:00", "+09:00", 9 * time.Hour},
{"UTC to -12:00", "-12:00", -12 * time.Hour},
}

// Connect to Vitess
conn, err := mysql.Connect(context.Background(), &vtParams)
require.NoError(t, err)

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Set the initial time zone and get the time
utils.Exec(t, conn, "set time_zone = '+00:00'")
rs1 := utils.Exec(t, conn, "select now()")

// Set the target time zone and get the time
utils.Exec(t, conn, fmt.Sprintf("set time_zone = '%s'", tc.targetTZ))
rs2 := utils.Exec(t, conn, "select now()")

// Parse the times from the query result
layout := "2006-01-02 15:04:05" // MySQL default datetime format
time1, err := time.Parse(layout, rs1.Rows[0][0].ToString())
require.NoError(t, err)
time2, err := time.Parse(layout, rs2.Rows[0][0].ToString())
require.NoError(t, err)

// Calculate the actual difference between time2 and time1
actualDiff := time2.Sub(time1)
allowableDeviation := time.Second // allow up to 1-second difference

// Use a range to allow for slight variations
require.InDeltaf(t, tc.expectedDiff.Seconds(), actualDiff.Seconds(), allowableDeviation.Seconds(),
"time2 should be approximately %v after time1, within 1 second tolerance\n%v vs %v", tc.expectedDiff, time1, time2)
})
}
}

// TestSemiJoin tests that the semi join works as intended.
func TestSemiJoin(t *testing.T) {
mcmp, closer := start(t)
defer closer()

for i := 1; i <= 1000; i++ {
mcmp.Exec(fmt.Sprintf("insert into t1(id1, id2) values (%d, %d)", i, 2*i))
mcmp.Exec(fmt.Sprintf("insert into tbl(id, unq_col, nonunq_col) values (%d, %d, %d)", i, 2*i, 3*i))
}

// Test that the semi join works as intended
for _, mode := range []string{"oltp", "olap"} {
mcmp.Run(mode, func(mcmp *utils.MySQLCompare) {
utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = %s", mode))

mcmp.Exec("select id1, id2 from t1 where exists (select id from tbl where nonunq_col = t1.id2) order by id1")
})
}
}
>>>>>>> c25802da2a (Fix Data race in semi-join (#17417))
7 changes: 5 additions & 2 deletions go/vt/vtgate/engine/fake_primitive_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ type fakePrimitive struct {
// sendErr is sent at the end of the stream if it's set.
sendErr error

log []string
noLog bool
log []string

allResultsInOneCall bool

Expand Down Expand Up @@ -85,7 +86,9 @@ func (f *fakePrimitive) TryExecute(ctx context.Context, vcursor VCursor, bindVar
}

func (f *fakePrimitive) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error {
f.log = append(f.log, fmt.Sprintf("StreamExecute %v %v", printBindVars(bindVars), wantfields))
if !f.noLog {
f.log = append(f.log, fmt.Sprintf("StreamExecute %v %v", printBindVars(bindVars), wantfields))
}
if f.results == nil {
return f.sendErr
}
Expand Down
13 changes: 8 additions & 5 deletions go/vt/vtgate/engine/semi_join.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package engine

import (
"context"
"sync/atomic"

"vitess.io/vitess/go/sqltypes"
querypb "vitess.io/vitess/go/vt/proto/query"
Expand Down Expand Up @@ -62,24 +63,26 @@ func (jn *SemiJoin) TryExecute(ctx context.Context, vcursor VCursor, bindVars ma

// TryStreamExecute performs a streaming exec.
func (jn *SemiJoin) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error {
joinVars := make(map[string]*querypb.BindVariable)
err := vcursor.StreamExecutePrimitive(ctx, jn.Left, bindVars, wantfields, func(lresult *sqltypes.Result) error {
joinVars := make(map[string]*querypb.BindVariable)
result := &sqltypes.Result{Fields: lresult.Fields}
for _, lrow := range lresult.Rows {
for k, col := range jn.Vars {
joinVars[k] = sqltypes.ValueBindVariable(lrow[col])
}
rowAdded := false
var rowAdded atomic.Bool
err := vcursor.StreamExecutePrimitive(ctx, jn.Right, combineVars(bindVars, joinVars), false, func(rresult *sqltypes.Result) error {
if len(rresult.Rows) > 0 && !rowAdded {
result.Rows = append(result.Rows, lrow)
rowAdded = true
if len(rresult.Rows) > 0 {
rowAdded.Store(true)
}
return nil
})
if err != nil {
return err
}
if rowAdded.Load() {
result.Rows = append(result.Rows, lrow)
}
}
return callback(result)
})
Expand Down
79 changes: 79 additions & 0 deletions go/vt/vtgate/engine/semi_join_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package engine

import (
"context"
"sync"
"testing"

"vitess.io/vitess/go/test/utils"
Expand Down Expand Up @@ -159,3 +160,81 @@ func TestSemiJoinStreamExecute(t *testing.T) {
"4|d|dd",
))
}

// TestSemiJoinStreamExecuteParallelExecution tests SemiJoin stream execution with parallel execution
// to ensure we have no data races.
func TestSemiJoinStreamExecuteParallelExecution(t *testing.T) {
leftPrim := &fakePrimitive{
results: []*sqltypes.Result{
sqltypes.MakeTestResult(
sqltypes.MakeTestFields(
"col1|col2|col3",
"int64|varchar|varchar",
),
"1|a|aa",
"2|b|bb",
), sqltypes.MakeTestResult(
sqltypes.MakeTestFields(
"col1|col2|col3",
"int64|varchar|varchar",
),
"3|c|cc",
"4|d|dd",
),
},
async: true,
}
rightFields := sqltypes.MakeTestFields(
"col4|col5|col6",
"int64|varchar|varchar",
)
rightPrim := &fakePrimitive{
// we'll return non-empty results for rows 2 and 4
results: sqltypes.MakeTestStreamingResults(rightFields,
"4|d|dd",
"---",
"---",
"5|e|ee",
"6|f|ff",
"7|g|gg",
),
async: true,
noLog: true,
}

jn := &SemiJoin{
Left: leftPrim,
Right: rightPrim,
Vars: map[string]int{
"bv": 1,
},
}
var res *sqltypes.Result
var mu sync.Mutex
err := jn.TryStreamExecute(context.Background(), &noopVCursor{}, map[string]*querypb.BindVariable{}, true, func(result *sqltypes.Result) error {
mu.Lock()
defer mu.Unlock()
if res == nil {
res = result
} else {
res.Rows = append(res.Rows, result.Rows...)
}
return nil
})
require.NoError(t, err)
leftPrim.ExpectLog(t, []string{
`StreamExecute true`,
})
// We'll get all the rows back in left primitive, since we're returning the same set of rows
// from the right primitive that makes them all qualify.
expectResultAnyOrder(t, res, sqltypes.MakeTestResult(
sqltypes.MakeTestFields(
"col1|col2|col3",
"int64|varchar|varchar",
),
"1|a|aa",
"2|b|bb",
"3|c|cc",
"4|d|dd",
))
}

0 comments on commit 3fd2c7f

Please sign in to comment.