Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -225,21 +225,23 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
&& muchSmaller(right, left) ||
!RowOrdering.isOrderable(leftKeys) =>
Seq(joins.ShuffledHashJoinExec(
leftKeys, rightKeys, joinType, BuildRight, condition, planLater(left), planLater(right)))
leftKeys, rightKeys, leftKeys, rightKeys, joinType, BuildRight, condition,
planLater(left), planLater(right)))

case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
if !conf.preferSortMergeJoin && canBuildLeft(joinType) && canBuildLocalHashMap(left)
&& muchSmaller(left, right) ||
!RowOrdering.isOrderable(leftKeys) =>
Seq(joins.ShuffledHashJoinExec(
leftKeys, rightKeys, joinType, BuildLeft, condition, planLater(left), planLater(right)))
leftKeys, rightKeys, leftKeys, rightKeys, joinType, BuildLeft, condition, planLater(left),
planLater(right)))

// --- SortMergeJoin ------------------------------------------------------------

case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
if RowOrdering.isOrderable(leftKeys) =>
joins.SortMergeJoinExec(
leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil
joins.SortMergeJoinExec(leftKeys, rightKeys, leftKeys, rightKeys, joinType, condition,
planLater(left), planLater(right)) :: Nil

// --- Without joining keys ------------------------------------------------------------

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@

package org.apache.spark.sql.execution.exchange

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec,
SortMergeJoinExec}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.internal.SQLConf

/**
Expand Down Expand Up @@ -220,45 +220,99 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
operator.withNewChildren(children)
}

private def isSubset(biggerSet: Seq[Expression], smallerSet: Seq[Expression]): Boolean =
smallerSet.length <= biggerSet.length &&
smallerSet.forall(x => biggerSet.exists(_.semanticEquals(x)))

/**
* Reorders `leftKeys` and `rightKeys` by aligning `currentOrderOfKeys` to be a prefix of
* `expectedOrderOfKeys`
*/
private def reorder(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
expectedOrderOfKeys: Seq[Expression],
currentOrderOfKeys: Seq[Expression]): (Seq[Expression], Seq[Expression]) = {
val leftKeysBuffer = ArrayBuffer[Expression]()
val rightKeysBuffer = ArrayBuffer[Expression]()
expectedOrderOfKeys: Seq[Expression], // comes from child's output partitioning
currentOrderOfKeys: Seq[Expression]): // comes from join predicate
(Seq[Expression], Seq[Expression], Seq[Expression], Seq[Expression]) = {
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you please add a comment describing the return type? a tuple4 is not such a descriptive type 😃


assert(leftKeys.length == rightKeys.length)

val allLeftKeys = ArrayBuffer[Expression]()
val allRightKeys = ArrayBuffer[Expression]()
val reorderedLeftKeys = ArrayBuffer[Expression]()
val reorderedRightKeys = ArrayBuffer[Expression]()

// Tracking indicies here to track to which keys are accounted. Using a set based approach
// won't work because its possible that some keys are repeated in the join clause
// eg. a.key1 = b.key1 AND a.key1 = b.key2
val processedIndicies = mutable.Set[Int]()

expectedOrderOfKeys.foreach(expression => {
val index = currentOrderOfKeys.indexWhere(e => e.semanticEquals(expression))
leftKeysBuffer.append(leftKeys(index))
rightKeysBuffer.append(rightKeys(index))
val index = currentOrderOfKeys.zipWithIndex.find { case (currKey, i) =>
!processedIndicies.contains(i) && currKey.semanticEquals(expression)
}.get._2
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is the find guaranteed to always succeed?
if so, worth a comment on method's pre/post conditions.

a getOrElse(sys error "...") might also be a good way of documenting this.

processedIndicies.add(index)

reorderedLeftKeys.append(leftKeys(index))
allLeftKeys.append(leftKeys(index))

reorderedRightKeys.append(rightKeys(index))
allRightKeys.append(rightKeys(index))
})
(leftKeysBuffer, rightKeysBuffer)

// If len(currentOrderOfKeys) > len(expectedOrderOfKeys), then the re-ordering won't have
// all the keys. Append the remaining keys to the end so that we are covering all the keys
for (i <- leftKeys.indices) {
if (!processedIndicies.contains(i)) {
allLeftKeys.append(leftKeys(i))
allRightKeys.append(rightKeys(i))
}
}

assert(allLeftKeys.length == leftKeys.length)
assert(allRightKeys.length == rightKeys.length)
assert(reorderedLeftKeys.length == reorderedRightKeys.length)

(allLeftKeys, reorderedLeftKeys, allRightKeys, reorderedRightKeys)
}

