sql/core/src/main/scala/org/apache/spark/sql/execution/ReplaceHashWithSortAgg.scala
@@ -20,17 +20,18 @@ package org.apache.spark.sql.execution
 import org.apache.spark.sql.catalyst.expressions.SortOrder
 import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, Final, Partial}
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.execution.aggregate.HashAggregateExec
+import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec}
 import org.apache.spark.sql.internal.SQLConf
 
 /**
- * Replace [[HashAggregateExec]] with [[SortAggregateExec]] in the spark plan if:
+ * Replace hash-based aggregate with sort aggregate in the spark plan if:
  *
- * 1. The plan is a pair of partial and final [[HashAggregateExec]], and the child of partial
- *    aggregate satisfies the sort order of corresponding [[SortAggregateExec]].
+ * 1. The plan is a pair of partial and final [[HashAggregateExec]] or [[ObjectHashAggregateExec]],
+ *    and the child of partial aggregate satisfies the sort order of corresponding
+ *    [[SortAggregateExec]].
  * or
- * 2. The plan is a [[HashAggregateExec]], and the child satisfies the sort order of
- *    corresponding [[SortAggregateExec]].
+ * 2. The plan is a [[HashAggregateExec]] or [[ObjectHashAggregateExec]], and the child satisfies
+ *    the sort order of corresponding [[SortAggregateExec]].
  *
  * Examples:
  * 1. aggregate after join:
@@ -47,9 +48,9 @@ import org.apache.spark.sql.internal.SQLConf
  *         |          =>          |
  *     Sort(t1.i)             Sort(t1.i)
  *
- * [[HashAggregateExec]] can be replaced when its child satisfies the sort order of
- * corresponding [[SortAggregateExec]]. [[SortAggregateExec]] is faster in the sense that
- * it does not have hashing overhead of [[HashAggregateExec]].
+ * Hash-based aggregate can be replaced when its child satisfies the sort order of
+ * the corresponding sort aggregate. Sort aggregate is faster in the sense that
+ * it does not have the hashing overhead of hash aggregate.
  */
 object ReplaceHashWithSortAgg extends Rule[SparkPlan] {
   def apply(plan: SparkPlan): SparkPlan = {
@@ -61,14 +62,15 @@ object ReplaceHashWithSortAgg extends Rule[SparkPlan] {
   }
 
   /**
-   * Replace [[HashAggregateExec]] with [[SortAggregateExec]].
+   * Replace [[HashAggregateExec]] and [[ObjectHashAggregateExec]] with [[SortAggregateExec]].
    */
   private def replaceHashAgg(plan: SparkPlan): SparkPlan = {
     plan.transformDown {
-      case hashAgg: HashAggregateExec if hashAgg.groupingExpressions.nonEmpty =>
+      case hashAgg: BaseAggregateExec if isHashBasedAggWithKeys(hashAgg) =>
         val sortAgg = hashAgg.toSortAggregate
         hashAgg.child match {
-          case partialAgg: HashAggregateExec if isPartialAgg(partialAgg, hashAgg) =>
+          case partialAgg: BaseAggregateExec
+              if isHashBasedAggWithKeys(partialAgg) && isPartialAgg(partialAgg, hashAgg) =>
             if (SortOrder.orderingSatisfies(
                 partialAgg.child.outputOrdering, sortAgg.requiredChildOrdering.head)) {
               sortAgg.copy(
@@ -92,7 +94,7 @@ object ReplaceHashWithSortAgg extends Rule[SparkPlan] {
   /**
    * Check if `partialAgg` is the partial aggregate of `finalAgg`.
    */
-  private def isPartialAgg(partialAgg: HashAggregateExec, finalAgg: HashAggregateExec): Boolean = {
+  private def isPartialAgg(partialAgg: BaseAggregateExec, finalAgg: BaseAggregateExec): Boolean = {
     if (partialAgg.aggregateExpressions.forall(_.mode == Partial) &&
         finalAgg.aggregateExpressions.forall(_.mode == Final)) {
       (finalAgg.logicalLink, partialAgg.logicalLink) match {
@@ -103,4 +105,16 @@ object ReplaceHashWithSortAgg extends Rule[SparkPlan] {
         false
     }
   }
+
+  /**
+   * Check if `agg` is a [[HashAggregateExec]] or [[ObjectHashAggregateExec]]
+   * and has grouping keys.
+   */
+  private def isHashBasedAggWithKeys(agg: BaseAggregateExec): Boolean = {
+    val isHashBasedAgg = agg match {
+      case _: HashAggregateExec | _: ObjectHashAggregateExec => true
+      case _ => false
+    }
+    isHashBasedAgg && agg.groupingExpressions.nonEmpty
+  }
 }
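Reviewer note: the rewrite is easiest to see end to end from a spark-shell session. The sketch below mirrors the SORT BY test case in the suite at the end of this diff; it is a minimal sketch, not part of the PR. The gating config key (spark.sql.execution.replaceHashWithSortAgg) sits outside this diff, so treat it as an assumption, and AQE is switched off here only so that a plain collect over the physical plan sees the final operators.

```scala
// Minimal sketch (not part of this PR) of the rule firing on COLLECT_LIST,
// which plans as ObjectHashAggregateExec. Assumes the rule is gated by
// spark.sql.execution.replaceHashWithSortAgg (off by default).
import org.apache.spark.sql.execution.aggregate.{ObjectHashAggregateExec, SortAggregateExec}

spark.conf.set("spark.sql.adaptive.enabled", "false")
spark.conf.set("spark.sql.execution.replaceHashWithSortAgg", "true")
spark.range(100).selectExpr("id AS key").createOrReplaceTempView("t")

val df = spark.sql(
  """
    |SELECT key, COLLECT_LIST(key)
    |FROM (SELECT key FROM t WHERE key > 10 SORT BY key)
    |GROUP BY key
  """.stripMargin)

val plan = df.queryExecution.executedPlan
// The partial aggregate sits directly on the SORT BY, so it is rewritten to
// SortAggregateExec; the final aggregate sits above a shuffle that destroys
// the ordering, so it stays hash-based.
assert(plan.collect { case s: SortAggregateExec => s }.length == 1)
assert(plan.collect { case o: ObjectHashAggregateExec => o }.length == 1)
```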
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala
@@ -30,6 +30,7 @@ trait BaseAggregateExec extends UnaryExecNode with AliasAwareOutputPartitioning
   def groupingExpressions: Seq[NamedExpression]
   def aggregateExpressions: Seq[AggregateExpression]
   def aggregateAttributes: Seq[Attribute]
+  def initialInputBufferOffset: Int
   def resultExpressions: Seq[NamedExpression]
 
   override def verboseStringWithOperatorId(): String = {
@@ -95,4 +96,13 @@
       case None => UnspecifiedDistribution :: Nil
     }
   }
+
+  /**
+   * The corresponding [[SortAggregateExec]] to get the same result as this node.
+   */
+  def toSortAggregate: SortAggregateExec = {
+    SortAggregateExec(
+      requiredChildDistributionExpressions, groupingExpressions, aggregateExpressions,
+      aggregateAttributes, initialInputBufferOffset, resultExpressions, child)
+  }
 }
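Hoisting toSortAggregate from HashAggregateExec (removed in the next file) into this trait is what lets the rule treat both hash-based operators uniformly; it works because every constructor field SortAggregateExec needs, including the newly abstract initialInputBufferOffset, is now declared on the trait. A toy sketch of the pattern, with entirely hypothetical names:

```scala
// Toy illustration of the refactor (all names here are hypothetical, not
// Spark APIs): once each field the conversion needs is an abstract member
// of the base trait, the factory method can live on the trait rather than
// being duplicated in one concrete node.
trait BaseNode {
  def keys: Seq[String] // stands in for groupingExpressions, etc.
  def offset: Int       // stands in for initialInputBufferOffset
  def toSorted: SortedNode = SortedNode(keys, offset)
}
case class HashNode(keys: Seq[String], offset: Int) extends BaseNode
case class ObjectHashNode(keys: Seq[String], offset: Int) extends BaseNode
case class SortedNode(keys: Seq[String], offset: Int)
```

Either concrete node now converts the same way, e.g. ObjectHashNode(Seq("k"), 0).toSorted == SortedNode(Seq("k"), 0).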
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -1153,15 +1153,6 @@ case class HashAggregateExec(
     }
   }
 
-  /**
-   * The corresponding [[SortAggregateExec]] to get same result as this node.
-   */
-  def toSortAggregate: SortAggregateExec = {
-    SortAggregateExec(
-      requiredChildDistributionExpressions, groupingExpressions, aggregateExpressions,
-      aggregateAttributes, initialInputBufferOffset, resultExpressions, child)
-  }
-
   override protected def withNewChildInternal(newChild: SparkPlan): HashAggregateExec =
     copy(child = newChild)
 }
sql/core/src/test/scala/org/apache/spark/sql/execution/ReplaceHashWithSortAggSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution
 
 import org.apache.spark.sql.{DataFrame, QueryTest}
 import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite}
-import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec}
+import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 
@@ -30,7 +30,9 @@ abstract class ReplaceHashWithSortAggSuiteBase
 
   private def checkNumAggs(df: DataFrame, hashAggCount: Int, sortAggCount: Int): Unit = {
     val plan = df.queryExecution.executedPlan
-    assert(collectWithSubqueries(plan) { case s: HashAggregateExec => s }.length == hashAggCount)
+    assert(collectWithSubqueries(plan) {
+      case s @ (_: HashAggregateExec | _: ObjectHashAggregateExec) => s
+    }.length == hashAggCount)
     assert(collectWithSubqueries(plan) { case s: SortAggregateExec => s }.length == sortAggCount)
   }
 
@@ -55,71 +57,79 @@
   test("replace partial hash aggregate with sort aggregate") {
     withTempView("t") {
       spark.range(100).selectExpr("id as key").repartition(10).createOrReplaceTempView("t")
-      val query =
-        """
-          |SELECT key, FIRST(key)
-          |FROM
-          |(
-          |  SELECT key
-          |  FROM t
-          |  WHERE key > 10
-          |  SORT BY key
-          |)
-          |GROUP BY key
-        """.stripMargin
-      checkAggs(query, 1, 1, 2, 0)
+      Seq("FIRST", "COLLECT_LIST").foreach { aggExpr =>
+        val query =
+          s"""
+             |SELECT key, $aggExpr(key)
+             |FROM
+             |(
+             |  SELECT key
+             |  FROM t
+             |  WHERE key > 10
+             |  SORT BY key
+             |)
+             |GROUP BY key
+           """.stripMargin
+        checkAggs(query, 1, 1, 2, 0)
+      }
     }
   }
 
   test("replace partial and final hash aggregate together with sort aggregate") {
     withTempView("t1", "t2") {
       spark.range(100).selectExpr("id as key").createOrReplaceTempView("t1")
       spark.range(50).selectExpr("id as key").createOrReplaceTempView("t2")
-      val query =
-        """
-          |SELECT key, COUNT(key)
-          |FROM
-          |(
-          |  SELECT /*+ SHUFFLE_MERGE(t1) */ t1.key AS key
-          |  FROM t1
-          |  JOIN t2
-          |  ON t1.key = t2.key
-          |)
-          |GROUP BY key
-        """.stripMargin
-      checkAggs(query, 0, 1, 2, 0)
+      Seq("COUNT", "COLLECT_LIST").foreach { aggExpr =>
+        val query =
+          s"""
+             |SELECT key, $aggExpr(key)
+             |FROM
+             |(
+             |  SELECT /*+ SHUFFLE_MERGE(t1) */ t1.key AS key
+             |  FROM t1
+             |  JOIN t2
+             |  ON t1.key = t2.key
+             |)
+             |GROUP BY key
+           """.stripMargin
+        checkAggs(query, 0, 1, 2, 0)
+      }
     }
   }
 
   test("do not replace hash aggregate if child does not have sort order") {
     withTempView("t1", "t2") {
       spark.range(100).selectExpr("id as key").createOrReplaceTempView("t1")
       spark.range(50).selectExpr("id as key").createOrReplaceTempView("t2")
-      val query =
-        """
-          |SELECT key, COUNT(key)
-          |FROM
-          |(
-          |  SELECT /*+ BROADCAST(t1) */ t1.key AS key
-          |  FROM t1
-          |  JOIN t2
-          |  ON t1.key = t2.key
-          |)
-          |GROUP BY key
-        """.stripMargin
-      checkAggs(query, 2, 0, 2, 0)
+      Seq("COUNT", "COLLECT_LIST").foreach { aggExpr =>
+        val query =
+          s"""
+             |SELECT key, $aggExpr(key)
+             |FROM
+             |(
+             |  SELECT /*+ BROADCAST(t1) */ t1.key AS key
+             |  FROM t1
+             |  JOIN t2
+             |  ON t1.key = t2.key
+             |)
+             |GROUP BY key
+           """.stripMargin
+        checkAggs(query, 2, 0, 2, 0)
+      }
     }
   }
 
   test("do not replace hash aggregate if there is no group-by column") {
     withTempView("t1") {
       spark.range(100).selectExpr("id as key").createOrReplaceTempView("t1")
-      val query =
-        """
-          |SELECT COUNT(key)
-          |FROM t1
-        """.stripMargin
-      checkAggs(query, 2, 0, 2, 0)
+      Seq("COUNT", "COLLECT_LIST").foreach { aggExpr =>
+        val query =
+          s"""
+             |SELECT $aggExpr(key)
+             |FROM t1
+           """.stripMargin
+        checkAggs(query, 2, 0, 2, 0)
+      }
     }
   }
 }
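Reading the four counts passed to checkAggs (its body is collapsed out of this diff): they are the hash-aggregate and sort-aggregate counts with the rule enabled, then the same two counts with it disabled. A hedged sketch of what the helper presumably looks like; the parameter names are assumptions, and the SQLConf constant is inferred from Spark's naming conventions rather than copied from the PR:

```scala
// Hedged sketch of the collapsed helper (not copied from the PR): run the
// query once with the rule on and once with it off, counting aggregate
// operators each time via checkNumAggs above.
private def checkAggs(
    query: String,
    enabledHashAggCount: Int,
    enabledSortAggCount: Int,
    disabledHashAggCount: Int,
    disabledSortAggCount: Int): Unit = {
  withSQLConf(SQLConf.REPLACE_HASH_WITH_SORT_AGG_ENABLED.key -> "true") {
    checkNumAggs(sql(query), enabledHashAggCount, enabledSortAggCount)
  }
  withSQLConf(SQLConf.REPLACE_HASH_WITH_SORT_AGG_ENABLED.key -> "false") {
    checkNumAggs(sql(query), disabledHashAggCount, disabledSortAggCount)
  }
}
```

Under this reading, checkAggs(query, 1, 1, 2, 0) says: with the rule on, one hash-based aggregate and one sort aggregate survive; with it off, both aggregates stay hash-based.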