Skip to content
Closed
Next Next commit
Support converting both the build side and the probe side to local shuffle readers as far as possible in BroadcastHashJoin
  • Loading branch information
JkSelf committed Oct 29, 2019
commit 5b7ff2dc680e37032cee048f987e3809d6ebfd94
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ case class AdaptiveSparkPlanExec(
}

private def newQueryStage(e: Exchange): QueryStageExec = {
val optimizedPlan = applyPhysicalRules(e.child, queryStageOptimizerRules)
val optimizedPlan = applyPhysicalRules(e.child, queryStageOptimizerRules, Some(e))
val queryStage = e match {
case s: ShuffleExchangeExec =>
ShuffleQueryStageExec(currentStageId, s.copy(child = optimizedPlan))
Expand Down Expand Up @@ -524,8 +524,16 @@ object AdaptiveSparkPlanExec {
/**
* Apply a list of physical operator rules on a [[SparkPlan]].
*/
def applyPhysicalRules(plan: SparkPlan, rules: Seq[Rule[SparkPlan]]): SparkPlan = {
rules.foldLeft(plan) { case (sp, rule) => rule.apply(sp) }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: the space should be there.

def applyPhysicalRules(
plan: SparkPlan, rules: Seq[Rule[SparkPlan]],
parent: Option[SparkPlan] = None): SparkPlan = {
rules.foldLeft(plan) {
case (sp, rule) =>
if (parent.nonEmpty && rule.isInstanceOf[OptimizeLocalShuffleReader]) {
rule.asInstanceOf[OptimizeLocalShuffleReader].parent = parent
}
rule.apply(sp)
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,32 +23,39 @@ import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BuildLeft, BuildRight}
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, EnsureRequirements, ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
import org.apache.spark.sql.internal.SQLConf

case class OptimizeLocalShuffleReader(conf: SQLConf) extends Rule[SparkPlan] {
case class OptimizeLocalShuffleReader(
conf: SQLConf,
var parent: Option[SparkPlan] = None) extends Rule[SparkPlan] {

def canUseLocalShuffleReaderLeft(join: BroadcastHashJoinExec): Boolean = {
join.buildSide == BuildRight && ShuffleQueryStageExec.isShuffleQueryStageExec(join.left)
}

def canUseLocalShuffleReaderRight(join: BroadcastHashJoinExec): Boolean = {
join.buildSide == BuildLeft && ShuffleQueryStageExec.isShuffleQueryStageExec(join.right)
def canUseLocalShuffleReader(plan: SparkPlan): Boolean = {
ShuffleQueryStageExec.isShuffleQueryStageExec(plan)
}

override def apply(plan: SparkPlan): SparkPlan = {
if (!conf.getConf(SQLConf.OPTIMIZE_LOCAL_SHUFFLE_READER_ENABLED)) {
return plan
}

val optimizedPlan = plan.transformDown {
case join: BroadcastHashJoinExec if canUseLocalShuffleReaderRight(join) =>
val localReader = LocalShuffleReaderExec(join.right.asInstanceOf[QueryStageExec])
join.copy(right = localReader)
case join: BroadcastHashJoinExec if canUseLocalShuffleReaderLeft(join) =>
val localReader = LocalShuffleReaderExec(join.left.asInstanceOf[QueryStageExec])
join.copy(left = localReader)
// In order to choose the optimal plan after applying the local reader rule:
// when a `BroadcastExchangeExec` is the parent of a `ShuffleQueryStageExec`,
// we also convert the shuffle reader into a local shuffle reader.
val newPlan = if (canUseLocalShuffleReader(plan) &&
parent.nonEmpty && parent.get.isInstanceOf[BroadcastExchangeExec]) {
LocalShuffleReaderExec(plan.asInstanceOf[QueryStageExec])
} else plan

val optimizedPlan = newPlan.transformDown {
case join: BroadcastHashJoinExec =>
val optimizedRightPlan = if (canUseLocalShuffleReader(join.right)) {
LocalShuffleReaderExec(join.right.asInstanceOf[QueryStageExec])
} else join.right
val optimizedLeftPlan = if (canUseLocalShuffleReader(join.left)) {
LocalShuffleReaderExec(join.left.asInstanceOf[QueryStageExec])
} else join.left
join.copy(left = optimizedLeftPlan, right = optimizedRightPlan)
}

def numExchanges(plan: SparkPlan): Int = {
Expand All @@ -59,7 +66,6 @@ case class OptimizeLocalShuffleReader(conf: SQLConf) extends Rule[SparkPlan] {

val numExchangeBefore = numExchanges(EnsureRequirements(conf).apply(plan))
val numExchangeAfter = numExchanges(EnsureRequirements(conf).apply(optimizedPlan))

if (numExchangeAfter > numExchangeBefore) {
logDebug("OptimizeLocalShuffleReader rule is not applied due" +
" to additional shuffles will be introduced.")
Expand Down Expand Up @@ -108,25 +114,4 @@ case class LocalShuffleReaderExec(child: QueryStageExec) extends UnaryExecNode {
}
cachedShuffleRDD
}

override def generateTreeString(
depth: Int,
lastChildren: Seq[Boolean],
append: String => Unit,
verbose: Boolean,
prefix: String = "",
addSuffix: Boolean = false,
maxFields: Int,
printNodeId: Boolean): Unit = {
super.generateTreeString(depth,
lastChildren,
append,
verbose,
prefix,
addSuffix,
maxFields,
printNodeId)
child.generateTreeString(
depth + 1, lastChildren :+ true, append, verbose, "", false, maxFields, printNodeId)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ class AdaptiveQueryExecSuite
assert(smj.size == 1)
val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 1)
checkNumLocalShuffleReaders(adaptivePlan, 1)
checkNumLocalShuffleReaders(adaptivePlan, 2)
}
}

Expand All @@ -110,7 +110,7 @@ class AdaptiveQueryExecSuite
val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 1)

checkNumLocalShuffleReaders(adaptivePlan, 1)
checkNumLocalShuffleReaders(adaptivePlan, 2)
}
}

Expand All @@ -125,7 +125,7 @@ class AdaptiveQueryExecSuite
assert(smj.size == 1)
val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 1)
checkNumLocalShuffleReaders(adaptivePlan, 1)
checkNumLocalShuffleReaders(adaptivePlan, 2)
}
}

Expand All @@ -141,7 +141,7 @@ class AdaptiveQueryExecSuite
val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 1)

checkNumLocalShuffleReaders(adaptivePlan, 1)
checkNumLocalShuffleReaders(adaptivePlan, 2)
}
}

Expand All @@ -163,9 +163,38 @@ class AdaptiveQueryExecSuite
assert(smj.size == 3)
val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 3)
// The child of remaining one BroadcastHashJoin is not ShuffleQueryStage.
// So only two LocalShuffleReader.
checkNumLocalShuffleReaders(adaptivePlan, 2)
// *(6) BroadcastHashJoin [b#24], [a#33], Inner, BuildLeft
// :- BroadcastQueryStage 6
// : +- BroadcastExchange HashedRelationBroadcastMode
// : +- LocalShuffleReader
// : +- ShuffleQueryStage 5
// : +- Exchange hashpartitioning(b#24, 5), true, [id=#437]
// : +- *(5) BroadcastHashJoin [key#13], [a#23], Inner, BuildLeft
// : :- BroadcastQueryStage 4
// : : +- BroadcastExchange HashedRelationBroadcastMode
// : : +- LocalShuffleReader
// : : +- ShuffleQueryStage 0
// :
// : +- LocalShuffleReader
// : +- ShuffleQueryStage 1
// : +- Exchange hashpartitioning(a#23, 5), true, [id=#213]
// :
// +- *(6) BroadcastHashJoin [n#93], [a#33], Inner, BuildRight
// :- LocalShuffleReader
// : +- ShuffleQueryStage 2
// : +- Exchange hashpartitioning(n#93, 5), true, [id=#230]
// :
// +- BroadcastQueryStage 7
// +- BroadcastExchange HashedRelationBroadcastMode
// +- LocalShuffleReader
// +- ShuffleQueryStage 3