/**
* Returns a tuple of "all join keys" and "minimal set of join keys which are satisfied by
* child" for both the left and right sub-trees
*
* eg. Assume `table_left` and `table_right` are joined over columns [a, b, c, d] and
* both the tables were already shuffled over columns [c, d] before the join is performed. The
* join keys are re-ordered as [c, d, a, b] by this method. The output of the method would be:
*
* {
* left:[c, d, a, b], // leftJoinKeys
* left:[c, d], // leftDistributionKeys
* right:[c, d, a, b], // rightJoinKeys
* right:[c, d] // rightDistributionKeys
* }
*/
private def reorderJoinKeys(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
leftPartitioning: Partitioning,
rightPartitioning: Partitioning): (Seq[Expression], Seq[Expression]) = {
rightPartitioning: Partitioning):
(Seq[Expression], Seq[Expression], Seq[Expression], Seq[Expression]) = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should add some documentation to explain what the return value is.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added more doc. I wasn't sure how to make it easier to understand. Hope that the example helps with that


if (leftKeys.forall(_.deterministic) && rightKeys.forall(_.deterministic)) {
leftPartitioning match {
case HashPartitioning(leftExpressions, _)
if leftExpressions.length == leftKeys.length &&
leftKeys.forall(x => leftExpressions.exists(_.semanticEquals(x))) =>
(leftPartitioning, rightPartitioning) match {
case (HashPartitioning(leftExpressions, _), HashPartitioning(_, _))
if isSubset(leftKeys, leftExpressions) =>
reorder(leftKeys, rightKeys, leftExpressions, leftKeys)

case _ => rightPartitioning match {
case HashPartitioning(rightExpressions, _)
if rightExpressions.length == rightKeys.length &&
rightKeys.forall(x => rightExpressions.exists(_.semanticEquals(x))) =>
reorder(leftKeys, rightKeys, rightExpressions, rightKeys)
case (HashPartitioning(_, _), HashPartitioning(rightExpressions, _))
if isSubset(rightKeys, rightExpressions) =>
reorder(leftKeys, rightKeys, rightExpressions, rightKeys)

case _ => (leftKeys, rightKeys)
}
case _ =>
(leftKeys, leftKeys, rightKeys, rightKeys)
}
} else {
(leftKeys, rightKeys)
(leftKeys, leftKeys, rightKeys, rightKeys)
}
}

Expand All @@ -271,23 +325,24 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
*/
private def reorderJoinPredicates(plan: SparkPlan): SparkPlan = {
plan.transformUp {
case BroadcastHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removal of BroadcastHashJoinExec is intentional. The children are expected to have BroadcastDistribution or UnspecifiedDistribution so this method wont help here (this optimization only helps in case of shuffle based joins)

right) =>
val (reorderedLeftKeys, reorderedRightKeys) =
reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning)
BroadcastHashJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, buildSide, condition,
left, right)

case ShuffledHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, right) =>
val (reorderedLeftKeys, reorderedRightKeys) =
reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning)
ShuffledHashJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, buildSide, condition,
left, right)

case SortMergeJoinExec(leftKeys, rightKeys, joinType, condition, left, right) =>
val (reorderedLeftKeys, reorderedRightKeys) =
reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning)
SortMergeJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, condition, left, right)
case ShuffledHashJoinExec(leftJoinKeys, rightJoinKeys, _, _, joinType, buildSide, condition,
left, right) =>
val (reorderedLeftKeys, reducedLeftDistributionKeys, reorderedRightKeys,
reducedRightDistributionKeys) =
reorderJoinKeys(leftJoinKeys, rightJoinKeys, left.outputPartitioning,
right.outputPartitioning)

