Skip to content
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,26 +27,46 @@ import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ShuffleExcha
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BuildLeft, BuildRight}
import org.apache.spark.sql.internal.SQLConf

/**
* A rule to optimize the shuffle reader to local reader as far as possible
* when converting the 'SortMergeJoinExec' to 'BroadcastHashJoinExec' in runtime.
*
* This rule can be divided into two steps:
* Step1: Add the local reader in probe side adn then check whether additional
* shuffle introduced. If introduced, we will revert all the local
* reader in probe side.
* Step2: Add the local reader in build side and will not check whether
* additional shuffle introduced.Because the build side will not introduce
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: space after ... introduce.

* additional shuffle.
*/
case class OptimizeLocalShuffleReader(conf: SQLConf) extends Rule[SparkPlan] {

def canUseLocalShuffleReaderLeft(join: BroadcastHashJoinExec): Boolean = {
def canUseLocalShuffleReaderProbeLeft(join: BroadcastHashJoinExec): Boolean = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can remove it now

join.buildSide == BuildRight && ShuffleQueryStageExec.isShuffleQueryStageExec(join.left)
}

def canUseLocalShuffleReaderRight(join: BroadcastHashJoinExec): Boolean = {
def canUseLocalShuffleReaderProbeRight(join: BroadcastHashJoinExec): Boolean = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

join.buildSide == BuildLeft && ShuffleQueryStageExec.isShuffleQueryStageExec(join.right)
}

def canUseLocalShuffleReaderBuildLeft(join: BroadcastHashJoinExec): Boolean = {
join.buildSide == BuildLeft && ShuffleQueryStageExec.isShuffleQueryStageExec(join.left)
}

def canUseLocalShuffleReaderBuildRight(join: BroadcastHashJoinExec): Boolean = {
join.buildSide == BuildRight && ShuffleQueryStageExec.isShuffleQueryStageExec(join.right)
}

override def apply(plan: SparkPlan): SparkPlan = {
if (!conf.getConf(SQLConf.OPTIMIZE_LOCAL_SHUFFLE_READER_ENABLED)) {
return plan
}

val optimizedPlan = plan.transformDown {
case join: BroadcastHashJoinExec if canUseLocalShuffleReaderRight(join) =>
// Add local reader in probe side.
val tmpOptimizedProbeSidePlan = plan.transformDown {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: withProbeSideLocalReader

case join: BroadcastHashJoinExec if canUseLocalShuffleReaderProbeRight(join) =>
val localReader = LocalShuffleReaderExec(join.right.asInstanceOf[QueryStageExec])
join.copy(right = localReader)
case join: BroadcastHashJoinExec if canUseLocalShuffleReaderLeft(join) =>
case join: BroadcastHashJoinExec if canUseLocalShuffleReaderProbeLeft(join) =>
val localReader = LocalShuffleReaderExec(join.left.asInstanceOf[QueryStageExec])
join.copy(left = localReader)
}
Expand All @@ -56,16 +76,25 @@ case class OptimizeLocalShuffleReader(conf: SQLConf) extends Rule[SparkPlan] {
case e: ShuffleExchangeExec => e
}.length
}

// Check whether additional shuffle introduced. If introduced, revert the local reader.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now this rule converts local shuffle reader for all BroadcastHashJoinExec and then reverts all local shuffle readers if any of local shuffle reader causes additional shuffle.

Can we just revert the local shuffle readers that cause additional shuffle and keep these not?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the best, but I don't know if there is an easy way to do it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can implement using revert all the local reader currently and re-optimize later when we find a better way.

val numExchangeBefore = numExchanges(EnsureRequirements(conf).apply(plan))
val numExchangeAfter = numExchanges(EnsureRequirements(conf).apply(optimizedPlan))

if (numExchangeAfter > numExchangeBefore) {
logDebug("OptimizeLocalShuffleReader rule is not applied due" +
val numExchangeAfter = numExchanges(EnsureRequirements(conf).apply(tmpOptimizedProbeSidePlan))
val optimizedProbeSidePlan = if (numExchangeAfter > numExchangeBefore) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe simply "optimizedPlan"

logDebug("OptimizeLocalShuffleReader rule is not applied in the probe side due" +
" to additional shuffles will be introduced.")
plan
} else {
optimizedPlan
tmpOptimizedProbeSidePlan
}
// Add the local reader in build side and will not check whether
// additional shuffle introduced.
optimizedProbeSidePlan.transformDown {
case join: BroadcastHashJoinExec if canUseLocalShuffleReaderBuildLeft(join) =>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: case join: BroadcastHashJoinExec if join.buildSide == BuildLeft && isShuffle(join.left) =>

val localReader = LocalShuffleReaderExec(join.left.asInstanceOf[QueryStageExec])
join.copy(left = localReader)
case join: BroadcastHashJoinExec if canUseLocalShuffleReaderBuildRight(join) =>
val localReader = LocalShuffleReaderExec(join.right.asInstanceOf[QueryStageExec])
join.copy(right = localReader)
}
}
}
Expand Down Expand Up @@ -108,25 +137,4 @@ case class LocalShuffleReaderExec(child: QueryStageExec) extends UnaryExecNode {
}
cachedShuffleRDD
}

