@@ -525,7 +525,7 @@ object AdaptiveSparkPlanExec {
* Apply a list of physical operator rules on a [[SparkPlan]].
*/
def applyPhysicalRules(plan: SparkPlan, rules: Seq[Rule[SparkPlan]]): SparkPlan = {
rules.foldLeft(plan) { case (sp, rule) => rule.apply(sp) }
Contributor

nit: the space should be there.

rules.foldLeft(plan) { case (sp, rule) => rule.apply(sp)}
}
}

@@ -24,31 +24,33 @@ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartit
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BuildLeft, BuildRight}
import org.apache.spark.sql.internal.SQLConf

case class OptimizeLocalShuffleReader(conf: SQLConf) extends Rule[SparkPlan] {

def canUseLocalShuffleReaderLeft(join: BroadcastHashJoinExec): Boolean = {
join.buildSide == BuildRight && ShuffleQueryStageExec.isShuffleQueryStageExec(join.left)
}

def canUseLocalShuffleReaderRight(join: BroadcastHashJoinExec): Boolean = {
join.buildSide == BuildLeft && ShuffleQueryStageExec.isShuffleQueryStageExec(join.right)
}

override def apply(plan: SparkPlan): SparkPlan = {
if (!conf.getConf(SQLConf.OPTIMIZE_LOCAL_SHUFFLE_READER_ENABLED)) {
return plan
}

val optimizedPlan = plan.transformDown {
case join: BroadcastHashJoinExec if canUseLocalShuffleReaderRight(join) =>
val localReader = LocalShuffleReaderExec(join.right.asInstanceOf[QueryStageExec])
join.copy(right = localReader)
case join: BroadcastHashJoinExec if canUseLocalShuffleReaderLeft(join) =>
val localReader = LocalShuffleReaderExec(join.left.asInstanceOf[QueryStageExec])
join.copy(left = localReader)
def collectShuffleStages(plan: SparkPlan): Seq[ShuffleQueryStageExec] = plan match {
case _: LocalShuffleReaderExec => Nil
case stage: ShuffleQueryStageExec => Seq(stage)
case ReusedQueryStageExec(_, stage: ShuffleQueryStageExec, _) => Seq(stage)
case _ => plan.children.flatMap(collectShuffleStages)
}
val shuffleStages = collectShuffleStages(plan)

val optimizedPlan = if (shuffleStages.isEmpty ||
!shuffleStages.forall(_.plan.canChangeNumPartitions)) {
Contributor

This is different from ReduceNumShufflePartitions. ReduceNumShufflePartitions needs to change all the shuffles together, so as long as there is a user-added shuffle, we need to skip it.

OptimizeLocalShuffleReader can add local reader to any shuffle, so it's simple

private def canAddLocalReader(stage: QueryStage): Boolean = stage match {
  case s: ShuffleQueryStage => s.plan.canChangeNumPartitions
  case ReusedQueryStage(s: ShuffleQueryStage) => s.plan.canChangeNumPartitions
}

plan.transformUp {
  case stage: QueryStageExec if canAddLocalReader(stage) =>
    LocalShuffleReaderExec(stage)
}
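
To connect this suggestion to the concrete classes used in this PR, a minimal sketch is shown below. It is illustrative only: the helper names canUseLocalReader and addLocalReaders are placeholders, and the types from OptimizeLocalShuffleReader (QueryStageExec, ShuffleQueryStageExec, ReusedQueryStageExec, LocalShuffleReaderExec, SparkPlan) are assumed to be in scope.

private def canUseLocalReader(stage: QueryStageExec): Boolean = stage match {
  // canChangeNumPartitions is false for user-specified repartitions, so those shuffles are skipped.
  case s: ShuffleQueryStageExec => s.plan.canChangeNumPartitions
  case ReusedQueryStageExec(_, s: ShuffleQueryStageExec, _) => s.plan.canChangeNumPartitions
  case _ => false
}

def addLocalReaders(plan: SparkPlan): SparkPlan = plan.transformUp {
  // Wrap each eligible shuffle stage in a LocalShuffleReaderExec.
  case stage: QueryStageExec if canUseLocalReader(stage) => LocalShuffleReaderExec(stage)
}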

// For an Exchange introduced by repartition, don't apply this rule,
// to avoid introducing an additional shuffle for the parent stage.
plan
} else {
plan.transformUp {
case stage: QueryStageExec if (ShuffleQueryStageExec.isShuffleQueryStageExec(stage)) =>
LocalShuffleReaderExec(stage)
}
}

def numExchanges(plan: SparkPlan): Int = {
@@ -59,7 +61,6 @@ case class OptimizeLocalShuffleReader(conf: SQLConf) extends Rule[SparkPlan] {

val numExchangeBefore = numExchanges(EnsureRequirements(conf).apply(plan))
val numExchangeAfter = numExchanges(EnsureRequirements(conf).apply(optimizedPlan))

if (numExchangeAfter > numExchangeBefore) {
logDebug("OptimizeLocalShuffleReader rule is not applied because" +
" additional shuffles would be introduced.")
@@ -108,25 +109,4 @@ case class LocalShuffleReaderExec(child: QueryStageExec) extends UnaryExecNode {
}
cachedShuffleRDD
}

override def generateTreeString(
depth: Int,
lastChildren: Seq[Boolean],
append: String => Unit,
verbose: Boolean,
prefix: String = "",
addSuffix: Boolean = false,
maxFields: Int,
printNodeId: Boolean): Unit = {
super.generateTreeString(depth,
lastChildren,
append,
verbose,
prefix,
addSuffix,
maxFields,
printNodeId)
child.generateTreeString(
depth + 1, lastChildren :+ true, append, verbose, "", false, maxFields, printNodeId)
}
}
@@ -93,7 +93,7 @@ class AdaptiveQueryExecSuite
assert(smj.size == 1)
val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 1)
checkNumLocalShuffleReaders(adaptivePlan, 1)
checkNumLocalShuffleReaders(adaptivePlan, 2)
}
}
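
For context, the expected values asserted here and in the tests below come from the suite's checkNumLocalShuffleReaders helper, whose body is outside this diff. A hypothetical minimal version, assuming it simply counts LocalShuffleReaderExec nodes in the executed adaptive plan, might look like:

private def checkNumLocalShuffleReaders(plan: SparkPlan, expected: Int): Unit = {
  // Collect every LocalShuffleReaderExec node that the adaptive plan ended up with.
  val numReaders = plan.collect {
    case reader: LocalShuffleReaderExec => reader
  }.length
  assert(numReaders == expected,
    s"expected $expected local shuffle readers, found $numReaders")
}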

@@ -110,7 +110,7 @@ class AdaptiveQueryExecSuite
val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 1)

checkNumLocalShuffleReaders(adaptivePlan, 1)
checkNumLocalShuffleReaders(adaptivePlan, 2)
}
}

@@ -125,7 +125,7 @@ class AdaptiveQueryExecSuite
assert(smj.size == 1)
val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 1)
checkNumLocalShuffleReaders(adaptivePlan, 1)
checkNumLocalShuffleReaders(adaptivePlan, 2)
}
}

@@ -141,7 +141,7 @@ class AdaptiveQueryExecSuite
val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 1)

checkNumLocalShuffleReaders(adaptivePlan, 1)
checkNumLocalShuffleReaders(adaptivePlan, 2)
}
}

@@ -163,9 +163,38 @@ class AdaptiveQueryExecSuite
assert(smj.size == 3)
val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 3)
// The child of remaining one BroadcastHashJoin is not ShuffleQueryStage.
// So only two LocalShuffleReader.
checkNumLocalShuffleReaders(adaptivePlan, 2)
// *(6) BroadcastHashJoin [b#24], [a#33], Inner, BuildLeft
// :- BroadcastQueryStage 6
// : +- BroadcastExchange HashedRelationBroadcastMode
// : +- LocalShuffleReader
// : +- ShuffleQueryStage 5
// : +- Exchange hashpartitioning(b#24, 5), true, [id=#437]
// : +- *(5) BroadcastHashJoin [key#13], [a#23], Inner, BuildLeft
// : :- BroadcastQueryStage 4
// : : +- BroadcastExchange HashedRelationBroadcastMode
// : : +- LocalShuffleReader
// : : +- ShuffleQueryStage 0
// :
// : +- LocalShuffleReader
// : +- ShuffleQueryStage 1
// : +- Exchange hashpartitioning(a#23, 5), true, [id=#213]
// :
// +- *(6) BroadcastHashJoin [n#93], [a#33], Inner, BuildRight
// :- LocalShuffleReader
// : +- ShuffleQueryStage 2
// : +- Exchange hashpartitioning(n#93, 5), true, [id=#230]
// :
// +- BroadcastQueryStage 7
// +- BroadcastExchange HashedRelationBroadcastMode
// +- LocalShuffleReader
// +- ShuffleQueryStage 3

// After applying the 'OptimizeLocalShuffleReader' rule, all four shuffle readers in the
// bottom two 'BroadcastHashJoin's are converted to local shuffle readers.
// For the top-level 'BroadcastHashJoin', the probe side is not a shuffle query stage,
// but the build-side shuffle query stage is also converted to a local shuffle reader.

checkNumLocalShuffleReaders(adaptivePlan, 5)
}
}

@@ -189,9 +218,32 @@ class AdaptiveQueryExecSuite
assert(smj.size == 3)
val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 3)
// The child of remaining two BroadcastHashJoin is not ShuffleQueryStage.
// So only two LocalShuffleReader.
checkNumLocalShuffleReaders(adaptivePlan, 1)
// *(7) BroadcastHashJoin [b#24], [a#33], Inner, BuildLeft
Contributor

It's a bit verbose to put the entire query plan here. How about we only put the sketch?

BroadcastHashJoin
+- BroadcastExchange
   +- LocalShuffleReader*
      +- ShuffleExchange
         +- BroadcastHashJoin
            +- BroadcastExchange
               +- LocalShuffleReader*
                  +- ShuffleExchange
            +- LocalShuffleReader*
               +- ShuffleExchange
+- BroadcastHashJoin
   +- LocalShuffleReader*
      +- ShuffleExchange
   +- BroadcastExchange
      +- HashAggregate
         +- CoalescedShuffleReader
            +- ShuffleExchange

// :- BroadcastQueryStage 6
// : +- BroadcastExchange HashedRelationBroadcastMode(
// : +- LocalShuffleReader
// : +- ShuffleQueryStage 5
// : +- Exchange hashpartitioning(b#24, 5), true, [id=#452]
// : +- *(5) BroadcastHashJoin [key#13], [a#23], Inner, BuildLeft
// : :- BroadcastQueryStage 4
// : : +- BroadcastExchange HashedRelationBroadcastMode(
// : : +- LocalShuffleReader
// : : +- ShuffleQueryStage 0
// :
// : +- LocalShuffleReader
// : +- ShuffleQueryStage 1
// :
// +- *(7) BroadcastHashJoin [n#93], [a#33], Inner, BuildRight
// :- LocalShuffleReader
// : +- ShuffleQueryStage 2
//
// +- BroadcastQueryStage 7
// +- BroadcastExchange HashedRelationBroadcastMode
// +- *(6) HashAggregate(keys=[a#33], functions=[sum(cast(b#34 as bigint))],
// output=[a#33, sum(b)#219L])
// +- CoalescedShuffleReader [0]
// +- ShuffleQueryStage 3
checkNumLocalShuffleReaders(adaptivePlan, 4)
}
}

@@ -215,9 +267,32 @@ class AdaptiveQueryExecSuite
assert(smj.size == 3)
val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 3)
// The child of remaining two BroadcastHashJoin is not ShuffleQueryStage.
// So only two LocalShuffleReader.
checkNumLocalShuffleReaders(adaptivePlan, 1)
// *(6) BroadcastHashJoin [cast(value#14 as int)], [a#220], Inner, BuildLeft
Contributor

ditto

// :- BroadcastQueryStage 7
// : +- BroadcastExchange HashedRelationBroadcastMode
// : +- LocalShuffleReader
// : +- ShuffleQueryStage 6
// : +- Exchange hashpartitioning(cast(value#14 as int), 5), true, [id=#537]
// : +- *(5) BroadcastHashJoin [key#13], [a#23], Inner, BuildLeft
// : :- BroadcastQueryStage 4
// : : +- BroadcastExchange HashedRelationBroadcastMode
// : : +- LocalShuffleReader
// : : +- ShuffleQueryStage 0
// :
// : +- LocalShuffleReader
// : +- ShuffleQueryStage 1
// :
// +- *(6) BroadcastHashJoin [n#93], [b#218], Inner, BuildLeft
// :- BroadcastQueryStage 5
// : +- BroadcastExchange HashedRelationBroadcastMode
// : +- LocalShuffleReader
// : +- ShuffleQueryStage 2
// :
// +- *(6) Filter isnotnull(b#218)
// +- *(6) HashAggregate(keys=[a#220], functions=[max(b#221)], output=[a#220, b#218])
// +- CoalescedShuffleReader [0]
// +- ShuffleQueryStage 3
checkNumLocalShuffleReaders(adaptivePlan, 4)
}
}

@@ -232,7 +307,7 @@ class AdaptiveQueryExecSuite
assert(smj.size == 3)
val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 2)
checkNumLocalShuffleReaders(adaptivePlan, 2)
checkNumLocalShuffleReaders(adaptivePlan, 4)
// Even with local shuffle reader, the query stage reuse can still work.
val ex = findReusedExchange(adaptivePlan)
assert(ex.size == 1)
@@ -250,7 +325,7 @@ class AdaptiveQueryExecSuite
assert(smj.size == 1)
val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 1)
checkNumLocalShuffleReaders(adaptivePlan, 1)
checkNumLocalShuffleReaders(adaptivePlan, 2)
// Even with local shuffle reader, the query stage reuse can still work.
val ex = findReusedExchange(adaptivePlan)
assert(ex.size == 1)
@@ -270,7 +345,7 @@ class AdaptiveQueryExecSuite
assert(smj.size == 1)
val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 1)
checkNumLocalShuffleReaders(adaptivePlan, 1)
checkNumLocalShuffleReaders(adaptivePlan, 2)
// Even with local shuffle reader, the query stage reuse can still work.
val ex = findReusedExchange(adaptivePlan)
assert(ex.nonEmpty)
@@ -291,7 +366,7 @@ class AdaptiveQueryExecSuite
assert(smj.size == 1)
val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 1)
checkNumLocalShuffleReaders(adaptivePlan, 1)
checkNumLocalShuffleReaders(adaptivePlan, 2)
// Even with local shuffle reader, the query stage reuse can still work.
val ex = findReusedExchange(adaptivePlan)
assert(ex.isEmpty)
@@ -315,7 +390,7 @@ class AdaptiveQueryExecSuite
assert(smj.size == 1)
val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 1)
checkNumLocalShuffleReaders(adaptivePlan, 1)
checkNumLocalShuffleReaders(adaptivePlan, 2)
// Even with local shuffle reader, the query stage reuse can still work.
val ex = findReusedExchange(adaptivePlan)
assert(ex.nonEmpty)
@@ -393,8 +468,10 @@ class AdaptiveQueryExecSuite
assert(smj.size == 2)
val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 1)
// additional shuffle exchange introduced, so revert OptimizeLocalShuffleReader rule.
checkNumLocalShuffleReaders(adaptivePlan, 0)
// An additional shuffle exchange would be introduced, so the OptimizeLocalShuffleReader rule
// is reverted first. However, when creating the new broadcast query stage, we can still
// convert the shuffle reader on the build side to a local shuffle reader.
checkNumLocalShuffleReaders(adaptivePlan, 1)
}
}
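
The revert described in the comment above is driven by the exchange-counting guard in OptimizeLocalShuffleReader (the numExchanges / EnsureRequirements comparison shown earlier in this diff). A condensed, self-contained sketch of that guard follows; LocalReaderGuardSketch and choosePlan are illustrative names, and numExchanges is assumed to simply count ShuffleExchangeExec nodes.

import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ShuffleExchangeExec}
import org.apache.spark.sql.internal.SQLConf

object LocalReaderGuardSketch {
  // Count the shuffle exchanges left in a plan after EnsureRequirements re-plans it.
  private def numExchanges(conf: SQLConf, plan: SparkPlan): Int =
    EnsureRequirements(conf).apply(plan).collect { case e: ShuffleExchangeExec => e }.length

  // Keep the local-reader plan only if it does not require extra shuffles after re-planning.
  def choosePlan(conf: SQLConf, plan: SparkPlan, optimizedPlan: SparkPlan): SparkPlan =
    if (numExchanges(conf, optimizedPlan) > numExchanges(conf, plan)) plan else optimizedPlan
}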