// After applying the 'OptimizeLocalShuffleReader' rule, all four shuffle readers
// in the bottom two 'BroadcastHashJoin's are converted to local shuffle readers.
// For the top-level 'BroadcastHashJoin', the probe side is not a shuffle query stage,
// while the build side's shuffle query stage is also converted to a local shuffle reader.

checkNumLocalShuffleReaders(adaptivePlan, 5)
}
}

Expand All @@ -189,9 +218,32 @@ class AdaptiveQueryExecSuite
assert(smj.size == 3)
val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 3)
// The child of remaining two BroadcastHashJoin is not ShuffleQueryStage.
// So only two LocalShuffleReader.
checkNumLocalShuffleReaders(adaptivePlan, 1)
// *(7) BroadcastHashJoin [b#24], [a#33], Inner, BuildLeft
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a bit verbose to put the entire query plan here. How about we only put the sketch?

BroadcastHashJoin
+- BroadcastExchange
   +- LocalShuffleReader*
      +- ShuffleExchange
         +- BroadcastHashJoin
            +- BroadcastExchange
               +- LocalShuffleReader*
                  +- ShuffleExchange
            +- LocalShuffleReader*
               +- ShuffleExchange
+- BroadcastHashJoin
   +- LocalShuffleReader*
      +- ShuffleExchange
   +- BroadcastExchange
      +-HashAggregate
         +- CoalescedShuffleReader
            +- ShuffleExchange

