Skip to content
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
rebase + updates
  • Loading branch information
tejasapatil committed Jan 20, 2018
commit 00bb14b0145a2bd42c8b4c8a9d4f108322804f71
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@

package org.apache.spark.sql.execution.exchange

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec,
SortMergeJoinExec}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.internal.SQLConf

/**
Expand Down Expand Up @@ -220,45 +220,99 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
operator.withNewChildren(children)
}

private def isSubset(biggerSet: Seq[Expression], smallerSet: Seq[Expression]): Boolean =
smallerSet.length <= biggerSet.length &&
smallerSet.forall(x => biggerSet.exists(_.semanticEquals(x)))

/**
* Reorders `leftKeys` and `rightKeys` by aligning `currentOrderOfKeys` to be a prefix of
* `expectedOrderOfKeys`
*/
private def reorder(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
expectedOrderOfKeys: Seq[Expression],
currentOrderOfKeys: Seq[Expression]): (Seq[Expression], Seq[Expression]) = {
val leftKeysBuffer = ArrayBuffer[Expression]()
val rightKeysBuffer = ArrayBuffer[Expression]()
expectedOrderOfKeys: Seq[Expression], // comes from child's output partitioning
currentOrderOfKeys: Seq[Expression]): // comes from join predicate
(Seq[Expression], Seq[Expression], Seq[Expression], Seq[Expression]) = {
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you please add a comment describing the return type? a tuple4 is not such a descriptive type 😃


assert(leftKeys.length == rightKeys.length)

val allLeftKeys = ArrayBuffer[Expression]()
val allRightKeys = ArrayBuffer[Expression]()
val reorderedLeftKeys = ArrayBuffer[Expression]()
val reorderedRightKeys = ArrayBuffer[Expression]()

// Tracking indicies here to track to which keys are accounted. Using a set based approach
// won't work because its possible that some keys are repeated in the join clause
// eg. a.key1 = b.key1 AND a.key1 = b.key2
val processedIndicies = mutable.Set[Int]()

expectedOrderOfKeys.foreach(expression => {
val index = currentOrderOfKeys.indexWhere(e => e.semanticEquals(expression))
leftKeysBuffer.append(leftKeys(index))
rightKeysBuffer.append(rightKeys(index))
val index = currentOrderOfKeys.zipWithIndex.find { case (currKey, i) =>
!processedIndicies.contains(i) && currKey.semanticEquals(expression)
}.get._2
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is the find guaranteed to always succeed?
if so, worth a comment on method's pre/post conditions.

a getOrElse(sys error "...") might also be a good way of documenting this.

processedIndicies.add(index)

reorderedLeftKeys.append(leftKeys(index))
allLeftKeys.append(leftKeys(index))

reorderedRightKeys.append(rightKeys(index))
allRightKeys.append(rightKeys(index))
})
(leftKeysBuffer, rightKeysBuffer)

// If len(currentOrderOfKeys) > len(expectedOrderOfKeys), then the re-ordering won't have
// all the keys. Append the remaining keys to the end so that we are covering all the keys
for (i <- leftKeys.indices) {
if (!processedIndicies.contains(i)) {
allLeftKeys.append(leftKeys(i))
allRightKeys.append(rightKeys(i))
}
}

assert(allLeftKeys.length == leftKeys.length)
assert(allRightKeys.length == rightKeys.length)
assert(reorderedLeftKeys.length == reorderedRightKeys.length)

(allLeftKeys, reorderedLeftKeys, allRightKeys, reorderedRightKeys)
}

/**
* Returns a tuple of "all join keys" and "minimal set of join keys which are satisfied by
* child" for both the left and right sub-trees
*
* eg. Assume `table_left` and `table_right` are joined over columns [a, b, c, d] and
* both the tables were already shuffled over columns [c, d] before the join is performed. The
* join keys are re-ordered as [c, d, a, b] by this method. The output of the method would be:
*
* {
* left:[c, d, a, b], // leftJoinKeys
* left:[c, d], // leftDistributionKeys
* right:[c, d, a, b], // rightJoinKeys
* right:[c, d] // rightDistributionKeys
* }
*/
private def reorderJoinKeys(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
leftPartitioning: Partitioning,
rightPartitioning: Partitioning): (Seq[Expression], Seq[Expression]) = {
rightPartitioning: Partitioning):
(Seq[Expression], Seq[Expression], Seq[Expression], Seq[Expression]) = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should add some documentation to explain what the return value is.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added more doc. I wasn't sure how to make it easier to understand. Hope that the example helps with that


