diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 910294853c31..21eba0a941c5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -225,21 +225,23 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
              && muchSmaller(right, left) ||
            !RowOrdering.isOrderable(leftKeys) =>
         Seq(joins.ShuffledHashJoinExec(
-          leftKeys, rightKeys, joinType, BuildRight, condition, planLater(left), planLater(right)))
+          leftKeys, rightKeys, leftKeys, rightKeys, joinType, BuildRight, condition,
+          planLater(left), planLater(right)))
 
       case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
          if !conf.preferSortMergeJoin && canBuildLeft(joinType) && canBuildLocalHashMap(left)
            && muchSmaller(left, right) ||
            !RowOrdering.isOrderable(leftKeys) =>
         Seq(joins.ShuffledHashJoinExec(
-          leftKeys, rightKeys, joinType, BuildLeft, condition, planLater(left), planLater(right)))
+          leftKeys, rightKeys, leftKeys, rightKeys, joinType, BuildLeft, condition, planLater(left),
+          planLater(right)))
 
       // --- SortMergeJoin ------------------------------------------------------------
 
       case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
         if RowOrdering.isOrderable(leftKeys) =>
-        joins.SortMergeJoinExec(
-          leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil
+        joins.SortMergeJoinExec(leftKeys, rightKeys, leftKeys, rightKeys, joinType, condition,
+          planLater(left), planLater(right)) :: Nil
 
       // --- Without joining keys ------------------------------------------------------------
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index e3d28388c547..b94b7c9a1996 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -17,14 +17,14 @@
 
 package org.apache.spark.sql.execution.exchange
 
+import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec,
-  SortMergeJoinExec}
+import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec, SortMergeJoinExec}
 import org.apache.spark.sql.internal.SQLConf
 
 /**
@@ -220,45 +220,99 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
     operator.withNewChildren(children)
   }
 
+  private def isSubset(biggerSet: Seq[Expression], smallerSet: Seq[Expression]): Boolean =
+    smallerSet.length <= biggerSet.length &&
+      smallerSet.forall(x => biggerSet.exists(_.semanticEquals(x)))
+
+  /**
+   * Reorders `leftKeys` and `rightKeys` so that the keys matching `expectedOrderOfKeys` form a
+   * prefix of the result, followed by the remaining keys of `currentOrderOfKeys`.
+   */
   private def reorder(
       leftKeys: Seq[Expression],
       rightKeys: Seq[Expression],
-      expectedOrderOfKeys: Seq[Expression],
-      currentOrderOfKeys: Seq[Expression]): (Seq[Expression], Seq[Expression]) = {
-    val leftKeysBuffer = ArrayBuffer[Expression]()
-    val rightKeysBuffer = ArrayBuffer[Expression]()
+      expectedOrderOfKeys: Seq[Expression], // comes from child's output partitioning
+      currentOrderOfKeys: Seq[Expression]): // comes from join predicate
+      (Seq[Expression], Seq[Expression], Seq[Expression], Seq[Expression]) = {
+
+    assert(leftKeys.length == rightKeys.length)
+
+    val allLeftKeys = ArrayBuffer[Expression]()
+    val allRightKeys = ArrayBuffer[Expression]()
+    val reorderedLeftKeys = ArrayBuffer[Expression]()
+    val reorderedRightKeys = ArrayBuffer[Expression]()
+
+    // Track indices here to record which keys have been accounted for. A set of expressions
+    // would not work because it is possible for keys to be repeated in the join clause,
+    // e.g. a.key1 = b.key1 AND a.key1 = b.key2
+    val processedIndices = mutable.Set[Int]()
     expectedOrderOfKeys.foreach(expression => {
-      val index = currentOrderOfKeys.indexWhere(e => e.semanticEquals(expression))
-      leftKeysBuffer.append(leftKeys(index))
-      rightKeysBuffer.append(rightKeys(index))
+      val index = currentOrderOfKeys.zipWithIndex.find { case (currKey, i) =>
+        !processedIndices.contains(i) && currKey.semanticEquals(expression)
+      }.get._2
+      processedIndices.add(index)
+
+      reorderedLeftKeys.append(leftKeys(index))
+      allLeftKeys.append(leftKeys(index))
+
+      reorderedRightKeys.append(rightKeys(index))
+      allRightKeys.append(rightKeys(index))
     })
-    (leftKeysBuffer, rightKeysBuffer)
+
+    // If len(currentOrderOfKeys) > len(expectedOrderOfKeys), the re-ordering above will not have
+    // visited all the keys. Append the remaining keys at the end so that every key is covered.
+    for (i <- leftKeys.indices) {
+      if (!processedIndices.contains(i)) {
+        allLeftKeys.append(leftKeys(i))
+        allRightKeys.append(rightKeys(i))
+      }
+    }
+
+    assert(allLeftKeys.length == leftKeys.length)
+    assert(allRightKeys.length == rightKeys.length)
+    assert(reorderedLeftKeys.length == reorderedRightKeys.length)
+
+    (allLeftKeys, reorderedLeftKeys, allRightKeys, reorderedRightKeys)
   }
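As a quick aid for review, here is a minimal standalone sketch of the prefix-alignment that `reorder` above performs, with plain strings in place of Catalyst `Expression`s and `==` in place of `semanticEquals`; the object name `ReorderSketch` and all data are illustrative only, not part of the patch:

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

object ReorderSketch {
  // Mirrors `reorder`: returns (allLeftKeys, prefixLeftKeys, allRightKeys, prefixRightKeys),
  // where the prefix is the part of the join keys matched by the expected (partitioning) order.
  def reorder(
      leftKeys: Seq[String],
      rightKeys: Seq[String],
      expectedOrderOfKeys: Seq[String],
      currentOrderOfKeys: Seq[String]): (Seq[String], Seq[String], Seq[String], Seq[String]) = {
    val allLeft = ArrayBuffer[String]()
    val allRight = ArrayBuffer[String]()
    val processed = mutable.Set[Int]()
    // Walk the expected order first, consuming one not-yet-used join key per match.
    expectedOrderOfKeys.foreach { e =>
      val index = currentOrderOfKeys.zipWithIndex.find {
        case (k, i) => !processed.contains(i) && k == e
      }.get._2
      processed.add(index)
      allLeft.append(leftKeys(index))
      allRight.append(rightKeys(index))
    }
    val prefixLeft = allLeft.toList
    val prefixRight = allRight.toList
    // Then append whatever join keys were not matched by the partitioning.
    for (i <- leftKeys.indices if !processed.contains(i)) {
      allLeft.append(leftKeys(i))
      allRight.append(rightKeys(i))
    }
    (allLeft.toList, prefixLeft, allRight.toList, prefixRight)
  }

  def main(args: Array[String]): Unit = {
    // Join over [a, b, c, d]; both sides already hash-partitioned on [c, d]:
    println(reorder(
      leftKeys = Seq("a", "b", "c", "d"),
      rightKeys = Seq("a", "b", "c", "d"),
      expectedOrderOfKeys = Seq("c", "d"),
      currentOrderOfKeys = Seq("a", "b", "c", "d")))
    // prints: (List(c, d, a, b),List(c, d),List(c, d, a, b),List(c, d))
  }
}

Running `main` produces the same [c, d, a, b] / [c, d] split described in the scaladoc that follows.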
 
+  /**
+   * Returns a tuple of "all join keys" and "the minimal set of join keys that is satisfied by
+   * the child" for both the left and right sub-trees.
+   *
+   * e.g. Assume `table_left` and `table_right` are joined over columns [a, b, c, d] and
+   * both tables were already shuffled over columns [c, d] before the join is performed. The
+   * join keys are re-ordered as [c, d, a, b] by this method. The output of the method would be:
+   *
+   * {
+   *   left:[c, d, a, b], // leftJoinKeys
+   *   left:[c, d], // leftDistributionKeys
+   *   right:[c, d, a, b], // rightJoinKeys
+   *   right:[c, d] // rightDistributionKeys
+   * }
+   */
   private def reorderJoinKeys(
       leftKeys: Seq[Expression],
       rightKeys: Seq[Expression],
       leftPartitioning: Partitioning,
-      rightPartitioning: Partitioning): (Seq[Expression], Seq[Expression]) = {
+      rightPartitioning: Partitioning):
+      (Seq[Expression], Seq[Expression], Seq[Expression], Seq[Expression]) = {
+
     if (leftKeys.forall(_.deterministic) && rightKeys.forall(_.deterministic)) {
-      leftPartitioning match {
-        case HashPartitioning(leftExpressions, _)
-          if leftExpressions.length == leftKeys.length &&
-            leftKeys.forall(x => leftExpressions.exists(_.semanticEquals(x))) =>
+      (leftPartitioning, rightPartitioning) match {
+        case (HashPartitioning(leftExpressions, _), HashPartitioning(_, _))
+          if isSubset(leftKeys, leftExpressions) =>
           reorder(leftKeys, rightKeys, leftExpressions, leftKeys)
 
-        case _ => rightPartitioning match {
-          case HashPartitioning(rightExpressions, _)
-            if rightExpressions.length == rightKeys.length &&
-              rightKeys.forall(x => rightExpressions.exists(_.semanticEquals(x))) =>
-            reorder(leftKeys, rightKeys, rightExpressions, rightKeys)
+        case (HashPartitioning(_, _), HashPartitioning(rightExpressions, _))
+          if isSubset(rightKeys, rightExpressions) =>
+          reorder(leftKeys, rightKeys, rightExpressions, rightKeys)
 
-          case _ => (leftKeys, rightKeys)
-        }
+        case _ =>
+          (leftKeys, leftKeys, rightKeys, rightKeys)
       }
     } else {
-      (leftKeys, rightKeys)
+      (leftKeys, leftKeys, rightKeys, rightKeys)
     }
   }
 
@@ -271,23 +325,24 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
    */
   private def reorderJoinPredicates(plan: SparkPlan): SparkPlan = {
     plan.transformUp {
-      case BroadcastHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left,
-        right) =>
-        val (reorderedLeftKeys, reorderedRightKeys) =
-          reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning)
-        BroadcastHashJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, buildSide, condition,
-          left, right)
-
-      case ShuffledHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, right) =>
-        val (reorderedLeftKeys, reorderedRightKeys) =
-          reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning)
-        ShuffledHashJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, buildSide, condition,
-          left, right)
-
-      case SortMergeJoinExec(leftKeys, rightKeys, joinType, condition, left, right) =>
-        val (reorderedLeftKeys, reorderedRightKeys) =
-          reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning)
-        SortMergeJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, condition, left, right)
+      case ShuffledHashJoinExec(leftJoinKeys, rightJoinKeys, _, _, joinType, buildSide, condition,
+          left, right) =>
+        val (reorderedLeftKeys, reducedLeftDistributionKeys, reorderedRightKeys,
+          reducedRightDistributionKeys) =
+          reorderJoinKeys(leftJoinKeys, rightJoinKeys, left.outputPartitioning,
+            right.outputPartitioning)
+
+        ShuffledHashJoinExec(reorderedLeftKeys, reorderedRightKeys, reducedLeftDistributionKeys,
+          reducedRightDistributionKeys, joinType, buildSide, condition, left, right)
+
+      case SortMergeJoinExec(leftJoinKeys, rightJoinKeys, _, _, joinType, condition, left, right) =>
+        val (reorderedLeftKeys, reducedLeftDistributionKeys, reorderedRightKeys,
+          reducedRightDistributionKeys) =
+          reorderJoinKeys(leftJoinKeys, rightJoinKeys, left.outputPartitioning,
reorderJoinKeys(leftJoinKeys, rightJoinKeys, left.outputPartitioning, + right.outputPartitioning) + + SortMergeJoinExec(reorderedLeftKeys, reorderedRightKeys, reducedLeftDistributionKeys, + reducedRightDistributionKeys, joinType, condition, left, right) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala index 897a4dae39f3..cb2bc2f2676f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala @@ -32,6 +32,8 @@ import org.apache.spark.sql.execution.metric.SQLMetrics case class ShuffledHashJoinExec( leftKeys: Seq[Expression], rightKeys: Seq[Expression], + leftDistributionKeys: Seq[Expression], + rightDistributionKeys: Seq[Expression], joinType: JoinType, buildSide: BuildSide, condition: Option[Expression], @@ -46,7 +48,9 @@ case class ShuffledHashJoinExec( "avgHashProbe" -> SQLMetrics.createAverageMetric(sparkContext, "avg hash probe")) override def requiredChildDistribution: Seq[Distribution] = - HashClusteredDistribution(leftKeys) :: HashClusteredDistribution(rightKeys) :: Nil + HashClusteredDistribution(leftDistributionKeys) :: + HashClusteredDistribution(rightDistributionKeys) :: + Nil private def buildHashedRelation(iter: Iterator[InternalRow]): HashedRelation = { val buildDataSize = longMetric("buildDataSize") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 2de2f30eb05d..45ef4bdc7b8a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -34,8 +34,10 @@ import org.apache.spark.util.collection.BitSet * Performs a sort merge join of two child relations. */ case class SortMergeJoinExec( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], + leftJoinKeys: Seq[Expression], + rightJoinKeys: Seq[Expression], + leftDistributionKeys: Seq[Expression], + rightDistributionKeys: Seq[Expression], joinType: JoinType, condition: Option[Expression], left: SparkPlan, @@ -78,24 +80,26 @@ case class SortMergeJoinExec( } override def requiredChildDistribution: Seq[Distribution] = - HashClusteredDistribution(leftKeys) :: HashClusteredDistribution(rightKeys) :: Nil + HashClusteredDistribution(leftDistributionKeys) :: + HashClusteredDistribution(rightDistributionKeys) :: + Nil override def outputOrdering: Seq[SortOrder] = joinType match { // For inner join, orders of both sides keys should be kept. case _: InnerLike => - val leftKeyOrdering = getKeyOrdering(leftKeys, left.outputOrdering) - val rightKeyOrdering = getKeyOrdering(rightKeys, right.outputOrdering) + val leftKeyOrdering = getKeyOrdering(leftJoinKeys, left.outputOrdering) + val rightKeyOrdering = getKeyOrdering(rightJoinKeys, right.outputOrdering) leftKeyOrdering.zip(rightKeyOrdering).map { case (lKey, rKey) => // Also add the right key and its `sameOrderExpressions` SortOrder(lKey.child, Ascending, lKey.sameOrderExpressions + rKey.child ++ rKey .sameOrderExpressions) } // For left and right outer joins, the output is ordered by the streamed input's join keys. 
-    case LeftOuter => getKeyOrdering(leftKeys, left.outputOrdering)
-    case RightOuter => getKeyOrdering(rightKeys, right.outputOrdering)
+    case LeftOuter => getKeyOrdering(leftJoinKeys, left.outputOrdering)
+    case RightOuter => getKeyOrdering(rightJoinKeys, right.outputOrdering)
     // There are null rows in both streams, so there is no order.
     case FullOuter => Nil
-    case LeftExistence(_) => getKeyOrdering(leftKeys, left.outputOrdering)
+    case LeftExistence(_) => getKeyOrdering(leftJoinKeys, left.outputOrdering)
     case x =>
       throw new IllegalArgumentException(
         s"${getClass.getSimpleName} should not take $x as the JoinType")
@@ -122,7 +126,7 @@ case class SortMergeJoinExec(
   }
 
   override def requiredChildOrdering: Seq[Seq[SortOrder]] =
-    requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil
+    requiredOrders(leftJoinKeys) :: requiredOrders(rightJoinKeys) :: Nil
 
   private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = {
     // This must be ascending in order to agree with the `keyOrdering` defined in `doExecute()`.
@@ -130,10 +134,10 @@ case class SortMergeJoinExec(
   }
 
   private def createLeftKeyGenerator(): Projection =
-    UnsafeProjection.create(leftKeys, left.output)
+    UnsafeProjection.create(leftJoinKeys, left.output)
 
   private def createRightKeyGenerator(): Projection =
-    UnsafeProjection.create(rightKeys, right.output)
+    UnsafeProjection.create(rightJoinKeys, right.output)
 
   private def getSpillThreshold: Int = {
     sqlContext.conf.sortMergeJoinExecBufferSpillThreshold
@@ -157,7 +161,7 @@ case class SortMergeJoinExec(
       }
 
       // An ordering that can be used to compare keys from both sides.
-      val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType))
+      val keyOrdering = newNaturalAscendingOrdering(leftJoinKeys.map(_.dataType))
       val resultProj: InternalRow => InternalRow = UnsafeProjection.create(output, output)
 
       joinType match {
@@ -398,7 +402,7 @@ case class SortMergeJoinExec(
 
   private def copyKeys(ctx: CodegenContext, vars: Seq[ExprCode]): Seq[ExprCode] = {
     vars.zipWithIndex.map { case (ev, i) =>
-      ctx.addBufferedState(leftKeys(i).dataType, "value", ev.value)
+      ctx.addBufferedState(leftJoinKeys(i).dataType, "value", ev.value)
     }
   }
 
@@ -406,7 +410,7 @@ case class SortMergeJoinExec(
     val comparisons = a.zip(b).zipWithIndex.map { case ((l, r), i) =>
       s"""
          |if (comp == 0) {
-         |  comp = ${ctx.genComp(leftKeys(i).dataType, l.value, r.value)};
+         |  comp = ${ctx.genComp(leftJoinKeys(i).dataType, l.value, r.value)};
          |}
        """.stripMargin.trim
     }
@@ -427,9 +431,9 @@ case class SortMergeJoinExec(
     val rightRow = ctx.addMutableState("InternalRow", "rightRow", forceInline = true)
 
     // Create variables for join keys from both sides.
-    val leftKeyVars = createJoinKey(ctx, leftRow, leftKeys, left.output)
+    val leftKeyVars = createJoinKey(ctx, leftRow, leftJoinKeys, left.output)
     val leftAnyNull = leftKeyVars.map(_.isNull).mkString(" || ")
-    val rightKeyTmpVars = createJoinKey(ctx, rightRow, rightKeys, right.output)
+    val rightKeyTmpVars = createJoinKey(ctx, rightRow, rightJoinKeys, right.output)
     val rightAnyNull = rightKeyTmpVars.map(_.isNull).mkString(" || ")
 
     // Copy the right key as class members so they could be used in next function call.
     val rightKeyVars = copyKeys(ctx, rightKeyTmpVars)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 771e1186e63a..4d56d2a3e7f1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -836,8 +836,8 @@ class JoinSuite extends QueryTest with SharedSQLContext {
 
     joinPairs.foreach {
       case(join1, join2) =>
-        val leftKeys = join1.leftKeys
-        val rightKeys = join1.rightKeys
+        val leftKeys = join1.leftJoinKeys
+        val rightKeys = join1.rightJoinKeys
         val outputOrderingPhysical = join1.outputOrdering
         val outputOrderingExecuted = join2.outputOrdering
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index f8b26f5b28cc..895834125d8d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -420,6 +420,8 @@ class PlannerSuite extends SharedSQLContext {
       None)
 
     val inputPlan = SortMergeJoinExec(
+      Literal(1) :: Nil,
+      Literal(1) :: Nil,
       Literal(1) :: Nil,
       Literal(1) :: Nil,
       Inner,
@@ -437,6 +439,8 @@ class PlannerSuite extends SharedSQLContext {
     // nested exchanges
     val inputPlan2 = SortMergeJoinExec(
+      Literal(1) :: Nil,
+      Literal(1) :: Nil,
       Literal(1) :: Nil,
       Literal(1) :: Nil,
       Inner,
@@ -496,7 +500,9 @@ class PlannerSuite extends SharedSQLContext {
 
   test("EnsureRequirements skips sort when either side of join keys is required after inner SMJ") {
     Seq(Inner, Cross).foreach { joinType =>
-      val innerSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, joinType, None, planA, planB)
+      val innerSmj =
+        SortMergeJoinExec(exprA :: Nil, exprB :: Nil, exprA :: Nil, exprB :: Nil, joinType, None,
+          planA, planB)
 
       // Both left and right keys should be sorted after the SMJ.
       Seq(orderingA, orderingB).foreach { ordering =>
         assertSortRequirementsAreSatisfied(
@@ -510,8 +516,13 @@ class PlannerSuite extends SharedSQLContext {
   test("EnsureRequirements skips sort when key order of a parent SMJ is propagated from its " +
    "child SMJ") {
     Seq(Inner, Cross).foreach { joinType =>
-      val childSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, joinType, None, planA, planB)
-      val parentSmj = SortMergeJoinExec(exprB :: Nil, exprC :: Nil, joinType, None, childSmj, planC)
+      val childSmj =
+        SortMergeJoinExec(exprA :: Nil, exprB :: Nil, exprA :: Nil, exprB :: Nil, joinType, None,
+          planA, planB)
+
+      val parentSmj =
+        SortMergeJoinExec(exprB :: Nil, exprC :: Nil, exprB :: Nil, exprC :: Nil, joinType, None,
+          childSmj, planC)
 
       // After the second SMJ, exprA, exprB and exprC should all be sorted.
       Seq(orderingA, orderingB, orderingC).foreach { ordering =>
         assertSortRequirementsAreSatisfied(
@@ -524,7 +535,8 @@ class PlannerSuite extends SharedSQLContext {
 
   test("EnsureRequirements for sort operator after left outer sort merge join") {
     // Only left key is sorted after left outer SMJ (thus doesn't need a sort).
-    val leftSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, LeftOuter, None, planA, planB)
+    val leftSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, exprA :: Nil, exprB :: Nil,
+      LeftOuter, None, planA, planB)
     Seq((orderingA, false), (orderingB, true)).foreach { case (ordering, needSort) =>
       assertSortRequirementsAreSatisfied(
         childPlan = leftSmj,
@@ -535,7 +547,8 @@ class PlannerSuite extends SharedSQLContext {
 
   test("EnsureRequirements for sort operator after right outer sort merge join") {
     // Only right key is sorted after right outer SMJ (thus doesn't need a sort).
-    val rightSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, RightOuter, None, planA, planB)
+    val rightSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, exprA :: Nil, exprB :: Nil,
+      RightOuter, None, planA, planB)
     Seq((orderingA, true), (orderingB, false)).foreach { case (ordering, needSort) =>
       assertSortRequirementsAreSatisfied(
         childPlan = rightSmj,
@@ -546,7 +559,8 @@ class PlannerSuite extends SharedSQLContext {
 
   test("EnsureRequirements adds sort after full outer sort merge join") {
     // Neither keys is sorted after full outer SMJ, so they both need sorts.
-    val fullSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, FullOuter, None, planA, planB)
+    val fullSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, exprA :: Nil, exprB :: Nil,
+      FullOuter, None, planA, planB)
     Seq(orderingA, orderingB).foreach { ordering =>
       assertSortRequirementsAreSatisfied(
         childPlan = fullSmj,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala
index 38377164c10e..cfce2095432d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala
@@ -106,14 +106,14 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext {
     withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
       checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
         EnsureRequirements(left.sqlContext.sessionState.conf).apply(
-          ShuffledHashJoinExec(
-            leftKeys, rightKeys, joinType, BuildRight, boundCondition, left, right)),
+          ShuffledHashJoinExec(leftKeys, rightKeys, leftKeys, rightKeys, joinType, BuildRight,
+            boundCondition, left, right)),
         expectedAnswer,
         sortAnswers = true)
 
       checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
         EnsureRequirements(left.sqlContext.sessionState.conf).apply(
-          createLeftSemiPlusJoin(ShuffledHashJoinExec(
-            leftKeys, rightKeys, leftSemiPlus, BuildRight, boundCondition, left, right))),
+          createLeftSemiPlusJoin(ShuffledHashJoinExec(leftKeys, rightKeys, leftKeys, rightKeys,
+            leftSemiPlus, BuildRight, boundCondition, left, right))),
         expectedAnswer,
         sortAnswers = true)
     }
@@ -144,13 +144,14 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext {
     withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
       checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
         EnsureRequirements(left.sqlContext.sessionState.conf).apply(
-          SortMergeJoinExec(leftKeys, rightKeys, joinType, boundCondition, left, right)),
+          SortMergeJoinExec(leftKeys, rightKeys, leftKeys, rightKeys, joinType, boundCondition,
+            left, right)),
         expectedAnswer,
         sortAnswers = true)
 
       checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
         EnsureRequirements(left.sqlContext.sessionState.conf).apply(
-          createLeftSemiPlusJoin(SortMergeJoinExec(
-            leftKeys, rightKeys, leftSemiPlus, boundCondition, left, right))),
+          createLeftSemiPlusJoin(SortMergeJoinExec(leftKeys, rightKeys, leftKeys, rightKeys,
+            leftSemiPlus, boundCondition, left, right))),
         expectedAnswer,
         sortAnswers = true)
     }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
index 4408ece11225..c233a01ffe16 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
@@ -109,8 +109,8 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
         leftPlan: SparkPlan,
         rightPlan: SparkPlan,
         side: BuildSide) = {
-      val shuffledHashJoin = joins.ShuffledHashJoinExec(leftKeys, rightKeys, Inner,
-        side, None, leftPlan, rightPlan)
+      val shuffledHashJoin = joins.ShuffledHashJoinExec(leftKeys, rightKeys, leftKeys, rightKeys,
+        Inner, side, None, leftPlan, rightPlan)
       val filteredJoin =
         boundCondition.map(FilterExec(_, shuffledHashJoin)).getOrElse(shuffledHashJoin)
       EnsureRequirements(spark.sessionState.conf).apply(filteredJoin)
@@ -122,8 +122,8 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
         boundCondition: Option[Expression],
         leftPlan: SparkPlan,
         rightPlan: SparkPlan) = {
-      val sortMergeJoin = joins.SortMergeJoinExec(leftKeys, rightKeys, Inner, boundCondition,
-        leftPlan, rightPlan)
+      val sortMergeJoin = joins.SortMergeJoinExec(leftKeys, rightKeys, leftKeys, rightKeys, Inner,
+        boundCondition, leftPlan, rightPlan)
       EnsureRequirements(spark.sessionState.conf).apply(sortMergeJoin)
     }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
index 001feb0f2b39..df6a2824dc9f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
@@ -83,8 +83,8 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext {
         val buildSide = if (joinType == LeftOuter) BuildRight else BuildLeft
         checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
           EnsureRequirements(spark.sessionState.conf).apply(
-            ShuffledHashJoinExec(
-              leftKeys, rightKeys, joinType, buildSide, boundCondition, left, right)),
+            ShuffledHashJoinExec(leftKeys, rightKeys, leftKeys, rightKeys, joinType, buildSide,
+              boundCondition, left, right)),
           expectedAnswer.map(Row.fromTuple),
           sortAnswers = true)
       }
@@ -116,7 +116,8 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext {
       withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
         checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
           EnsureRequirements(spark.sessionState.conf).apply(
-            SortMergeJoinExec(leftKeys, rightKeys, joinType, boundCondition, left, right)),
+            SortMergeJoinExec(leftKeys, rightKeys, leftKeys, rightKeys, joinType, boundCondition,
+              left, right)),
           expectedAnswer.map(Row.fromTuple),
           sortAnswers = true)
       }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
index fb61fa716b94..b108dd2a5db6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
@@ -569,18 +569,33 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils {
     })
   }
 
+  test("SPARK-18067 Avoid shuffling Join's child if join keys are superset of child's " +
+    "partitioning keys") {
+
+    val bucketedTableTestSpec = BucketedTableTestSpec(
+      Some(BucketSpec(8, Seq("i", "j"), Seq("i", "j"))),
+      numPartitions = 1,
+      expectedShuffle = false)
+
+    Seq(
+      Seq("i", "j", "k"),
+      Seq("i", "k", "j"),
+      Seq("j", "k", "i"),
+      Seq("j", "i", "k"),
+      Seq("k", "j", "i"),
+      Seq("k", "i", "j")
+    ).foreach(joinKeys => {
+      testBucketing(
+        bucketedTableTestSpecLeft = bucketedTableTestSpec,
+        bucketedTableTestSpecRight = bucketedTableTestSpec,
+        joinCondition = joinCondition(joinKeys)
+      )
+    })
+  }
+
   test("SPARK-19122 No re-ordering should happen if set of join columns != set of child's " +
     "partitioning columns") {
-    // join predicates is a super set of child's partitioning columns
-    val bucketedTableTestSpec1 =
-      BucketedTableTestSpec(Some(BucketSpec(8, Seq("i", "j"), Seq("i", "j"))), numPartitions = 1)
-    testBucketing(
-      bucketedTableTestSpecLeft = bucketedTableTestSpec1,
-      bucketedTableTestSpecRight = bucketedTableTestSpec1,
-      joinCondition = joinCondition(Seq("i", "j", "k"))
-    )
-
     // child's partitioning columns is a super set of join predicates
     val bucketedTableTestSpec2 =
       BucketedTableTestSpec(Some(BucketSpec(8, Seq("i", "j", "k"), Seq("i", "j", "k"))),
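To see the user-facing effect end to end, the following spark-shell style snippet sketches the scenario the new BucketedReadSuite test exercises; the table names "t1"/"t2" and the generated data are made up for illustration, and it assumes a session with this patch applied. Both tables are bucketed by (i, j) and joined on (k, i, j); with this change the join keys are re-ordered so that (i, j) forms the prefix, and each side's bucketing satisfies the join's required distribution directly:

// Hypothetical example; names and data are illustrative only.
spark.range(0, 100)
  .selectExpr("id % 5 as i", "id % 7 as j", "id % 11 as k")
  .write.bucketBy(8, "i", "j").saveAsTable("t1")

spark.range(0, 100)
  .selectExpr("id % 5 as i", "id % 7 as j", "id % 13 as k")
  .write.bucketBy(8, "i", "j").saveAsTable("t2")

val joined = spark.table("t1").join(spark.table("t2"), Seq("k", "i", "j"))
joined.explain()
// Before this patch: an Exchange on hashpartitioning(k, i, j) appears above both scans.
// After this patch: no Exchange is expected above either bucketed scan.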