diff --git a/velox/exec/MergeJoin.cpp b/velox/exec/MergeJoin.cpp index 33de24908ccc..e7bb64c15f4a 100644 --- a/velox/exec/MergeJoin.cpp +++ b/velox/exec/MergeJoin.cpp @@ -35,8 +35,10 @@ MergeJoin::MergeJoin( numKeys_{joinNode->leftKeys().size()}, joinNode_(joinNode) { VELOX_USER_CHECK( - joinNode_->isInnerJoin() || joinNode_->isLeftJoin(), - "Merge join supports only inner and left joins. Other join types are not supported yet."); + joinNode_->isInnerJoin() || joinNode_->isLeftJoin() || + joinNode_->isLeftSemiFilterJoin() || + joinNode_->isRightSemiFilterJoin(), + "Merge join supports only inner, left and left semi joins. Other join types are not supported yet."); } void MergeJoin::initialize() { @@ -64,6 +66,12 @@ void MergeJoin::initialize() { } } + if (joinNode_->isRightSemiFilterJoin()) { + VELOX_USER_CHECK( + leftProjections_.empty(), + "The left side projections should be empty for right semi join"); + } + for (auto i = 0; i < rightType->size(); ++i) { auto name = rightType->nameOf(i); auto outIndex = outputType_->getChildIdxIfExists(name); @@ -72,6 +80,12 @@ void MergeJoin::initialize() { } } + if (joinNode_->isLeftSemiFilterJoin()) { + VELOX_USER_CHECK( + rightProjections_.empty(), + "The right side projections should be empty for left semi join"); + } + if (joinNode_->filter()) { initializeFilter(joinNode_->filter(), leftType, rightType); @@ -383,6 +397,17 @@ bool MergeJoin::addToOutput() { auto rightEnd = r == numRights - 1 ? rightMatch_->endIndex : right->size(); + // TODO: Since semi joins only require determining if there is at least + // one match on the other side, we could explore specialized algorithms + // or data structures that short-circuit the join process once a match + // is found. + if (isLeftSemiFilterJoin(joinType_) || + isRightSemiFilterJoin(joinType_)) { + // LeftSemiFilter produce each row from the left at most once. + // RightSemiFilter produce each row from the right at most once. + rightEnd = rightStart + 1; + } + for (auto j = rightStart; j < rightEnd; ++j) { if (outputSize_ == outputBatchSize_) { leftMatch_->setCursor(l, i); diff --git a/velox/exec/fuzzer/JoinFuzzer.cpp b/velox/exec/fuzzer/JoinFuzzer.cpp index f85c4cc26b5e..dea22aaf61fe 100644 --- a/velox/exec/fuzzer/JoinFuzzer.cpp +++ b/velox/exec/fuzzer/JoinFuzzer.cpp @@ -963,7 +963,8 @@ void JoinFuzzer::makeAlternativePlans( .planNode()}); // Use OrderBy + MergeJoin (if join type is inner or left). - if (joinNode->isInnerJoin() || joinNode->isLeftJoin()) { + if (joinNode->isInnerJoin() || joinNode->isLeftJoin() || + joinNode->isLeftSemiFilterJoin() || joinNode->isRightSemiFilterJoin()) { auto planWithSplits = makeMergeJoinPlan( joinType, probeKeys, buildKeys, probeInput, buildInput, outputColumns); plans.push_back(planWithSplits); diff --git a/velox/exec/tests/MergeJoinTest.cpp b/velox/exec/tests/MergeJoinTest.cpp index aeeab5867e82..8ec158050caa 100644 --- a/velox/exec/tests/MergeJoinTest.cpp +++ b/velox/exec/tests/MergeJoinTest.cpp @@ -498,6 +498,49 @@ TEST_F(MergeJoinTest, lazyVectors) { "SELECT c0, rc0, c1, rc1, c2, c3 FROM t, u WHERE t.c0 = u.rc0 and c1 + rc1 < 30"); } +TEST_F(MergeJoinTest, semiJoin) { + auto left = makeRowVector( + {"t0"}, {makeNullableFlatVector({1, 2, 2, 6, std::nullopt})}); + + auto right = makeRowVector( + {"u0"}, + {makeNullableFlatVector( + {1, 2, 2, 7, std::nullopt, std::nullopt})}); + + createDuckDbTable("t", {left}); + createDuckDbTable("u", {right}); + + auto testSemiJoin = [&](const std::string& filter, + const std::string& sql, + const std::vector& outputLayout, + core::JoinType joinType) { + auto planNodeIdGenerator = std::make_shared(); + auto plan = + PlanBuilder(planNodeIdGenerator) + .values({left}) + .mergeJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator).values({right}).planNode(), + filter, + outputLayout, + joinType) + .planNode(); + AssertQueryBuilder(plan, duckDbQueryRunner_).assertResults(sql); + }; + + testSemiJoin( + "t0 >1", + "SELECT t0 FROM t where t0 IN (SELECT u0 from u) and t0 > 1", + {"t0"}, + core::JoinType::kLeftSemiFilter); + testSemiJoin( + "u0 > 1", + "SELECT u0 FROM u where u0 IN (SELECT t0 from t) and u0 > 1", + {"u0"}, + core::JoinType::kRightSemiFilter); +} + TEST_F(MergeJoinTest, nullKeys) { auto left = makeRowVector( {"t0"}, {makeNullableFlatVector({1, 2, 5, std::nullopt})});