diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 3ad3416256c7..f35389d0fbac 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -358,6 +358,15 @@ object SQLConf {
     .booleanConf
     .createWithDefault(false)
 
+  val ADAPTIVE_EXECUTION_FORCE_APPLY = buildConf("spark.sql.adaptive.forceApply")
+    .internal()
+    .doc("Adaptive query execution is skipped when the query does not have exchanges or " +
+      "sub-queries. By setting this config to true (together with " +
+      s"'${ADAPTIVE_EXECUTION_ENABLED.key}' enabled), Spark will force apply adaptive query " +
+      "execution for all supported queries.")
+    .booleanConf
+    .createWithDefault(false)
+
   val REDUCE_POST_SHUFFLE_PARTITIONS_ENABLED =
     buildConf("spark.sql.adaptive.shuffle.reducePostShufflePartitions.enabled")
       .doc(s"When true and '${ADAPTIVE_EXECUTION_ENABLED.key}' is enabled, this enables reducing " +
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala
index 9252827856af..621c063e5a7d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala
@@ -40,49 +40,60 @@ case class InsertAdaptiveSparkPlan(
 
   private val conf = adaptiveExecutionContext.session.sessionState.conf
 
-  def containShuffle(plan: SparkPlan): Boolean = {
-    plan.find {
-      case _: Exchange => true
-      case s: SparkPlan => !s.requiredChildDistribution.forall(_ == UnspecifiedDistribution)
-    }.isDefined
-  }
-
-  def containSubQuery(plan: SparkPlan): Boolean = {
-    plan.find(_.expressions.exists(_.find {
-      case _: SubqueryExpression => true
-      case _ => false
-    }.isDefined)).isDefined
-  }
-
   override def apply(plan: SparkPlan): SparkPlan = applyInternal(plan, false)
 
   private def applyInternal(plan: SparkPlan, isSubquery: Boolean): SparkPlan = plan match {
+    case _ if !conf.adaptiveExecutionEnabled => plan
     case _: ExecutedCommandExec => plan
-    case _ if conf.adaptiveExecutionEnabled && supportAdaptive(plan)
-      && (isSubquery || containShuffle(plan) || containSubQuery(plan)) =>
-      try {
-        // Plan sub-queries recursively and pass in the shared stage cache for exchange reuse. Fall
-        // back to non-adaptive mode if adaptive execution is supported in any of the sub-queries.
-        val subqueryMap = buildSubqueryMap(plan)
-        val planSubqueriesRule = PlanAdaptiveSubqueries(subqueryMap)
-        val preprocessingRules = Seq(
-          planSubqueriesRule)
-        // Run pre-processing rules.
-        val newPlan = AdaptiveSparkPlanExec.applyPhysicalRules(plan, preprocessingRules)
-        logDebug(s"Adaptive execution enabled for plan: $plan")
-        AdaptiveSparkPlanExec(newPlan, adaptiveExecutionContext, preprocessingRules, isSubquery)
-      } catch {
-        case SubqueryAdaptiveNotSupportedException(subquery) =>
-          logWarning(s"${SQLConf.ADAPTIVE_EXECUTION_ENABLED.key} is enabled " +
-            s"but is not supported for sub-query: $subquery.")
-          plan
-      }
-    case _ =>
-      if (conf.adaptiveExecutionEnabled) {
+    case _ if shouldApplyAQE(plan, isSubquery) =>
+      if (supportAdaptive(plan)) {
+        try {
+          // Plan sub-queries recursively and pass in the shared stage cache for exchange reuse.
+          // Fall back to non-AQE mode if AQE is not supported in any of the sub-queries.
+          val subqueryMap = buildSubqueryMap(plan)
+          val planSubqueriesRule = PlanAdaptiveSubqueries(subqueryMap)
+          val preprocessingRules = Seq(
+            planSubqueriesRule)
+          // Run pre-processing rules.
+          val newPlan = AdaptiveSparkPlanExec.applyPhysicalRules(plan, preprocessingRules)
+          logDebug(s"Adaptive execution enabled for plan: $plan")
+          AdaptiveSparkPlanExec(newPlan, adaptiveExecutionContext, preprocessingRules, isSubquery)
+        } catch {
+          case SubqueryAdaptiveNotSupportedException(subquery) =>
+            logWarning(s"${SQLConf.ADAPTIVE_EXECUTION_ENABLED.key} is enabled " +
+              s"but is not supported for sub-query: $subquery.")
+            plan
+        }
+      } else {
         logWarning(s"${SQLConf.ADAPTIVE_EXECUTION_ENABLED.key} is enabled " +
           s"but is not supported for query: $plan.")
+        plan
       }
-      plan
+
+    case _ => plan
+  }
+
+  // AQE is only useful when the query has exchanges or sub-queries. This method returns true if
+  // one of the following conditions is satisfied:
+  //   - The config ADAPTIVE_EXECUTION_FORCE_APPLY is true.
+  //   - The input query is from a sub-query. When this happens, it means we've already decided to
+  //     apply AQE for the main query and we must continue to do it.
+  //   - The query contains exchanges.
+  //   - The query may need to add exchanges. It's overkill to run `EnsureRequirements` here, so
+  //     we just check `SparkPlan.requiredChildDistribution` and see if it's possible that the
+  //     query needs to add exchanges later.
+  //   - The query contains sub-queries.
+  private def shouldApplyAQE(plan: SparkPlan, isSubquery: Boolean): Boolean = {
+    conf.getConf(SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY) || isSubquery || {
+      plan.find {
+        case _: Exchange => true
+        case p if !p.requiredChildDistribution.forall(_ == UnspecifiedDistribution) => true
+        case p => p.expressions.exists(_.find {
+          case _: SubqueryExpression => true
+          case _ => false
+        }.isDefined)
+      }.isDefined
+    }
   }
 
   private def supportAdaptive(plan: SparkPlan): Boolean = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
index 78a118366474..96e977221e51 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
@@ -780,4 +780,13 @@ class AdaptiveQueryExecSuite
       )
     }
   }
+
+  test("force apply AQE") {
+    withSQLConf(
+      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+      SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true") {
+      val plan = sql("SELECT * FROM testData").queryExecution.executedPlan
+      assert(plan.isInstanceOf[AdaptiveSparkPlanExec])
+    }
+  }
 }
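
Usage sketch (editor's note, not part of the patch): with both flags on, even a bare scan that has no exchanges or sub-queries gets wrapped in AdaptiveSparkPlanExec, which is exactly what the new test asserts. The snippet below is a minimal standalone illustration assuming a build with this patch applied and a local SparkSession; the object name and temp view are hypothetical, and spark.sql.adaptive.forceApply is an internal config, set here purely for demonstration.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec

object ForceApplyAQEDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("force-apply-aqe-demo")
      // forceApply only takes effect when AQE itself is enabled.
      .config("spark.sql.adaptive.enabled", "true")
      .config("spark.sql.adaptive.forceApply", "true")  // internal config, shown for illustration
      .getOrCreate()
    import spark.implicits._

    // Hypothetical temp view mirroring the testData table used in the suite.
    Seq((1, "a"), (2, "b")).toDF("key", "value").createOrReplaceTempView("testData")

    // A plain scan has no exchanges or sub-queries; without forceApply,
    // InsertAdaptiveSparkPlan would skip it and leave a non-adaptive plan.
    val plan = spark.sql("SELECT * FROM testData").queryExecution.executedPlan
    assert(plan.isInstanceOf[AdaptiveSparkPlanExec])

    spark.stop()
  }
}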