diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 8d251eeab848..96fcdfb052d2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -459,6 +459,8 @@ object LimitPushDown extends Rule[LogicalPlan] { val newJoin = joinType match { case RightOuter => join.copy(right = maybePushLocalLimit(exp, right)) case LeftOuter => join.copy(left = maybePushLocalLimit(exp, left)) + case Cross => join.copy(left = maybePushLocalLimit(exp, left), + right = maybePushLocalLimit(exp, right)) case _ => join } LocalLimit(exp, newJoin) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala index 17fb9fc5d11e..792c782fbc5b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.Add -import org.apache.spark.sql.catalyst.plans.{FullOuter, LeftOuter, PlanTest, RightOuter} +import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, LeftOuter, PlanTest, RightOuter} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ @@ -171,4 +171,25 @@ class LimitPushdownSuite extends PlanTest { // No pushdown for FULL OUTER JOINS. comparePlans(optimized, originalQuery) } + + test("cross join") { + val originalQuery = x.join(y, Cross).limit(1) + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = Limit(1, LocalLimit(1, x).join(LocalLimit(1, y), Cross)).analyze + comparePlans(optimized, correctAnswer) + } + + test("cross join and left sides are limited") { + val originalQuery = x.limit(2).join(y, Cross).limit(1) + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = Limit(1, LocalLimit(1, x).join(LocalLimit(1, y), Cross)).analyze + comparePlans(optimized, correctAnswer) + } + + test("cross join and right sides are limited") { + val originalQuery = x.join(y.limit(2), Cross).limit(1) + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = Limit(1, LocalLimit(1, x).join(LocalLimit(1, y), Cross)).analyze + comparePlans(optimized, correctAnswer) + } }