Commit c67a774

mihailotim-db authored and cloud-fan committed
[SPARK-52895][SQL] Don't add duplicate elements in resolveExprsWithAggregate
### What changes were proposed in this pull request?

Don't add duplicate elements in `resolveExprsWithAggregate`.

### Why are the changes needed?

This is needed in order to resolve plan mismatches between the fixed-point and single-pass analyzers. At the moment, fixed-point duplicates columns when the same missing column appears more than once in HAVING/ORDER BY. However, if there are LCAs (lateral column aliases), fixed-point deduplicates these columns, because LCA resolution uses a set and runs after ORDER BY/HAVING resolution. In single-pass, LCA resolution runs first, and ORDER BY/HAVING resolution, which adds the duplicates, comes afterwards. This PR makes the behavior consistent across all cases by never adding duplicates.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Added new test cases to golden files.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #51567 from mihailotim-db/mihailotim-db/deduplicate_agg_exprs.

Authored-by: Mihailo Timotic <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
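The heart of the fix is replacing an append-only buffer with an insertion-ordered map, so a column that is missing from the select list but referenced twice (e.g. `ORDER BY col2, col2`) is appended to the `Aggregate` only once. Below is a minimal standalone sketch of that idea, with plain strings standing in for Catalyst expressions; names like `DedupOrderSketch` are illustrative, not from the patch.

```scala
import java.util.LinkedHashMap
import scala.jdk.CollectionConverters._

// Standalone sketch: strings stand in for canonicalized Catalyst expressions.
// A LinkedHashMap drops duplicate insertions while preserving first-seen order,
// which is the property the patch relies on.
object DedupOrderSketch {
  def main(args: Array[String]): Unit = {
    val requested = Seq("col2", "col1", "col2") // e.g. ORDER BY col2, col1, col2
    val extraAggExprs = new LinkedHashMap[String, String]
    requested.foreach { expr =>
      extraAggExprs.computeIfAbsent(expr, key => s"alias($key)")
    }
    // The duplicate col2 produced a single entry:
    println(extraAggExprs.values().asScala.toSeq) // List(alias(col2), alias(col1))
  }
}
```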
1 parent 45374d8 · commit c67a774

File tree

13 files changed: +157 −21 lines

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 18 additions & 16 deletions
```diff
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.catalyst.analysis

 import java.util
-import java.util.Locale
+import java.util.{LinkedHashMap, Locale}

 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
@@ -2919,21 +2919,21 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
   def resolveExprsWithAggregate(
       exprs: Seq[Expression],
       agg: Aggregate): (Seq[NamedExpression], Seq[Expression]) = {
-    val extraAggExprs = ArrayBuffer.empty[NamedExpression]
+    val extraAggExprs = new LinkedHashMap[Expression, NamedExpression]
     val transformed = exprs.map { e =>
       if (!e.resolved) {
         e
       } else {
         buildAggExprList(e, agg, extraAggExprs)
       }
     }
-    (extraAggExprs.toSeq, transformed)
+    (extraAggExprs.values().asScala.toSeq, transformed)
   }

   private def buildAggExprList(
       expr: Expression,
       agg: Aggregate,
-      aggExprList: ArrayBuffer[NamedExpression]): Expression = {
+      aggExprMap: LinkedHashMap[Expression, NamedExpression]): Expression = {
     // Avoid adding an extra aggregate expression if it's already present in
     // `agg.aggregateExpressions`. Trim inner aliases from aggregate expressions because of
     // expressions like `spark_grouping_id` that can have inner aliases.
@@ -2949,20 +2949,22 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
     expr match {
       case ae: AggregateExpression =>
         val cleaned = trimTempResolvedColumn(ae)
-        val alias =
-          Alias(cleaned, toPrettySQL(e = cleaned, shouldTrimTempResolvedColumn = true))()
-        aggExprList += alias
-        alias.toAttribute
+        val resultAlias = aggExprMap.computeIfAbsent(
+          cleaned.canonicalized,
+          _ => Alias(cleaned, toPrettySQL(e = cleaned, shouldTrimTempResolvedColumn = true))()
+        )
+        resultAlias.toAttribute
       case grouping: Expression if agg.groupingExpressions.exists(grouping.semanticEquals) =>
         trimTempResolvedColumn(grouping) match {
           case ne: NamedExpression =>
-            aggExprList += ne
-            ne.toAttribute
+            val resultAttribute = aggExprMap.computeIfAbsent(ne.canonicalized, _ => ne)
+            resultAttribute.toAttribute
           case other =>
-            val alias =
-              Alias(other, toPrettySQL(e = other, shouldTrimTempResolvedColumn = true))()
-            aggExprList += alias
-            alias.toAttribute
+            val resultAlias = aggExprMap.computeIfAbsent(
+              other.canonicalized,
+              _ => Alias(other, toPrettySQL(e = other, shouldTrimTempResolvedColumn = true))()
+            )
+            resultAlias.toAttribute
         }
       case t: TempResolvedColumn =>
         if (t.child.isInstanceOf[Attribute]) {
@@ -2977,15 +2979,15 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
           val childWithTempCol = t.child.transformUp {
             case a: Attribute => TempResolvedColumn(a, Seq(a.name))
           }
-          val newChild = buildAggExprList(childWithTempCol, agg, aggExprList)
+          val newChild = buildAggExprList(childWithTempCol, agg, aggExprMap)
           if (newChild.containsPattern(TEMP_RESOLVED_COLUMN)) {
            withOrigin(t.origin)(t.copy(hasTried = true))
           } else {
            newChild
           }
         }
       case other =>
-        other.withNewChildren(other.children.map(buildAggExprList(_, agg, aggExprList)))
+        other.withNewChildren(other.children.map(buildAggExprList(_, agg, aggExprMap)))
     }
   }
 }
```
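Note that the map is keyed by `Expression.canonicalized`, so occurrences that differ only cosmetically collapse into one entry, and every duplicate reference is rewritten to the attribute of the same reused `Alias`. A rough illustration of the keying idea follows, where `canonical` is a toy stand-in for Catalyst's canonicalization (not the real API):

```scala
import java.util.LinkedHashMap

// Rough illustration only: `canonical` is a toy stand-in for
// Expression.canonicalized, which normalizes away cosmetic differences
// before the expression is used as a map key.
object CanonicalKeySketch {
  private def canonical(expr: String): String =
    expr.toLowerCase.replaceAll("\\s+", "")

  def main(args: Array[String]): Unit = {
    val occurrences = Seq("SUM(b) + 1", "sum(b)+1") // semantically equal spellings
    val aggExprMap = new LinkedHashMap[String, String]
    occurrences.foreach { e =>
      // Only the first occurrence builds an alias; later ones reuse the entry.
      aggExprMap.computeIfAbsent(canonical(e), _ => s"alias($e)")
    }
    println(aggExprMap.size) // 1 -- both spellings share one extra aggregate expression
  }
}
```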

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/HavingResolver.scala

Lines changed: 3 additions & 1 deletion
```diff
@@ -74,8 +74,10 @@ class HavingResolver(resolver: Resolver, expressionResolver: ExpressionResolver)
     val (resolvedConditionWithAliasReplacement, filteredMissingExpressions) =
       tryReplaceSortOrderOrHavingConditionWithAlias(resolvedCondition, scopes, missingExpressions)

+    val deduplicatedMissingExpressions = deduplicateMissingExpressions(filteredMissingExpressions)
+
     val resolvedChildWithMissingAttributes =
-      insertMissingExpressions(resolvedChild, filteredMissingExpressions)
+      insertMissingExpressions(resolvedChild, deduplicatedMissingExpressions)

     val isChildChangedByMissingExpressions = !resolvedChildWithMissingAttributes.eq(resolvedChild)
```

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolvesNameByHiddenOutput.scala

Lines changed: 14 additions & 0 deletions
```diff
@@ -222,6 +222,20 @@ trait ResolvesNameByHiddenOutput extends SQLConfHelper {
     case other => other
   }

+  /**
+   * Deduplicates missing expressions by [[ExprId]].
+   */
+  def deduplicateMissingExpressions(
+      missingExpressions: Seq[NamedExpression]): Seq[NamedExpression] = {
+    val duplicateMissingExpressions = new HashSet[ExprId]
+    missingExpressions.collect {
+      case expression: NamedExpression
+          if !duplicateMissingExpressions.contains(expression.exprId) =>
+        duplicateMissingExpressions.add(expression.exprId)
+        expression
+    }
+  }
+
   private def expandOperatorsOutputList(
       operator: UnaryNode,
       operatorOutput: Seq[NamedExpression],
```
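The single-pass resolver deduplicates by `ExprId`, keeping the first occurrence of each id in order. Here is a self-contained sketch of the same pattern on plain data, where `Named` is a hypothetical stand-in for `NamedExpression`:

```scala
import scala.collection.mutable

// Self-contained sketch: `Named` is a hypothetical stand-in for Catalyst's
// NamedExpression, with a Long playing the role of ExprId.
final case class Named(exprId: Long, name: String)

object DedupByIdSketch {
  def deduplicateMissingExpressions(missing: Seq[Named]): Seq[Named] = {
    val seen = mutable.HashSet.empty[Long]
    missing.collect {
      // HashSet#add returns false for an id that is already present,
      // so only the first occurrence of each exprId is kept.
      case e if seen.add(e.exprId) => e
    }
  }

  def main(args: Array[String]): Unit = {
    val missing = Seq(Named(2, "col2"), Named(2, "col2"), Named(1, "col1"))
    println(deduplicateMissingExpressions(missing))
    // List(Named(2,col2), Named(1,col1))
  }
}
```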

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/SortResolver.scala

Lines changed: 3 additions & 1 deletion
```diff
@@ -134,8 +134,10 @@ class SortResolver(operatorResolver: Resolver, expressionResolver: ExpressionRes
     val (resolvedOrderExpressionsWithAliasesReplaced, filteredMissingExpressions) =
       tryReplaceSortOrderWithAlias(resolvedOrderExpressions, missingExpressions)

+    val deduplicatedMissingExpressions = deduplicateMissingExpressions(filteredMissingExpressions)
+
     val resolvedChildWithMissingAttributes =
-      insertMissingExpressions(resolvedChild, filteredMissingExpressions)
+      insertMissingExpressions(resolvedChild, deduplicatedMissingExpressions)

     val isChildChangedByMissingExpressions = !resolvedChildWithMissingAttributes.eq(resolvedChild)
```

sql/core/src/test/resources/sql-tests/analyzer-results/group-analytics.sql.out

Lines changed: 1 addition & 1 deletion
```diff
@@ -383,7 +383,7 @@ HAVING GROUPING(year) = 1 AND GROUPING_ID(course, year) > 0 ORDER BY course, yea
 Sort [course#x ASC NULLS FIRST, year#x ASC NULLS FIRST], true
 +- Project [course#x, year#x]
    +- Filter ((cast(cast((shiftright(spark_grouping_id#xL, 0) & 1) as tinyint) as int) = 1) AND (spark_grouping_id#xL > cast(0 as bigint)))
-      +- Aggregate [course#x, year#x, spark_grouping_id#xL], [course#x, year#x, spark_grouping_id#xL, spark_grouping_id#xL]
+      +- Aggregate [course#x, year#x, spark_grouping_id#xL], [course#x, year#x, spark_grouping_id#xL]
         +- Expand [[course#x, year#x, earnings#x, course#x, year#x, 0], [course#x, year#x, earnings#x, course#x, null, 1], [course#x, year#x, earnings#x, null, year#x, 2], [course#x, year#x, earnings#x, null, null, 3]], [course#x, year#x, earnings#x, course#x, year#x, spark_grouping_id#xL]
            +- Project [course#x, year#x, earnings#x, course#x AS course#x, year#x AS year#x]
               +- SubqueryAlias coursesales
```

sql/core/src/test/resources/sql-tests/analyzer-results/having.sql.out

Lines changed: 20 additions & 0 deletions
```diff
@@ -486,3 +486,23 @@ Filter cast(scalar-subquery#x [alias#x] as boolean)
 :  +- OneRowRelation
 +- Aggregate [col1#x], [col1#x AS alias#x]
    +- LocalRelation [col1#x]
+
+
+-- !query
+SELECT col1 FROM VALUES(1,2) GROUP BY col1, col2 HAVING col2 = col2
+-- !query analysis
+Project [col1#x]
++- Filter (col2#x = col2#x)
+   +- Aggregate [col1#x, col2#x], [col1#x, col2#x]
+      +- LocalRelation [col1#x, col2#x]
+
+
+-- !query
+SELECT col1 AS a, a AS b FROM VALUES(1,2) GROUP BY col1, col2 HAVING col2 = col2
+-- !query analysis
+Project [a#x, b#x]
++- Filter (col2#x = col2#x)
+   +- Project [a#x, a#x AS b#x, col2#x]
+      +- Project [col1#x, col2#x, col1#x AS a#x]
+         +- Aggregate [col1#x, col2#x], [col1#x, col2#x]
+            +- LocalRelation [col1#x, col2#x]
```

sql/core/src/test/resources/sql-tests/analyzer-results/order-by.sql.out

Lines changed: 20 additions & 0 deletions
```diff
@@ -454,6 +454,26 @@ Sort [(sum(b) + 1)#xL ASC NULLS FIRST], true
 +- LocalRelation [a#x, b#x]


+-- !query
+SELECT col1 FROM VALUES(1,2) GROUP BY col1, col2 ORDER BY col2, col2
+-- !query analysis
+Project [col1#x]
++- Sort [col2#x ASC NULLS FIRST, col2#x ASC NULLS FIRST], true
+   +- Aggregate [col1#x, col2#x], [col1#x, col2#x]
+      +- LocalRelation [col1#x, col2#x]
+
+
+-- !query
+SELECT col1 AS a, a AS b FROM VALUES(1,2) GROUP BY col1, col2 ORDER BY col2, col2
+-- !query analysis
+Project [a#x, b#x]
++- Sort [col2#x ASC NULLS FIRST, col2#x ASC NULLS FIRST], true
+   +- Project [a#x, a#x AS b#x, col2#x]
+      +- Project [col1#x, col2#x, col1#x AS a#x]
+         +- Aggregate [col1#x, col2#x], [col1#x, col2#x]
+            +- LocalRelation [col1#x, col2#x]
+
+
 -- !query
 DROP VIEW IF EXISTS testData
 -- !query analysis
```

sql/core/src/test/resources/sql-tests/analyzer-results/udf/udf-group-analytics.sql.out

Lines changed: 1 addition & 1 deletion
```diff
@@ -256,7 +256,7 @@ HAVING GROUPING(year) = 1 AND GROUPING_ID(course, year) > 0 ORDER BY course, udf
 Sort [course#x ASC NULLS FIRST, cast(udf(cast(year#x as string)) as int) ASC NULLS FIRST], true
 +- Project [course#x, year#x]
    +- Filter ((cast(cast((shiftright(spark_grouping_id#xL, 0) & 1) as tinyint) as int) = 1) AND (spark_grouping_id#xL > cast(0 as bigint)))
-      +- Aggregate [course#x, year#x, spark_grouping_id#xL], [course#x, year#x, spark_grouping_id#xL, spark_grouping_id#xL]
+      +- Aggregate [course#x, year#x, spark_grouping_id#xL], [course#x, year#x, spark_grouping_id#xL]
         +- Expand [[course#x, year#x, earnings#x, course#x, year#x, 0], [course#x, year#x, earnings#x, course#x, null, 1], [course#x, year#x, earnings#x, null, year#x, 2], [course#x, year#x, earnings#x, null, null, 3]], [course#x, year#x, earnings#x, course#x, year#x, spark_grouping_id#xL]
            +- Project [course#x, year#x, earnings#x, course#x AS course#x, year#x AS year#x]
               +- SubqueryAlias coursesales
```

sql/core/src/test/resources/sql-tests/inputs/having.sql

Lines changed: 5 additions & 0 deletions
```diff
@@ -91,3 +91,8 @@ GROUP BY col1
 HAVING (
   SELECT col1[0] = 1
 );
+
+-- Missing attribute (col2) in HAVING is added only once
+
+SELECT col1 FROM VALUES(1,2) GROUP BY col1, col2 HAVING col2 = col2;
+SELECT col1 AS a, a AS b FROM VALUES(1,2) GROUP BY col1, col2 HAVING col2 = col2;
```

sql/core/src/test/resources/sql-tests/inputs/order-by.sql

Lines changed: 5 additions & 0 deletions
```diff
@@ -54,5 +54,10 @@ SELECT MAX(a) + SUM(b) FROM testData ORDER BY SUM(b) + MAX(a);
 SELECT SUM(a) + 1 + MIN(a) FROM testData ORDER BY 1 + 1 + 1 + MIN(a) + 1 + SUM(a);
 SELECT SUM(b) + 1 FROM testData HAVING SUM(b) + 1 > 0 ORDER BY SUM(b) + 1;

+-- Missing attribute (col2) in ORDER BY is added only once
+
+SELECT col1 FROM VALUES(1,2) GROUP BY col1, col2 ORDER BY col2, col2;
+SELECT col1 AS a, a AS b FROM VALUES(1,2) GROUP BY col1, col2 ORDER BY col2, col2;
+
 -- Clean up
 DROP VIEW IF EXISTS testData;
```
