Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,11 @@ import scala.collection.mutable

import org.apache.commons.io.FileUtils

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, EnsureRequirements}
Expand Down Expand Up @@ -196,8 +200,10 @@ case class OptimizeSkewedJoin(
}
logDebug(s"number of skewed partitions: left $numSkewedLeft, right $numSkewedRight")
if (numSkewedLeft > 0 || numSkewedRight > 0) {
Some((AQEShuffleReadExec(left, leftSidePartitions.toSeq),
AQEShuffleReadExec(right, rightSidePartitions.toSeq)))
Some((
SkewJoinChildWrapper(AQEShuffleReadExec(left, leftSidePartitions.toSeq)),
SkewJoinChildWrapper(AQEShuffleReadExec(right, rightSidePartitions.toSeq))
))
} else {
None
}
Expand All @@ -207,25 +213,19 @@ case class OptimizeSkewedJoin(
case smj @ SortMergeJoinExec(_, _, joinType, _,
s1 @ SortExec(_, _, ShuffleStage(left: ShuffleQueryStageExec), _),
s2 @ SortExec(_, _, ShuffleStage(right: ShuffleQueryStageExec), _), false) =>
val newChildren = tryOptimizeJoinChildren(left, right, joinType)
if (newChildren.isDefined) {
val (newLeft, newRight) = newChildren.get
smj.copy(
left = s1.copy(child = newLeft), right = s2.copy(child = newRight), isSkewJoin = true)
} else {
smj
}
tryOptimizeJoinChildren(left, right, joinType).map {
case (newLeft, newRight) =>
smj.copy(
left = s1.copy(child = newLeft), right = s2.copy(child = newRight), isSkewJoin = true)
}.getOrElse(smj)

case shj @ ShuffledHashJoinExec(_, _, joinType, _, _,
ShuffleStage(left: ShuffleQueryStageExec),
ShuffleStage(right: ShuffleQueryStageExec), false) =>
val newChildren = tryOptimizeJoinChildren(left, right, joinType)
if (newChildren.isDefined) {
val (newLeft, newRight) = newChildren.get
shj.copy(left = newLeft, right = newRight, isSkewJoin = true)
} else {
shj
}
tryOptimizeJoinChildren(left, right, joinType).map {
case (newLeft, newRight) =>
shj.copy(left = newLeft, right = newRight, isSkewJoin = true)
}.getOrElse(shj)
}

override def apply(plan: SparkPlan): SparkPlan = {
Expand All @@ -252,7 +252,9 @@ case class OptimizeSkewedJoin(
// SHJ
// Shuffle
// Shuffle
val optimized = ensureRequirements.apply(optimizeSkewJoin(plan))
val optimized = ensureRequirements.apply(optimizeSkewJoin(plan)).transform {
case SkewJoinChildWrapper(child) => child
}
val originCost = costEvaluator.evaluateCost(plan)
val optimizedCost = costEvaluator.evaluateCost(optimized)
// two cases we will pick new plan:
Expand All @@ -277,3 +279,13 @@ case class OptimizeSkewedJoin(
}
}
}

// After optimizing skew joins, we need to run EnsureRequirements again to add necessary shuffles
// caused by skew join optimization. However, this shouldn't apply to the sub-plan under skew join,
// as it's guaranteed to satisfy distribution requirement.
case class SkewJoinChildWrapper(plan: SparkPlan) extends LeafExecNode {
override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException()
override def output: Seq[Attribute] = plan.output
override def outputPartitioning: Partitioning = plan.outputPartitioning
override def outputOrdering: Seq[SortOrder] = plan.outputOrdering
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,36 +35,6 @@ case class SimpleCost(value: Long) extends Cost {
}
}

