Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -231,44 +231,24 @@ case class OptimizeSkewedJoin(ensureRequirements: EnsureRequirements)
return plan
}

def collectShuffleStages(plan: SparkPlan): Seq[ShuffleQueryStageExec] = plan match {
case stage: ShuffleQueryStageExec => Seq(stage)
case _ => plan.children.flatMap(collectShuffleStages)
// We try to optimize every skewed sort-merge/shuffle-hash joins in the query plan. If this
// introduces extra shuffles, we give up the optimization and return the original query plan, or
// accept the extra shuffles if the force-apply config is true.
// TODO: It's possible that only one skewed join in the query plan leads to extra shuffles and
// we only need to skip optimizing that join. We should make the strategy smarter here.
val optimized = optimizeSkewJoin(plan)
val requirementSatisfied = if (ensureRequirements.requiredDistribution.isDefined) {
ValidateRequirements.validate(optimized, ensureRequirements.requiredDistribution.get)
} else {
ValidateRequirements.validate(optimized)
}

val shuffleStages = collectShuffleStages(plan)

if (shuffleStages.length == 2) {
// When multi table join, there will be too many complex combination to consider.
// Currently we only handle 2 table join like following use case.
// SMJ
// Sort
// Shuffle
// Sort
// Shuffle
// Or
// SHJ
// Shuffle
// Shuffle
val optimized = optimizeSkewJoin(plan)
val requirementSatisfied = if (ensureRequirements.requiredDistribution.isDefined) {
ValidateRequirements.validate(optimized, ensureRequirements.requiredDistribution.get)
} else {
ValidateRequirements.validate(optimized)
if (requirementSatisfied) {
optimized.transform {
case SkewJoinChildWrapper(child) => child
}
// Two cases we will apply the skewed join optimization:
// 1. optimize the skew join without extra shuffle
// 2. optimize the skew join with extra shuffle but the force-apply config is true.
if (requirementSatisfied) {
optimized.transform {
case SkewJoinChildWrapper(child) => child
}
} else if (conf.getConf(SQLConf.ADAPTIVE_FORCE_OPTIMIZE_SKEWED_JOIN)) {
ensureRequirements.apply(optimized).transform {
case SkewJoinChildWrapper(child) => child
}
} else {
plan
} else if (conf.getConf(SQLConf.ADAPTIVE_FORCE_OPTIMIZE_SKEWED_JOIN)) {
ensureRequirements.apply(optimized).transform {
case SkewJoinChildWrapper(child) => child
}
} else {
plan
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2310,6 +2310,53 @@ class AdaptiveQueryExecSuite
assert(bhj.length == 1)
}
}

test("SPARK-37328: skew join with 3 tables") {
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "100",
SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "100",
SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1",
SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
withTempView("skewData1", "skewData2", "skewData3") {
spark
.range(0, 1000, 1, 10)
.selectExpr("id % 3 as key1", "id % 3 as value1")
.createOrReplaceTempView("skewData1")
spark
.range(0, 1000, 1, 10)
.selectExpr("id % 1 as key2", "id as value2")
.createOrReplaceTempView("skewData2")
spark
.range(0, 1000, 1, 10)
.selectExpr("id % 1 as key3", "id as value3")
.createOrReplaceTempView("skewData3")

// skewedJoin doesn't happen in last stage
val (_, adaptive1) =
runAdaptiveAndVerifyResult("SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2 " +
"JOIN skewData3 ON value2 = value3")
val shuffles1 = collect(adaptive1) {
case s: ShuffleExchangeExec => s
}
assert(shuffles1.size == 4)
val smj1 = findTopLevelSortMergeJoin(adaptive1)
assert(smj1.size == 2 && smj1.last.isSkewJoin && !smj1.head.isSkewJoin)

// Query has two skewJoin in two continuous stages.
val (_, adaptive2) =
runAdaptiveAndVerifyResult("SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2 " +
"JOIN skewData3 ON value1 = value3")
val shuffles2 = collect(adaptive2) {
case s: ShuffleExchangeExec => s
}
assert(shuffles2.size == 4)
val smj2 = findTopLevelSortMergeJoin(adaptive2)
assert(smj2.size == 2 && smj2.forall(_.isSkewJoin))
}
}
}
}

/**
Expand Down