-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-11275] [SQL] Incorrect results when using rollup/cube #9815
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -211,45 +211,35 @@ class Analyzer( | |
| GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations) | ||
| case x: GroupingSets => | ||
| val gid = AttributeReference(VirtualColumn.groupingIdName, IntegerType, false)() | ||
| // We will insert another Projection if the GROUP BY keys contain | ||
| // non-attribute expressions, so that the top operators can reference those | ||
| // expressions by their aliases. | ||
| // e.g. SELECT key%5 as c1 FROM src GROUP BY key%5 ==> | ||
| // SELECT a as c1 FROM (SELECT key%5 AS a FROM src) GROUP BY a | ||
|
|
||
| // find all of the non-attribute expressions in the GROUP BY keys | ||
| val nonAttributeGroupByExpressions = new ArrayBuffer[Alias]() | ||
|
|
||
| // The pair of (the original GROUP BY key, associated attribute) | ||
| val groupByExprPairs = x.groupByExprs.map(_ match { | ||
| case e: NamedExpression => (e, e.toAttribute) | ||
| case other => { | ||
| val alias = Alias(other, other.toString)() | ||
| nonAttributeGroupByExpressions += alias // add the non-attributes expression alias | ||
| (other, alias.toAttribute) | ||
| } | ||
| }) | ||
|
|
||
| // substitute the non-attribute expressions for aggregations. | ||
| val aggregation = x.aggregations.map(expr => expr.transformDown { | ||
| case e => groupByExprPairs.find(_._1.semanticEquals(e)).map(_._2).getOrElse(e) | ||
| }.asInstanceOf[NamedExpression]) | ||
|
|
||
| // substitute the group by expressions. | ||
| val newGroupByExprs = groupByExprPairs.map(_._2) | ||
| // Expand works by setting grouping expressions to null as determined by the bitmasks. To | ||
| // prevent these null values from being used in an aggregate instead of the original value | ||
| // we need to create new aliases for all group by expressions that will only be used for | ||
| // the intended purpose. | ||
| val groupByAliases: Seq[Alias] = x.groupByExprs.map { | ||
| case e: NamedExpression => Alias(e, e.name)() | ||
| case other => Alias(other, other.toString)() | ||
| } | ||
|
|
||
| val child = if (nonAttributeGroupByExpressions.length > 0) { | ||
| // insert additional projection if contains the | ||
| // non-attribute expressions in the GROUP BY keys | ||
| Project(x.child.output ++ nonAttributeGroupByExpressions, x.child) | ||
| } else { | ||
| x.child | ||
| val aggregations: Seq[NamedExpression] = x.aggregations.map { | ||
| // If an expression is an aggregate (contains an AggregateExpression) then we don't change | ||
| // it so that the aggregation is computed on the unmodified value of its argument | ||
| // expressions. | ||
| case expr if expr.find(_.isInstanceOf[AggregateExpression]).nonEmpty => expr | ||
| // If not then it's a grouping expression and we need to use the modified (with nulls from | ||
| // Expand) value of the expression. | ||
| case expr => expr.transformDown { | ||
| case e => groupByAliases.find(_.child.semanticEquals(e)).map(_.toAttribute).getOrElse(e) | ||
| }.asInstanceOf[NamedExpression] | ||
| } | ||
|
|
||
| val child = Project(x.child.output ++ groupByAliases, x.child) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. So, we will rely on our optimizer to remove this Project if it is not necessary, right? |
||
| val groupByAttributes = groupByAliases.map(_.toAttribute) | ||
|
|
||
| Aggregate( | ||
| newGroupByExprs :+ VirtualColumn.groupingIdAttribute, | ||
| aggregation, | ||
| Expand(x.bitmasks, newGroupByExprs, gid, child)) | ||
| groupByAttributes :+ VirtualColumn.groupingIdAttribute, | ||
| aggregations, | ||
| Expand(x.bitmasks, groupByAttributes, gid, child)) | ||
| } | ||
| } | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -323,6 +323,10 @@ trait GroupingAnalytics extends UnaryNode { | |
|
|
||
| override def output: Seq[Attribute] = aggregations.map(_.toAttribute) | ||
|
|
||
| // Needs to be unresolved before it's translated to Aggregate + Expand because output attributes | ||
| // will change in analysis. | ||
| override lazy val resolved: Boolean = false | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is this the only change we need to fix this problem?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. No. This was a (minor) problem before that was not caught by any of the test cases; it's now more necessary since we duplicate all the grouping columns in the analyzer rule. |
||
|
|
||
| def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics | ||
| } | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -60,6 +60,68 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { | |
| ) | ||
| } | ||
|
|
||
| test("rollup") { | ||
| checkAnswer( | ||
| courseSales.rollup("course", "year").sum("earnings"), | ||
| Row("Java", 2012, 20000.0) :: | ||
| Row("Java", 2013, 30000.0) :: | ||
| Row("Java", null, 50000.0) :: | ||
| Row("dotNET", 2012, 15000.0) :: | ||
| Row("dotNET", 2013, 48000.0) :: | ||
| Row("dotNET", null, 63000.0) :: | ||
| Row(null, null, 113000.0) :: Nil | ||
| ) | ||
| } | ||
|
|
||
| test("cube") { | ||
| checkAnswer( | ||
| courseSales.cube("course", "year").sum("earnings"), | ||
| Row("Java", 2012, 20000.0) :: | ||
| Row("Java", 2013, 30000.0) :: | ||
| Row("Java", null, 50000.0) :: | ||
| Row("dotNET", 2012, 15000.0) :: | ||
| Row("dotNET", 2013, 48000.0) :: | ||
| Row("dotNET", null, 63000.0) :: | ||
| Row(null, 2012, 35000.0) :: | ||
| Row(null, 2013, 78000.0) :: | ||
| Row(null, null, 113000.0) :: Nil | ||
| ) | ||
| } | ||
|
|
||
| test("rollup overlapping columns") { | ||
| checkAnswer( | ||
| testData2.rollup($"a" + $"b" as "foo", $"b" as "bar").agg(sum($"a" - $"b") as "foo"), | ||
| Row(2, 1, 0) :: Row(3, 2, -1) :: Row(3, 1, 1) :: Row(4, 2, 0) :: Row(4, 1, 2) :: Row(5, 2, 1) | ||
| :: Row(2, null, 0) :: Row(3, null, 0) :: Row(4, null, 2) :: Row(5, null, 1) | ||
| :: Row(null, null, 3) :: Nil | ||
| ) | ||
|
|
||
| checkAnswer( | ||
| testData2.rollup("a", "b").agg(sum("b")), | ||
| Row(1, 1, 1) :: Row(1, 2, 2) :: Row(2, 1, 1) :: Row(2, 2, 2) :: Row(3, 1, 1) :: Row(3, 2, 2) | ||
| :: Row(1, null, 3) :: Row(2, null, 3) :: Row(3, null, 3) | ||
| :: Row(null, null, 9) :: Nil | ||
| ) | ||
| } | ||
|
|
||
| test("cube overlapping columns") { | ||
| checkAnswer( | ||
| testData2.cube($"a" + $"b", $"b").agg(sum($"a" - $"b")), | ||
| Row(2, 1, 0) :: Row(3, 2, -1) :: Row(3, 1, 1) :: Row(4, 2, 0) :: Row(4, 1, 2) :: Row(5, 2, 1) | ||
| :: Row(2, null, 0) :: Row(3, null, 0) :: Row(4, null, 2) :: Row(5, null, 1) | ||
| :: Row(null, 1, 3) :: Row(null, 2, 0) | ||
| :: Row(null, null, 3) :: Nil | ||
| ) | ||
|
|
||
| checkAnswer( | ||
| testData2.cube("a", "b").agg(sum("b")), | ||
| Row(1, 1, 1) :: Row(1, 2, 2) :: Row(2, 1, 1) :: Row(2, 2, 2) :: Row(3, 1, 1) :: Row(3, 2, 2) | ||
| :: Row(1, null, 3) :: Row(2, null, 3) :: Row(3, null, 3) | ||
| :: Row(null, 1, 3) :: Row(null, 2, 6) | ||
| :: Row(null, null, 9) :: Nil | ||
| ) | ||
| } | ||
|
|
||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. these
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. correct |
||
| test("spark.sql.retainGroupColumns config") { | ||
| checkAnswer( | ||
| testData2.groupBy("a").agg(sum($"b")), | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, I see. We need to distinguish the real grouping expression values and those nulls added by expand. Otherwise, we generate wrong results.