// :- BroadcastQueryStage 6
// : +- BroadcastExchange HashedRelationBroadcastMode(
// : +- LocalShuffleReader
// : +- ShuffleQueryStage 5
// : +- Exchange hashpartitioning(b#24, 5), true, [id=#452]
// : +- *(5) BroadcastHashJoin [key#13], [a#23], Inner, BuildLeft
// : :- BroadcastQueryStage 4
// : : +- BroadcastExchange HashedRelationBroadcastMode(
// : : +- LocalShuffleReader
// : : +- ShuffleQueryStage 0
// :
// : +- LocalShuffleReader
// : +- ShuffleQueryStage 1
// :
// +- *(7) BroadcastHashJoin [n#93], [a#33], Inner, BuildRight
// :- LocalShuffleReader
// : +- ShuffleQueryStage 2
//
// +- BroadcastQueryStage 7
// +- BroadcastExchange HashedRelationBroadcastMode
// +- *(6) HashAggregate(keys=[a#33], functions=[sum(cast(b#34 as bigint))],
// output=[a#33, sum(b)#219L])
// +- CoalescedShuffleReader [0]
// +- ShuffleQueryStage 3
checkNumLocalShuffleReaders(adaptivePlan, 4)
}
}

Expand All @@ -215,9 +267,32 @@ class AdaptiveQueryExecSuite
assert(smj.size == 3)
val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 3)
// The child of remaining two BroadcastHashJoin is not ShuffleQueryStage.
// So only two LocalShuffleReader.
checkNumLocalShuffleReaders(adaptivePlan, 1)
// *(6) BroadcastHashJoin [cast(value#14 as int)], [a#220], Inner, BuildLeft
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

// :- BroadcastQueryStage 7
// : +- BroadcastExchange HashedRelationBroadcastMode
// : +- LocalShuffleReader
// : +- ShuffleQueryStage 6
// : +- Exchange hashpartitioning(cast(value#14 as int), 5), true, [id=#537]
// : +- *(5) BroadcastHashJoin [key#13], [a#23], Inner, BuildLeft
// : :- BroadcastQueryStage 4
// : : +- BroadcastExchange HashedRelationBroadcastMode
// : : +- LocalShuffleReader
// : : +- ShuffleQueryStage 0
// :
// : +- LocalShuffleReader
// : +- ShuffleQueryStage 1
// :
// +- *(6) BroadcastHashJoin [n#93], [b#218], Inner, BuildLeft
// :- BroadcastQueryStage 5
// : +- BroadcastExchange HashedRelationBroadcastMode
// : +- LocalShuffleReader
// : +- ShuffleQueryStage 2
// :
// +- *(6) Filter isnotnull(b#218)
// +- *(6) HashAggregate(keys=[a#220], functions=[max(b#221)], output=[a#220, b#218])
// +- CoalescedShuffleReader [0]
// +- ShuffleQueryStage 3
checkNumLocalShuffleReaders(adaptivePlan, 4)
}
}

Expand All @@ -232,7 +307,7 @@ class AdaptiveQueryExecSuite
assert(smj.size == 3)
val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 2)
checkNumLocalShuffleReaders(adaptivePlan, 2)
checkNumLocalShuffleReaders(adaptivePlan, 4)
// Even with the local shuffle reader, query stage reuse still works.
val ex = findReusedExchange(adaptivePlan)
assert(ex.size == 1)
Expand All @@ -250,7 +325,7 @@ class AdaptiveQueryExecSuite
assert(smj.size == 1)
val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 1)
checkNumLocalShuffleReaders(adaptivePlan, 1)
checkNumLocalShuffleReaders(adaptivePlan, 2)
// Even with the local shuffle reader, query stage reuse still works.
val ex = findReusedExchange(adaptivePlan)
assert(ex.size == 1)
Expand All @@ -270,7 +345,7 @@ class AdaptiveQueryExecSuite
assert(smj.size == 1)
val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 1)
checkNumLocalShuffleReaders(adaptivePlan, 1)
checkNumLocalShuffleReaders(adaptivePlan, 2)
// Even with the local shuffle reader, query stage reuse still works.
val ex = findReusedExchange(adaptivePlan)
assert(ex.nonEmpty)
Expand All @@ -291,7 +366,7 @@ class AdaptiveQueryExecSuite
assert(smj.size == 1)
val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 1)
checkNumLocalShuffleReaders(adaptivePlan, 1)
checkNumLocalShuffleReaders(adaptivePlan, 2)
// Even with the local shuffle reader, query stage reuse still works.
val ex = findReusedExchange(adaptivePlan)
assert(ex.isEmpty)
Expand All @@ -315,7 +390,7 @@ class AdaptiveQueryExecSuite
assert(smj.size == 1)
val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 1)
checkNumLocalShuffleReaders(adaptivePlan, 1)
checkNumLocalShuffleReaders(adaptivePlan, 2)
// Even with the local shuffle reader, query stage reuse still works.
val ex = findReusedExchange(adaptivePlan)
assert(ex.nonEmpty)
Expand Down Expand Up @@ -393,8 +468,10 @@ class AdaptiveQueryExecSuite
assert(smj.size == 2)
val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 1)
// additional shuffle exchange introduced, so revert OptimizeLocalShuffleReader rule.
checkNumLocalShuffleReaders(adaptivePlan, 0)
// An additional shuffle exchange would be introduced, so the OptimizeLocalShuffleReader
// rule is reverted first. However, when creating the new broadcast query stage,
// the shuffle reader on the build side can still be converted to a local shuffle reader.
checkNumLocalShuffleReaders(adaptivePlan, 1)
}
}

Expand Down