if (leftKeys.forall(_.deterministic) && rightKeys.forall(_.deterministic)) {
leftPartitioning match {
case HashPartitioning(leftExpressions, _)
if leftExpressions.length == leftKeys.length &&
leftKeys.forall(x => leftExpressions.exists(_.semanticEquals(x))) =>
(leftPartitioning, rightPartitioning) match {
case (HashPartitioning(leftExpressions, _), HashPartitioning(_, _))
if isSubset(leftKeys, leftExpressions) =>
reorder(leftKeys, rightKeys, leftExpressions, leftKeys)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if leftPartitioning is HashPartitioning, we don't need to care about rightPartitioning at all?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

given that this was only done over SortMergeJoinExec and ShuffledHashJoinExec where both the partitionings are HashPartitioning, things worked fine. I have changed this to have a stricter check.


case _ => rightPartitioning match {
case HashPartitioning(rightExpressions, _)
if rightExpressions.length == rightKeys.length &&
rightKeys.forall(x => rightExpressions.exists(_.semanticEquals(x))) =>
reorder(leftKeys, rightKeys, rightExpressions, rightKeys)
case (HashPartitioning(_, _), HashPartitioning(rightExpressions, _))
if isSubset(rightKeys, rightExpressions) =>
reorder(leftKeys, rightKeys, rightExpressions, rightKeys)

case _ => (leftKeys, rightKeys)
}
case _ =>
(leftKeys, leftKeys, rightKeys, rightKeys)
}
} else {
(leftKeys, rightKeys)
(leftKeys, leftKeys, rightKeys, rightKeys)
}
}

Expand All @@ -271,23 +325,24 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
*/
private def reorderJoinPredicates(plan: SparkPlan): SparkPlan = {
plan.transformUp {
case BroadcastHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removal of BroadcastHashJoinExec is intentional. The children are expected to have BroadcastDistribution or UnspecifiedDistribution so this method wont help here (this optimization only helps in case of shuffle based joins)

right) =>
val (reorderedLeftKeys, reorderedRightKeys) =
reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning)
BroadcastHashJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, buildSide, condition,
left, right)

case ShuffledHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, right) =>
val (reorderedLeftKeys, reorderedRightKeys) =
reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning)
ShuffledHashJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, buildSide, condition,
left, right)

case SortMergeJoinExec(leftKeys, rightKeys, joinType, condition, left, right) =>
val (reorderedLeftKeys, reorderedRightKeys) =
reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning)
SortMergeJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, condition, left, right)
case ShuffledHashJoinExec(leftJoinKeys, rightJoinKeys, _, _, joinType, buildSide, condition,
left, right) =>
val (reorderedLeftKeys, reducedLeftDistributionKeys, reorderedRightKeys,
reducedRightDistributionKeys) =
reorderJoinKeys(leftJoinKeys, rightJoinKeys, left.outputPartitioning,
right.outputPartitioning)

ShuffledHashJoinExec(reorderedLeftKeys, reorderedRightKeys, reducedLeftDistributionKeys,
reducedRightDistributionKeys, joinType, buildSide, condition, left, right)

case SortMergeJoinExec(leftJoinKeys, rightJoinKeys, _, _, joinType, condition, left, right) =>
val (reorderedLeftKeys, reducedLeftDistributionKeys, reorderedRightKeys,
reducedRightDistributionKeys) =
reorderJoinKeys(leftJoinKeys, rightJoinKeys, left.outputPartitioning,
right.outputPartitioning)

SortMergeJoinExec(reorderedLeftKeys, reorderedRightKeys, reducedLeftDistributionKeys,
reducedRightDistributionKeys, joinType, condition, left, right)
}
}

Expand Down