Skip to content

Commit 4a6b2b9

Browse files
ulysses-you and cloud-fan
authored and committed
[SPARK-33832][SQL] Support optimize skewed join even if introduce extra shuffle
### What changes were proposed in this pull request?

- Move the rule `OptimizeSkewedJoin` from the stage optimization phase to the stage preparation phase.
- Run the rule `EnsureRequirements` one more time after the `OptimizeSkewedJoin` rule in the stage preparation phase.
- Add `SkewJoinAwareCost` to support estimating the cost of a skewed join.
- Add a new config to decide whether to force optimizing skewed joins.
- In `OptimizeSkewedJoin`, we generate 2 physical plans, one with skew join optimization and one without. Then we use the cost evaluator w.r.t. the force-skew-join flag and pick the plan with the lower cost.

### Why are the changes needed?

In general, a skewed join has more impact on performance than one more shuffle. It makes sense to force optimizing the skewed join even if it introduces an extra shuffle.

A common case:

```
HashAggregate
  SortMergeJoin
    Sort
      Exchange
    Sort
      Exchange
```

and after this PR, the plan looks like:

```
HashAggregate
  Exchange
    SortMergeJoin (isSkew=true)
      Sort
        Exchange
      Sort
        Exchange
```

Note that the newly introduced shuffle can also be optimized by AQE.

### Does this PR introduce _any_ user-facing change?

Yes, a new config.

### How was this patch tested?

* Add a new test
* Pass the existing test `SPARK-30524: Do not optimize skew join if introduce additional shuffle`
* Pass the existing test `SPARK-33551: Do not use custom shuffle reader for repartition`

Closes #32816 from ulysses-you/support-extra-shuffle.

Authored-by: ulysses-you <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
1 parent e1e1961 commit 4a6b2b9

File tree

6 files changed

+217
-58
lines changed

6 files changed

