diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 46a90f600b2a..b08187d0bc3b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -502,7 +502,7 @@ object RemoveNoopOperators extends Rule[LogicalPlan] {
 }
 
 /**
- * Pushes down [[LocalLimit]] beneath UNION ALL and beneath the streamed inputs of outer joins.
+ * Pushes down [[LocalLimit]] beneath UNION ALL and joins.
  */
 object LimitPushDown extends Rule[LogicalPlan] {
 
@@ -539,12 +539,16 @@ object LimitPushDown extends Rule[LogicalPlan] {
     // pushdown Limit.
     case LocalLimit(exp, u: Union) =>
       LocalLimit(exp, u.copy(children = u.children.map(maybePushLocalLimit(exp, _))))
-    // Add extra limits below JOIN. For LEFT OUTER and RIGHT OUTER JOIN we push limits to
-    // the left and right sides, respectively. For INNER and CROSS JOIN we push limits to
-    // both the left and right sides if join condition is empty. It's not safe to push limits
-    // below FULL OUTER JOIN in the general case without a more invasive rewrite.
-    // We also need to ensure that this limit pushdown rule will not eventually introduce limits
-    // on both sides if it is applied multiple times. Therefore:
+    // Add extra limits below JOIN:
+    // 1. For LEFT OUTER and RIGHT OUTER JOIN, we push limits to the left and right sides,
+    //    respectively.
+    // 2. For INNER and CROSS JOIN, we push limits to both the left and right sides if join
+    //    condition is empty.
+    // 3. For LEFT SEMI and LEFT ANTI JOIN, we push limits to the left side if join condition
+    //    is empty.
+    // It's not safe to push limits below FULL OUTER JOIN in the general case without a more
+    // invasive rewrite. We also need to ensure that this limit pushdown rule will not eventually
+    // introduce limits on both sides if it is applied multiple times. Therefore:
     //   - If one side is already limited, stack another limit on top if the new limit is smaller.
     //     The redundant limit will be collapsed by the CombineLimits rule.
     case LocalLimit(exp, join @ Join(left, right, joinType, conditionOpt, _)) =>
@@ -555,6 +559,8 @@ object LimitPushDown extends Rule[LogicalPlan] {
           join.copy(
             left = maybePushLocalLimit(exp, left),
             right = maybePushLocalLimit(exp, right))
+        case LeftSemi | LeftAnti if conditionOpt.isEmpty =>
+          join.copy(left = maybePushLocalLimit(exp, left))
         case _ => join
       }
       LocalLimit(exp, newJoin)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala
index 5c760264ff21..7a33b5b4b53d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala
@@ -22,7 +22,7 @@
 import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions.Add
-import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, LeftOuter, PlanTest, RightOuter}
+import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, LeftAnti, LeftOuter, LeftSemi, PlanTest, RightOuter}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
@@ -212,4 +212,22 @@ class LimitPushdownSuite extends PlanTest {
       comparePlans(optimized, correctAnswer)
     }
   }
+
+  test("SPARK-34514: Push down limit through LEFT SEMI and LEFT ANTI join") {
+    // Push down when condition is empty
+    Seq(LeftSemi, LeftAnti).foreach { joinType =>
+      val originalQuery = x.join(y, joinType).limit(1)
+      val optimized = Optimize.execute(originalQuery.analyze)
+      val correctAnswer = Limit(1, LocalLimit(1, x).join(y, joinType)).analyze
+      comparePlans(optimized, correctAnswer)
+    }
+
+    // No push down when condition is not empty
+    Seq(LeftSemi, LeftAnti).foreach { joinType =>
+      val originalQuery = x.join(y, joinType, Some("x.a".attr === "y.b".attr)).limit(1)
+      val optimized = Optimize.execute(originalQuery.analyze)
+      val correctAnswer = Limit(1, x.join(y, joinType, Some("x.a".attr === "y.b".attr))).analyze
+      comparePlans(optimized, correctAnswer)
+    }
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index fe8a080ac5ae..82c49f9cbf29 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -4034,6 +4034,36 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
       checkAnswer(df, Row(0, 0) :: Row(0, 1) :: Row(0, 2) :: Nil)
     }
   }
+
+  test("SPARK-34514: Push down limit through LEFT SEMI and LEFT ANTI join") {
+    withTable("left_table", "nonempty_right_table", "empty_right_table") {
+      spark.range(5).toDF().repartition(1).write.saveAsTable("left_table")
+      spark.range(3).write.saveAsTable("nonempty_right_table")
+      spark.range(0).write.saveAsTable("empty_right_table")
+      Seq("LEFT SEMI", "LEFT ANTI").foreach { joinType =>
+        val joinWithNonEmptyRightDf = spark.sql(
+          s"SELECT * FROM left_table $joinType JOIN nonempty_right_table LIMIT 3")
+        val joinWithEmptyRightDf = spark.sql(
+          s"SELECT * FROM left_table $joinType JOIN empty_right_table LIMIT 3")
+
+        Seq(joinWithNonEmptyRightDf, joinWithEmptyRightDf).foreach { df =>
+          val pushedLocalLimits = df.queryExecution.optimizedPlan.collect {
+            case l @ LocalLimit(_, _: LogicalRelation) => l
+          }
+          assert(pushedLocalLimits.length === 1)
+        }
+
+        val expectedAnswer = Seq(Row(0), Row(1), Row(2))
+        if (joinType == "LEFT SEMI") {
+          checkAnswer(joinWithNonEmptyRightDf, expectedAnswer)
+          checkAnswer(joinWithEmptyRightDf, Seq.empty)
+        } else {
+          checkAnswer(joinWithNonEmptyRightDf, Seq.empty)
+          checkAnswer(joinWithEmptyRightDf, expectedAnswer)
+        }
+      }
+    }
+  }
 }
 
 case class Foo(bar: Option[String])
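
For reference, below is a minimal standalone sketch (not part of the patch) that shows the effect of the new rule from the user-facing API. The table names, the local SparkSession configuration, and the LimitPushdownSemiJoinSketch object name are assumptions chosen for the example; only the SQL shape (a LEFT SEMI join with no join condition plus an outer LIMIT) comes from the patch's tests.

// Sketch only, under the assumptions stated above.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.logical.LocalLimit

object LimitPushdownSemiJoinSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("limit-pushdown-sketch")
      .getOrCreate()

    // Two small temp views standing in for the tables used in the SQLQuerySuite test.
    spark.range(5).toDF("id").createOrReplaceTempView("left_table")
    spark.range(3).toDF("id").createOrReplaceTempView("right_table")

    // LEFT SEMI JOIN with no join condition and an outer LIMIT: with this patch the
    // optimizer is expected to add a LocalLimit on the left side of the join as well.
    val df = spark.sql("SELECT * FROM left_table LEFT SEMI JOIN right_table LIMIT 3")

    // Count LocalLimit nodes in the optimized plan. Without the patch only the top-level
    // limit is expected; with it, an additional LocalLimit should sit below the join.
    val localLimits = df.queryExecution.optimizedPlan.collect { case l: LocalLimit => l }
    println(s"LocalLimit nodes in optimized plan: ${localLimits.length}")

    df.show()
    spark.stop()
  }
}

The restriction to an empty join condition is what makes the pushdown safe: with no condition, a LEFT SEMI join returns every left row when the right side is non-empty (and none otherwise), and LEFT ANTI is the mirror case, so limiting the left input to N cannot drop rows the outer LIMIT would have returned. With a condition, the first N left rows are not guaranteed to satisfy it, which is why the rule leaves those plans unchanged, as the LimitPushdownSuite test asserts.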