ShuffledHashJoinExec(reorderedLeftKeys, reorderedRightKeys, reducedLeftDistributionKeys,
reducedRightDistributionKeys, joinType, buildSide, condition, left, right)

case SortMergeJoinExec(leftJoinKeys, rightJoinKeys, _, _, joinType, condition, left, right) =>
val (reorderedLeftKeys, reducedLeftDistributionKeys, reorderedRightKeys,
reducedRightDistributionKeys) =
reorderJoinKeys(leftJoinKeys, rightJoinKeys, left.outputPartitioning,
right.outputPartitioning)

SortMergeJoinExec(reorderedLeftKeys, reorderedRightKeys, reducedLeftDistributionKeys,
reducedRightDistributionKeys, joinType, condition, left, right)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ import org.apache.spark.sql.execution.metric.SQLMetrics
case class ShuffledHashJoinExec(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
leftDistributionKeys: Seq[Expression],
rightDistributionKeys: Seq[Expression],
joinType: JoinType,
buildSide: BuildSide,
condition: Option[Expression],
Expand All @@ -46,7 +48,9 @@ case class ShuffledHashJoinExec(
"avgHashProbe" -> SQLMetrics.createAverageMetric(sparkContext, "avg hash probe"))

override def requiredChildDistribution: Seq[Distribution] =
HashClusteredDistribution(leftKeys) :: HashClusteredDistribution(rightKeys) :: Nil
HashClusteredDistribution(leftDistributionKeys) ::
HashClusteredDistribution(rightDistributionKeys) ::
Nil

private def buildHashedRelation(iter: Iterator[InternalRow]): HashedRelation = {
val buildDataSize = longMetric("buildDataSize")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,10 @@ import org.apache.spark.util.collection.BitSet
* Performs a sort merge join of two child relations.
*/
case class SortMergeJoinExec(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
leftJoinKeys: Seq[Expression],
rightJoinKeys: Seq[Expression],
leftDistributionKeys: Seq[Expression],
rightDistributionKeys: Seq[Expression],
joinType: JoinType,
condition: Option[Expression],
left: SparkPlan,
Expand Down Expand Up @@ -78,24 +80,26 @@ case class SortMergeJoinExec(
}

override def requiredChildDistribution: Seq[Distribution] =
HashClusteredDistribution(leftKeys) :: HashClusteredDistribution(rightKeys) :: Nil
HashClusteredDistribution(leftDistributionKeys) ::
HashClusteredDistribution(rightDistributionKeys) ::
Nil

override def outputOrdering: Seq[SortOrder] = joinType match {
// For inner join, orders of both sides keys should be kept.
case _: InnerLike =>
val leftKeyOrdering = getKeyOrdering(leftKeys, left.outputOrdering)
val rightKeyOrdering = getKeyOrdering(rightKeys, right.outputOrdering)
val leftKeyOrdering = getKeyOrdering(leftJoinKeys, left.outputOrdering)
val rightKeyOrdering = getKeyOrdering(rightJoinKeys, right.outputOrdering)
leftKeyOrdering.zip(rightKeyOrdering).map { case (lKey, rKey) =>
// Also add the right key and its `sameOrderExpressions`
SortOrder(lKey.child, Ascending, lKey.sameOrderExpressions + rKey.child ++ rKey
.sameOrderExpressions)
}
// For left and right outer joins, the output is ordered by the streamed input's join keys.
case LeftOuter => getKeyOrdering(leftKeys, left.outputOrdering)
case RightOuter => getKeyOrdering(rightKeys, right.outputOrdering)
case LeftOuter => getKeyOrdering(leftJoinKeys, left.outputOrdering)
case RightOuter => getKeyOrdering(rightJoinKeys, right.outputOrdering)
// There are null rows in both streams, so there is no order.
case FullOuter => Nil
case LeftExistence(_) => getKeyOrdering(leftKeys, left.outputOrdering)
case LeftExistence(_) => getKeyOrdering(leftJoinKeys, left.outputOrdering)
case x =>
throw new IllegalArgumentException(
s"${getClass.getSimpleName} should not take $x as the JoinType")
Expand All @@ -122,18 +126,18 @@ case class SortMergeJoinExec(
}

override def requiredChildOrdering: Seq[Seq[SortOrder]] =
requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil
requiredOrders(leftJoinKeys) :: requiredOrders(rightJoinKeys) :: Nil

private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = {
// This must be ascending in order to agree with the `keyOrdering` defined in `doExecute()`.
keys.map(SortOrder(_, Ascending))
}

private def createLeftKeyGenerator(): Projection =
UnsafeProjection.create(leftKeys, left.output)
UnsafeProjection.create(leftJoinKeys, left.output)

private def createRightKeyGenerator(): Projection =
UnsafeProjection.create(rightKeys, right.output)
UnsafeProjection.create(rightJoinKeys, right.output)

private def getSpillThreshold: Int = {
sqlContext.conf.sortMergeJoinExecBufferSpillThreshold
Expand All @@ -157,7 +161,7 @@ case class SortMergeJoinExec(
}

// An ordering that can be used to compare keys from both sides.
val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType))
val keyOrdering = newNaturalAscendingOrdering(leftJoinKeys.map(_.dataType))
val resultProj: InternalRow => InternalRow = UnsafeProjection.create(output, output)

joinType match {
Expand Down Expand Up @@ -398,15 +402,15 @@ case class SortMergeJoinExec(

private def copyKeys(ctx: CodegenContext, vars: Seq[ExprCode]): Seq[ExprCode] = {
vars.zipWithIndex.map { case (ev, i) =>
ctx.addBufferedState(leftKeys(i).dataType, "value", ev.value)
ctx.addBufferedState(leftJoinKeys(i).dataType, "value", ev.value)
}
}

private def genComparison(ctx: CodegenContext, a: Seq[ExprCode], b: Seq[ExprCode]): String = {
val comparisons = a.zip(b).zipWithIndex.map { case ((l, r), i) =>
s"""
|if (comp == 0) {
| comp = ${ctx.genComp(leftKeys(i).dataType, l.value, r.value)};
| comp = ${ctx.genComp(leftJoinKeys(i).dataType, l.value, r.value)};
|}
""".stripMargin.trim
}
Expand All @@ -427,9 +431,9 @@ case class SortMergeJoinExec(
val rightRow = ctx.addMutableState("InternalRow", "rightRow", forceInline = true)

// Create variables for join keys from both sides.
val leftKeyVars = createJoinKey(ctx, leftRow, leftKeys, left.output)
val leftKeyVars = createJoinKey(ctx, leftRow, leftJoinKeys, left.output)
val leftAnyNull = leftKeyVars.map(_.isNull).mkString(" || ")
val rightKeyTmpVars = createJoinKey(ctx, rightRow, rightKeys, right.output)
val rightKeyTmpVars = createJoinKey(ctx, rightRow, rightJoinKeys, right.output)
val rightAnyNull = rightKeyTmpVars.map(_.isNull).mkString(" || ")
// Copy the right key as class members so they could be used in next function call.
val rightKeyVars = copyKeys(ctx, rightKeyTmpVars)
Expand Down
4 changes: 2 additions & 2 deletions sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -836,8 +836,8 @@ class JoinSuite extends QueryTest with SharedSQLContext {

joinPairs.foreach {
case(join1, join2) =>
val leftKeys = join1.leftKeys
val rightKeys = join1.rightKeys
val leftKeys = join1.leftJoinKeys
val rightKeys = join1.rightJoinKeys
val outputOrderingPhysical = join1.outputOrdering
val outputOrderingExecuted = join2.outputOrdering

Expand Down
Loading