Closed
Changes from 3 commits
@@ -1504,6 +1504,13 @@ object SQLConf {
      .booleanConf
      .createWithDefault(true)

  val REPLACE_HASH_WITH_SORT_AGG_ENABLED = buildConf("spark.sql.execution.replaceHashWithSortAgg")
    .internal()
    .doc("Whether to replace hash aggregate node with sort aggregate based on children's ordering")
    .version("3.3.0")
    .booleanConf
    .createWithDefault(true)

  val STATE_STORE_PROVIDER_CLASS =
    buildConf("spark.sql.streaming.stateStore.providerClass")
      .internal()
@@ -423,6 +423,9 @@ object QueryExecution {
      PlanSubqueries(sparkSession),
      RemoveRedundantProjects,
      EnsureRequirements(),
      // `ReplaceHashWithSortAgg` needs to be added after `EnsureRequirements` to guarantee the
Contributor:

Is it because the planner is top-down so we don't know the child ordering during planning? Then we have to add a new rule to change the agg algorithm in a post-hoc way.

Contributor Author:

Yes, it is. If we changed planning to be bottom-up and propagated each node's output ordering during planning, we could run this rule during planning. For now, we have to add it after EnsureRequirements.
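
To make the post-hoc idea concrete, here is a minimal sketch on a toy plan model. It is not Spark's API: `Plan`, `Scan`, `Sort`, `HashAgg`, `SortAgg` and the string-based ordering are all hypothetical stand-ins. It assumes the child orderings are already final, which is the situation after `EnsureRequirements` has run.

```scala
// Toy model, not Spark's classes: every name here is hypothetical.
object PostHocReplaceSketch {
  sealed trait Plan { def outputOrdering: Seq[String] }

  // A leaf whose output is already sorted by `ordering` (possibly empty).
  case class Scan(ordering: Seq[String]) extends Plan {
    def outputOrdering: Seq[String] = ordering
  }
  // Sorting fixes the output ordering regardless of the child.
  case class Sort(keys: Seq[String], child: Plan) extends Plan {
    def outputOrdering: Seq[String] = keys
  }
  // In this toy model, hash aggregation reports no output ordering.
  case class HashAgg(groupKeys: Seq[String], child: Plan) extends Plan {
    def outputOrdering: Seq[String] = Nil
  }
  // Sort aggregation emits groups in key order.
  case class SortAgg(groupKeys: Seq[String], child: Plan) extends Plan {
    def outputOrdering: Seq[String] = groupKeys
  }

  // A required ordering is satisfied when it is a prefix of the actual ordering.
  def orderingSatisfies(actual: Seq[String], required: Seq[String]): Boolean =
    actual.startsWith(required)

  // Top-down rewrite: decide for the current node, then recurse into its child.
  def replaceHashAgg(plan: Plan): Plan = plan match {
    case HashAgg(keys, child)
        if keys.nonEmpty && orderingSatisfies(child.outputOrdering, keys) =>
      SortAgg(keys, replaceHashAgg(child))
    case HashAgg(keys, child) => HashAgg(keys, replaceHashAgg(child))
    case SortAgg(keys, child) => SortAgg(keys, replaceHashAgg(child))
    case Sort(keys, child) => Sort(keys, replaceHashAgg(child))
    case scan: Scan => scan
  }
}
```

In this sketch the aggregate over `Sort(i)` becomes a `SortAgg`, while the same aggregate over an unsorted `Scan` stays a `HashAgg`. The decision only needs the child's ordering, which is why it can run as a separate pass once that ordering is known.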

      // sort order of each node is checked to be valid.
      ReplaceHashWithSortAgg,
      // `RemoveRedundantSorts` needs to be added after `EnsureRequirements` to guarantee the same
      // number of partitions when instantiating PartitioningCollection.
      RemoveRedundantSorts,
@@ -0,0 +1,139 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution

import org.apache.spark.sql.catalyst.expressions.SortOrder
import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, Final, Partial}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.internal.SQLConf

/**
 * Replace [[HashAggregateExec]] with [[SortAggregateExec]] in the spark plan if:
 *
 * 1. The plan is a pair of partial and final [[HashAggregateExec]], and the child of partial
 *    aggregate satisfies the sort order of corresponding [[SortAggregateExec]].
 * or
 * 2. The plan is a [[HashAggregateExec]], and the child satisfies the sort order of
 *    corresponding [[SortAggregateExec]].
 *
 * Examples:
 * 1. aggregate after join:
 *
 * HashAggregate(t1.i, SUM, final)
 *                 |                       SortAggregate(t1.i, SUM, complete)
 * HashAggregate(t1.i, SUM, partial)  =>                    |
Contributor:

This seems like an orthogonal optimization: we can merge adjacent partial and final aggregates (no shuffle between them) into one complete aggregate.

Contributor Author:

Yeah, I think we can add a rule later to optimize it. I vaguely remember someone proposing this in OSS before, but the impact did not seem high.
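
For illustration only, the equivalence behind that merge can be sketched on plain Scala collections. This is a toy model with made-up names (`partial`, `finalMerge`, `complete`), using an average so that the partial buffer differs from the final result. It assumes the final step sees exactly the partial output of the same partition, i.e. no shuffle in between.

```scala
// Toy model of partial/final/complete aggregation; names are hypothetical.
object MergeAggSketch {
  type Row = (Int, Long)          // (grouping key, value)
  type Buffer = (Int, Long, Long) // (grouping key, running sum, running count)

  // Partial step: one (sum, count) buffer per key seen in this partition.
  def partial(rows: Seq[Row]): Seq[Buffer] =
    rows.groupBy(_._1).toSeq.sortBy(_._1).map { case (k, rs) =>
      (k, rs.map(_._2).sum, rs.size.toLong)
    }

  // Final step: merge the buffers of each key, then compute the average.
  def finalMerge(buffers: Seq[Buffer]): Seq[(Int, Double)] =
    buffers.groupBy(_._1).toSeq.sortBy(_._1).map { case (k, bs) =>
      (k, bs.map(_._2).sum.toDouble / bs.map(_._3).sum)
    }

  // Complete step: go from input rows to averages in a single operator.
  def complete(rows: Seq[Row]): Seq[(Int, Double)] =
    rows.groupBy(_._1).toSeq.sortBy(_._1).map { case (k, rs) =>
      (k, rs.map(_._2).sum.toDouble / rs.size)
    }
}
```

When both steps run over the same partition, `finalMerge(partial(rows))` and `complete(rows)` return the same rows, so the pair can be collapsed into one operator.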

 *                 |                           SortMergeJoin(t1.i = t2.j)
 *    SortMergeJoin(t1.i = t2.j)
 *
 * 2. aggregate after sort:
 *
 * HashAggregate(t1.i, SUM, partial)       SortAggregate(t1.i, SUM, partial)
 *                 |                  =>                   |
 *            Sort(t1.i)                              Sort(t1.i)
 *
 * [[HashAggregateExec]] can be replaced when its child satisfies the sort order of
 * corresponding [[SortAggregateExec]]. [[SortAggregateExec]] is faster in the sense that
 * it does not have hashing overhead of [[HashAggregateExec]].
 */
object ReplaceHashWithSortAgg extends Rule[SparkPlan] {
  def apply(plan: SparkPlan): SparkPlan = {
    if (!conf.getConf(SQLConf.REPLACE_HASH_WITH_SORT_AGG_ENABLED)) {
      plan
    } else {
      replaceHashAgg(plan)
    }
  }

  /**
   * Replace [[HashAggregateExec]] with [[SortAggregateExec]].
   */
  private def replaceHashAgg(plan: SparkPlan): SparkPlan = {
    plan.transformDown {
      case hashAgg: HashAggregateExec if hashAgg.groupingExpressions.nonEmpty =>
Contributor:

BTW, shall we handle ObjectHashAggregateExec as well?

Contributor Author:

@cloud-fan - yeah, I agree. I don't see any reason why we cannot do it. Created https://issues.apache.org/jira/browse/SPARK-37557 as a follow-up. Will do it shortly, thanks.

        val sortAgg = hashAgg.toSortAggregate
        hashAgg.child match {
          case partialAgg: HashAggregateExec if isPartialAgg(partialAgg, hashAgg) =>
            if (SortOrder.orderingSatisfies(
                partialAgg.child.outputOrdering, sortAgg.requiredChildOrdering.head)) {
              sortAgg.copy(
                aggregateExpressions = sortAgg.aggregateExpressions.map(_.copy(mode = Complete)),
Contributor:

Is it always right? I think we also need to check the output partitioning to see whether we can eliminate the partial agg.

An example is df.sortWithinPartitions. It does not cluster the data; it just sorts it within each partition.

Contributor Author:

@cloud-fan - I don't think we need to check output partitioning, as we are matching a pair of final and partial hash agg, without shuffle in between:

  HashAggregate(final)
          |                                       SortAggregate(complete)
HashAggregate(partial)             =>                    |
          |                                            child
        child 

So child must already have the proper output partitioning for SortAggregate; otherwise it could not satisfy the original HashAggregate(final)'s required distribution.
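
A small sketch of why the distribution matters, on plain Scala collections rather than Spark partitions (all names are made up for illustration). It contrasts data that is only sorted within each partition with data that is also clustered by the grouping key.

```scala
// Toy model with made-up names; partitions are plain Scala collections, not Spark RDDs.
object ClusteringSketch {
  type Row = (Int, Long) // (grouping key, value)

  // Sum values per key inside one partition, emitting groups in key order.
  def sumByKey(rows: Seq[Row]): Seq[Row] =
    rows.groupBy(_._1).toSeq.sortBy(_._1).map { case (k, rs) => (k, rs.map(_._2).sum) }

  // A "complete" aggregate with no shuffle: each partition is aggregated on its own.
  def completePerPartition(partitions: Seq[Seq[Row]]): Seq[Row] =
    partitions.flatMap(sumByKey)

  // Sorted within each partition, but key 1 lives in both partitions.
  val sortedOnly: Seq[Seq[Row]] = Seq(Seq((1, 10L), (2, 20L)), Seq((1, 5L), (3, 30L)))

  // Clustered by key: all rows of a key share one partition.
  val clustered: Seq[Seq[Row]] = Seq(Seq((1, 10L), (1, 5L), (2, 20L)), Seq((3, 30L)))
}
```

With `sortedOnly`, key 1 comes out twice, so a per-partition complete aggregate would be wrong. With `clustered`, each key comes out once. The matched final aggregate already requires the clustered layout from its child, which is the argument made above.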

Contributor:

Ah OK, so if there is a shuffle in the middle, we can't optimize? This looks quite limited, as having a shuffle in the middle is very common.

Contributor Author:

"if there is a shuffle in the middle, we can't optimize?"

We can: the rule here also pattern-matches a single HashAggregate (the second case below). I added a unit test case in ReplaceHashWithSortAggSuite.scala to demonstrate replacing a partial aggregate: "replace partial hash aggregate with sort aggregate". But I think it would be rare to be able to replace a final aggregate (though this rule also covers it), as the final aggregate is almost always immediately after a shuffle, so there is no sort ordering before it.

Spark's native shuffle does not guarantee any sort order. Cosco (a remote shuffle service we run in-house) supports sorted shuffle, so there the final aggregate can also be replaced.
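
As a rough sketch of why sorted input helps (a toy function with hypothetical names, not `SortAggregateExec` itself): when rows arrive ordered by the grouping key, one pass with a single running buffer is enough, and no hash table is needed.

```scala
// Toy streaming aggregation over key-sorted input; names are hypothetical.
object SortAggSketch {
  type Row = (Int, Long) // (grouping key, value)

  // One pass over key-sorted rows, keeping only the current group's running sum.
  def sortAggregate(sortedRows: Iterator[Row]): List[Row] = {
    val out = List.newBuilder[Row]
    var current: Option[Row] = None
    sortedRows.foreach { case (key, value) =>
      current match {
        // Same key as the open group: extend its running sum.
        case Some((k, sum)) if k == key => current = Some((k, sum + value))
        // Key changed: the open group is complete, emit it and start a new one.
        case Some(done) =>
          out += done
          current = Some((key, value))
        // First row of the input.
        case None => current = Some((key, value))
      }
    }
    // Emit the last open group, if any.
    current.foreach(group => out += group)
    out.result()
  }
}
```

The result is only correct when equal keys are adjacent in the input, which is what the ordering check in the rule is there to guarantee.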

                child = partialAgg.child)
            } else {
              hashAgg
            }
          case other =>
            if (SortOrder.orderingSatisfies(
                other.outputOrdering, sortAgg.requiredChildOrdering.head)) {
              sortAgg
            } else {
              hashAgg
            }
        }
      case other => other
    }
  }

  /**
   * Check whether `partialAgg` is the partial aggregate of `finalAgg`.
   */
  private def isPartialAgg(partialAgg: HashAggregateExec, finalAgg: HashAggregateExec): Boolean = {
Contributor:

This looks like reverse engineering the AggUtils. Could we just link the partial and final agg when they are constructed?

Contributor Author:

@tanelk - yeah, I agree this is mostly reverse engineering and we can do a better job here. I tried linking the partial and final agg in AggUtils and checking whether the linked physical plan is the same. This does not quite work because we do top-down planning, so the linked partial agg is not the same as the planned partial agg (the linked partial agg contains a PlanLater operator).

I found a more elegant way to do it, by checking the linked logical plan of both aggs to be same. Updated.
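
Roughly, the idea looks like the following toy sketch. The types `LogicalAgg` and `PhysicalAgg` and the `link` field are hypothetical stand-ins, and reference equality is used here as an assumption for "the same logical plan"; the updated patch may compare the plans differently.

```scala
// Toy model of pairing aggregates through a shared logical node; names are hypothetical.
object LogicalLinkSketch {
  // Stand-in for a logical aggregate node.
  final class LogicalAgg(val groupKeys: Seq[String])

  sealed trait Mode
  case object Partial extends Mode
  case object Final extends Mode

  // Stand-in for a physical aggregate that remembers the logical node it was planned from.
  case class PhysicalAgg(mode: Mode, link: Option[LogicalAgg])

  // The two aggregates form a pair when the modes match and both point at the same logical node.
  def isPartialAgg(partialAgg: PhysicalAgg, finalAgg: PhysicalAgg): Boolean =
    partialAgg.mode == Partial && finalAgg.mode == Final &&
      ((partialAgg.link, finalAgg.link) match {
        case (Some(a), Some(b)) => a eq b // assumption: "same" means the same object
        case _ => false
      })
}
```

Compared with matching expressions and attributes one by one, this keeps the check independent of how the planner lays out buffers and result expressions.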

Contributor Author:

cc @cloud-fan for review, thanks.

    val partialGroupExprs = partialAgg.groupingExpressions
    val finalGroupExprs = finalAgg.groupingExpressions
    val partialAggExprs = partialAgg.aggregateExpressions
    val finalAggExprs = finalAgg.aggregateExpressions
    val partialAggAttrs = partialAggExprs.flatMap(_.aggregateFunction.aggBufferAttributes)
    val finalAggAttrs = finalAggExprs.map(_.resultAttribute)
    val partialResultExprs = partialGroupExprs ++
      partialAggExprs.flatMap(_.aggregateFunction.inputAggBufferAttributes)

    val groupExprsEqual = partialGroupExprs.length == finalGroupExprs.length &&
      partialGroupExprs.zip(finalGroupExprs).forall {
        case (e1, e2) => e1.semanticEquals(e2)
      }
    val aggExprsEqual = partialAggExprs.length == finalAggExprs.length &&
      partialAggExprs.forall(_.mode == Partial) && finalAggExprs.forall(_.mode == Final) &&
      partialAggExprs.zip(finalAggExprs).forall {
        case (e1, e2) => e1.aggregateFunction.semanticEquals(e2.aggregateFunction)
      }
    val isPartialAggAttrsValid = partialAggAttrs.length == partialAgg.aggregateAttributes.length &&
      partialAggAttrs.zip(partialAgg.aggregateAttributes).forall {
        case (a1, a2) => a1.semanticEquals(a2)
      }
    val isFinalAggAttrsValid = finalAggAttrs.length == finalAgg.aggregateAttributes.length &&
      finalAggAttrs.zip(finalAgg.aggregateAttributes).forall {
        case (a1, a2) => a1.semanticEquals(a2)
      }
    val isPartialResultExprsValid =
      partialResultExprs.length == partialAgg.resultExpressions.length &&
        partialResultExprs.zip(partialAgg.resultExpressions).forall {
          case (a1, a2) => a1.semanticEquals(a2)
        }
    val isRequiredDistributionValid =
      partialAgg.requiredChildDistributionExpressions.isEmpty &&
        finalAgg.requiredChildDistributionExpressions.exists { exprs =>
          exprs.length == finalGroupExprs.length &&
            exprs.zip(finalGroupExprs).forall {
              case (e1, e2) => e1.semanticEquals(e2)
            }
        }

    groupExprsEqual && aggExprsEqual && isPartialAggAttrsValid && isFinalAggAttrsValid &&
      isPartialResultExprsValid && isRequiredDistributionValid
  }
}
@@ -116,6 +116,7 @@ case class AdaptiveSparkPlanExec(
    Seq(
      RemoveRedundantProjects,
      ensureRequirements,
      ReplaceHashWithSortAgg,
      RemoveRedundantSorts,
      DisableUnnecessaryBucketedScan,
      OptimizeSkewedJoin(ensureRequirements)
@@ -1153,6 +1153,15 @@ case class HashAggregateExec(
    }
  }

  /**
   * The corresponding [[SortAggregateExec]] to get the same result as this node.
   */
  def toSortAggregate: SortAggregateExec = {
    SortAggregateExec(
      requiredChildDistributionExpressions, groupingExpressions, aggregateExpressions,
      aggregateAttributes, initialInputBufferOffset, resultExpressions, child)
  }

  override protected def withNewChildInternal(newChild: SparkPlan): HashAggregateExec =
    copy(child = newChild)
}
@@ -3,7 +3,7 @@
 +- Exchange (44)
    +- * HashAggregate (43)
       +- * HashAggregate (42)
-         +- * HashAggregate (41)
+         +- SortAggregate (41)
             +- * Project (40)
                +- * BroadcastHashJoin Inner BuildRight (39)
                   :- * Project (33)
@@ -221,21 +221,21 @@ Join condition: None
Output [3]: [cs_order_number#5, cs_ext_ship_cost#6, cs_net_profit#7]
Input [5]: [cs_ship_date_sk#1, cs_order_number#5, cs_ext_ship_cost#6, cs_net_profit#7, d_date_sk#23]

-(41) HashAggregate [codegen id : 11]
+(41) SortAggregate
Input [3]: [cs_order_number#5, cs_ext_ship_cost#6, cs_net_profit#7]
Keys [1]: [cs_order_number#5]
Functions [2]: [partial_sum(UnscaledValue(cs_ext_ship_cost#6)), partial_sum(UnscaledValue(cs_net_profit#7))]
Aggregate Attributes [2]: [sum(UnscaledValue(cs_ext_ship_cost#6))#26, sum(UnscaledValue(cs_net_profit#7))#27]
Results [3]: [cs_order_number#5, sum#28, sum#29]

-(42) HashAggregate [codegen id : 11]
+(42) HashAggregate [codegen id : 12]
Input [3]: [cs_order_number#5, sum#28, sum#29]
Keys [1]: [cs_order_number#5]
Functions [2]: [merge_sum(UnscaledValue(cs_ext_ship_cost#6)), merge_sum(UnscaledValue(cs_net_profit#7))]
Aggregate Attributes [2]: [sum(UnscaledValue(cs_ext_ship_cost#6))#26, sum(UnscaledValue(cs_net_profit#7))#27]
Results [3]: [cs_order_number#5, sum#28, sum#29]

-(43) HashAggregate [codegen id : 11]
+(43) HashAggregate [codegen id : 12]
Input [3]: [cs_order_number#5, sum#28, sum#29]
Keys: []
Functions [3]: [merge_sum(UnscaledValue(cs_ext_ship_cost#6)), merge_sum(UnscaledValue(cs_net_profit#7)), partial_count(distinct cs_order_number#5)]
@@ -246,7 +246,7 @@ Results [3]: [sum#28, sum#29, count#31]
Input [3]: [sum#28, sum#29, count#31]
Arguments: SinglePartition, ENSURE_REQUIREMENTS, [id=#32]

-(45) HashAggregate [codegen id : 12]
+(45) HashAggregate [codegen id : 13]
Input [3]: [sum#28, sum#29, count#31]
Keys: []
Functions [3]: [sum(UnscaledValue(cs_ext_ship_cost#6)), sum(UnscaledValue(cs_net_profit#7)), count(distinct cs_order_number#5)]