diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 2d03fbfb0d31..1636b33b6213 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -569,7 +569,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { split(joinCondition.map(splitConjunctivePredicates).getOrElse(Nil), left, right) joinType match { - case Inner => + case _ @ (Inner | LeftSemi) => // push down the single side only join filter for both sides sub queries val newLeft = leftJoinConditions. reduceLeftOption(And).map(Filter(_, left)).getOrElse(left) @@ -577,7 +577,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { reduceLeftOption(And).map(Filter(_, right)).getOrElse(right) val newJoinCond = commonJoinCondition.reduceLeftOption(And) - Join(newLeft, newRight, Inner, newJoinCond) + Join(newLeft, newRight, joinType, newJoinCond) case RightOuter => // push down the left side only join filter for left side sub query val newLeft = leftJoinConditions. @@ -586,14 +586,14 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { val newJoinCond = (rightJoinConditions ++ commonJoinCondition).reduceLeftOption(And) Join(newLeft, newRight, RightOuter, newJoinCond) - case _ @ (LeftOuter | LeftSemi) => + case LeftOuter => // push down the right side only join filter for right sub query val newLeft = left val newRight = rightJoinConditions. reduceLeftOption(And).map(Filter(_, right)).getOrElse(right) val newJoinCond = (leftJoinConditions ++ commonJoinCondition).reduceLeftOption(And) - Join(newLeft, newRight, joinType, newJoinCond) + Join(newLeft, newRight, LeftOuter, newJoinCond) case FullOuter => f } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index aa9708b164ef..67f3ac2e990b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries import org.apache.spark.sql.catalyst.expressions.{Count, Explode} import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.plans.{PlanTest, LeftOuter, RightOuter} +import org.apache.spark.sql.catalyst.plans.{LeftSemi, PlanTest, LeftOuter, RightOuter} import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.dsl.expressions._ @@ -43,6 +43,8 @@ class FilterPushdownSuite extends PlanTest { val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + val testRelation1 = LocalRelation('d.int) + // This test already passes. test("eliminate subqueries") { val originalQuery = @@ -197,6 +199,23 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("joins: push down left semi join") { + val x = testRelation.subquery('x) + val y = testRelation1.subquery('y) + + val originalQuery = { + x.join(y, LeftSemi, Option("x.a".attr === "y.d".attr && "x.b".attr >= 1 && "y.d".attr >= 2)) + } + + val optimized = Optimize.execute(originalQuery.analyze) + val left = testRelation.where('b >= 1) + val right = testRelation1.where('d >= 2) + val correctAnswer = + left.join(right, LeftSemi, Option("a".attr === "d".attr)).analyze + + comparePlans(optimized, correctAnswer) + } + test("joins: push down left outer join #1") { val x = testRelation.subquery('x) val y = testRelation.subquery('y)