-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-18067] Avoid shuffling child if join keys are superset of child's partitioning keys #19054
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,14 +17,14 @@ | |
|
|
||
| package org.apache.spark.sql.execution.exchange | ||
|
|
||
| import scala.collection.mutable | ||
| import scala.collection.mutable.ArrayBuffer | ||
|
|
||
| import org.apache.spark.sql.catalyst.expressions._ | ||
| import org.apache.spark.sql.catalyst.plans.physical._ | ||
| import org.apache.spark.sql.catalyst.rules.Rule | ||
| import org.apache.spark.sql.execution._ | ||
| import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec, | ||
| SortMergeJoinExec} | ||
| import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec, SortMergeJoinExec} | ||
| import org.apache.spark.sql.internal.SQLConf | ||
|
|
||
| /** | ||
|
|
@@ -220,45 +220,99 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { | |
| operator.withNewChildren(children) | ||
| } | ||
|
|
||
| private def isSubset(biggerSet: Seq[Expression], smallerSet: Seq[Expression]): Boolean = | ||
| smallerSet.length <= biggerSet.length && | ||
| smallerSet.forall(x => biggerSet.exists(_.semanticEquals(x))) | ||
|
|
||
| /** | ||
| * Reorders `leftKeys` and `rightKeys` by aligning `currentOrderOfKeys` to be a prefix of | ||
| * `expectedOrderOfKeys` | ||
| */ | ||
| private def reorder( | ||
| leftKeys: Seq[Expression], | ||
| rightKeys: Seq[Expression], | ||
| expectedOrderOfKeys: Seq[Expression], | ||
| currentOrderOfKeys: Seq[Expression]): (Seq[Expression], Seq[Expression]) = { | ||
| val leftKeysBuffer = ArrayBuffer[Expression]() | ||
| val rightKeysBuffer = ArrayBuffer[Expression]() | ||
| expectedOrderOfKeys: Seq[Expression], // comes from child's output partitioning | ||
| currentOrderOfKeys: Seq[Expression]): // comes from join predicate | ||
| (Seq[Expression], Seq[Expression], Seq[Expression], Seq[Expression]) = { | ||
|
|
||
| assert(leftKeys.length == rightKeys.length) | ||
|
|
||
| val allLeftKeys = ArrayBuffer[Expression]() | ||
| val allRightKeys = ArrayBuffer[Expression]() | ||
| val reorderedLeftKeys = ArrayBuffer[Expression]() | ||
| val reorderedRightKeys = ArrayBuffer[Expression]() | ||
|
|
||
| // Track indices here to record which keys have been accounted for. A set-based approach | ||
| // won't work because it's possible that some keys are repeated in the join clause, | ||
| // e.g. a.key1 = b.key1 AND a.key1 = b.key2 | ||
| val processedIndicies = mutable.Set[Int]() | ||
|
|
||
| expectedOrderOfKeys.foreach(expression => { | ||
| val index = currentOrderOfKeys.indexWhere(e => e.semanticEquals(expression)) | ||
| leftKeysBuffer.append(leftKeys(index)) | ||
| rightKeysBuffer.append(rightKeys(index)) | ||
| val index = currentOrderOfKeys.zipWithIndex.find { case (currKey, i) => | ||
| !processedIndicies.contains(i) && currKey.semanticEquals(expression) | ||
| }.get._2 | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is the find guaranteed to always succeed? A getOrElse(sys error "...") might also be a good way of documenting this. |
||
| processedIndicies.add(index) | ||
|
|
||
| reorderedLeftKeys.append(leftKeys(index)) | ||
| allLeftKeys.append(leftKeys(index)) | ||
|
|
||
| reorderedRightKeys.append(rightKeys(index)) | ||
| allRightKeys.append(rightKeys(index)) | ||
| }) | ||
| (leftKeysBuffer, rightKeysBuffer) | ||
|
|
||
| // If len(currentOrderOfKeys) > len(expectedOrderOfKeys), then the re-ordering won't have | ||
| // all the keys. Append the remaining keys to the end so that we are covering all the keys | ||
| for (i <- leftKeys.indices) { | ||
| if (!processedIndicies.contains(i)) { | ||
| allLeftKeys.append(leftKeys(i)) | ||
| allRightKeys.append(rightKeys(i)) | ||
| } | ||
| } | ||
|
|
||
| assert(allLeftKeys.length == leftKeys.length) | ||
| assert(allRightKeys.length == rightKeys.length) | ||
| assert(reorderedLeftKeys.length == reorderedRightKeys.length) | ||
|
|
||
| (allLeftKeys, reorderedLeftKeys, allRightKeys, reorderedRightKeys) | ||
| } | ||
|
|
||
| /** | ||
| * Returns a tuple of "all join keys" and "minimal set of join keys which are satisfied by | ||
| * child" for both the left and right sub-trees | ||
| * | ||
| * eg. Assume `table_left` and `table_right` are joined over columns [a, b, c, d] and | ||
| * both the tables were already shuffled over columns [c, d] before the join is performed. The | ||
| * join keys are re-ordered as [c, d, a, b] by this method. The output of the method would be: | ||
| * | ||
| * { | ||
| * left:[c, d, a, b], // leftJoinKeys | ||
| * left:[c, d], // leftDistributionKeys | ||
| * right:[c, d, a, b], // rightJoinKeys | ||
| * right:[c, d] // rightDistributionKeys | ||
| * } | ||
| */ | ||
| private def reorderJoinKeys( | ||
| leftKeys: Seq[Expression], | ||
| rightKeys: Seq[Expression], | ||
| leftPartitioning: Partitioning, | ||
| rightPartitioning: Partitioning): (Seq[Expression], Seq[Expression]) = { | ||
| rightPartitioning: Partitioning): | ||
| (Seq[Expression], Seq[Expression], Seq[Expression], Seq[Expression]) = { | ||
|
||
|
|
||
| if (leftKeys.forall(_.deterministic) && rightKeys.forall(_.deterministic)) { | ||
| leftPartitioning match { | ||
| case HashPartitioning(leftExpressions, _) | ||
| if leftExpressions.length == leftKeys.length && | ||
| leftKeys.forall(x => leftExpressions.exists(_.semanticEquals(x))) => | ||
| (leftPartitioning, rightPartitioning) match { | ||
| case (HashPartitioning(leftExpressions, _), HashPartitioning(_, _)) | ||
| if isSubset(leftKeys, leftExpressions) => | ||
| reorder(leftKeys, rightKeys, leftExpressions, leftKeys) | ||
|
|
||
| case _ => rightPartitioning match { | ||
| case HashPartitioning(rightExpressions, _) | ||
| if rightExpressions.length == rightKeys.length && | ||
| rightKeys.forall(x => rightExpressions.exists(_.semanticEquals(x))) => | ||
| reorder(leftKeys, rightKeys, rightExpressions, rightKeys) | ||
| case (HashPartitioning(_, _), HashPartitioning(rightExpressions, _)) | ||
| if isSubset(rightKeys, rightExpressions) => | ||
| reorder(leftKeys, rightKeys, rightExpressions, rightKeys) | ||
|
|
||
| case _ => (leftKeys, rightKeys) | ||
| } | ||
| case _ => | ||
| (leftKeys, leftKeys, rightKeys, rightKeys) | ||
| } | ||
| } else { | ||
| (leftKeys, rightKeys) | ||
| (leftKeys, leftKeys, rightKeys, rightKeys) | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -271,23 +325,24 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { | |
| */ | ||
| private def reorderJoinPredicates(plan: SparkPlan): SparkPlan = { | ||
| plan.transformUp { | ||
| case BroadcastHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Removal of |
||
| right) => | ||
| val (reorderedLeftKeys, reorderedRightKeys) = | ||
| reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning) | ||
| BroadcastHashJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, buildSide, condition, | ||
| left, right) | ||
|
|
||
| case ShuffledHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, right) => | ||
| val (reorderedLeftKeys, reorderedRightKeys) = | ||
| reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning) | ||
| ShuffledHashJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, buildSide, condition, | ||
| left, right) | ||
|
|
||
| case SortMergeJoinExec(leftKeys, rightKeys, joinType, condition, left, right) => | ||
| val (reorderedLeftKeys, reorderedRightKeys) = | ||
| reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning) | ||
| SortMergeJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, condition, left, right) | ||
| case ShuffledHashJoinExec(leftJoinKeys, rightJoinKeys, _, _, joinType, buildSide, condition, | ||
| left, right) => | ||
| val (reorderedLeftKeys, reducedLeftDistributionKeys, reorderedRightKeys, | ||
| reducedRightDistributionKeys) = | ||
| reorderJoinKeys(leftJoinKeys, rightJoinKeys, left.outputPartitioning, | ||
| right.outputPartitioning) | ||
|
|
||
| ShuffledHashJoinExec(reorderedLeftKeys, reorderedRightKeys, reducedLeftDistributionKeys, | ||
| reducedRightDistributionKeys, joinType, buildSide, condition, left, right) | ||
|
|
||
| case SortMergeJoinExec(leftJoinKeys, rightJoinKeys, _, _, joinType, condition, left, right) => | ||
| val (reorderedLeftKeys, reducedLeftDistributionKeys, reorderedRightKeys, | ||
| reducedRightDistributionKeys) = | ||
| reorderJoinKeys(leftJoinKeys, rightJoinKeys, left.outputPartitioning, | ||
| right.outputPartitioning) | ||
|
|
||
| SortMergeJoinExec(reorderedLeftKeys, reorderedRightKeys, reducedLeftDistributionKeys, | ||
| reducedRightDistributionKeys, joinType, condition, left, right) | ||
| } | ||
| } | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you please add a comment describing the return type? A Tuple4 is not a very descriptive type 😃