diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala
index 2fe5b18a75ec..4e97f114e77f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala
@@ -21,7 +21,11 @@ import scala.collection.mutable

 import org.apache.commons.io.FileUtils

+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder}
 import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, EnsureRequirements}
@@ -196,8 +200,10 @@
       }
       logDebug(s"number of skewed partitions: left $numSkewedLeft, right $numSkewedRight")
       if (numSkewedLeft > 0 || numSkewedRight > 0) {
-        Some((AQEShuffleReadExec(left, leftSidePartitions.toSeq),
-          AQEShuffleReadExec(right, rightSidePartitions.toSeq)))
+        Some((
+          SkewJoinChildWrapper(AQEShuffleReadExec(left, leftSidePartitions.toSeq)),
+          SkewJoinChildWrapper(AQEShuffleReadExec(right, rightSidePartitions.toSeq))
+        ))
       } else {
         None
       }
@@ -207,25 +213,19 @@
     case smj @ SortMergeJoinExec(_, _, joinType, _,
        s1 @ SortExec(_, _, ShuffleStage(left: ShuffleQueryStageExec), _),
        s2 @ SortExec(_, _, ShuffleStage(right: ShuffleQueryStageExec), _), false) =>
-      val newChildren = tryOptimizeJoinChildren(left, right, joinType)
-      if (newChildren.isDefined) {
-        val (newLeft, newRight) = newChildren.get
-        smj.copy(
-          left = s1.copy(child = newLeft), right = s2.copy(child = newRight), isSkewJoin = true)
-      } else {
-        smj
-      }
+      tryOptimizeJoinChildren(left, right, joinType).map {
+        case (newLeft, newRight) =>
+          smj.copy(
+            left = s1.copy(child = newLeft), right = s2.copy(child = newRight), isSkewJoin = true)
+      }.getOrElse(smj)

     case shj @ ShuffledHashJoinExec(_, _, joinType, _, _,
        ShuffleStage(left: ShuffleQueryStageExec),
        ShuffleStage(right: ShuffleQueryStageExec), false) =>
-      val newChildren = tryOptimizeJoinChildren(left, right, joinType)
-      if (newChildren.isDefined) {
-        val (newLeft, newRight) = newChildren.get
-        shj.copy(left = newLeft, right = newRight, isSkewJoin = true)
-      } else {
-        shj
-      }
+      tryOptimizeJoinChildren(left, right, joinType).map {
+        case (newLeft, newRight) =>
+          shj.copy(left = newLeft, right = newRight, isSkewJoin = true)
+      }.getOrElse(shj)
   }

   override def apply(plan: SparkPlan): SparkPlan = {
@@ -252,7 +252,9 @@
       //   SHJ
       //     Shuffle
       //     Shuffle
-      val optimized = ensureRequirements.apply(optimizeSkewJoin(plan))
+      val optimized = ensureRequirements.apply(optimizeSkewJoin(plan)).transform {
+        case SkewJoinChildWrapper(child) => child
+      }
       val originCost = costEvaluator.evaluateCost(plan)
       val optimizedCost = costEvaluator.evaluateCost(optimized)
       // two cases we will pick new plan:
@@ -277,3 +279,13 @@
     }
   }
 }
+
+// After optimizing skew joins, we need to run EnsureRequirements again to add any shuffles
+// made necessary by the skew join optimization. However, it shouldn't apply to the sub-plan
+// under the skew join, as that sub-plan is guaranteed to satisfy the distribution requirement.
+case class SkewJoinChildWrapper(plan: SparkPlan) extends LeafExecNode {
+  override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException()
+  override def output: Seq[Attribute] = plan.output
+  override def outputPartitioning: Partitioning = plan.outputPartitioning
+  override def outputOrdering: Seq[SortOrder] = plan.outputOrdering
+}
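The wrapper works because Spark's tree transforms recurse through a node's children, and a `LeafExecNode` reports none: everything below `SkewJoinChildWrapper` stays invisible to `EnsureRequirements` until the wrapper is stripped off again. A minimal, self-contained sketch of that mechanism, using toy classes rather than Spark's (only the recursion scheme of `transform` is borrowed):

```scala
sealed trait Plan {
  def transform(rule: PartialFunction[Plan, Plan]): Plan
}

case class Shuffle(child: Plan) extends Plan {
  override def transform(rule: PartialFunction[Plan, Plan]): Plan =
    rule.applyOrElse(Shuffle(child.transform(rule)), identity[Plan])
}

case class Scan(table: String) extends Plan {
  override def transform(rule: PartialFunction[Plan, Plan]): Plan =
    rule.applyOrElse(this, identity[Plan])
}

// Like SkewJoinChildWrapper: the real sub-plan sits in a plain field, not among the
// children, so `transform` never descends into it.
case class Wrapper(plan: Plan) extends Plan {
  override def transform(rule: PartialFunction[Plan, Plan]): Plan =
    rule.applyOrElse(this, identity[Plan])
}

object WrapperDemo extends App {
  // A stand-in for a rule, such as EnsureRequirements, that rewrites shuffles.
  val removeShuffles: PartialFunction[Plan, Plan] = { case Shuffle(c) => c }

  val plan = Shuffle(Wrapper(Shuffle(Scan("t"))))
  val rewritten = plan.transform(removeShuffles)
  // The outer shuffle is rewritten; the one protected by the wrapper survives.
  assert(rewritten == Wrapper(Shuffle(Scan("t"))))
  // Unwrapping afterwards restores the protected sub-plan, mirroring the
  // `case SkewJoinChildWrapper(child) => child` step in OptimizeSkewedJoin.
  assert(rewritten.transform { case Wrapper(p) => p } == Shuffle(Scan("t")))
}
```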
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/simpleCosting.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/simpleCosting.scala
index 864563be3855..9c9c8e13d2d5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/simpleCosting.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/simpleCosting.scala
@@ -35,36 +35,6 @@ case class SimpleCost(value: Long) extends Cost {
   }
 }

-/**
- * A skew join aware implementation of [[Cost]], which consider shuffle number and skew join number.
- *
- * We always pick the cost which has more skew join even if it introduces one or more extra shuffle.
- * Otherwise, if two costs have the same number of skew join or no skew join, we will pick the one
- * with small number of shuffle.
- */
-case class SkewJoinAwareCost(
-    numShuffles: Int,
-    numSkewJoins: Int) extends Cost {
-  override def compare(that: Cost): Int = that match {
-    case other: SkewJoinAwareCost =>
-      // If more skew joins are optimized or less shuffle nodes, it means the cost is lower
-      if (numSkewJoins > other.numSkewJoins) {
-        -1
-      } else if (numSkewJoins < other.numSkewJoins) {
-        1
-      } else if (numShuffles < other.numShuffles) {
-        -1
-      } else if (numShuffles > other.numShuffles) {
-        1
-      } else {
-        0
-      }
-
-    case _ =>
-      throw QueryExecutionErrors.cannotCompareCostWithTargetCostError(that.toString)
-  }
-}
-
 /**
  * A skew join aware implementation of [[CostEvaluator]], which counts the number of
  * [[ShuffleExchangeLike]] nodes and skew join nodes in the plan.
@@ -79,7 +49,9 @@ case class SimpleCostEvaluator(forceOptimizeSkewedJoin: Boolean) extends CostEva
       val numSkewJoins = plan.collect {
         case j: ShuffledJoin if j.isSkewJoin => j
       }.size
-      SkewJoinAwareCost(numShuffles, numSkewJoins)
+      // We put `-numSkewJoins` in the higher 32 bits of the long value, so that it is compared
+      // first, and a larger `numSkewJoins` means a lower cost.
+      SimpleCost(-numSkewJoins.toLong << 32 | numShuffles)
     } else {
       SimpleCost(numShuffles)
     }
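Since `SimpleCost` compares plain longs, the packed value reproduces exactly the ordering the removed `SkewJoinAwareCost.compare` implemented: the negated skew-join count occupies the high 32 bits and dominates the comparison, while the shuffle count only breaks ties. A quick standalone check (the `encode` helper is just for illustration, not Spark API):

```scala
object CostEncodingDemo extends App {
  // Hypothetical helper mirroring the packed-long expression in SimpleCostEvaluator above.
  def encode(numSkewJoins: Int, numShuffles: Int): Long =
    -numSkewJoins.toLong << 32 | numShuffles

  // More optimized skew joins always wins, even at the price of extra shuffles ...
  assert(encode(2, 5) < encode(1, 3))
  // ... and with the same number of skew joins, fewer shuffles wins.
  assert(encode(1, 3) < encode(1, 5))
  // With no skew joins the value degenerates to the plain shuffle count,
  // consistent with the `SimpleCost(numShuffles)` branch.
  assert(encode(0, 2) == 2L)
}
```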
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index 86b2344629d2..e73b48c4e7de 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -54,7 +54,7 @@ case class EnsureRequirements(
     assert(requiredChildDistributions.length == originalChildren.length)
     assert(requiredChildOrderings.length == originalChildren.length)
     // Ensure that the operator's children satisfy their output distribution requirements.
-    var newChildren = originalChildren.zip(requiredChildDistributions).map {
+    var children = originalChildren.zip(requiredChildDistributions).map {
       case (child, distribution) if child.outputPartitioning.satisfies(distribution) =>
         child
       case (child, BroadcastDistribution(mode)) =>
@@ -74,7 +74,7 @@
       }.map(_._2)

       val childrenNumPartitions =
-        childrenIndexes.map(newChildren(_).outputPartitioning.numPartitions).toSet
+        childrenIndexes.map(children(_).outputPartitioning.numPartitions).toSet

       if (childrenNumPartitions.size > 1) {
         // Get the number of partitions which is explicitly required by the distributions.
@@ -92,7 +92,7 @@
         // 1. We should avoid shuffling these children.
         // 2. We should have a reasonable parallelism.
         val nonShuffleChildrenNumPartitions =
-          childrenIndexes.map(newChildren).filterNot(_.isInstanceOf[ShuffleExchangeExec])
+          childrenIndexes.map(children).filterNot(_.isInstanceOf[ShuffleExchangeExec])
             .map(_.outputPartitioning.numPartitions)
         val expectedChildrenNumPartitions = if (nonShuffleChildrenNumPartitions.nonEmpty) {
           if (nonShuffleChildrenNumPartitions.length == childrenIndexes.length) {
@@ -111,7 +111,7 @@

       val targetNumPartitions = requiredNumPartitions.getOrElse(expectedChildrenNumPartitions)

-      newChildren = newChildren.zip(requiredChildDistributions).zipWithIndex.map {
+      children = children.zip(requiredChildDistributions).zipWithIndex.map {
        case ((child, distribution), index) if childrenIndexes.contains(index) =>
          if (child.outputPartitioning.numPartitions == targetNumPartitions) {
            child
@@ -129,7 +129,7 @@
     }

     // Now that we've performed any necessary shuffles, add sorts to guarantee output orderings:
-    newChildren = newChildren.zip(requiredChildOrderings).map { case (child, requiredOrdering) =>
+    children = children.zip(requiredChildOrderings).map { case (child, requiredOrdering) =>
       // If child.outputOrdering already satisfies the requiredOrdering, we do not need to sort.
       if (SortOrder.orderingSatisfies(child.outputOrdering, requiredOrdering)) {
         child
@@ -138,7 +138,7 @@
       }
     }

-    newChildren
+    children
   }

   private def reorder(
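For orientation, the method being touched has a two-pass shape: first satisfy each child's distribution requirement (inserting shuffles), then satisfy its ordering requirement (inserting sorts); the rename from `newChildren` to `children` simply reflects that one sequence is rebuilt across both passes. A toy sketch of that shape, with hypothetical types standing in for Spark's plans:

```scala
// Hypothetical stand-in for a physical child plan; not Spark's API.
case class ToyChild(partitioning: String, ordering: Seq[String])

object EnsureRequirementsSketch extends App {
  // Pass 1: children whose output partitioning misses the required distribution get
  // "shuffled" (Spark would wrap them in a ShuffleExchangeExec), which loses ordering.
  // Pass 2: children whose output ordering misses the required ordering get "sorted"
  // (Spark would wrap them in a SortExec). Sorts come after shuffles for this reason.
  def ensure(
      originalChildren: Seq[ToyChild],
      requiredDistributions: Seq[String],
      requiredOrderings: Seq[Seq[String]]): Seq[ToyChild] = {
    var children = originalChildren.zip(requiredDistributions).map {
      case (child, dist) if child.partitioning == dist => child
      case (_, dist) => ToyChild(dist, Nil) // "shuffle": repartition, ordering lost
    }
    children = children.zip(requiredOrderings).map {
      case (child, order) if child.ordering.startsWith(order) => child
      case (child, order) => child.copy(ordering = order) // "sort"
    }
    children
  }

  val out = ensure(
    Seq(ToyChild("hash(a)", Seq("a")), ToyChild("roundrobin", Nil)),
    requiredDistributions = Seq("hash(a)", "hash(b)"),
    requiredOrderings = Seq(Seq("a"), Seq("b")))
  // The first child already satisfied both requirements; the second was shuffled and sorted.
  assert(out == Seq(ToyChild("hash(a)", Seq("a")), ToyChild("hash(b)", Seq("b"))))
}
```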