diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index beabacfc88e3..bd073c3ff0ed 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -208,31 +208,34 @@ class Analyzer( // We will insert another Projection if the GROUP BY keys contains the // non-attribute expressions. And the top operators can references those // expressions by its alias. - // e.g. SELECT key%5 as c1 FROM src GROUP BY key%5 ==> - // SELECT a as c1 FROM (SELECT key%5 AS a FROM src) GROUP BY a + // e.g. SELECT key%5 as c1 FROM src GROUP BY key%5 with rollup ==> + // SELECT a as c1 FROM (SELECT key%5 AS a FROM src) GROUP BY a with rollup // find all of the non-attribute expressions in the GROUP BY keys val nonAttributeGroupByExpressions = new ArrayBuffer[Alias]() // The pair of (the original GROUP BY key, associated attribute) - val groupByExprPairs = x.groupByExprs.map(_ match { + val groupByExprPairs = x.groupByExprs.map { case e: NamedExpression => (e, e.toAttribute) case other => { val alias = Alias(other, other.toString)() nonAttributeGroupByExpressions += alias // add the non-attributes expression alias (other, alias.toAttribute) } - }) + } - // substitute the non-attribute expressions for aggregations. + // substitute the non-attribute expressions in aggregations + // by the generated aliases. Here, it does not include the ones that + // function as the input parameters in the other expressions. val aggregation = x.aggregations.map(expr => expr.transformDown { - case e => groupByExprPairs.find(_._1.semanticEquals(e)).map(_._2).getOrElse(e) + case alias @ Alias(e: Expression, _) => + groupByExprPairs.find(_._1.semanticEquals(e)).map(_._2).getOrElse(alias) }.asInstanceOf[NamedExpression]) // substitute the group by expressions. val newGroupByExprs = groupByExprPairs.map(_._2) - val child = if (nonAttributeGroupByExpressions.length > 0) { + val child = if (nonAttributeGroupByExpressions.nonEmpty) { // insert additional projection if contains the // non-attribute expressions in the GROUP BY keys Project(x.child.output ++ nonAttributeGroupByExpressions, x.child) @@ -240,10 +243,57 @@ class Analyzer( x.child } + // When expanding the input rows during evaluation, the column values will be set to + // null for the grouping sets that contains null (e.g., (null, null)). If these values + // are also used in aggregation, the aggregated values are always null. + // Thus, we will insert another Projection if the GROUP BY keys are contained in the + // aggregation. And the top operators can references those keys by its alias. + // e.g. SELECT a, b, sum(a) FROM src GROUP BY a, b with rollup ==> + // SELECT a, b, sum(a1) FROM (SELECT a, b, a AS a1 FROM src) GROUP BY a, b with rollup + + // collect all the distinct attributes that are in both aggregation functions and + // group by clauses + val attrInAggregatedFuncAndGroupBy = aggregation.collect { + case aggFunc: Alias => aggFunc.collect { + case a : Attribute if newGroupByExprs.contains(a) => a} + }.flatten.distinct + + val alias4AttrInAggregatedFuncAndGroupBy = new ArrayBuffer[Alias]() + + // generate alias for each attribute in attrInAggregatedFuncAndGroupBy + val attrInAggregatedFuncPairs = attrInAggregatedFuncAndGroupBy.map(a => { + val alias = Alias(a, a.toString)() + alias4AttrInAggregatedFuncAndGroupBy += alias + (a, alias.toAttribute) + }) + + val nonAttributeGroupByExpressionsToAttribute = + nonAttributeGroupByExpressions.map(a => a.toAttribute) + + val newAggregation = aggregation.map { + case a : Alias => a.transform { + // Must avoid replace the alias replaced by the first step; + // otherwise, the following case will fail: + // select a + b, b, sum(a - b) from test group by a + b, b with cube + case e => attrInAggregatedFuncPairs.find( + _._1==e && !nonAttributeGroupByExpressionsToAttribute.contains(e)). + map(_._2).getOrElse(e) + }.asInstanceOf[NamedExpression] + case other => other + } + + val newChild = if (alias4AttrInAggregatedFuncAndGroupBy.nonEmpty) { + // When applying this rule, two Projections could be generated. + // Here, we expect the optimizer can collapse them. + Project(child.output ++ alias4AttrInAggregatedFuncAndGroupBy, child) + } else { + child + } + Aggregate( newGroupByExprs :+ VirtualColumn.groupingIdAttribute, - aggregation, - Expand(x.bitmasks, newGroupByExprs, gid, child)) + newAggregation, + Expand(x.bitmasks, newGroupByExprs, gid, newChild)) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala index 2e5cae415e54..e6eab125d5f0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.hive -import org.apache.spark.sql.{DataFrame, QueryTest} +import org.apache.spark.sql.{Row, DataFrame, QueryTest} import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive.test.TestHiveSingleton import org.scalatest.BeforeAndAfterAll @@ -32,7 +32,7 @@ class HiveDataFrameAnalyticsSuite extends QueryTest with TestHiveSingleton with private var testData: DataFrame = _ override def beforeAll() { - testData = Seq((1, 2), (2, 4)).toDF("a", "b") + testData = Seq((1, 2), (2, 4), (2, 9)).toDF("a", "b") hiveContext.registerDataFrameAsTable(testData, "mytable") } @@ -40,27 +40,160 @@ class HiveDataFrameAnalyticsSuite extends QueryTest with TestHiveSingleton with hiveContext.dropTempTable("mytable") } - test("rollup") { + test("rollup: aggregation input parameters overlap with the non-attribute expressions in group by") { + val sqlRollUp = sql( + """ + SELECT a + b, b, sum(a + b) as ab, (a + b) as c FROM mytable GROUP BY a + b, b WITH ROLLUP + """.stripMargin) + + val res = sqlRollUp.collect() + + val expected = + Row ( null, null, 20, null) :: + Row ( 3, null, 3, 3 ) :: + Row ( 6, null, 6, 6 ) :: + Row (11, null, 11, 11 ) :: + Row ( 3, 2, 3, 3 ) :: + Row ( 6, 4, 6, 6 ) :: + Row (11, 9, 11, 11 ) :: Nil + + checkAnswer(sqlRollUp, expected) + + checkAnswer( + testData.rollup($"a" + $"b", $"b").agg(sum($"a" + $"b"), $"a" + $"b"), + expected + ) + } + + test("rollup: group by function") { + val sqlRollUp = sql( + """ + |SELECT a + b, b, sum(a - b) as ab + |FROM mytable + |GROUP BY a + b, b WITH ROLLUP + """.stripMargin) + + val expected = + Row (null, null, -10) :: + Row (3, null, -1) :: + Row (6, null, -2) :: + Row (11, null, -7) :: + Row (3, 2, -1) :: + Row (6, 4, -2) :: + Row (11, 9, -7) :: Nil + + checkAnswer(sqlRollUp, expected) + checkAnswer( testData.rollup($"a" + $"b", $"b").agg(sum($"a" - $"b")), - sql("select a + b, b, sum(a - b) from mytable group by a + b, b with rollup").collect() + expected + ) + } + + test("rollup: aggregation function parameters overlap with the group by columns") { + val sqlRollUp = sql( + """ + |SELECT a, b, sum(b), max(b), min(b+b) + |FROM mytable + |GROUP BY a, b WITH ROLLUP + """.stripMargin) + + val expected = + Row (null, null, 15, 9, 4) :: + Row (1, null, 2, 2, 4) :: + Row (2, null, 13, 9, 8) :: + Row (1, 2, 2, 2, 4) :: + Row (2, 4, 4, 4, 8) :: + Row (2, 9, 9, 9, 18) :: Nil + + checkAnswer(sqlRollUp, expected) + + checkAnswer( + testData.rollup("a", "b").agg(sum("b"), max("b"), min($"b" + $"b")), + expected ) + } + + test("cube: aggregation input parameters overlap with the non-attribute expressions in group by") { + val sqlCube = sql( + """ + SELECT a + b, b, sum(a + b) as ab, (a + b) as c FROM mytable GROUP BY a + b, b WITH CUBE + """.stripMargin) + + val res = sqlCube.collect() + + val expected = + Row ( null, 2, 3, null) :: + Row ( null, 4, 6, null) :: + Row ( null, 9, 11, null) :: + Row ( null, null, 20, null) :: + Row ( 3, null, 3, 3 ) :: + Row ( 6, null, 6, 6 ) :: + Row (11, null, 11, 11 ) :: + Row ( 3, 2, 3, 3 ) :: + Row ( 6, 4, 6, 6 ) :: + Row (11, 9, 11, 11 ) :: Nil + + checkAnswer(sqlCube, expected) checkAnswer( - testData.rollup("a", "b").agg(sum("b")), - sql("select a, b, sum(b) from mytable group by a, b with rollup").collect() + testData.cube($"a" + $"b", $"b").agg(sum($"a" + $"b"), $"a" + $"b"), + expected ) } - test("cube") { + test("cube: group by function") { + val sqlCube = sql( + """ + |SELECT a + b, b, sum(a - b) as ab + |FROM mytable + |GROUP BY a + b, b WITH CUBE + """.stripMargin) + + val expected = + Row (null, 2, -1) :: + Row (null, 4, -2) :: + Row (null, 9, -7) :: + Row (null, null, -10) :: + Row (3, null, -1) :: + Row (6, null, -2) :: + Row (11, null, -7) :: + Row (3, 2, -1) :: + Row (6, 4, -2) :: + Row (11, 9, -7) :: Nil + + checkAnswer(sqlCube, expected) + checkAnswer( testData.cube($"a" + $"b", $"b").agg(sum($"a" - $"b")), - sql("select a + b, b, sum(a - b) from mytable group by a + b, b with cube").collect() + expected ) + } + + test("cube: aggregation function parameters overlap with the group by columns") { + val sqlCube = sql( + """ + |SELECT a, b, sum(b), max(b), min(b+b) + |FROM mytable + |GROUP BY a, b WITH CUBE + """.stripMargin) + + val expected = + Row (null, 2, 2, 2, 4) :: + Row (null, 4, 4, 4, 8) :: + Row (null, 9, 9, 9, 18) :: + Row (null, null, 15, 9, 4) :: + Row (1, null, 2, 2, 4) :: + Row (2, null, 13, 9, 8) :: + Row (1, 2, 2, 2, 4) :: + Row (2, 4, 4, 4, 8) :: + Row (2, 9, 9, 9, 18) :: Nil + + checkAnswer(sqlCube, expected) checkAnswer( - testData.cube("a", "b").agg(sum("b")), - sql("select a, b, sum(b) from mytable group by a, b with cube").collect() + testData.cube("a", "b").agg(sum("b"), max("b"), min($"b" + $"b")), + expected ) } }