-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-40382][SQL] Group distinct aggregate expressions by semantically equivalent children in RewriteDistinctAggregates
#37825
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Changes from 1 commit
Commits
Show all changes
13 commits
Select commit
Hold shift + click to select a range
a5a6fc0
First attempt
bersprockets 0a109d9
Fix for foldables
bersprockets 38f1f6a
Update
bersprockets 3fa3588
Update test
bersprockets 4a40f91
Simplify
bersprockets 165f558
Rename
bersprockets 27dcffe
Update comments
bersprockets 484ca8e
Replace Symbol usage with $"" in new unit tests
bersprockets 208fe82
Update tests
bersprockets 882cdaa
Update
bersprockets f53136d
Use ExpressionSet as key for various distinct aggregate child maps
bersprockets 9938252
Handle case of one distinct grouping with superficially different function children
bersprockets f7d29df
Update
bersprockets File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Simplify
- Loading branch information
commit 4a40f910f4fd44a526d5959fc255e5702aa29151
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -213,32 +213,17 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { | |
| case a: Aggregate if mayNeedtoRewrite(a) => rewrite(a) | ||
| } | ||
|
|
||
| def rewrite(a: Aggregate): Aggregate = { | ||
| def rewrite(aRaw: Aggregate): Aggregate = { | ||
| // make children of distinct aggregations the same if they are different | ||
| // only because of superficial differences. | ||
| val a = getSanitizedAggregate(aRaw) | ||
|
|
||
| val aggExpressions = collectAggregateExprs(a) | ||
| val distinctAggs = aggExpressions.filter(_.isDistinct) | ||
|
|
||
| val funcChildren = distinctAggs.flatMap { e => | ||
| e.aggregateFunction.children.filter(!_.foldable) | ||
| } | ||
|
|
||
| // For each function child, find the first instance that is semantically equivalent. | ||
| // E.g., assume funcChildren is the following three expressions: | ||
| // [('a + 1), (1 + 'a), 'b] | ||
| // then we want the map to be: | ||
| // Map(('a + 1) -> ('a + 1), (1 + 'a) -> ('a + 1), 'b -> 'b) | ||
| // That is, both ('a + 1) and (1 + 'a) map to ('a + 1). | ||
| // This is an n^2 operation, where n is the number of distinct aggregate children, but it | ||
| // happens only once every time this rule is called. | ||
| val funcChildrenLookup = funcChildren.map { e => | ||
| (e, funcChildren.find(fc => e.semanticEquals(fc)).getOrElse(e)) | ||
| }.toMap | ||
|
|
||
| // Extract distinct aggregate expressions. | ||
| val distinctAggGroups = aggExpressions.filter(_.isDistinct).groupBy { e => | ||
| val unfoldableChildren = e.aggregateFunction.children.filter(!_.foldable).map { fc => | ||
| funcChildrenLookup.getOrElse(fc, fc) | ||
| }.toSet | ||
| val unfoldableChildren = e.aggregateFunction.children.filter(!_.foldable).toSet | ||
| if (unfoldableChildren.nonEmpty) { | ||
| // Only expand the unfoldable children | ||
| unfoldableChildren | ||
|
|
@@ -253,42 +238,8 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { | |
| } | ||
| } | ||
|
|
||
| // get the count of aggregation groups that takes into account | ||
| // even superficial differences in the function children | ||
| val distictAggGroupsCount = aggExpressions.filter(_.isDistinct).map { e => | ||
| val unfoldableChildren = e.aggregateFunction.children.filter(!_.foldable).toSet | ||
| if (unfoldableChildren.nonEmpty) { | ||
| unfoldableChildren | ||
| } else { | ||
| e.aggregateFunction.children.take(1).toSet | ||
| } | ||
| }.toSet.size | ||
|
|
||
| def patchAggregateFunctionChildren( | ||
| af: AggregateFunction)( | ||
| attrs: Expression => Option[Expression]): AggregateFunction = { | ||
| val newChildren = af.children.map(c => attrs(c).getOrElse(c)) | ||
| af.withNewChildren(newChildren).asInstanceOf[AggregateFunction] | ||
| } | ||
|
|
||
| // Aggregation strategy can handle queries with a single distinct group without filter clause. | ||
| if (distinctAggGroups.size == 1 && distictAggGroupsCount > 1 | ||
| && !distinctAggs.exists(_.filter.isDefined)) { | ||
| // we have multiple groups only because of | ||
| // superficial differences. Make them the same so that SparkStrategies | ||
| // doesn't complain during sanity check. That is, if we have an aggList of: | ||
| // [count(distinct b + 1), sum(distinct 1 + b), sum(c)] | ||
| // Change it to: | ||
| // [count(distinct b + 1), sum(distinct b + 1), sum(c)] | ||
| // therefore we have distinct aggregations over only one expression | ||
| val patchedAggExpressions = a.aggregateExpressions.map { e => | ||
| e.transformDown { | ||
| case e: Expression => | ||
| funcChildrenLookup.getOrElse(e, e) | ||
| }.asInstanceOf[NamedExpression] | ||
| } | ||
| a.copy(aggregateExpressions = patchedAggExpressions) | ||
| } else if (distinctAggGroups.size > 1 || distinctAggs.exists(_.filter.isDefined)) { | ||
| if (distinctAggGroups.size > 1 || distinctAggs.exists(_.filter.isDefined)) { | ||
| // Create the attributes for the grouping id and the group by clause. | ||
| val gid = AttributeReference("gid", IntegerType, nullable = false)() | ||
| val groupByMap = a.groupingExpressions.collect { | ||
|
|
@@ -337,12 +288,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { | |
| af | ||
| } else { | ||
| patchAggregateFunctionChildren(af) { x => | ||
| // x might not exactly match any key of distinctAggChildAttrLookup | ||
| // only because `distinctAggChildAttrLookup`'s keys have been de-duped | ||
| // based on semantic equivalence. So we need to translate x to the | ||
| // semantic equivalent that we are actually using. | ||
| val x2 = funcChildrenLookup.getOrElse(x, x) | ||
| distinctAggChildAttrLookup.get(x2) | ||
| distinctAggChildAttrLookup.get(x) | ||
| } | ||
| } | ||
| val newCondition = if (condition.isDefined) { | ||
|
|
@@ -463,6 +409,42 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { | |
| }} | ||
| } | ||
|
|
||
| private def patchAggregateFunctionChildren( | ||
| af: AggregateFunction)( | ||
| attrs: Expression => Option[Expression]): AggregateFunction = { | ||
| val newChildren = af.children.map(c => attrs(c).getOrElse(c)) | ||
| af.withNewChildren(newChildren).asInstanceOf[AggregateFunction] | ||
| } | ||
|
|
||
| private def getSanitizedAggregate(a: Aggregate): Aggregate = { | ||
| val aggExpressions = collectAggregateExprs(a) | ||
| val distinctAggs = aggExpressions.filter(_.isDistinct) | ||
bersprockets marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| val funcChildren = distinctAggs.flatMap { e => | ||
| e.aggregateFunction.children.filter(!_.foldable) | ||
| } | ||
|
|
||
| // For each function child, find the first instance that is semantically equivalent. | ||
| // E.g., assume funcChildren is the following three expressions: | ||
| // [('a + 1), (1 + 'a), 'b] | ||
| // then we want the map to be: | ||
| // Map(('a + 1) -> ('a + 1), (1 + 'a) -> ('a + 1), 'b -> 'b) | ||
| // That is, both ('a + 1) and (1 + 'a) map to ('a + 1). | ||
| // This is an n^2 operation, where n is the number of distinct aggregate children, but it | ||
| // happens only once every time this rule is called. | ||
| val funcChildrenLookup = funcChildren.map { e => | ||
| (e, funcChildren.find(fc => e.semanticEquals(fc)).getOrElse(e)) | ||
|
||
| }.toMap | ||
|
|
||
| val patchedAggExpressions = a.aggregateExpressions.map { e => | ||
| e.transformDown { | ||
| case e: Expression => | ||
| funcChildrenLookup.getOrElse(e, e) | ||
| }.asInstanceOf[NamedExpression] | ||
| } | ||
| a.copy(aggregateExpressions = patchedAggExpressions) | ||
| } | ||
|
|
||
| private def nullify(e: Expression) = Literal.create(null, e.dataType) | ||
|
|
||
| private def expressionAttributePair(e: Expression) = | ||
|
|
||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.