+217
-58
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -666,6 +666,13 @@ object SQLConf {
666666
.booleanConf
667667
.createWithDefault(true)
668668

669+
val ADAPTIVE_FORCE_OPTIMIZE_SKEWED_JOIN =
670+
buildConf("spark.sql.adaptive.forceOptimizeSkewedJoin")
671+
.doc("When true, force enable OptimizeSkewedJoin even if it introduces extra shuffle.")
672+
.version("3.3.0")
673+
.booleanConf
674+
.createWithDefault(false)
675+
669676
val ADAPTIVE_CUSTOM_COST_EVALUATOR_CLASS =
670677
buildConf("spark.sql.adaptive.customCostEvaluatorClass")
671678
.doc("The custom cost evaluator class to be used for adaptive execution. If not being set," +

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -97,27 +97,36 @@ case class AdaptiveSparkPlanExec(
9797
AQEUtils.getRequiredDistribution(inputPlan)
9898
}
9999

100+
@transient private val costEvaluator =
101+
conf.getConf(SQLConf.ADAPTIVE_CUSTOM_COST_EVALUATOR_CLASS) match {
102+
case Some(className) => CostEvaluator.instantiate(className, session.sparkContext.getConf)
103+
case _ => SimpleCostEvaluator(conf.getConf(SQLConf.ADAPTIVE_FORCE_OPTIMIZE_SKEWED_JOIN))
104+
}
105+
100106
// A list of physical plan rules to be applied before creation of query stages. The physical
101107
// plan should reach a final status of query stages (i.e., no more addition or removal of
102108
// Exchange nodes) after running these rules.
103-
@transient private val queryStagePreparationRules: Seq[Rule[SparkPlan]] = Seq(
104-
RemoveRedundantProjects,
109+
@transient private val queryStagePreparationRules: Seq[Rule[SparkPlan]] = {
105110
// For cases like `df.repartition(a, b).select(c)`, there is no distribution requirement for
106111
// the final plan, but we do need to respect the user-specified repartition. Here we ask
107112
// `EnsureRequirements` to not optimize out the user-specified repartition-by-col to work
108113
// around this case.
109-
EnsureRequirements(optimizeOutRepartition = requiredDistribution.isDefined),
110-
RemoveRedundantSorts,
111-
DisableUnnecessaryBucketedScan
112-
) ++ context.session.sessionState.queryStagePrepRules
114+
val ensureRequirements =
115+
EnsureRequirements(requiredDistribution.isDefined, requiredDistribution)
116+
Seq(
117+
RemoveRedundantProjects,
118+
ensureRequirements,
119+
RemoveRedundantSorts,
120+
DisableUnnecessaryBucketedScan,
121+
OptimizeSkewedJoin(ensureRequirements, costEvaluator)
122+
) ++ context.session.sessionState.queryStagePrepRules
123+
}
113124

114125
// A list of physical optimizer rules to be applied to a new stage before its execution. These
115126
// optimizations should be stage-independent.
116127
@transient private val queryStageOptimizerRules: Seq[Rule[SparkPlan]] = Seq(
117128
PlanAdaptiveDynamicPruningFilters(this),
118129
ReuseAdaptiveSubquery(context.subqueryCache),
119-
// Skew join does not handle `AQEShuffleRead` so needs to be applied first.
120-
OptimizeSkewedJoin,
121130
OptimizeSkewInRebalancePartitions,
122131
CoalesceShufflePartitions(context.session),
123132
// `OptimizeShuffleWithLocalRead` needs to make use of 'AQEShuffleReadExec.partitionSpecs'
@@ -169,12 +178,6 @@ case class AdaptiveSparkPlanExec(
169178
optimized
170179
}
171180

172-
@transient private val costEvaluator =
173-
conf.getConf(SQLConf.ADAPTIVE_CUSTOM_COST_EVALUATOR_CLASS) match {
174-
case Some(className) => CostEvaluator.instantiate(className, session.sparkContext.getConf)
175-
case _ => SimpleCostEvaluator
176-
}
177-
178181
@transient val initialPlan = context.session.withActive {
179182
applyPhysicalRules(
180183
inputPlan, queryStagePreparationRules, Some((planChangeLogger, "AQE Preparations")))

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,9 @@ import scala.collection.mutable
2222
import org.apache.commons.io.FileUtils
2323

2424
import org.apache.spark.sql.catalyst.plans._
25+
import org.apache.spark.sql.catalyst.rules.Rule
2526
import org.apache.spark.sql.execution._
26-
import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, ShuffleOrigin}
27+
import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, EnsureRequirements}
2728
import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec}
2829
import org.apache.spark.sql.internal.SQLConf
2930

@@ -48,9 +49,10 @@ import org.apache.spark.sql.internal.SQLConf
4849
* (L3, R3-1), (L3, R3-2),
4950
* (L4-1, R4-1), (L4-2, R4-1), (L4-1, R4-2), (L4-2, R4-2)
5051
*/
51-
object OptimizeSkewedJoin extends AQEShuffleReadRule {
52-
53-
override val supportedShuffleOrigins: Seq[ShuffleOrigin] = Seq(ENSURE_REQUIREMENTS)
52+
case class OptimizeSkewedJoin(
53+
ensureRequirements: EnsureRequirements,
54+
costEvaluator: CostEvaluator)
55+
extends Rule[SparkPlan] {
5456

5557
/**
5658
* A partition is considered as a skewed partition if its size is larger than the median
@@ -250,15 +252,26 @@ object OptimizeSkewedJoin extends AQEShuffleReadRule {
250252
// SHJ
251253
// Shuffle
252254
// Shuffle
253-
optimizeSkewJoin(plan)
255+
val optimized = ensureRequirements.apply(optimizeSkewJoin(plan))
256+
val originCost = costEvaluator.evaluateCost(plan)
257+
val optimizedCost = costEvaluator.evaluateCost(optimized)
258+
// two cases we will pick new plan:
259+
// 1. optimize the skew join without extra shuffle
260+
// 2. optimize the skew join with extra shuffle but the costEvaluator thinks it's better
261+
if (optimizedCost <= originCost) {
262+
optimized
263+
} else {
264+
plan
265+
}
254266
} else {
255267
plan
256268
}
257269
}
258270

259271
object ShuffleStage {
260272
def unapply(plan: SparkPlan): Option[ShuffleQueryStageExec] = plan match {
261-
case s: ShuffleQueryStageExec if s.mapStats.isDefined && isSupported(s.shuffle) =>
273+
case s: ShuffleQueryStageExec if s.isMaterialized && s.mapStats.isDefined &&
274+
s.shuffle.shuffleOrigin == ENSURE_REQUIREMENTS =>
262275
Some(s)
263276
case _ => None
264277
}

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/simpleCosting.scala

Lines changed: 43 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.adaptive
2020
import org.apache.spark.sql.errors.QueryExecutionErrors
2121
import org.apache.spark.sql.execution.SparkPlan
2222
import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike
23+
import org.apache.spark.sql.execution.joins.ShuffledJoin
2324

2425
/**
2526
* A simple implementation of [[Cost]], which takes a number of [[Long]] as the cost value.
@@ -35,15 +36,52 @@ case class SimpleCost(value: Long) extends Cost {
3536
}
3637

3738
/**
38-
* A simple implementation of [[CostEvaluator]], which counts the number of
39-
* [[ShuffleExchangeLike]] nodes in the plan.
39+
* A skew join aware implementation of [[Cost]], which considers the number of shuffles and the
* number of skew joins.
40+
*
41+
* We always pick the cost which has more skew joins even if it introduces one or more extra shuffles.
42+
* Otherwise, if two costs have the same number of skew joins or no skew join, we will pick the one
43+
* with the smaller number of shuffles.
4044
*/
41-
object SimpleCostEvaluator extends CostEvaluator {
45+
case class SkewJoinAwareCost(
46+
numShuffles: Int,
47+
numSkewJoins: Int) extends Cost {
48+
override def compare(that: Cost): Int = that match {
49+
case other: SkewJoinAwareCost =>
50+
// If more skew joins are optimized or less shuffle nodes, it means the cost is lower
51+
if (numSkewJoins > other.numSkewJoins) {
52+
-1
53+
} else if (numSkewJoins < other.numSkewJoins) {
54+
1
55+
} else if (numShuffles < other.numShuffles) {
56+
-1
57+
} else if (numShuffles > other.numShuffles) {
58+
1
59+
} else {
60+
0
61+
}
62+
63+
case _ =>
64+
throw QueryExecutionErrors.cannotCompareCostWithTargetCostError(that.toString)
65+
}
66+
}
4267

68+
/**
69+
* A skew join aware implementation of [[CostEvaluator]], which counts the number of
70+
* [[ShuffleExchangeLike]] nodes and skew join nodes in the plan.
71+
*/
72+
case class SimpleCostEvaluator(forceOptimizeSkewedJoin: Boolean) extends CostEvaluator {
4373
override def evaluateCost(plan: SparkPlan): Cost = {
44-
val cost = plan.collect {
74+
val numShuffles = plan.collect {
4575
case s: ShuffleExchangeLike => s
4676
}.size
47-
SimpleCost(cost)
77+
78+
if (forceOptimizeSkewedJoin) {
79+
val numSkewJoins = plan.collect {
80+
case j: ShuffledJoin if j.isSkewJoin => j
81+
}.size
82+
SkewJoinAwareCost(numShuffles, numSkewJoins)
83+
} else {
84+
SimpleCost(numShuffles)
85+
}
4886
}
4987
}

sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala

Lines changed: 63 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -38,26 +38,31 @@ import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoin
3838
* but can be false in AQE when AQE optimization may change the plan
3939
* output partitioning and need to retain the user-specified
4040
* repartition shuffles in the plan.
41+
* @param requiredDistribution The root required distribution we should ensure. This value is used
42+
* in AQE in case we change final stage output partitioning.
4143
*/
42-
case class EnsureRequirements(optimizeOutRepartition: Boolean = true) extends Rule[SparkPlan] {
43-
44-
private def ensureDistributionAndOrdering(operator: SparkPlan): SparkPlan = {
45-
val requiredChildDistributions: Seq[Distribution] = operator.requiredChildDistribution
46-
val requiredChildOrderings: Seq[Seq[SortOrder]] = operator.requiredChildOrdering
47-
var children: Seq[SparkPlan] = operator.children
48-
assert(requiredChildDistributions.length == children.length)
49-
assert(requiredChildOrderings.length == children.length)
44+
case class EnsureRequirements(
45+
optimizeOutRepartition: Boolean = true,
46+
requiredDistribution: Option[Distribution] = None)
47+
extends Rule[SparkPlan] {
5048

49+
private def ensureDistributionAndOrdering(
50+
originalChildren: Seq[SparkPlan],
51+
requiredChildDistributions: Seq[Distribution],
52+
requiredChildOrderings: Seq[Seq[SortOrder]],
53+
shuffleOrigin: ShuffleOrigin): Seq[SparkPlan] = {
54+
assert(requiredChildDistributions.length == originalChildren.length)
55+
assert(requiredChildOrderings.length == originalChildren.length)
5156
// Ensure that the operator's children satisfy their output distribution requirements.
52-
children = children.zip(requiredChildDistributions).map {
57+
var newChildren = originalChildren.zip(requiredChildDistributions).map {
5358
case (child, distribution) if child.outputPartitioning.satisfies(distribution) =>
5459
child
5560
case (child, BroadcastDistribution(mode)) =>
5661
BroadcastExchangeExec(mode, child)
5762
case (child, distribution) =>
5863
val numPartitions = distribution.requiredNumPartitions
5964
.getOrElse(conf.numShufflePartitions)
60-
ShuffleExchangeExec(distribution.createPartitioning(numPartitions), child)
65+
ShuffleExchangeExec(distribution.createPartitioning(numPartitions), child, shuffleOrigin)
6166
}
6267

6368
// Get the indexes of children which have specified distribution requirements and need to have
@@ -69,7 +74,7 @@ case class EnsureRequirements(optimizeOutRepartition: Boolean = true) extends Ru
6974
}.map(_._2)
7075

7176
val childrenNumPartitions =
72-
childrenIndexes.map(children(_).outputPartitioning.numPartitions).toSet
77+
childrenIndexes.map(newChildren(_).outputPartitioning.numPartitions).toSet
7378

7479
if (childrenNumPartitions.size > 1) {
7580
// Get the number of partitions which is explicitly required by the distributions.
@@ -78,7 +83,7 @@ case class EnsureRequirements(optimizeOutRepartition: Boolean = true) extends Ru
7883
index => requiredChildDistributions(index).requiredNumPartitions
7984
}.toSet
8085
assert(numPartitionsSet.size <= 1,
81-
s"$operator have incompatible requirements of the number of partitions for its children")
86+
s"$requiredChildDistributions have incompatible requirements of the number of partitions")
8287
numPartitionsSet.headOption
8388
}
8489

@@ -87,7 +92,7 @@ case class EnsureRequirements(optimizeOutRepartition: Boolean = true) extends Ru
8792
// 1. We should avoid shuffling these children.
8893
// 2. We should have a reasonable parallelism.
8994
val nonShuffleChildrenNumPartitions =
90-
childrenIndexes.map(children).filterNot(_.isInstanceOf[ShuffleExchangeExec])
95+
childrenIndexes.map(newChildren).filterNot(_.isInstanceOf[ShuffleExchangeExec])
9196
.map(_.outputPartitioning.numPartitions)
9297
val expectedChildrenNumPartitions = if (nonShuffleChildrenNumPartitions.nonEmpty) {
9398
if (nonShuffleChildrenNumPartitions.length == childrenIndexes.length) {
@@ -106,7 +111,7 @@ case class EnsureRequirements(optimizeOutRepartition: Boolean = true) extends Ru
106111

107112
val targetNumPartitions = requiredNumPartitions.getOrElse(expectedChildrenNumPartitions)
108113

109-
children = children.zip(requiredChildDistributions).zipWithIndex.map {
114+
newChildren = newChildren.zip(requiredChildDistributions).zipWithIndex.map {
110115
case ((child, distribution), index) if childrenIndexes.contains(index) =>
111116
if (child.outputPartitioning.numPartitions == targetNumPartitions) {
112117
child
@@ -124,7 +129,7 @@ case class EnsureRequirements(optimizeOutRepartition: Boolean = true) extends Ru
124129
}
125130

126131
// Now that we've performed any necessary shuffles, add sorts to guarantee output orderings:
127-
children = children.zip(requiredChildOrderings).map { case (child, requiredOrdering) =>
132+
newChildren = newChildren.zip(requiredChildOrderings).map { case (child, requiredOrdering) =>
128133
// If child.outputOrdering already satisfies the requiredOrdering, we do not need to sort.
129134
if (SortOrder.orderingSatisfies(child.outputOrdering, requiredOrdering)) {
130135
child
@@ -133,7 +138,7 @@ case class EnsureRequirements(optimizeOutRepartition: Boolean = true) extends Ru
133138
}
134139
}
135140

136-
operator.withNewChildren(children)
141+
newChildren
137142
}
138143

139144
private def reorder(
@@ -254,25 +259,50 @@ case class EnsureRequirements(optimizeOutRepartition: Boolean = true) extends Ru
254259
}
255260
}
256261

257-
def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
258-
case operator @ ShuffleExchangeExec(upper: HashPartitioning, child, shuffleOrigin)
259-
if optimizeOutRepartition &&
260-
(shuffleOrigin == REPARTITION_BY_COL || shuffleOrigin == REPARTITION_BY_NUM) =>
261-
def hasSemanticEqualPartitioning(partitioning: Partitioning): Boolean = {
262-
partitioning match {
263-
case lower: HashPartitioning if upper.semanticEquals(lower) => true
264-
case lower: PartitioningCollection =>
265-
lower.partitionings.exists(hasSemanticEqualPartitioning)
266-
case _ => false
262+
def apply(plan: SparkPlan): SparkPlan = {
263+
val newPlan = plan.transformUp {
264+
case operator @ ShuffleExchangeExec(upper: HashPartitioning, child, shuffleOrigin)
265+
if optimizeOutRepartition &&
266+
(shuffleOrigin == REPARTITION_BY_COL || shuffleOrigin == REPARTITION_BY_NUM) =>
267+
def hasSemanticEqualPartitioning(partitioning: Partitioning): Boolean = {
268+
partitioning match {
269+
case lower: HashPartitioning if upper.semanticEquals(lower) => true
270+
case lower: PartitioningCollection =>
271+
lower.partitionings.exists(hasSemanticEqualPartitioning)
272+
case _ => false
273+
}
267274
}
268-
}
269-
if (hasSemanticEqualPartitioning(child.outputPartitioning)) {
270-
child
275+
if (hasSemanticEqualPartitioning(child.outputPartitioning)) {
276+
child
277+
} else {
278+
operator
279+
}
280+
281+
case operator: SparkPlan =>
282+
val reordered = reorderJoinPredicates(operator)
283+
val newChildren = ensureDistributionAndOrdering(
284+
reordered.children,
285+
reordered.requiredChildDistribution,
286+
reordered.requiredChildOrdering,
287+
ENSURE_REQUIREMENTS)
288+
reordered.withNewChildren(newChildren)
289+
}
290+
291+
if (requiredDistribution.isDefined) {
292+
val shuffleOrigin = if (requiredDistribution.get.requiredNumPartitions.isDefined) {
293+
REPARTITION_BY_NUM
271294
} else {
272-
operator
295+
REPARTITION_BY_COL
273296
}
274-
275-
case operator: SparkPlan =>
276-
ensureDistributionAndOrdering(reorderJoinPredicates(operator))
297+
val finalPlan = ensureDistributionAndOrdering(
298+
newPlan :: Nil,
299+
requiredDistribution.get :: Nil,
300+
Seq(Nil),
301+
shuffleOrigin)
302+
assert(finalPlan.size == 1)
303+
finalPlan.head
304+
} else {
305+
newPlan
306+
}
277307
}
278308
}

0 commit comments

Comments
 (0)