override def generateTreeString(
depth: Int,
lastChildren: Seq[Boolean],
append: String => Unit,
verbose: Boolean,
prefix: String = "",
addSuffix: Boolean = false,
maxFields: Int,
printNodeId: Boolean): Unit = {
super.generateTreeString(depth,
lastChildren,
append,
verbose,
prefix,
addSuffix,
maxFields,
printNodeId)
child.generateTreeString(
depth + 1, lastChildren :+ true, append, verbose, "", false, maxFields, printNodeId)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ class AdaptiveQueryExecSuite
assert(smj.size == 1)
val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 1)
checkNumLocalShuffleReaders(adaptivePlan, 1)
checkNumLocalShuffleReaders(adaptivePlan, 2)
}
}

Expand All @@ -110,7 +110,7 @@ class AdaptiveQueryExecSuite
val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 1)

checkNumLocalShuffleReaders(adaptivePlan, 1)
checkNumLocalShuffleReaders(adaptivePlan, 2)
}
}

Expand All @@ -125,7 +125,7 @@ class AdaptiveQueryExecSuite
assert(smj.size == 1)
val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 1)
checkNumLocalShuffleReaders(adaptivePlan, 1)
checkNumLocalShuffleReaders(adaptivePlan, 2)
}
}

Expand All @@ -141,7 +141,7 @@ class AdaptiveQueryExecSuite
val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 1)

checkNumLocalShuffleReaders(adaptivePlan, 1)
checkNumLocalShuffleReaders(adaptivePlan, 2)
}
}

Expand All @@ -163,9 +163,38 @@ class AdaptiveQueryExecSuite
assert(smj.size == 3)
val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 3)
// The child of remaining one BroadcastHashJoin is not ShuffleQueryStage.
// So only two LocalShuffleReader.
checkNumLocalShuffleReaders(adaptivePlan, 2)
// *(6) BroadcastHashJoin [b#24], [a#33], Inner, BuildLeft
// :- BroadcastQueryStage 6
// : +- BroadcastExchange HashedRelationBroadcastMode
// : +- LocalShuffleReader
// : +- ShuffleQueryStage 5
// : +- Exchange hashpartitioning(b#24, 5), true, [id=#437]
// : +- *(5) BroadcastHashJoin [key#13], [a#23], Inner, BuildLeft
// : :- BroadcastQueryStage 4
// : : +- BroadcastExchange HashedRelationBroadcastMode
// : : +- LocalShuffleReader
// : : +- ShuffleQueryStage 0
// :
// : +- LocalShuffleReader
// : +- ShuffleQueryStage 1
// : +- Exchange hashpartitioning(a#23, 5), true, [id=#213]
// :
// +- *(6) BroadcastHashJoin [n#93], [a#33], Inner, BuildRight
// :- LocalShuffleReader
// : +- ShuffleQueryStage 2
// : +- Exchange hashpartitioning(n#93, 5), true, [id=#230]
// :
// +- BroadcastQueryStage 7
// +- BroadcastExchange HashedRelationBroadcastMode
// +- LocalShuffleReader
// +- ShuffleQueryStage 3