/**
* A skew join aware implementation of [[Cost]], which consider shuffle number and skew join number.
*
* We always pick the cost which has more skew join even if it introduces one or more extra shuffle.
* Otherwise, if two costs have the same number of skew join or no skew join, we will pick the one
* with small number of shuffle.
*/
case class SkewJoinAwareCost(
numShuffles: Int,
numSkewJoins: Int) extends Cost {
override def compare(that: Cost): Int = that match {
case other: SkewJoinAwareCost =>
// If more skew joins are optimized or less shuffle nodes, it means the cost is lower
if (numSkewJoins > other.numSkewJoins) {
-1
} else if (numSkewJoins < other.numSkewJoins) {
1
} else if (numShuffles < other.numShuffles) {
-1
} else if (numShuffles > other.numShuffles) {
1
} else {
0
}

case _ =>
throw QueryExecutionErrors.cannotCompareCostWithTargetCostError(that.toString)
}
}

/**
* A skew join aware implementation of [[CostEvaluator]], which counts the number of
* [[ShuffleExchangeLike]] nodes and skew join nodes in the plan.
Expand All @@ -79,7 +49,9 @@ case class SimpleCostEvaluator(forceOptimizeSkewedJoin: Boolean) extends CostEva
val numSkewJoins = plan.collect {
case j: ShuffledJoin if j.isSkewJoin => j
}.size
SkewJoinAwareCost(numShuffles, numSkewJoins)
// We put `-numSkewJoins` in the first 32 bits of the long value, so that it's compared first
// when comparing the cost, and larger `numSkewJoins` means lower cost.
SimpleCost(-numSkewJoins.toLong << 32 | numShuffles)
} else {
SimpleCost(numShuffles)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ case class EnsureRequirements(
assert(requiredChildDistributions.length == originalChildren.length)
assert(requiredChildOrderings.length == originalChildren.length)
// Ensure that the operator's children satisfy their output distribution requirements.
var newChildren = originalChildren.zip(requiredChildDistributions).map {
var children = originalChildren.zip(requiredChildDistributions).map {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel the original name children is better, as it's a var and we keep updating it.

case (child, distribution) if child.outputPartitioning.satisfies(distribution) =>
child
case (child, BroadcastDistribution(mode)) =>
Expand All @@ -74,7 +74,7 @@ case class EnsureRequirements(
}.map(_._2)

val childrenNumPartitions =
childrenIndexes.map(newChildren(_).outputPartitioning.numPartitions).toSet
childrenIndexes.map(children(_).outputPartitioning.numPartitions).toSet

if (childrenNumPartitions.size > 1) {
// Get the number of partitions which is explicitly required by the distributions.
Expand All @@ -92,7 +92,7 @@ case class EnsureRequirements(
// 1. We should avoid shuffling these children.
// 2. We should have a reasonable parallelism.
val nonShuffleChildrenNumPartitions =
childrenIndexes.map(newChildren).filterNot(_.isInstanceOf[ShuffleExchangeExec])
childrenIndexes.map(children).filterNot(_.isInstanceOf[ShuffleExchangeExec])
.map(_.outputPartitioning.numPartitions)
val expectedChildrenNumPartitions = if (nonShuffleChildrenNumPartitions.nonEmpty) {
if (nonShuffleChildrenNumPartitions.length == childrenIndexes.length) {
Expand All @@ -111,7 +111,7 @@ case class EnsureRequirements(

val targetNumPartitions = requiredNumPartitions.getOrElse(expectedChildrenNumPartitions)

newChildren = newChildren.zip(requiredChildDistributions).zipWithIndex.map {
children = children.zip(requiredChildDistributions).zipWithIndex.map {
case ((child, distribution), index) if childrenIndexes.contains(index) =>
if (child.outputPartitioning.numPartitions == targetNumPartitions) {
child
Expand All @@ -129,7 +129,7 @@ case class EnsureRequirements(
}

// Now that we've performed any necessary shuffles, add sorts to guarantee output orderings:
newChildren = newChildren.zip(requiredChildOrderings).map { case (child, requiredOrdering) =>
children = children.zip(requiredChildOrderings).map { case (child, requiredOrdering) =>
// If child.outputOrdering already satisfies the requiredOrdering, we do not need to sort.
if (SortOrder.orderingSatisfies(child.outputOrdering, requiredOrdering)) {
child
Expand All @@ -138,7 +138,7 @@ case class EnsureRequirements(
}
}

newChildren
children
}

private def reorder(
Expand Down