-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-28560][SQL][followup] support the build side to local shuffle reader as far as possible in BroadcastHashJoin #26289
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
5b7ff2d
e13f637
782827a
f3bb9ce
152aaa6
e510e96
1e947db
573ffcd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -24,31 +24,33 @@ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartit | |
| import org.apache.spark.sql.catalyst.rules.Rule | ||
| import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} | ||
| import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ShuffleExchangeExec} | ||
| import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BuildLeft, BuildRight} | ||
| import org.apache.spark.sql.internal.SQLConf | ||
|
|
||
| case class OptimizeLocalShuffleReader(conf: SQLConf) extends Rule[SparkPlan] { | ||
|
|
||
| def canUseLocalShuffleReaderLeft(join: BroadcastHashJoinExec): Boolean = { | ||
cloud-fan marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| join.buildSide == BuildRight && ShuffleQueryStageExec.isShuffleQueryStageExec(join.left) | ||
| } | ||
|
|
||
| def canUseLocalShuffleReaderRight(join: BroadcastHashJoinExec): Boolean = { | ||
| join.buildSide == BuildLeft && ShuffleQueryStageExec.isShuffleQueryStageExec(join.right) | ||
| } | ||
|
|
||
| override def apply(plan: SparkPlan): SparkPlan = { | ||
| if (!conf.getConf(SQLConf.OPTIMIZE_LOCAL_SHUFFLE_READER_ENABLED)) { | ||
| return plan | ||
| } | ||
|
|
||
| val optimizedPlan = plan.transformDown { | ||
| case join: BroadcastHashJoinExec if canUseLocalShuffleReaderRight(join) => | ||
| val localReader = LocalShuffleReaderExec(join.right.asInstanceOf[QueryStageExec]) | ||
| join.copy(right = localReader) | ||
| case join: BroadcastHashJoinExec if canUseLocalShuffleReaderLeft(join) => | ||
| val localReader = LocalShuffleReaderExec(join.left.asInstanceOf[QueryStageExec]) | ||
| join.copy(left = localReader) | ||
| def collectShuffleStages(plan: SparkPlan): Seq[ShuffleQueryStageExec] = plan match { | ||
| case _: LocalShuffleReaderExec => Nil | ||
| case stage: ShuffleQueryStageExec => Seq(stage) | ||
| case ReusedQueryStageExec(_, stage: ShuffleQueryStageExec, _) => Seq(stage) | ||
| case _ => plan.children.flatMap(collectShuffleStages) | ||
| } | ||
| val shuffleStages = collectShuffleStages(plan) | ||
|
|
||
| val optimizedPlan = if (shuffleStages.isEmpty || | ||
| !shuffleStages.forall(_.plan.canChangeNumPartitions)) { | ||
|
||
| // For the Exchange introduced by repartition, | ||
| // don't apply this rule to avoid additional shuffle introduced for the parent stage. | ||
| plan | ||
| } else { | ||
| plan.transformUp { | ||
| case stage: QueryStageExec if (ShuffleQueryStageExec.isShuffleQueryStageExec(stage)) => | ||
| LocalShuffleReaderExec(stage) | ||
| } | ||
| } | ||
|
|
||
| def numExchanges(plan: SparkPlan): Int = { | ||
|
|
@@ -59,7 +61,6 @@ case class OptimizeLocalShuffleReader(conf: SQLConf) extends Rule[SparkPlan] { | |
|
|
||
| val numExchangeBefore = numExchanges(EnsureRequirements(conf).apply(plan)) | ||
| val numExchangeAfter = numExchanges(EnsureRequirements(conf).apply(optimizedPlan)) | ||
|
|
||
| if (numExchangeAfter > numExchangeBefore) { | ||
| logDebug("OptimizeLocalShuffleReader rule is not applied due" + | ||
| " to additional shuffles will be introduced.") | ||
|
|
@@ -108,25 +109,4 @@ case class LocalShuffleReaderExec(child: QueryStageExec) extends UnaryExecNode { | |
| } | ||
| cachedShuffleRDD | ||
| } | ||
|
|
||
| override def generateTreeString( | ||
| depth: Int, | ||
| lastChildren: Seq[Boolean], | ||
| append: String => Unit, | ||
| verbose: Boolean, | ||
| prefix: String = "", | ||
| addSuffix: Boolean = false, | ||
| maxFields: Int, | ||
| printNodeId: Boolean): Unit = { | ||
| super.generateTreeString(depth, | ||
| lastChildren, | ||
| append, | ||
| verbose, | ||
| prefix, | ||
| addSuffix, | ||
| maxFields, | ||
| printNodeId) | ||
| child.generateTreeString( | ||
| depth + 1, lastChildren :+ true, append, verbose, "", false, maxFields, printNodeId) | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -93,7 +93,7 @@ class AdaptiveQueryExecSuite | |
| assert(smj.size == 1) | ||
| val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) | ||
| assert(bhj.size == 1) | ||
| checkNumLocalShuffleReaders(adaptivePlan, 1) | ||
| checkNumLocalShuffleReaders(adaptivePlan, 2) | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -110,7 +110,7 @@ class AdaptiveQueryExecSuite | |
| val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) | ||
| assert(bhj.size == 1) | ||
|
|
||
| checkNumLocalShuffleReaders(adaptivePlan, 1) | ||
| checkNumLocalShuffleReaders(adaptivePlan, 2) | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -125,7 +125,7 @@ class AdaptiveQueryExecSuite | |
| assert(smj.size == 1) | ||
| val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) | ||
| assert(bhj.size == 1) | ||
| checkNumLocalShuffleReaders(adaptivePlan, 1) | ||
| checkNumLocalShuffleReaders(adaptivePlan, 2) | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -141,7 +141,7 @@ class AdaptiveQueryExecSuite | |
| val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) | ||
| assert(bhj.size == 1) | ||
|
|
||
| checkNumLocalShuffleReaders(adaptivePlan, 1) | ||
| checkNumLocalShuffleReaders(adaptivePlan, 2) | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -163,9 +163,38 @@ class AdaptiveQueryExecSuite | |
| assert(smj.size == 3) | ||
| val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) | ||
| assert(bhj.size == 3) | ||
| // The child of remaining one BroadcastHashJoin is not ShuffleQueryStage. | ||
| // So only two LocalShuffleReader. | ||
| checkNumLocalShuffleReaders(adaptivePlan, 2) | ||
| // *(6) BroadcastHashJoin [b#24], [a#33], Inner, BuildLeft | ||
| // :- BroadcastQueryStage 6 | ||
| // : +- BroadcastExchange HashedRelationBroadcastMode | ||
| // : +- LocalShuffleReader | ||
| // : +- ShuffleQueryStage 5 | ||
| // : +- Exchange hashpartitioning(b#24, 5), true, [id=#437] | ||
| // : +- *(5) BroadcastHashJoin [key#13], [a#23], Inner, BuildLeft | ||
| // : :- BroadcastQueryStage 4 | ||
| // : : +- BroadcastExchange HashedRelationBroadcastMode | ||
| // : : +- LocalShuffleReader | ||
| // : : +- ShuffleQueryStage 0 | ||
| // : | ||
| // : +- LocalShuffleReader | ||
| // : +- ShuffleQueryStage 1 | ||
| // : +- Exchange hashpartitioning(a#23, 5), true, [id=#213] | ||
| // : | ||
| // +- *(6) BroadcastHashJoin [n#93], [a#33], Inner, BuildRight | ||
| // :- LocalShuffleReader | ||
| // : +- ShuffleQueryStage 2 | ||
| // : +- Exchange hashpartitioning(n#93, 5), true, [id=#230] | ||
| // : | ||
| // +- BroadcastQueryStage 7 | ||
| // +- BroadcastExchange HashedRelationBroadcastMode | ||
| // +- LocalShuffleReader | ||
| // +- ShuffleQueryStage 3 | ||
|
|
||
| // After applied the 'OptimizeLocalShuffleReader' rule, we can convert all the four | ||
| // shuffle reader to local shuffle reader in the bottom two 'BroadcastHashJoin'. | ||
| // For the opt level 'BroadcastHashJoin', the probe side is not shuffle query stage | ||
| // and the build side shuffle query stage is also converted to local shuffle reader. | ||
|
|
||
| checkNumLocalShuffleReaders(adaptivePlan, 5) | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -189,9 +218,32 @@ class AdaptiveQueryExecSuite | |
| assert(smj.size == 3) | ||
| val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) | ||
| assert(bhj.size == 3) | ||
| // The child of remaining two BroadcastHashJoin is not ShuffleQueryStage. | ||
| // So only two LocalShuffleReader. | ||
| checkNumLocalShuffleReaders(adaptivePlan, 1) | ||
| // *(7) BroadcastHashJoin [b#24], [a#33], Inner, BuildLeft | ||
|
||
| // :- BroadcastQueryStage 6 | ||
| // : +- BroadcastExchange HashedRelationBroadcastMode( | ||
| // : +- LocalShuffleReader | ||
| // : +- ShuffleQueryStage 5 | ||
| // : +- Exchange hashpartitioning(b#24, 5), true, [id=#452] | ||
| // : +- *(5) BroadcastHashJoin [key#13], [a#23], Inner, BuildLeft | ||
| // : :- BroadcastQueryStage 4 | ||
| // : : +- BroadcastExchange HashedRelationBroadcastMode( | ||
| // : : +- LocalShuffleReader | ||
| // : : +- ShuffleQueryStage 0 | ||
| // : | ||
| // : +- LocalShuffleReader | ||
| // : +- ShuffleQueryStage 1 | ||
| // : | ||
| // +- *(7) BroadcastHashJoin [n#93], [a#33], Inner, BuildRight | ||
| // :- LocalShuffleReader | ||
| // : +- ShuffleQueryStage 2 | ||
| // | ||
| // +- BroadcastQueryStage 7 | ||
| // +- BroadcastExchange HashedRelationBroadcastMode | ||
| // +- *(6) HashAggregate(keys=[a#33], functions=[sum(cast(b#34 as bigint))], | ||
| // output=[a#33, sum(b)#219L]) | ||
| // +- CoalescedShuffleReader [0] | ||
| // +- ShuffleQueryStage 3 | ||
| checkNumLocalShuffleReaders(adaptivePlan, 4) | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -215,9 +267,32 @@ class AdaptiveQueryExecSuite | |
| assert(smj.size == 3) | ||
| val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) | ||
| assert(bhj.size == 3) | ||
| // The child of remaining two BroadcastHashJoin is not ShuffleQueryStage. | ||
| // So only two LocalShuffleReader. | ||
| checkNumLocalShuffleReaders(adaptivePlan, 1) | ||
| // *(6) BroadcastHashJoin [cast(value#14 as int)], [a#220], Inner, BuildLeft | ||
|
||
| // :- BroadcastQueryStage 7 | ||
| // : +- BroadcastExchange HashedRelationBroadcastMode | ||
| // : +- LocalShuffleReader | ||
| // : +- ShuffleQueryStage 6 | ||
| // : +- Exchange hashpartitioning(cast(value#14 as int), 5), true, [id=#537] | ||
| // : +- *(5) BroadcastHashJoin [key#13], [a#23], Inner, BuildLeft | ||
| // : :- BroadcastQueryStage 4 | ||
| // : : +- BroadcastExchange HashedRelationBroadcastMode | ||
| // : : +- LocalShuffleReader | ||
| // : : +- ShuffleQueryStage 0 | ||
| // : | ||
| // : +- LocalShuffleReader | ||
| // : +- ShuffleQueryStage 1 | ||
| // : | ||
| // +- *(6) BroadcastHashJoin [n#93], [b#218], Inner, BuildLeft | ||
| // :- BroadcastQueryStage 5 | ||
| // : +- BroadcastExchange HashedRelationBroadcastMode | ||
| // : +- LocalShuffleReader | ||
| // : +- ShuffleQueryStage 2 | ||
| // : | ||
| // +- *(6) Filter isnotnull(b#218) | ||
| // +- *(6) HashAggregate(keys=[a#220], functions=[max(b#221)], output=[a#220, b#218]) | ||
| // +- CoalescedShuffleReader [0] | ||
| // +- ShuffleQueryStage 3 | ||
| checkNumLocalShuffleReaders(adaptivePlan, 4) | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -232,7 +307,7 @@ class AdaptiveQueryExecSuite | |
| assert(smj.size == 3) | ||
| val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) | ||
| assert(bhj.size == 2) | ||
| checkNumLocalShuffleReaders(adaptivePlan, 2) | ||
| checkNumLocalShuffleReaders(adaptivePlan, 4) | ||
| // Even with local shuffle reader, the query statge reuse can also work. | ||
| val ex = findReusedExchange(adaptivePlan) | ||
| assert(ex.size == 1) | ||
|
|
@@ -250,7 +325,7 @@ class AdaptiveQueryExecSuite | |
| assert(smj.size == 1) | ||
| val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) | ||
| assert(bhj.size == 1) | ||
| checkNumLocalShuffleReaders(adaptivePlan, 1) | ||
| checkNumLocalShuffleReaders(adaptivePlan, 2) | ||
| // Even with local shuffle reader, the query statge reuse can also work. | ||
| val ex = findReusedExchange(adaptivePlan) | ||
| assert(ex.size == 1) | ||
|
|
@@ -270,7 +345,7 @@ class AdaptiveQueryExecSuite | |
| assert(smj.size == 1) | ||
| val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) | ||
| assert(bhj.size == 1) | ||
| checkNumLocalShuffleReaders(adaptivePlan, 1) | ||
| checkNumLocalShuffleReaders(adaptivePlan, 2) | ||
| // Even with local shuffle reader, the query statge reuse can also work. | ||
| val ex = findReusedExchange(adaptivePlan) | ||
| assert(ex.nonEmpty) | ||
|
|
@@ -291,7 +366,7 @@ class AdaptiveQueryExecSuite | |
| assert(smj.size == 1) | ||
| val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) | ||
| assert(bhj.size == 1) | ||
| checkNumLocalShuffleReaders(adaptivePlan, 1) | ||
| checkNumLocalShuffleReaders(adaptivePlan, 2) | ||
| // Even with local shuffle reader, the query statge reuse can also work. | ||
| val ex = findReusedExchange(adaptivePlan) | ||
| assert(ex.isEmpty) | ||
|
|
@@ -315,7 +390,7 @@ class AdaptiveQueryExecSuite | |
| assert(smj.size == 1) | ||
| val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) | ||
| assert(bhj.size == 1) | ||
| checkNumLocalShuffleReaders(adaptivePlan, 1) | ||
| checkNumLocalShuffleReaders(adaptivePlan, 2) | ||
| // Even with local shuffle reader, the query statge reuse can also work. | ||
| val ex = findReusedExchange(adaptivePlan) | ||
| assert(ex.nonEmpty) | ||
|
|
@@ -393,8 +468,10 @@ class AdaptiveQueryExecSuite | |
| assert(smj.size == 2) | ||
| val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) | ||
| assert(bhj.size == 1) | ||
| // additional shuffle exchange introduced, so revert OptimizeLocalShuffleReader rule. | ||
| checkNumLocalShuffleReaders(adaptivePlan, 0) | ||
| // Additional shuffle exchange introduced, so revert OptimizeLocalShuffleReader rule firstly. | ||
| // However when creating new broadcast query stage, we also can | ||
| // change the shuffle reader to local shuffle reader in build side. | ||
| checkNumLocalShuffleReaders(adaptivePlan, 1) | ||
| } | ||
| } | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: the space should be there.