// After applied the 'OptimizeLocalShuffleReader' rule, we can convert all the four
// shuffle reader to local shuffle reader in the bottom two 'BroadcastHashJoin'.
// For the opt level 'BroadcastHashJoin', the probe side is not shuffle query stage
// and the build side shuffle query stage is also converted to local shuffle reader.

checkNumLocalShuffleReaders(adaptivePlan, 5)
}
}

Expand All @@ -189,9 +218,32 @@ class AdaptiveQueryExecSuite
assert(smj.size == 3)
val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 3)
// The child of remaining two BroadcastHashJoin is not ShuffleQueryStage.
// So only two LocalShuffleReader.
checkNumLocalShuffleReaders(adaptivePlan, 1)
// *(7) BroadcastHashJoin [b#24], [a#33], Inner, BuildLeft
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a bit verbose to put the entire query plan here. How about we only put the sketch?

BroadcastHashJoin
+- BroadcastExchange
   +- LocalShuffleReader*
      +- ShuffleExchange
         +- BroadcastHashJoin
            +- BroadcastExchange
               +- LocalShuffleReader*
                  +- ShuffleExchange
            +- LocalShuffleReader*
               +- ShuffleExchange
+- BroadcastHashJoin
   +- LocalShuffleReader*
      +- ShuffleExchange
   +- BroadcastExchange
      +-HashAggregate
         +- CoalescedShuffleReader
            +- ShuffleExchange

// :- BroadcastQueryStage 6
// : +- BroadcastExchange HashedRelationBroadcastMode(
// : +- LocalShuffleReader
// : +- ShuffleQueryStage 5
// : +- Exchange hashpartitioning(b#24, 5), true, [id=#452]
// : +- *(5) BroadcastHashJoin [key#13], [a#23], Inner, BuildLeft
// : :- BroadcastQueryStage 4
// : : +- BroadcastExchange HashedRelationBroadcastMode(
// : : +- LocalShuffleReader
// : : +- ShuffleQueryStage 0
// :
// : +- LocalShuffleReader
// : +- ShuffleQueryStage 1
// :
// +- *(7) BroadcastHashJoin [n#93], [a#33], Inner, BuildRight
// :- LocalShuffleReader
// : +- ShuffleQueryStage 2
//
// +- BroadcastQueryStage 7
// +- BroadcastExchange HashedRelationBroadcastMode
// +- *(6) HashAggregate(keys=[a#33], functions=[sum(cast(b#34 as bigint))],
// output=[a#33, sum(b)#219L])
// +- CoalescedShuffleReader [0]
// +- ShuffleQueryStage 3
checkNumLocalShuffleReaders(adaptivePlan, 4)
}
}

Expand All @@ -215,9 +267,32 @@ class AdaptiveQueryExecSuite
assert(smj.size == 3)
val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 3)
// The child of remaining two BroadcastHashJoin is not ShuffleQueryStage.
// So only two LocalShuffleReader.
checkNumLocalShuffleReaders(adaptivePlan, 1)
// *(6) BroadcastHashJoin [cast(value#14 as int)], [a#220], Inner, BuildLeft
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

