-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-11275][SQL] Rollup and Cube Generates the Incorrect Results when Aggregation Functions Use Group By Columns #9419
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -208,42 +208,92 @@ class Analyzer( | |
| // We will insert another Projection if the GROUP BY keys contains the | ||
| // non-attribute expressions. And the top operators can references those | ||
| // expressions by its alias. | ||
| // e.g. SELECT key%5 as c1 FROM src GROUP BY key%5 ==> | ||
| // SELECT a as c1 FROM (SELECT key%5 AS a FROM src) GROUP BY a | ||
| // e.g. SELECT key%5 as c1 FROM src GROUP BY key%5 with rollup ==> | ||
| // SELECT a as c1 FROM (SELECT key%5 AS a FROM src) GROUP BY a with rollup | ||
|
|
||
| // find all of the non-attribute expressions in the GROUP BY keys | ||
| val nonAttributeGroupByExpressions = new ArrayBuffer[Alias]() | ||
|
|
||
| // The pair of (the original GROUP BY key, associated attribute) | ||
| val groupByExprPairs = x.groupByExprs.map(_ match { | ||
| val groupByExprPairs = x.groupByExprs.map { | ||
| case e: NamedExpression => (e, e.toAttribute) | ||
| case other => { | ||
| val alias = Alias(other, other.toString)() | ||
| nonAttributeGroupByExpressions += alias // add the non-attributes expression alias | ||
| (other, alias.toAttribute) | ||
| } | ||
| }) | ||
| } | ||
|
|
||
| // substitute the non-attribute expressions for aggregations. | ||
| // substitute the non-attribute expressions in aggregations | ||
| // by the generated aliases. Here, it does not include the ones that | ||
| // function as the input parameters in the other expressions. | ||
| val aggregation = x.aggregations.map(expr => expr.transformDown { | ||
| case e => groupByExprPairs.find(_._1.semanticEquals(e)).map(_._2).getOrElse(e) | ||
| case alias @ Alias(e: Expression, _) => | ||
| groupByExprPairs.find(_._1.semanticEquals(e)).map(_._2).getOrElse(alias) | ||
| }.asInstanceOf[NamedExpression]) | ||
|
|
||
| // substitute the group by expressions. | ||
| val newGroupByExprs = groupByExprPairs.map(_._2) | ||
|
|
||
| val child = if (nonAttributeGroupByExpressions.length > 0) { | ||
| val child = if (nonAttributeGroupByExpressions.nonEmpty) { | ||
| // insert additional projection if contains the | ||
| // non-attribute expressions in the GROUP BY keys | ||
| Project(x.child.output ++ nonAttributeGroupByExpressions, x.child) | ||
| } else { | ||
| x.child | ||
| } | ||
|
|
||
| // When expanding the input rows during evaluation, the column values will be set to | ||
| // null for the grouping sets that contains null (e.g., (null, null)). If these values | ||
| // are also used in aggregation, the aggregated values are always null. | ||
| // Thus, we will insert another Projection if the GROUP BY keys are contained in the | ||
| // aggregation. And the top operators can references those keys by its alias. | ||
| // e.g. SELECT a, b, sum(a) FROM src GROUP BY a, b with rollup ==> | ||
| // SELECT a, b, sum(a1) FROM (SELECT a, b, a AS a1 FROM src) GROUP BY a, b with rollup | ||
|
|
||
| // collect all the distinct attributes that are in both aggregation functions and | ||
| // group by clauses | ||
| val attrInAggregatedFuncAndGroupBy = aggregation.collect { | ||
| case aggFunc: Alias => aggFunc.collect { | ||
| case a : Attribute if newGroupByExprs.contains(a) => a} | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We're doing a `contains` here on a sequence; this could get a bit slow with a large number of aggregates / grouping expressions. |
||
| }.flatten.distinct | ||
|
|
||
| val alias4AttrInAggregatedFuncAndGroupBy = new ArrayBuffer[Alias]() | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It seems we already know the initial size of this, so let's at least try to use that to reduce allocations. |
||
|
|
||
| // generate alias for each attribute in attrInAggregatedFuncAndGroupBy | ||
| val attrInAggregatedFuncPairs = attrInAggregatedFuncAndGroupBy.map(a => { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think this is just being used for lookup. Could we - do we expect this to always be pretty small? |
||
| val alias = Alias(a, a.toString)() | ||
| alias4AttrInAggregatedFuncAndGroupBy += alias | ||
| (a, alias.toAttribute) | ||
| }) | ||
|
|
||
| val nonAttributeGroupByExpressionsToAttribute = | ||
| nonAttributeGroupByExpressions.map(a => a.toAttribute) | ||
|
|
||
| val newAggregation = aggregation.map { | ||
| case a : Alias => a.transform { | ||
| // Must avoid replace the alias replaced by the first step; | ||
| // otherwise, the following case will fail: | ||
| // select a + b, b, sum(a - b) from test group by a + b, b with cube | ||
| case e => attrInAggregatedFuncPairs.find( | ||
| _._1==e && !nonAttributeGroupByExpressionsToAttribute.contains(e)). | ||
| map(_._2).getOrElse(e) | ||
| }.asInstanceOf[NamedExpression] | ||
| case other => other | ||
| } | ||
|
|
||
| val newChild = if (alias4AttrInAggregatedFuncAndGroupBy.nonEmpty) { | ||
| // When applying this rule, two Projections could be generated. | ||
| // Here, we expect the optimizer can collapse them. | ||
| Project(child.output ++ alias4AttrInAggregatedFuncAndGroupBy, child) | ||
| } else { | ||
| child | ||
| } | ||
|
|
||
| Aggregate( | ||
| newGroupByExprs :+ VirtualColumn.groupingIdAttribute, | ||
| aggregation, | ||
| Expand(x.bitmasks, newGroupByExprs, gid, child)) | ||
| newAggregation, | ||
| Expand(x.bitmasks, newGroupByExprs, gid, newChild)) | ||
| } | ||
| } | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,7 +17,7 @@ | |
|
|
||
| package org.apache.spark.sql.hive | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The unit tests for the DataFrame methods should be in the base sql package, not hive, so that they are run when not compiled with hive support.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Thank you, Andrew! Using SQL statements might generate a different logical plan. Thus, we need unit tests for both SQL and DataFrame in this package, org.apache.spark.sql.hive. Based on my understanding, the DataFrame methods are shared by both HiveContext and SQLContext, so we do not need to test them again in the base sql package. Right? |
||
|
|
||
| import org.apache.spark.sql.{DataFrame, QueryTest} | ||
| import org.apache.spark.sql.{Row, DataFrame, QueryTest} | ||
| import org.apache.spark.sql.functions._ | ||
| import org.apache.spark.sql.hive.test.TestHiveSingleton | ||
| import org.scalatest.BeforeAndAfterAll | ||
|
|
@@ -32,35 +32,168 @@ class HiveDataFrameAnalyticsSuite extends QueryTest with TestHiveSingleton with | |
| private var testData: DataFrame = _ | ||
|
|
||
| override def beforeAll() { | ||
| testData = Seq((1, 2), (2, 4)).toDF("a", "b") | ||
| testData = Seq((1, 2), (2, 4), (2, 9)).toDF("a", "b") | ||
| hiveContext.registerDataFrameAsTable(testData, "mytable") | ||
| } | ||
|
|
||
| override def afterAll(): Unit = { | ||
| hiveContext.dropTempTable("mytable") | ||
| } | ||
|
|
||
| test("rollup") { | ||
| test("rollup: aggregation input parameters overlap with the non-attribute expressions in group by") { | ||
| val sqlRollUp = sql( | ||
| """ | ||
| SELECT a + b, b, sum(a + b) as ab, (a + b) as c FROM mytable GROUP BY a + b, b WITH ROLLUP | ||
| """.stripMargin) | ||
|
|
||
| val res = sqlRollUp.collect() | ||
|
|
||
| val expected = | ||
| Row ( null, null, 20, null) :: | ||
| Row ( 3, null, 3, 3 ) :: | ||
| Row ( 6, null, 6, 6 ) :: | ||
| Row (11, null, 11, 11 ) :: | ||
| Row ( 3, 2, 3, 3 ) :: | ||
| Row ( 6, 4, 6, 6 ) :: | ||
| Row (11, 9, 11, 11 ) :: Nil | ||
|
|
||
| checkAnswer(sqlRollUp, expected) | ||
|
|
||
| checkAnswer( | ||
| testData.rollup($"a" + $"b", $"b").agg(sum($"a" + $"b"), $"a" + $"b"), | ||
| expected | ||
| ) | ||
| } | ||
|
|
||
| test("rollup: group by function") { | ||
| val sqlRollUp = sql( | ||
| """ | ||
| |SELECT a + b, b, sum(a - b) as ab | ||
| |FROM mytable | ||
| |GROUP BY a + b, b WITH ROLLUP | ||
| """.stripMargin) | ||
|
|
||
| val expected = | ||
| Row (null, null, -10) :: | ||
| Row (3, null, -1) :: | ||
| Row (6, null, -2) :: | ||
| Row (11, null, -7) :: | ||
| Row (3, 2, -1) :: | ||
| Row (6, 4, -2) :: | ||
| Row (11, 9, -7) :: Nil | ||
|
|
||
| checkAnswer(sqlRollUp, expected) | ||
|
|
||
| checkAnswer( | ||
| testData.rollup($"a" + $"b", $"b").agg(sum($"a" - $"b")), | ||
| sql("select a + b, b, sum(a - b) from mytable group by a + b, b with rollup").collect() | ||
| expected | ||
| ) | ||
| } | ||
|
|
||
| test("rollup: aggregation function parameters overlap with the group by columns") { | ||
| val sqlRollUp = sql( | ||
| """ | ||
| |SELECT a, b, sum(b), max(b), min(b+b) | ||
| |FROM mytable | ||
| |GROUP BY a, b WITH ROLLUP | ||
| """.stripMargin) | ||
|
|
||
| val expected = | ||
| Row (null, null, 15, 9, 4) :: | ||
| Row (1, null, 2, 2, 4) :: | ||
| Row (2, null, 13, 9, 8) :: | ||
| Row (1, 2, 2, 2, 4) :: | ||
| Row (2, 4, 4, 4, 8) :: | ||
| Row (2, 9, 9, 9, 18) :: Nil | ||
|
|
||
| checkAnswer(sqlRollUp, expected) | ||
|
|
||
| checkAnswer( | ||
| testData.rollup("a", "b").agg(sum("b"), max("b"), min($"b" + $"b")), | ||
| expected | ||
| ) | ||
| } | ||
|
|
||
| test("cube: aggregation input parameters overlap with the non-attribute expressions in group by") { | ||
| val sqlCube = sql( | ||
| """ | ||
| SELECT a + b, b, sum(a + b) as ab, (a + b) as c FROM mytable GROUP BY a + b, b WITH CUBE | ||
| """.stripMargin) | ||
|
|
||
| val res = sqlCube.collect() | ||
|
|
||
| val expected = | ||
| Row ( null, 2, 3, null) :: | ||
| Row ( null, 4, 6, null) :: | ||
| Row ( null, 9, 11, null) :: | ||
| Row ( null, null, 20, null) :: | ||
| Row ( 3, null, 3, 3 ) :: | ||
| Row ( 6, null, 6, 6 ) :: | ||
| Row (11, null, 11, 11 ) :: | ||
| Row ( 3, 2, 3, 3 ) :: | ||
| Row ( 6, 4, 6, 6 ) :: | ||
| Row (11, 9, 11, 11 ) :: Nil | ||
|
|
||
| checkAnswer(sqlCube, expected) | ||
|
|
||
| checkAnswer( | ||
| testData.rollup("a", "b").agg(sum("b")), | ||
| sql("select a, b, sum(b) from mytable group by a, b with rollup").collect() | ||
| testData.cube($"a" + $"b", $"b").agg(sum($"a" + $"b"), $"a" + $"b"), | ||
| expected | ||
| ) | ||
| } | ||
|
|
||
| test("cube") { | ||
| test("cube: group by function") { | ||
| val sqlCube = sql( | ||
| """ | ||
| |SELECT a + b, b, sum(a - b) as ab | ||
| |FROM mytable | ||
| |GROUP BY a + b, b WITH CUBE | ||
| """.stripMargin) | ||
|
|
||
| val expected = | ||
| Row (null, 2, -1) :: | ||
| Row (null, 4, -2) :: | ||
| Row (null, 9, -7) :: | ||
| Row (null, null, -10) :: | ||
| Row (3, null, -1) :: | ||
| Row (6, null, -2) :: | ||
| Row (11, null, -7) :: | ||
| Row (3, 2, -1) :: | ||
| Row (6, 4, -2) :: | ||
| Row (11, 9, -7) :: Nil | ||
|
|
||
| checkAnswer(sqlCube, expected) | ||
|
|
||
| checkAnswer( | ||
| testData.cube($"a" + $"b", $"b").agg(sum($"a" - $"b")), | ||
| sql("select a + b, b, sum(a - b) from mytable group by a + b, b with cube").collect() | ||
| expected | ||
| ) | ||
| } | ||
|
|
||
| test("cube: aggregation function parameters overlap with the group by columns") { | ||
| val sqlCube = sql( | ||
| """ | ||
| |SELECT a, b, sum(b), max(b), min(b+b) | ||
| |FROM mytable | ||
| |GROUP BY a, b WITH CUBE | ||
| """.stripMargin) | ||
|
|
||
| val expected = | ||
| Row (null, 2, 2, 2, 4) :: | ||
| Row (null, 4, 4, 4, 8) :: | ||
| Row (null, 9, 9, 9, 18) :: | ||
| Row (null, null, 15, 9, 4) :: | ||
| Row (1, null, 2, 2, 4) :: | ||
| Row (2, null, 13, 9, 8) :: | ||
| Row (1, 2, 2, 2, 4) :: | ||
| Row (2, 4, 4, 4, 8) :: | ||
| Row (2, 9, 9, 9, 18) :: Nil | ||
|
|
||
| checkAnswer(sqlCube, expected) | ||
|
|
||
| checkAnswer( | ||
| testData.cube("a", "b").agg(sum("b")), | ||
| sql("select a, b, sum(b) from mytable group by a, b with cube").collect() | ||
| testData.cube("a", "b").agg(sum("b"), max("b"), min($"b" + $"b")), | ||
| expected | ||
| ) | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Take this for example:
I think we probably need to add extra column(s) for the output of
Expand (e.g., (a-b) in this case). Would that make this fix simpler? There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi, @chenghao-intel ,
Could you explain it a little bit more?
So far, this query is processed correctly and returns the correct result. Since b is part of an aggregate function, the fix adds an extra column for b. Below is the generated plan:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh, sorry, actually there is no big difference from what I mentioned.
But I got an error when running a query like the one below; can you please take a look at it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@chenghao-intel, good catch! Thank you!
This issue is fixed in the latest change. I also added two more test cases to cover it.