From 1b3223e59f8a1e3927cd15148b1bd1f68911ccd9 Mon Sep 17 00:00:00 2001 From: Manan Gupta <35839558+GuptaManan100@users.noreply.github.com> Date: Sat, 28 Dec 2024 15:03:16 +0530 Subject: [PATCH] Fix Data race in semi-join (#17417) Signed-off-by: Manan Gupta --- .../endtoend/vtgate/queries/misc/misc_test.go | 20 +++++ go/vt/vtgate/engine/fake_primitive_test.go | 7 +- go/vt/vtgate/engine/semi_join.go | 13 +-- go/vt/vtgate/engine/semi_join_test.go | 79 +++++++++++++++++++ 4 files changed, 112 insertions(+), 7 deletions(-) diff --git a/go/test/endtoend/vtgate/queries/misc/misc_test.go b/go/test/endtoend/vtgate/queries/misc/misc_test.go index d57345ed0e6..699f04518d7 100644 --- a/go/test/endtoend/vtgate/queries/misc/misc_test.go +++ b/go/test/endtoend/vtgate/queries/misc/misc_test.go @@ -400,3 +400,23 @@ func TestHandleNullableColumn(t *testing.T) { // tbl.nonunq_col is not nullable according to the schema, but because of the left join, it can be NULL mcmp.ExecWithColumnCompare(`select * from t1 left join tbl on t1.id2 = tbl.id where t1.id1 = 6 or tbl.nonunq_col = 6`) } + +// TestSemiJoin tests that the semi join works as intended. 
+func TestSemiJoin(t *testing.T) { + mcmp, closer := start(t) + defer closer() + + for i := 1; i <= 1000; i++ { + mcmp.Exec(fmt.Sprintf("insert into t1(id1, id2) values (%d, %d)", i, 2*i)) + mcmp.Exec(fmt.Sprintf("insert into tbl(id, unq_col, nonunq_col) values (%d, %d, %d)", i, 2*i, 3*i)) + } + + // Test that the semi join works as intended + for _, mode := range []string{"oltp", "olap"} { + mcmp.Run(mode, func(mcmp *utils.MySQLCompare) { + utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = %s", mode)) + + mcmp.Exec("select id1, id2 from t1 where exists (select id from tbl where nonunq_col = t1.id2) order by id1") + }) + } +} diff --git a/go/vt/vtgate/engine/fake_primitive_test.go b/go/vt/vtgate/engine/fake_primitive_test.go index e992c2a4623..d614cf52d56 100644 --- a/go/vt/vtgate/engine/fake_primitive_test.go +++ b/go/vt/vtgate/engine/fake_primitive_test.go @@ -39,7 +39,8 @@ type fakePrimitive struct { // sendErr is sent at the end of the stream if it's set. sendErr error - log []string + noLog bool + log []string allResultsInOneCall bool @@ -84,7 +85,9 @@ func (f *fakePrimitive) TryExecute(ctx context.Context, vcursor VCursor, bindVar } func (f *fakePrimitive) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error { - f.log = append(f.log, fmt.Sprintf("StreamExecute %v %v", printBindVars(bindVars), wantfields)) + if !f.noLog { + f.log = append(f.log, fmt.Sprintf("StreamExecute %v %v", printBindVars(bindVars), wantfields)) + } if f.results == nil { return f.sendErr } diff --git a/go/vt/vtgate/engine/semi_join.go b/go/vt/vtgate/engine/semi_join.go index 8ab0465249c..06c6a9aa94a 100644 --- a/go/vt/vtgate/engine/semi_join.go +++ b/go/vt/vtgate/engine/semi_join.go @@ -20,6 +20,7 @@ import ( "context" "fmt" "strings" + "sync/atomic" "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" @@ -72,24 +73,26 @@ func (jn *SemiJoin) TryExecute(ctx 
context.Context, vcursor VCursor, bindVars ma // TryStreamExecute performs a streaming exec. func (jn *SemiJoin) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error { - joinVars := make(map[string]*querypb.BindVariable) err := vcursor.StreamExecutePrimitive(ctx, jn.Left, bindVars, wantfields, func(lresult *sqltypes.Result) error { + joinVars := make(map[string]*querypb.BindVariable) result := &sqltypes.Result{Fields: projectFields(lresult.Fields, jn.Cols)} for _, lrow := range lresult.Rows { for k, col := range jn.Vars { joinVars[k] = sqltypes.ValueBindVariable(lrow[col]) } - rowAdded := false + var rowAdded atomic.Bool err := vcursor.StreamExecutePrimitive(ctx, jn.Right, combineVars(bindVars, joinVars), false, func(rresult *sqltypes.Result) error { - if len(rresult.Rows) > 0 && !rowAdded { - result.Rows = append(result.Rows, projectRows(lrow, jn.Cols)) - rowAdded = true + if len(rresult.Rows) > 0 { + rowAdded.Store(true) } return nil }) if err != nil { return err } + if rowAdded.Load() { + result.Rows = append(result.Rows, projectRows(lrow, jn.Cols)) + } } return callback(result) }) diff --git a/go/vt/vtgate/engine/semi_join_test.go b/go/vt/vtgate/engine/semi_join_test.go index 9cf55d4f78f..a7086a5ff86 100644 --- a/go/vt/vtgate/engine/semi_join_test.go +++ b/go/vt/vtgate/engine/semi_join_test.go @@ -18,6 +18,7 @@ package engine import ( "context" + "sync" "testing" "vitess.io/vitess/go/test/utils" @@ -161,3 +162,81 @@ func TestSemiJoinStreamExecute(t *testing.T) { "4|d|dd", )) } + +// TestSemiJoinStreamExecuteParallelExecution tests SemiJoin stream execution with parallel execution +// to ensure we have no data races. 
+func TestSemiJoinStreamExecuteParallelExecution(t *testing.T) { + leftPrim := &fakePrimitive{ + results: []*sqltypes.Result{ + sqltypes.MakeTestResult( + sqltypes.MakeTestFields( + "col1|col2|col3", + "int64|varchar|varchar", + ), + "1|a|aa", + "2|b|bb", + ), sqltypes.MakeTestResult( + sqltypes.MakeTestFields( + "col1|col2|col3", + "int64|varchar|varchar", + ), + "3|c|cc", + "4|d|dd", + ), + }, + async: true, + } + rightFields := sqltypes.MakeTestFields( + "col4|col5|col6", + "int64|varchar|varchar", + ) + rightPrim := &fakePrimitive{ + // the same set of results is returned for every left row, so all four left rows qualify + results: sqltypes.MakeTestStreamingResults(rightFields, + "4|d|dd", + "---", + "---", + "5|e|ee", + "6|f|ff", + "7|g|gg", + ), + async: true, + noLog: true, + } + + jn := &SemiJoin{ + Left: leftPrim, + Right: rightPrim, + Vars: map[string]int{ + "bv": 1, + }, + } + var res *sqltypes.Result + var mu sync.Mutex + err := jn.TryStreamExecute(context.Background(), &noopVCursor{}, map[string]*querypb.BindVariable{}, true, func(result *sqltypes.Result) error { + mu.Lock() + defer mu.Unlock() + if res == nil { + res = result + } else { + res.Rows = append(res.Rows, result.Rows...) + } + return nil + }) + require.NoError(t, err) + leftPrim.ExpectLog(t, []string{ + `StreamExecute true`, + }) + // We'll get all the rows back in left primitive, since we're returning the same set of rows + // from the right primitive that makes them all qualify. + expectResultAnyOrder(t, res, sqltypes.MakeTestResult( + sqltypes.MakeTestFields( + "col1|col2|col3", + "int64|varchar|varchar", + ), + "1|a|aa", + "2|b|bb", + "3|c|cc", + "4|d|dd", + )) +}