Skip to content
This repository was archived by the owner on Nov 15, 2024. It is now read-only.

Commit 296223a

Browse files
DonnyZoneMatthewRBruce
authored andcommitted
[SPARK-21980][SQL] References in grouping functions should be indexed with semanticEquals
## What changes were proposed in this pull request? https://issues.apache.org/jira/browse/SPARK-21980 This PR fixes the issue in ResolveGroupingAnalytics rule, which indexes the column references in grouping functions without considering case sensitive configurations. The problem can be reproduced by: `val df = spark.createDataFrame(Seq((1, 1), (2, 1), (2, 2))).toDF("a", "b") df.cube("a").agg(grouping("A")).show()` ## How was this patch tested? unit tests Author: donnyzone <[email protected]> Closes apache#19202 from DonnyZone/ResolveGroupingAnalytics. (cherry picked from commit 21c4450) Signed-off-by: gatorsmile <[email protected]>
1 parent ff7910d commit 296223a

File tree

2 files changed

+17
-1
lines changed

2 files changed

+17
-1
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,7 @@ class Analyzer(
315315
s"grouping columns (${groupByExprs.mkString(",")})")
316316
}
317317
case e @ Grouping(col: Expression) =>
318-
val idx = groupByExprs.indexOf(col)
318+
val idx = groupByExprs.indexWhere(_.semanticEquals(col))
319319
if (idx >= 0) {
320320
Alias(Cast(BitwiseAnd(ShiftRight(gid, Literal(groupByExprs.length - 1 - idx)),
321321
Literal(1)), ByteType), toPrettySQL(e))()

sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,22 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
186186
)
187187
}
188188

189+
test("SPARK-21980: References in grouping functions should be indexed with semanticEquals") {
190+
checkAnswer(
191+
courseSales.cube("course", "year")
192+
.agg(grouping("CouRse"), grouping("year")),
193+
Row("Java", 2012, 0, 0) ::
194+
Row("Java", 2013, 0, 0) ::
195+
Row("Java", null, 0, 1) ::
196+
Row("dotNET", 2012, 0, 0) ::
197+
Row("dotNET", 2013, 0, 0) ::
198+
Row("dotNET", null, 0, 1) ::
199+
Row(null, 2012, 1, 0) ::
200+
Row(null, 2013, 1, 0) ::
201+
Row(null, null, 1, 1) :: Nil
202+
)
203+
}
204+
189205
test("rollup overlapping columns") {
190206
checkAnswer(
191207
testData2.rollup($"a" + $"b" as "foo", $"b" as "bar").agg(sum($"a" - $"b") as "foo"),

0 commit comments

Comments
 (0)