[SPARK-37455][SQL] Replace hash with sort aggregate if child is already sorted #34702
@@ -0,0 +1,139 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution

import org.apache.spark.sql.catalyst.expressions.SortOrder
import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, Final, Partial}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.internal.SQLConf

/**
 * Replace [[HashAggregateExec]] with [[SortAggregateExec]] in the spark plan if:
 *
 * 1. The plan is a pair of partial and final [[HashAggregateExec]], and the child of partial
 *    aggregate satisfies the sort order of corresponding [[SortAggregateExec]].
 * or
 * 2. The plan is a [[HashAggregateExec]], and the child satisfies the sort order of
 *    corresponding [[SortAggregateExec]].
 *
 * Examples:
 * 1. aggregate after join:
 *
 *   HashAggregate(t1.i, SUM, final)
 *                 |                        SortAggregate(t1.i, SUM, complete)
 *   HashAggregate(t1.i, SUM, partial)  =>                 |
Contributor: This seems like an orthogonal optimization: we can merge adjacent partial and final aggregates (no shuffle between them) into one complete aggregate.

Contributor (Author): Yeah, I think we can add a rule later to optimize it. I vaguely remember someone proposed this in OSS before, but the impact did not seem high.
 *                 |                        SortMergeJoin(t1.i = t2.j)
 *   SortMergeJoin(t1.i = t2.j)
 *
 * 2. aggregate after sort:
 *
 *   HashAggregate(t1.i, SUM, partial)      SortAggregate(t1.i, SUM, partial)
 *                 |                    =>                 |
 *             Sort(t1.i)                              Sort(t1.i)
 *
 * [[HashAggregateExec]] can be replaced when its child satisfies the sort order of
 * corresponding [[SortAggregateExec]]. [[SortAggregateExec]] is faster in the sense that
 * it does not have hashing overhead of [[HashAggregateExec]].
 */
object ReplaceHashWithSortAgg extends Rule[SparkPlan] {
  def apply(plan: SparkPlan): SparkPlan = {
    if (!conf.getConf(SQLConf.REPLACE_HASH_WITH_SORT_AGG_ENABLED)) {
      plan
    } else {
      replaceHashAgg(plan)
    }
  }

  /**
   * Replace [[HashAggregateExec]] with [[SortAggregateExec]].
   */
  private def replaceHashAgg(plan: SparkPlan): SparkPlan = {
    plan.transformDown {
      case hashAgg: HashAggregateExec if hashAgg.groupingExpressions.nonEmpty =>
Contributor: BTW, shall we handle …

Contributor (Author): @cloud-fan - yeah, I agree; I don't see a reason why we cannot do it. Created https://issues.apache.org/jira/browse/SPARK-37557 for the follow-up. Will do it shortly, thanks.
        val sortAgg = hashAgg.toSortAggregate
        hashAgg.child match {
          case partialAgg: HashAggregateExec if isPartialAgg(partialAgg, hashAgg) =>
            if (SortOrder.orderingSatisfies(
                partialAgg.child.outputOrdering, sortAgg.requiredChildOrdering.head)) {
              sortAgg.copy(
                aggregateExpressions = sortAgg.aggregateExpressions.map(_.copy(mode = Complete)),
Contributor: Is it always right? I think we also need to check the output partitioning to see if we can eliminate the partial agg. An example is …

Contributor (Author): @cloud-fan - I don't think we need to check output partitioning, as we are matching a pair of final and partial hash aggregates with no shuffle in between. So …

Contributor: Ah ok, so if there is a shuffle in the middle, we can't optimize? That looks quite limited, as having a shuffle in the middle is very common.

Contributor (Author): We can, and the rule here also does pattern matching for a single HashAggregateExec. Spark's native shuffle does not guarantee any sort order, but for Cosco (a remote shuffle service we run in-house) we support sorted shuffle, so the final aggregate can also be replaced in that case.
                child = partialAgg.child)
            } else {
              hashAgg
            }
          case other =>
            if (SortOrder.orderingSatisfies(
                other.outputOrdering, sortAgg.requiredChildOrdering.head)) {
              sortAgg
            } else {
              hashAgg
            }
        }
      case other => other
    }
  }

  /**
   * Check if `partialAgg` is the partial aggregate of `finalAgg`.
   */
  private def isPartialAgg(partialAgg: HashAggregateExec, finalAgg: HashAggregateExec): Boolean = {
Contributor: This looks like reverse engineering the …

Contributor (Author): @tanelk - yeah, I agree this is mostly reverse engineering and we can do a better job here. I tried to link the partial and final agg in … I found a more elegant way to do it, by checking that the linked logical plans of both aggs are the same. Updated.

Contributor (Author): cc @cloud-fan for review, thanks.
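For reference, a minimal sketch of the logical-link based pairing check described in the thread above (not the code in this diff; the helper name is hypothetical, and it assumes both physical aggregates were planned from the same logical Aggregate so their logical links match):

import org.apache.spark.sql.catalyst.expressions.aggregate.{Final, Partial}
import org.apache.spark.sql.execution.aggregate.HashAggregateExec

// Hypothetical helper: pair the two nodes by the logical plan they were planned from,
// instead of reverse engineering their grouping/aggregate expressions.
def plannedFromSameAggregate(
    partialAgg: HashAggregateExec,
    finalAgg: HashAggregateExec): Boolean = {
  partialAgg.aggregateExpressions.forall(_.mode == Partial) &&
    finalAgg.aggregateExpressions.forall(_.mode == Final) &&
    finalAgg.logicalLink.isDefined &&
    finalAgg.logicalLink == partialAgg.logicalLink
}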
    val partialGroupExprs = partialAgg.groupingExpressions
    val finalGroupExprs = finalAgg.groupingExpressions
    val partialAggExprs = partialAgg.aggregateExpressions
    val finalAggExprs = finalAgg.aggregateExpressions
    val partialAggAttrs = partialAggExprs.flatMap(_.aggregateFunction.aggBufferAttributes)
    val finalAggAttrs = finalAggExprs.map(_.resultAttribute)
    val partialResultExprs = partialGroupExprs ++
      partialAggExprs.flatMap(_.aggregateFunction.inputAggBufferAttributes)

    val groupExprsEqual = partialGroupExprs.length == finalGroupExprs.length &&
      partialGroupExprs.zip(finalGroupExprs).forall {
        case (e1, e2) => e1.semanticEquals(e2)
      }
    val aggExprsEqual = partialAggExprs.length == finalAggExprs.length &&
      partialAggExprs.forall(_.mode == Partial) && finalAggExprs.forall(_.mode == Final) &&
      partialAggExprs.zip(finalAggExprs).forall {
        case (e1, e2) => e1.aggregateFunction.semanticEquals(e2.aggregateFunction)
      }
    val isPartialAggAttrsValid = partialAggAttrs.length == partialAgg.aggregateAttributes.length &&
      partialAggAttrs.zip(partialAgg.aggregateAttributes).forall {
        case (a1, a2) => a1.semanticEquals(a2)
      }
    val isFinalAggAttrsValid = finalAggAttrs.length == finalAgg.aggregateAttributes.length &&
      finalAggAttrs.zip(finalAgg.aggregateAttributes).forall {
        case (a1, a2) => a1.semanticEquals(a2)
      }
    val isPartialResultExprsValid =
      partialResultExprs.length == partialAgg.resultExpressions.length &&
        partialResultExprs.zip(partialAgg.resultExpressions).forall {
          case (a1, a2) => a1.semanticEquals(a2)
        }
    val isRequiredDistributionValid =
      partialAgg.requiredChildDistributionExpressions.isEmpty &&
        finalAgg.requiredChildDistributionExpressions.exists { exprs =>
          exprs.length == finalGroupExprs.length &&
            exprs.zip(finalGroupExprs).forall {
              case (e1, e2) => e1.semanticEquals(e2)
            }
        }

    groupExprsEqual && aggExprsEqual && isPartialAggAttrsValid && isFinalAggAttrsValid &&
      isPartialResultExprsValid && isRequiredDistributionValid
  }
}
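Following up on the first review thread above (merging an adjacent partial/final pair into one complete aggregate without switching to a sort aggregate): a rough sketch of what such a rule could look like. The rule name is hypothetical and this is not part of the PR; it also glosses over the pairing checks (something like isPartialAgg above) and bookkeeping such as initialInputBufferOffset.

import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, Final, Partial}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.aggregate.HashAggregateExec

// Hypothetical rule: collapse an adjacent partial/final HashAggregateExec pair
// (no exchange in between) into a single Complete-mode hash aggregate.
object MergeAdjacentHashAggregates extends Rule[SparkPlan] {
  def apply(plan: SparkPlan): SparkPlan = plan.transformDown {
    case finalAgg: HashAggregateExec
        if finalAgg.aggregateExpressions.nonEmpty &&
          finalAgg.aggregateExpressions.forall(_.mode == Final) =>
      finalAgg.child match {
        // A real rule would also verify that the two nodes form a partial/final pair,
        // e.g. with a check like isPartialAgg in the file above.
        case partialAgg: HashAggregateExec
            if partialAgg.aggregateExpressions.forall(_.mode == Partial) =>
          finalAgg.copy(
            aggregateExpressions =
              finalAgg.aggregateExpressions.map(_.copy(mode = Complete)),
            child = partialAgg.child)
        case _ => finalAgg
      }
  }
}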
Contributor: Is it because the planner is top-down, so we don't know the child ordering during planning? Then we have to add a new rule to change the agg algorithm in a post-hoc way.

Contributor (Author): Yes it is. If we change our planning to be bottom-up and propagate each node's output ordering info during planning, then we can run this rule during planning. For now, we have to add it after EnsureRequirements.
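To tie things together, a hedged end-to-end sketch of the scaladoc's first example (aggregating on the join key of a sort-merge join). It assumes a running SparkSession named spark; the config entry comes from the diff above, and the expected plan shape follows the scaladoc rather than a verified output.

import org.apache.spark.sql.internal.SQLConf

// Enable the new rule (off by default) and force a sort-merge join for this small example.
spark.conf.set(SQLConf.REPLACE_HASH_WITH_SORT_AGG_ENABLED.key, "true")
spark.conf.set(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key, "-1")

val t1 = spark.range(0, 100).selectExpr("id AS i", "id AS v")
val t2 = spark.range(0, 100).selectExpr("id AS j")

// Aggregating on the join key: the sort-merge join output is already clustered and
// sorted by i, so no exchange separates the partial and final HashAggregate, and the
// child ordering satisfies the sort aggregate's requirement.
val q = t1.join(t2, t1("i") === t2("j")).groupBy("i").sum("v")

// With the rule applied after EnsureRequirements (as discussed above), the plan is
// expected to show a single complete SortAggregate on top of the SortMergeJoin
// instead of the partial/final HashAggregate pair.
q.explain()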