Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Support DPP + AQE when find the broadcast exchange can reuse
  • Loading branch information
JkSelf committed Apr 15, 2021
commit 3bc4baf59335f63f003adf74c87ddb2992e515cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import scala.collection.mutable
import scala.concurrent.ExecutionContext
import scala.util.control.NonFatal

import org.apache.spark.SparkException
import org.apache.spark.{broadcast, SparkException}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
Expand Down Expand Up @@ -94,7 +94,7 @@ case class AdaptiveSparkPlanExec(
// A list of physical optimizer rules to be applied to a new stage before its execution. These
// optimizations should be stage-independent.
@transient private val queryStageOptimizerRules: Seq[Rule[SparkPlan]] = Seq(
PlanAdaptiveDynamicPruningFilters(context.stageCache),
PlanAdaptiveDynamicPruningFilters(inputPlan),
ReuseAdaptiveSubquery(context.subqueryCache),
CoalesceShufflePartitions(context.session),
// The following two rules need to make use of 'CustomShuffleReaderExec.partitionSpecs'
Expand Down Expand Up @@ -310,6 +310,11 @@ case class AdaptiveSparkPlanExec(
rdd
}

override def doExecuteBroadcast[T](): broadcast.Broadcast[T] = {
val broadcastPlan = getFinalPhysicalPlan()
broadcastPlan.doExecuteBroadcast()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: getFinalPhysicalPlan().doExecuteBroadcast()

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated.

}

protected override def stringArgs: Iterator[Any] = Iterator(s"isFinalPlan=$isFinalPlan")

override def generateTreeString(
Expand Down Expand Up @@ -476,7 +481,7 @@ case class AdaptiveSparkPlanExec(
throw new IllegalStateException(
"Custom columnar rules cannot transform shuffle node to something else.")
}
ShuffleQueryStageExec(currentStageId, newShuffle, s.canonicalized)
ShuffleQueryStageExec(currentStageId, newShuffle, s.child.canonicalized)
case b: BroadcastExchangeLike =>
val newBroadcast = applyPhysicalRules(
b.withNewChildren(Seq(optimizedPlan)),
Expand All @@ -486,7 +491,7 @@ case class AdaptiveSparkPlanExec(
throw new IllegalStateException(
"Custom columnar rules cannot transform broadcast node to something else.")
}
BroadcastQueryStageExec(currentStageId, newBroadcast, b.canonicalized)
BroadcastQueryStageExec(currentStageId, newBroadcast, b.child.canonicalized)
}
currentStageId += 1
setLogicalLinkForNewQueryStage(queryStage, e)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,18 @@

package org.apache.spark.sql.execution.adaptive

import scala.collection.concurrent.TrieMap

import org.apache.spark.sql.catalyst.expressions.{BindReferences, DynamicPruningExpression, Literal}
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
import org.apache.spark.sql.execution.joins.{HashedRelationBroadcastMode, HashJoin}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, HashedRelationBroadcastMode, HashJoin}

/**
* A rule to insert dynamic pruning predicates in order to reuse the results of broadcast.
*/
case class PlanAdaptiveDynamicPruningFilters(
stageCache: TrieMap[SparkPlan, QueryStageExec]) extends Rule[SparkPlan] {
originalPlan: SparkPlan) extends Rule[SparkPlan] {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rootPlan

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated.

def apply(plan: SparkPlan): SparkPlan = {
if (!conf.dynamicPartitionPruningEnabled) {
return plan
Expand All @@ -41,15 +40,26 @@ case class PlanAdaptiveDynamicPruningFilters(
adaptivePlan: AdaptiveSparkPlanExec), exprId, _)) =>
val packedKeys = BindReferences.bindReferences(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can move this into if (canReuseExchange)

HashJoin.rewriteKeyExpr(buildKeys), adaptivePlan.executedPlan.output)
val mode = HashedRelationBroadcastMode(packedKeys)
// plan a broadcast exchange of the build side of the join
val exchange = BroadcastExchangeExec(mode, adaptivePlan.executedPlan)
val existingStage = stageCache.get(exchange.canonicalized)
if (existingStage.nonEmpty && conf.exchangeReuseEnabled) {
val name = s"dynamicpruning#${exprId.id}"
val reuseQueryStage = existingStage.get.newReuseInstance(0, exchange.output)
val broadcastValues =
SubqueryBroadcastExec(name, index, buildKeys, reuseQueryStage)

val canReuseExchange = conf.exchangeReuseEnabled && buildKeys.nonEmpty &&
originalPlan.find {
case BroadcastHashJoinExec(_, _, _, BuildLeft, _, left, _, _) =>
left.sameResult(adaptivePlan.executedPlan)
case BroadcastHashJoinExec(_, _, _, BuildRight, _, _, right, _) =>
right.sameResult(adaptivePlan.executedPlan)
case _ => false
}.isDefined

if(canReuseExchange) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: if (canReuseExchange)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated.

val mode = HashedRelationBroadcastMode(packedKeys)
// plan a broadcast exchange of the build side of the join
val exchange = BroadcastExchangeExec(mode, adaptivePlan.executedPlan)
exchange.setLogicalLink(adaptivePlan.executedPlan.logicalLink.get)
val newAdaptivePlan = AdaptiveSparkPlanExec(
exchange, adaptivePlan.context, adaptivePlan.preprocessingRules, true)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto: adaptivePlan.copy(inputPlan = exchange)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated.


val broadcastValues = SubqueryBroadcastExec(
name, index, buildKeys, newAdaptivePlan)
DynamicPruningExpression(InSubqueryExec(value, broadcastValues, exprId))
} else {
DynamicPruningExpression(Literal.TrueLiteral)
Expand Down