// :- BroadcastQueryStage 7
// : +- BroadcastExchange HashedRelationBroadcastMode
// : +- LocalShuffleReader
// : +- ShuffleQueryStage 6
// : +- Exchange hashpartitioning(cast(value#14 as int), 5), true, [id=#537]
// : +- *(5) BroadcastHashJoin [key#13], [a#23], Inner, BuildLeft
// : :- BroadcastQueryStage 4
// : : +- BroadcastExchange HashedRelationBroadcastMode
// : : +- LocalShuffleReader
// : : +- ShuffleQueryStage 0
// :
// : +- LocalShuffleReader
// : +- ShuffleQueryStage 1
// :
// +- *(6) BroadcastHashJoin [n#93], [b#218], Inner, BuildLeft
// :- BroadcastQueryStage 5
// : +- BroadcastExchange HashedRelationBroadcastMode
// : +- LocalShuffleReader
// : +- ShuffleQueryStage 2
// :
// +- *(6) Filter isnotnull(b#218)
// +- *(6) HashAggregate(keys=[a#220], functions=[max(b#221)], output=[a#220, b#218])
// +- CoalescedShuffleReader [0]
// +- ShuffleQueryStage 3
checkNumLocalShuffleReaders(adaptivePlan, 4)
}
}

Expand All @@ -232,8 +307,8 @@ class AdaptiveQueryExecSuite
assert(smj.size == 3)
val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 2)
checkNumLocalShuffleReaders(adaptivePlan, 2)
// Even with local shuffle reader, the query statge reuse can also work.
checkNumLocalShuffleReaders(adaptivePlan, 4)
// Even with local shuffle reader, the query stage reuse can also work.
val ex = findReusedExchange(adaptivePlan)
assert(ex.size == 1)
}
Expand All @@ -250,8 +325,8 @@ class AdaptiveQueryExecSuite
assert(smj.size == 1)
val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 1)
checkNumLocalShuffleReaders(adaptivePlan, 1)
// Even with local shuffle reader, the query statge reuse can also work.
checkNumLocalShuffleReaders(adaptivePlan, 2)
// Even with local shuffle reader, the query stage reuse can also work.
val ex = findReusedExchange(adaptivePlan)
assert(ex.size == 1)
}
Expand All @@ -270,8 +345,8 @@ class AdaptiveQueryExecSuite
assert(smj.size == 1)
val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 1)
checkNumLocalShuffleReaders(adaptivePlan, 1)
// Even with local shuffle reader, the query statge reuse can also work.
checkNumLocalShuffleReaders(adaptivePlan, 2)
// Even with local shuffle reader, the query stage reuse can also work.
val ex = findReusedExchange(adaptivePlan)
assert(ex.nonEmpty)
val sub = findReusedSubquery(adaptivePlan)
Expand All @@ -291,8 +366,8 @@ class AdaptiveQueryExecSuite
assert(smj.size == 1)
val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 1)
checkNumLocalShuffleReaders(adaptivePlan, 1)
// Even with local shuffle reader, the query statge reuse can also work.
checkNumLocalShuffleReaders(adaptivePlan, 2)
// Even with local shuffle reader, the query stage reuse can also work.
val ex = findReusedExchange(adaptivePlan)
assert(ex.isEmpty)
val sub = findReusedSubquery(adaptivePlan)
Expand All @@ -315,8 +390,8 @@ class AdaptiveQueryExecSuite
assert(smj.size == 1)
val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 1)
checkNumLocalShuffleReaders(adaptivePlan, 1)
// Even with local shuffle reader, the query statge reuse can also work.
checkNumLocalShuffleReaders(adaptivePlan, 2)
// Even with local shuffle reader, the query stage reuse can also work.
val ex = findReusedExchange(adaptivePlan)
assert(ex.nonEmpty)
assert(ex.head.plan.isInstanceOf[BroadcastQueryStageExec])
Expand Down Expand Up @@ -393,8 +468,9 @@ class AdaptiveQueryExecSuite
assert(smj.size == 2)
val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 1)
// additional shuffle exchange introduced, so revert OptimizeLocalShuffleReader rule.
checkNumLocalShuffleReaders(adaptivePlan, 0)
// Even additional shuffle exchange introduced, we still
// can convert the shuffle reader to local reader in build side.
checkNumLocalShuffleReaders(adaptivePlan, 1)
}
}

Expand Down