diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
index 1b2e802ae939..26c95c0fedd1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
@@ -590,7 +590,12 @@ case class AdaptiveSparkPlanExec(
         // Apply `queryStageOptimizerRules` so that we can reuse subquery.
         // No need to apply `postStageCreationRules` for `InMemoryTableScanExec`
         // as it's a leaf node.
-        TableCacheQueryStageExec(currentStageId, optimizeQueryStage(i, isFinalStage = false))
+        val newPlan = optimizeQueryStage(i, isFinalStage = false)
+        if (!newPlan.isInstanceOf[InMemoryTableScanExec]) {
+          throw SparkException.internalError(
+            "Custom AQE rules cannot transform table scan node to something else.")
+        }
+        TableCacheQueryStageExec(currentStageId, newPlan)
     }
     currentStageId += 1
     setLogicalLinkForNewQueryStage(queryStage, plan)
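For context, the guard above applies to rules injected through `SparkSessionExtensions.injectQueryStageOptimizerRule`. A minimal sketch of a rule that would now fail fast instead of producing a corrupt table-cache stage (the `ReplaceCachedScan` rule and `MyExtensions` class are hypothetical, written only to illustrate the new error path; they are not part of this PR):

```scala
import org.apache.spark.sql.SparkSessionExtensions
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{LocalTableScanExec, SparkPlan}
import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec

// Hypothetical custom AQE rule that rewrites the cached scan into another node.
// Previously AQE would silently wrap the result in TableCacheQueryStageExec; with
// this change it throws
// "Custom AQE rules cannot transform table scan node to something else."
case class ReplaceCachedScan() extends Rule[SparkPlan] {
  override def apply(plan: SparkPlan): SparkPlan = plan.transform {
    case scan: InMemoryTableScanExec => LocalTableScanExec(scan.output, Seq.empty)
  }
}

class MyExtensions extends (SparkSessionExtensions => Unit) {
  override def apply(extensions: SparkSessionExtensions): Unit = {
    extensions.injectQueryStageOptimizerRule(_ => ReplaceCachedScan())
  }
}
```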
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala
index 1f05adc57a4b..9986e5d47870 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.adaptive
 import scala.collection.mutable
 
 import org.apache.spark.sql.catalyst.expressions
-import org.apache.spark.sql.catalyst.expressions.{ListQuery, SubqueryExpression}
+import org.apache.spark.sql.catalyst.expressions.{DynamicPruningSubquery, ListQuery, SubqueryExpression}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.plans.physical.UnspecifiedDistribution
 import org.apache.spark.sql.catalyst.rules.Rule
@@ -126,37 +126,22 @@ case class InsertAdaptiveSparkPlan(
    * Returns an expression-id-to-execution-plan map for all the sub-queries.
    * For each sub-query, generate the adaptive execution plan for each sub-query by applying this
    * rule.
+   * The returned subquery map holds the executed plans, so that [[PlanAdaptiveSubqueries]] can
+   * take them and create new subquery nodes.
    */
-  private def buildSubqueryMap(plan: SparkPlan): Map[Long, BaseSubqueryExec] = {
-    val subqueryMap = mutable.HashMap.empty[Long, BaseSubqueryExec]
+  private def buildSubqueryMap(plan: SparkPlan): Map[Long, SparkPlan] = {
+    val subqueryMap = mutable.HashMap.empty[Long, SparkPlan]
     if (!plan.containsAnyPattern(SCALAR_SUBQUERY, IN_SUBQUERY, DYNAMIC_PRUNING_SUBQUERY)) {
       return subqueryMap.toMap
     }
     plan.foreach(_.expressions.filter(_.containsPattern(PLAN_EXPRESSION)).foreach(_.foreach {
-      case expressions.ScalarSubquery(p, _, exprId, _, _, _)
-          if !subqueryMap.contains(exprId.id) =>
-        val executedPlan = compileSubquery(p)
-        verifyAdaptivePlan(executedPlan, p)
-        val subquery = SubqueryExec.createForScalarSubquery(
-          s"subquery#${exprId.id}", executedPlan)
-        subqueryMap.put(exprId.id, subquery)
-      case expressions.InSubquery(_, ListQuery(query, _, exprId, _, _, _))
-          if !subqueryMap.contains(exprId.id) =>
-        val executedPlan = compileSubquery(query)
-        verifyAdaptivePlan(executedPlan, query)
-        val subquery = SubqueryExec(s"subquery#${exprId.id}", executedPlan)
-        subqueryMap.put(exprId.id, subquery)
-      case expressions.DynamicPruningSubquery(value, buildPlan,
-          buildKeys, broadcastKeyIndex, onlyInBroadcast, exprId, _)
-          if !subqueryMap.contains(exprId.id) =>
-        val executedPlan = compileSubquery(buildPlan)
-        verifyAdaptivePlan(executedPlan, buildPlan)
-
-        val name = s"dynamicpruning#${exprId.id}"
-        val subquery = SubqueryAdaptiveBroadcastExec(
-          name, broadcastKeyIndex, onlyInBroadcast,
-          buildPlan, buildKeys, executedPlan)
-        subqueryMap.put(exprId.id, subquery)
+      case e @ (_: expressions.ScalarSubquery | _: ListQuery | _: DynamicPruningSubquery) =>
+        val subquery = e.asInstanceOf[SubqueryExpression]
+        if (!subqueryMap.contains(subquery.exprId.id)) {
+          val executedPlan = compileSubquery(subquery.plan)
+          verifyAdaptivePlan(executedPlan, subquery.plan)
+          subqueryMap.put(subquery.exprId.id, executedPlan)
+        }
       case _ =>
     }))
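The collapsed match above works because `ScalarSubquery`, `ListQuery`, and `DynamicPruningSubquery` all extend `SubqueryExpression`, which already exposes the `plan` and `exprId` fields the three old cases duplicated. Matching `ListQuery` directly, rather than the enclosing `InSubquery`, is safe because `Expression.foreach` also visits nested children. A self-contained sketch of that traversal (`collectSubqueries` is an illustrative helper, not code from this PR):

```scala
import scala.collection.mutable

import org.apache.spark.sql.catalyst.expressions.{DynamicPruningSubquery, Expression, ListQuery, ScalarSubquery, SubqueryExpression}

// Collect each subquery expression once, keyed by exprId: the same
// deduplication the rewritten buildSubqueryMap performs.
def collectSubqueries(root: Expression): Map[Long, SubqueryExpression] = {
  val seen = mutable.HashMap.empty[Long, SubqueryExpression]
  root.foreach {
    case e @ (_: ScalarSubquery | _: ListQuery | _: DynamicPruningSubquery) =>
      val sub = e.asInstanceOf[SubqueryExpression]
      if (!seen.contains(sub.exprId.id)) {
        seen.put(sub.exprId.id, sub)
      }
    // InSubquery itself is not matched; foreach reaches its nested ListQuery child.
    case _ =>
  }
  seen.toMap
}
```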
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala
index c3f427405835..5b4a7e50db71 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala
@@ -20,19 +20,20 @@ package org.apache.spark.sql.execution.adaptive
 import org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.sql.catalyst.expressions.{CreateNamedStruct, DynamicPruningExpression, ListQuery, Literal}
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.catalyst.trees.TreePattern.{DYNAMIC_PRUNING_SUBQUERY, IN_SUBQUERY,
-  SCALAR_SUBQUERY}
+import org.apache.spark.sql.catalyst.trees.TreePattern.{DYNAMIC_PRUNING_SUBQUERY, IN_SUBQUERY, SCALAR_SUBQUERY}
 import org.apache.spark.sql.execution
-import org.apache.spark.sql.execution.{BaseSubqueryExec, InSubqueryExec, SparkPlan}
+import org.apache.spark.sql.execution.{InSubqueryExec, SparkPlan, SubqueryAdaptiveBroadcastExec, SubqueryExec}
 
 case class PlanAdaptiveSubqueries(
-    subqueryMap: Map[Long, BaseSubqueryExec]) extends Rule[SparkPlan] {
+    subqueryMap: Map[Long, SparkPlan]) extends Rule[SparkPlan] {
 
   def apply(plan: SparkPlan): SparkPlan = {
     plan.transformAllExpressionsWithPruning(
       _.containsAnyPattern(SCALAR_SUBQUERY, IN_SUBQUERY, DYNAMIC_PRUNING_SUBQUERY)) {
       case expressions.ScalarSubquery(_, _, exprId, _, _, _) =>
-        execution.ScalarSubquery(subqueryMap(exprId.id), exprId)
+        val subquery = SubqueryExec.createForScalarSubquery(
+          s"subquery#${exprId.id}", subqueryMap(exprId.id))
+        execution.ScalarSubquery(subquery, exprId)
       case expressions.InSubquery(values, ListQuery(_, _, exprId, _, _, _)) =>
         val expr = if (values.length == 1) {
           values.head
@@ -43,9 +44,14 @@ case class PlanAdaptiveSubqueries(
             }
           )
         }
-        InSubqueryExec(expr, subqueryMap(exprId.id), exprId, shouldBroadcast = true)
-      case expressions.DynamicPruningSubquery(value, _, _, _, _, exprId, _) =>
-        DynamicPruningExpression(InSubqueryExec(value, subqueryMap(exprId.id), exprId))
+        val subquery = SubqueryExec(s"subquery#${exprId.id}", subqueryMap(exprId.id))
+        InSubqueryExec(expr, subquery, exprId, shouldBroadcast = true)
+      case expressions.DynamicPruningSubquery(value, buildPlan,
+          buildKeys, broadcastKeyIndex, onlyInBroadcast, exprId, _) =>
+        val name = s"dynamicpruning#${exprId.id}"
+        val subquery = SubqueryAdaptiveBroadcastExec(name, broadcastKeyIndex, onlyInBroadcast,
+          buildPlan, buildKeys, subqueryMap(exprId.id))
+        DynamicPruningExpression(InSubqueryExec(value, subquery, exprId))
     }
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ReuseAdaptiveSubquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ReuseAdaptiveSubquery.scala
index df6849447215..c1d0e93e3b97 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ReuseAdaptiveSubquery.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ReuseAdaptiveSubquery.scala
@@ -33,16 +33,11 @@ case class ReuseAdaptiveSubquery(
 
     plan.transformAllExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION)) {
       case sub: ExecSubqueryExpression =>
-        // The subquery can be already reused (the same Java object) due to filter pushdown
-        // of table cache. If it happens, we just need to wrap the current subquery with
-        // `ReusedSubqueryExec` and no need to update the `reuseMap`.
-        reuseMap.get(sub.plan.canonicalized).map { subquery =>
-          sub.withNewPlan(ReusedSubqueryExec(subquery))
-        }.getOrElse {
-          reuseMap.putIfAbsent(sub.plan.canonicalized, sub.plan) match {
-            case Some(subquery) => sub.withNewPlan(ReusedSubqueryExec(subquery))
-            case None => sub
-          }
+        val newPlan = reuseMap.getOrElseUpdate(sub.plan.canonicalized, sub.plan)
+        if (newPlan.ne(sub.plan)) {
+          sub.withNewPlan(ReusedSubqueryExec(newPlan))
+        } else {
+          sub
         }
     }
   }
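The rewritten `ReuseAdaptiveSubquery` leans on two facts: `TrieMap.getOrElseUpdate` returns whichever plan was registered first for a canonicalized key, and reference inequality (`ne`) then distinguishes "an equivalent plan already exists" from "this very object is the registered one". A standalone sketch of those semantics in plain Scala (the `Plan` type and its string `canonical` key are stand-ins for `SparkPlan` and `canonicalized`):

```scala
import scala.collection.concurrent.TrieMap

// Stand-in for SparkPlan; `canonical` plays the role of `canonicalized`.
final case class Plan(id: String) { def canonical: String = id.toLowerCase }

object ReuseDemo {
  private val reuseMap = TrieMap.empty[String, Plan]

  // Mirrors the new logic: register the first plan per canonical key, then
  // wrap any later, reference-distinct occurrence in a reuse marker.
  def reuse(p: Plan): String = {
    val existing = reuseMap.getOrElseUpdate(p.canonical, p)
    if (existing ne p) s"Reused(${existing.id})" else p.id
  }

  def main(args: Array[String]): Unit = {
    val a = Plan("Subquery#1")
    println(reuse(a))                  // Subquery#1 -> first registration, kept
    println(reuse(Plan("SUBQUERY#1"))) // Reused(Subquery#1) -> equivalent plan, wrapped
    println(reuse(a))                  // Subquery#1 -> same object, no wrapper
  }
}
```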