diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ReplaceHashWithSortAgg.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ReplaceHashWithSortAgg.scala index 63ad2d0cafb7..4495bc9b6a58 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ReplaceHashWithSortAgg.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ReplaceHashWithSortAgg.scala @@ -20,17 +20,18 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.catalyst.expressions.SortOrder import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, Final, Partial} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution.aggregate.HashAggregateExec +import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec} import org.apache.spark.sql.internal.SQLConf /** - * Replace [[HashAggregateExec]] with [[SortAggregateExec]] in the spark plan if: + * Replace hash-based aggregate with sort aggregate in the spark plan if: * - * 1. The plan is a pair of partial and final [[HashAggregateExec]], and the child of partial - * aggregate satisfies the sort order of corresponding [[SortAggregateExec]]. + * 1. The plan is a pair of partial and final [[HashAggregateExec]] or [[ObjectHashAggregateExec]], + * and the child of partial aggregate satisfies the sort order of corresponding + * [[SortAggregateExec]]. * or - * 2. The plan is a [[HashAggregateExec]], and the child satisfies the sort order of - * corresponding [[SortAggregateExec]]. + * 2. The plan is a [[HashAggregateExec]] or [[ObjectHashAggregateExec]], and the child satisfies + * the sort order of corresponding [[SortAggregateExec]]. * * Examples: * 1. aggregate after join: @@ -47,9 +48,9 @@ import org.apache.spark.sql.internal.SQLConf * | => | * Sort(t1.i) Sort(t1.i) * - * [[HashAggregateExec]] can be replaced when its child satisfies the sort order of - * corresponding [[SortAggregateExec]]. 
[[SortAggregateExec]] is faster in the sense that - * it does not have hashing overhead of [[HashAggregateExec]]. + * Hash-based aggregate can be replaced when its child satisfies the sort order of + * corresponding sort aggregate. Sort aggregate is faster in the sense that + * it does not have hashing overhead of hash aggregate. */ object ReplaceHashWithSortAgg extends Rule[SparkPlan] { def apply(plan: SparkPlan): SparkPlan = { @@ -61,14 +62,15 @@ object ReplaceHashWithSortAgg extends Rule[SparkPlan] { } /** - * Replace [[HashAggregateExec]] with [[SortAggregateExec]]. + * Replace [[HashAggregateExec]] and [[ObjectHashAggregateExec]] with [[SortAggregateExec]]. */ private def replaceHashAgg(plan: SparkPlan): SparkPlan = { plan.transformDown { - case hashAgg: HashAggregateExec if hashAgg.groupingExpressions.nonEmpty => + case hashAgg: BaseAggregateExec if isHashBasedAggWithKeys(hashAgg) => val sortAgg = hashAgg.toSortAggregate hashAgg.child match { - case partialAgg: HashAggregateExec if isPartialAgg(partialAgg, hashAgg) => + case partialAgg: BaseAggregateExec + if isHashBasedAggWithKeys(partialAgg) && isPartialAgg(partialAgg, hashAgg) => if (SortOrder.orderingSatisfies( partialAgg.child.outputOrdering, sortAgg.requiredChildOrdering.head)) { sortAgg.copy( @@ -92,7 +94,7 @@ object ReplaceHashWithSortAgg extends Rule[SparkPlan] { /** * Check if `partialAgg` to be partial aggregate of `finalAgg`. */ - private def isPartialAgg(partialAgg: HashAggregateExec, finalAgg: HashAggregateExec): Boolean = { + private def isPartialAgg(partialAgg: BaseAggregateExec, finalAgg: BaseAggregateExec): Boolean = { if (partialAgg.aggregateExpressions.forall(_.mode == Partial) && finalAgg.aggregateExpressions.forall(_.mode == Final)) { (finalAgg.logicalLink, partialAgg.logicalLink) match { @@ -103,4 +105,16 @@ object ReplaceHashWithSortAgg extends Rule[SparkPlan] { false } } + + /** + * Check if `agg` is [[HashAggregateExec]] or [[ObjectHashAggregateExec]], + * and has grouping keys. 
+ */ + private def isHashBasedAggWithKeys(agg: BaseAggregateExec): Boolean = { + val isHashBasedAgg = agg match { + case _: HashAggregateExec | _: ObjectHashAggregateExec => true + case _ => false + } + isHashBasedAgg && agg.groupingExpressions.nonEmpty + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala index c676609bc37e..b709c8092e46 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala @@ -30,6 +30,7 @@ trait BaseAggregateExec extends UnaryExecNode with AliasAwareOutputPartitioning def groupingExpressions: Seq[NamedExpression] def aggregateExpressions: Seq[AggregateExpression] def aggregateAttributes: Seq[Attribute] + def initialInputBufferOffset: Int def resultExpressions: Seq[NamedExpression] override def verboseStringWithOperatorId(): String = { @@ -95,4 +96,13 @@ trait BaseAggregateExec extends UnaryExecNode with AliasAwareOutputPartitioning case None => UnspecifiedDistribution :: Nil } } + + /** + * The corresponding [[SortAggregateExec]] that produces the same result as this node. 
+ */ + def toSortAggregate: SortAggregateExec = { + SortAggregateExec( + requiredChildDistributionExpressions, groupingExpressions, aggregateExpressions, + aggregateAttributes, initialInputBufferOffset, resultExpressions, child) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 85e81cb12dca..854515402860 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -1153,15 +1153,6 @@ case class HashAggregateExec( } } - /** - * The corresponding [[SortAggregateExec]] to get same result as this node. - */ - def toSortAggregate: SortAggregateExec = { - SortAggregateExec( - requiredChildDistributionExpressions, groupingExpressions, aggregateExpressions, - aggregateAttributes, initialInputBufferOffset, resultExpressions, child) - } - override protected def withNewChildInternal(newChild: SparkPlan): HashAggregateExec = copy(child = newChild) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ReplaceHashWithSortAggSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ReplaceHashWithSortAggSuite.scala index 78765fdf4f75..47679ed7865d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ReplaceHashWithSortAggSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ReplaceHashWithSortAggSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.{DataFrame, QueryTest} import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite} -import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec} +import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, 
SortAggregateExec} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession @@ -30,7 +30,9 @@ abstract class ReplaceHashWithSortAggSuiteBase private def checkNumAggs(df: DataFrame, hashAggCount: Int, sortAggCount: Int): Unit = { val plan = df.queryExecution.executedPlan - assert(collectWithSubqueries(plan) { case s: HashAggregateExec => s }.length == hashAggCount) + assert(collectWithSubqueries(plan) { + case s @ (_: HashAggregateExec | _: ObjectHashAggregateExec) => s + }.length == hashAggCount) assert(collectWithSubqueries(plan) { case s: SortAggregateExec => s }.length == sortAggCount) } @@ -55,19 +57,21 @@ abstract class ReplaceHashWithSortAggSuiteBase test("replace partial hash aggregate with sort aggregate") { withTempView("t") { spark.range(100).selectExpr("id as key").repartition(10).createOrReplaceTempView("t") - val query = - """ - |SELECT key, FIRST(key) - |FROM - |( - | SELECT key - | FROM t - | WHERE key > 10 - | SORT BY key - |) - |GROUP BY key - """.stripMargin - checkAggs(query, 1, 1, 2, 0) + Seq("FIRST", "COLLECT_LIST").foreach { aggExpr => + val query = + s""" + |SELECT key, $aggExpr(key) + |FROM + |( + | SELECT key + | FROM t + | WHERE key > 10 + | SORT BY key + |) + |GROUP BY key + """.stripMargin + checkAggs(query, 1, 1, 2, 0) + } } } @@ -75,19 +79,21 @@ abstract class ReplaceHashWithSortAggSuiteBase withTempView("t1", "t2") { spark.range(100).selectExpr("id as key").createOrReplaceTempView("t1") spark.range(50).selectExpr("id as key").createOrReplaceTempView("t2") - val query = - """ - |SELECT key, COUNT(key) - |FROM - |( - | SELECT /*+ SHUFFLE_MERGE(t1) */ t1.key AS key - | FROM t1 - | JOIN t2 - | ON t1.key = t2.key - |) - |GROUP BY key - """.stripMargin - checkAggs(query, 0, 1, 2, 0) + Seq("COUNT", "COLLECT_LIST").foreach { aggExpr => + val query = + s""" + |SELECT key, $aggExpr(key) + |FROM + |( + | SELECT /*+ SHUFFLE_MERGE(t1) */ t1.key AS key + | FROM t1 + | JOIN t2 + | ON t1.key = t2.key + |) + 
|GROUP BY key + """.stripMargin + checkAggs(query, 0, 1, 2, 0) + } } } @@ -95,31 +101,35 @@ abstract class ReplaceHashWithSortAggSuiteBase withTempView("t1", "t2") { spark.range(100).selectExpr("id as key").createOrReplaceTempView("t1") spark.range(50).selectExpr("id as key").createOrReplaceTempView("t2") - val query = - """ - |SELECT key, COUNT(key) - |FROM - |( - | SELECT /*+ BROADCAST(t1) */ t1.key AS key - | FROM t1 - | JOIN t2 - | ON t1.key = t2.key - |) - |GROUP BY key - """.stripMargin - checkAggs(query, 2, 0, 2, 0) + Seq("COUNT", "COLLECT_LIST").foreach { aggExpr => + val query = + s""" + |SELECT key, $aggExpr(key) + |FROM + |( + | SELECT /*+ BROADCAST(t1) */ t1.key AS key + | FROM t1 + | JOIN t2 + | ON t1.key = t2.key + |) + |GROUP BY key + """.stripMargin + checkAggs(query, 2, 0, 2, 0) + } } } test("do not replace hash aggregate if there is no group-by column") { withTempView("t1") { spark.range(100).selectExpr("id as key").createOrReplaceTempView("t1") - val query = - """ - |SELECT COUNT(key) - |FROM t1 - """.stripMargin - checkAggs(query, 2, 0, 2, 0) + Seq("COUNT", "COLLECT_LIST").foreach { aggExpr => + val query = + s""" + |SELECT $aggExpr(key) + |FROM t1 + """.stripMargin + checkAggs(query, 2, 0, 2, 0) + } } } }