diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala index f21d459603970..1c1ee7d03a4df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala @@ -231,44 +231,24 @@ case class OptimizeSkewedJoin(ensureRequirements: EnsureRequirements) return plan } - def collectShuffleStages(plan: SparkPlan): Seq[ShuffleQueryStageExec] = plan match { - case stage: ShuffleQueryStageExec => Seq(stage) - case _ => plan.children.flatMap(collectShuffleStages) + // We try to optimize every skewed sort-merge/shuffle-hash joins in the query plan. If this + // introduces extra shuffles, we give up the optimization and return the original query plan, or + // accept the extra shuffles if the force-apply config is true. + // TODO: It's possible that only one skewed join in the query plan leads to extra shuffles and + // we only need to skip optimizing that join. We should make the strategy smarter here. + val optimized = optimizeSkewJoin(plan) + val requirementSatisfied = if (ensureRequirements.requiredDistribution.isDefined) { + ValidateRequirements.validate(optimized, ensureRequirements.requiredDistribution.get) + } else { + ValidateRequirements.validate(optimized) } - - val shuffleStages = collectShuffleStages(plan) - - if (shuffleStages.length == 2) { - // When multi table join, there will be too many complex combination to consider. - // Currently we only handle 2 table join like following use case. - // SMJ - // Sort - // Shuffle - // Sort - // Shuffle - // Or - // SHJ - // Shuffle - // Shuffle - val optimized = optimizeSkewJoin(plan) - val requirementSatisfied = if (ensureRequirements.requiredDistribution.isDefined) { - ValidateRequirements.validate(optimized, ensureRequirements.requiredDistribution.get) - } else { - ValidateRequirements.validate(optimized) + if (requirementSatisfied) { + optimized.transform { + case SkewJoinChildWrapper(child) => child } - // Two cases we will apply the skewed join optimization: - // 1. optimize the skew join without extra shuffle - // 2. optimize the skew join with extra shuffle but the force-apply config is true. - if (requirementSatisfied) { - optimized.transform { - case SkewJoinChildWrapper(child) => child - } - } else if (conf.getConf(SQLConf.ADAPTIVE_FORCE_OPTIMIZE_SKEWED_JOIN)) { - ensureRequirements.apply(optimized).transform { - case SkewJoinChildWrapper(child) => child - } - } else { - plan + } else if (conf.getConf(SQLConf.ADAPTIVE_FORCE_OPTIMIZE_SKEWED_JOIN)) { + ensureRequirements.apply(optimized).transform { + case SkewJoinChildWrapper(child) => child } } else { plan diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala index 07bc5282faa87..51d476703a768 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala @@ -2310,6 +2310,53 @@ class AdaptiveQueryExecSuite assert(bhj.length == 1) } } + + test("SPARK-37328: skew join with 3 tables") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "100", + SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "100", + SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1", + SQLConf.SHUFFLE_PARTITIONS.key -> "10") { + withTempView("skewData1", "skewData2", "skewData3") { + spark + .range(0, 1000, 1, 10) + .selectExpr("id % 3 as key1", "id % 3 as value1") + .createOrReplaceTempView("skewData1") + spark + .range(0, 1000, 1, 10) + .selectExpr("id % 1 as key2", "id as value2") + .createOrReplaceTempView("skewData2") + spark + .range(0, 1000, 1, 10) + .selectExpr("id % 1 as key3", "id as value3") + .createOrReplaceTempView("skewData3") + + // skewedJoin doesn't happen in last stage + val (_, adaptive1) = + runAdaptiveAndVerifyResult("SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2 " + + "JOIN skewData3 ON value2 = value3") + val shuffles1 = collect(adaptive1) { + case s: ShuffleExchangeExec => s + } + assert(shuffles1.size == 4) + val smj1 = findTopLevelSortMergeJoin(adaptive1) + assert(smj1.size == 2 && smj1.last.isSkewJoin && !smj1.head.isSkewJoin) + + // Query has two skewJoin in two continuous stages. + val (_, adaptive2) = + runAdaptiveAndVerifyResult("SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2 " + + "JOIN skewData3 ON value1 = value3") + val shuffles2 = collect(adaptive2) { + case s: ShuffleExchangeExec => s + } + assert(shuffles2.size == 4) + val smj2 = findTopLevelSortMergeJoin(adaptive2) + assert(smj2.size == 2 && smj2.forall(_.isSkewJoin)) + } + } + } } /**