@@ -208,42 +208,92 @@ class Analyzer(
// We will insert another Projection if the GROUP BY keys contain
// non-attribute expressions, and the top operators can reference those
// expressions by their aliases.
// e.g. SELECT key%5 as c1 FROM src GROUP BY key%5 ==>
// SELECT a as c1 FROM (SELECT key%5 AS a FROM src) GROUP BY a
// e.g. SELECT key%5 as c1 FROM src GROUP BY key%5 with rollup ==>
// SELECT a as c1 FROM (SELECT key%5 AS a FROM src) GROUP BY a with rollup

// find all of the non-attribute expressions in the GROUP BY keys
val nonAttributeGroupByExpressions = new ArrayBuffer[Alias]()

// The pair of (the original GROUP BY key, associated attribute)
-val groupByExprPairs = x.groupByExprs.map(_ match {
+val groupByExprPairs = x.groupByExprs.map {
case e: NamedExpression => (e, e.toAttribute)
case other => {
val alias = Alias(other, other.toString)()
nonAttributeGroupByExpressions += alias // add the non-attribute expression alias
(other, alias.toAttribute)
}
-})
+}

-// substitute the non-attribute expressions for aggregations.
+// substitute the non-attribute expressions in aggregations
+// by the generated aliases. Here, it does not include the ones that
+// function as the input parameters in the other expressions.
val aggregation = x.aggregations.map(expr => expr.transformDown {
-case e => groupByExprPairs.find(_._1.semanticEquals(e)).map(_._2).getOrElse(e)
+case alias @ Alias(e: Expression, _) =>
+groupByExprPairs.find(_._1.semanticEquals(e)).map(_._2).getOrElse(alias)
}.asInstanceOf[NamedExpression])

// substitute the group by expressions.
val newGroupByExprs = groupByExprPairs.map(_._2)
Contributor:
Take this for example:

select sum(a-b) as ab from mytable group by b with rollup;

I think we probably need to add extra column(s) to the output of Expand (e.g. (a-b) in this case). Would that be a simpler fix?

Member Author:

Hi @chenghao-intel,

Could you explain it a little bit more?

So far, this query is processed correctly and returns a correct result. Since b is used inside an aggregate function, the fix adds an extra column for b. Below is the generated plan:

== Analyzed Logical Plan ==
ab: bigint
Aggregate [b#3,grouping__id#12], [sum(cast((a#2 - b#3#13) as bigint)) AS ab#4L]
 Expand [0,1], [b#3], grouping__id#12
  Project [a#2,b#3,b#3 AS b#3#13]
   Subquery mytable
    Project [_1#0 AS a#2,_2#1 AS b#3]
     LocalRelation [_1#0,_2#1], [[1,2],[2,4],[2,9]]

== Optimized Logical Plan ==
Aggregate [b#3,grouping__id#12], [sum(cast((a#2 - b#3#13) as bigint)) AS ab#4L]
 Expand [0,1], [b#3], grouping__id#12
  LocalRelation [a#2,b#3,b#3#13], [[1,2,2],[2,4,4],[2,9,9]]

Contributor:

Oh, sorry, actually the difference isn't as big as I mentioned.

But I got an error with a query like the one below; can you please take a look at it?

select sum(a+b) as ab from mytable group by a+b, b with rollup;
15/11/04 17:46:36 ERROR thriftserver.SparkSQLDriver: Failed in [select sum(a+b) as ab from mytable group by a+b, b with rollup]
org.apache.spark.sql.catalyst.analysis.UnresolvedException: Invalid call to dataType on unresolved object, tree: '(cast(a#109 as double) + b#110)
    at org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute.dataType(unresolved.scala:59)
    at org.apache.spark.sql.catalyst.plans.logical.Expand$$anonfun$expand$1$$anonfun$5$$anonfun$apply$3.applyOrElse(basicOperators.scala:291)
    at org.apache.spark.sql.catalyst.plans.logical.Expand$$anonfun$expand$1$$anonfun$5$$anonfun$apply$3.applyOrElse(basicOperators.scala:287)
    at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$3.apply(TreeNode.scala:227)
    at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$3.apply(TreeNode.scala:227)
    at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:51)
    at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:226)
    at org.apache.spark.sql.catalyst.plans.logical.Expand$$anonfun$expand$1$$anonfun$5.apply(basicOperators.scala:287)
    at org.apache.spark.sql.catalyst.plans.logical.Expand$$anonfun$expand$1$$anonfun$5.apply(basicOperators.scala:287)
    at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:244)
    at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:244)
    at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
    at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:47)
    at scala.collection.TraversableLike$class.map(TraversableLike.scala:244)
    at scala.collection.AbstractTraversable.map(Traversable.scala:105)
    at org.apache.spark.sql.catalyst.plans.logical.Expand$$anonfun$expand$1.apply(basicOperators.scala:287)
    at org.apache.spark.sql.catalyst.plans.logical.Expand$$anonfun$expand$1.apply(basicOperators.scala:283)
    at scala.collection.immutable.List.foreach(List.scala:318)
    at org.apache.spark.sql.catalyst.plans.logical.Expand.expand(basicOperators.scala:283)
    at org.apache.spark.sql.catalyst.plans.logical.Expand.<init>(basicOperators.scala:254)
    at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveGroupingAnalytics$$anonfun$apply$6.applyOrElse(Analyzer.scala:293)
    at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveGroupingAnalytics$$anonfun$apply$6.applyOrElse(Analyzer.scala:200)
    at org.apache.spark.sql.catalyst.plans.logical.LogicalPlan$$anonfun$resolveOperators$1.apply(LogicalPlan.scala:57)
    at org.apache.spark.sql.catalyst.plans.logical.LogicalPlan$$anonfun$resolveOperators$1.apply(LogicalPlan.scala:57)
    at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:51)
    at org.apache.spark.sql.catalyst.plans.logical.LogicalPlan.resolveOperators(LogicalPlan.scala:56)
    at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveGroupingAnalytics$.apply(Analyzer.scala:200)
    at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveGroupingAnalytics$.apply(Analyzer.scala:173)
    at org.apache.spark.sql.catalyst.rules.RuleExecutor$$anonfun$execute$1$$anonfun$apply$1.apply(RuleExecutor.scala:83)
    at org.apache.spark.sql.catalyst.rules.RuleExecutor$$anonfun$execute$1$$anonfun$apply$1.apply(RuleExecutor.scala:80)
    at scala.collection.LinearSeqOptimized$class.foldLeft(LinearSeqOptimized.scala:111)
    at scala.collection.immutable.List.foldLeft(List.scala:84)
    at org.apache.spark.sql.catalyst.rules.RuleExecutor$$anonfun$execute$1.apply(RuleExecutor.scala:80)
    at org.apache.spark.sql.catalyst.rules.RuleExecutor$$anonfun$execute$1.apply(RuleExecutor.scala:72)
    at scala.collection.immutable.List.foreach(List.scala:318)
    at org.apache.spark.sql.catalyst.rules.RuleExecutor.execute(RuleExecutor.scala:72)
    at org.apache.spark.sql.execution.QueryExecution.analyzed$lzycompute(QueryExecution.scala:38)
    at org.apache.spark.sql.execution.QueryExecution.analyzed(QueryExecution.scala:38)
    at org.apache.spark.sql.execution.QueryExecution.assertAnalyzed(QueryExecution.scala:36)
    at org.apache.spark.sql.DataFrame.<init>(DataFrame.scala:132)
    at org.apache.spark.sql.DataFrame$.apply(DataFrame.scala:51)
    at org.apache.spark.sql.SQLContext.sql(SQLContext.scala:784)

Member Author:

@chenghao-intel, good catch! Thank you!

This issue is fixed once you pick up the latest change. I also added two more test cases to cover it.


-val child = if (nonAttributeGroupByExpressions.length > 0) {
+val child = if (nonAttributeGroupByExpressions.nonEmpty) {
// insert an additional Projection if the GROUP BY keys
// contain non-attribute expressions
Project(x.child.output ++ nonAttributeGroupByExpressions, x.child)
} else {
x.child
}

// When expanding the input rows during evaluation, the column values will be set to
// null for the grouping sets that contain null (e.g., (null, null)). If these values
// are also used in aggregation, the aggregated values are always null.
// Thus, we will insert another Projection if the GROUP BY keys are contained in the
// aggregation, and the top operators can reference those keys by their aliases.
// e.g. SELECT a, b, sum(a) FROM src GROUP BY a, b with rollup ==>
// SELECT a, b, sum(a1) FROM (SELECT a, b, a AS a1 FROM src) GROUP BY a, b with rollup

// collect all the distinct attributes that are in both aggregation functions and
// group by clauses
val attrInAggregatedFuncAndGroupBy = aggregation.collect {
case aggFunc: Alias => aggFunc.collect {
case a : Attribute if newGroupByExprs.contains(a) => a}
Contributor:
We're doing a contains here on a sequence; this could get a bit slow with a large number of aggregates / grouping expressions (a set-based sketch follows this expression).

}.flatten.distinct
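A minimal sketch of the set-based lookup being suggested here (not part of the patch; it assumes Catalyst's AttributeSet is in scope, along with the surrounding names aggregation and newGroupByExprs):

```scala
// Hypothetical variant: build an AttributeSet once so each membership test
// is O(1) instead of a linear scan over newGroupByExprs.
val groupByAttrSet = AttributeSet(newGroupByExprs)
val attrInAggregatedFuncAndGroupBy = aggregation.collect {
  case aggFunc: Alias => aggFunc.collect {
    case a: Attribute if groupByAttrSet.contains(a) => a
  }
}.flatten.distinct
```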

val alias4AttrInAggregatedFuncAndGroupBy = new ArrayBuffer[Alias]()
Contributor:
It seems we already know the initial size of this, so let's try to use that to reduce allocations, as sketched below.
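For instance (a sketch; scala.collection.mutable.ArrayBuffer accepts an initial size in its constructor):

```scala
// Hypothetical: pre-size the buffer to the number of attributes that will
// be aliased, so appends never trigger an internal re-allocation.
val alias4AttrInAggregatedFuncAndGroupBy =
  new ArrayBuffer[Alias](attrInAggregatedFuncAndGroupBy.size)
```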


// generate an alias for each attribute in attrInAggregatedFuncAndGroupBy
val attrInAggregatedFuncPairs = attrInAggregatedFuncAndGroupBy.map(a => {
Contributor:

I think this is just being used for lookup; do we expect this to always be pretty small?
It also feels a little weird to construct the list of aliases as a side effect inside the map that is also creating the seq of (name, attribute) pairs, but I'm frequently a bit too concerned with functional style, so maybe this is OK (see the sketch after this map).

val alias = Alias(a, a.toString)()
alias4AttrInAggregatedFuncAndGroupBy += alias
(a, alias.toAttribute)
})
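A side-effect-free shape this could take (a sketch under the same assumptions as the surrounding patch, not the author's code):

```scala
// Hypothetical refactor: build the (attribute, alias) pairs purely, then
// derive both collections from them instead of appending inside the map.
val attrAliasPairs = attrInAggregatedFuncAndGroupBy.map(a => (a, Alias(a, a.toString)()))
val attrInAggregatedFuncPairs = attrAliasPairs.map { case (a, alias) => (a, alias.toAttribute) }
val alias4AttrInAggregatedFuncAndGroupBy = attrAliasPairs.map(_._2)
```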

val nonAttributeGroupByExpressionsToAttribute =
nonAttributeGroupByExpressions.map(a => a.toAttribute)

val newAggregation = aggregation.map {
case a : Alias => a.transform {
// Must avoid replacing the aliases already substituted in the first step;
// otherwise, the following case will fail:
// select a + b, b, sum(a - b) from test group by a + b, b with cube
case e => attrInAggregatedFuncPairs.find(
_._1==e && !nonAttributeGroupByExpressionsToAttribute.contains(e)).
map(_._2).getOrElse(e)
}.asInstanceOf[NamedExpression]
case other => other
}

val newChild = if (alias4AttrInAggregatedFuncAndGroupBy.nonEmpty) {
// When applying this rule, two Projections could be generated.
// Here, we expect the optimizer to collapse them.
Project(child.output ++ alias4AttrInAggregatedFuncAndGroupBy, child)
} else {
child
}

Aggregate(
newGroupByExprs :+ VirtualColumn.groupingIdAttribute,
-aggregation,
-Expand(x.bitmasks, newGroupByExprs, gid, child))
+newAggregation,
+Expand(x.bitmasks, newGroupByExprs, gid, newChild))
}
}

@@ -17,7 +17,7 @@

package org.apache.spark.sql.hive

Reviewer:

The unit tests for the DataFrame methods should be in the base sql package not hive so that they are run when not compiled with hive support.

Member Author:

Thank you, Andrew!

Using SQL statements might generate a different logical plan, so we need to run the unit tests for both SQL and DataFrame in this package, org.apache.spark.sql.hive.

Based on my understanding, DataFrame methods are shared by both hiveContext and sqlContext, so we do not need to repeat the tests in the base sql package. Right?
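A hedged illustration of that point (a hypothetical snippet using names from the tests below; queryExecution.analyzed exposes a DataFrame's analyzed logical plan):

```scala
// Hypothetical check: the SQL path and the DataFrame path may analyze to
// different logical plans, which is why the suite exercises both.
val sqlPlan = hiveContext.sql(
  "select a, b, sum(b) from mytable group by a, b with rollup")
  .queryExecution.analyzed
val dfPlan = testData.rollup("a", "b").agg(sum("b")).queryExecution.analyzed
// checkAnswer against the same expected rows covers both code paths even
// when the two plans differ.
```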


-import org.apache.spark.sql.{DataFrame, QueryTest}
+import org.apache.spark.sql.{Row, DataFrame, QueryTest}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.scalatest.BeforeAndAfterAll
@@ -32,35 +32,168 @@ class HiveDataFrameAnalyticsSuite extends QueryTest with TestHiveSingleton with
private var testData: DataFrame = _

override def beforeAll() {
-testData = Seq((1, 2), (2, 4)).toDF("a", "b")
+testData = Seq((1, 2), (2, 4), (2, 9)).toDF("a", "b")
hiveContext.registerDataFrameAsTable(testData, "mytable")
}

override def afterAll(): Unit = {
hiveContext.dropTempTable("mytable")
}

test("rollup") {
test("rollup: aggregation input parameters overlap with the non-attribute expressions in group by") {
val sqlRollUp = sql(
"""
SELECT a + b, b, sum(a + b) as ab, (a + b) as c FROM mytable GROUP BY a + b, b WITH ROLLUP
""".stripMargin)

val expected =
Row ( null, null, 20, null) ::
Row ( 3, null, 3, 3 ) ::
Row ( 6, null, 6, 6 ) ::
Row (11, null, 11, 11 ) ::
Row ( 3, 2, 3, 3 ) ::
Row ( 6, 4, 6, 6 ) ::
Row (11, 9, 11, 11 ) :: Nil

checkAnswer(sqlRollUp, expected)

checkAnswer(
testData.rollup($"a" + $"b", $"b").agg(sum($"a" + $"b"), $"a" + $"b"),
expected
)
}

test("rollup: group by function") {
val sqlRollUp = sql(
"""
|SELECT a + b, b, sum(a - b) as ab
|FROM mytable
|GROUP BY a + b, b WITH ROLLUP
""".stripMargin)

val expected =
Row (null, null, -10) ::
Row (3, null, -1) ::
Row (6, null, -2) ::
Row (11, null, -7) ::
Row (3, 2, -1) ::
Row (6, 4, -2) ::
Row (11, 9, -7) :: Nil

checkAnswer(sqlRollUp, expected)

checkAnswer(
testData.rollup($"a" + $"b", $"b").agg(sum($"a" - $"b")),
sql("select a + b, b, sum(a - b) from mytable group by a + b, b with rollup").collect()
expected
)
}

test("rollup: aggregation function parameters overlap with the group by columns") {
val sqlRollUp = sql(
"""
|SELECT a, b, sum(b), max(b), min(b+b)
|FROM mytable
|GROUP BY a, b WITH ROLLUP
""".stripMargin)

val expected =
Row (null, null, 15, 9, 4) ::
Row (1, null, 2, 2, 4) ::
Row (2, null, 13, 9, 8) ::
Row (1, 2, 2, 2, 4) ::
Row (2, 4, 4, 4, 8) ::
Row (2, 9, 9, 9, 18) :: Nil

checkAnswer(sqlRollUp, expected)

checkAnswer(
testData.rollup("a", "b").agg(sum("b"), max("b"), min($"b" + $"b")),
expected
)
}

test("cube: aggregation input parameters overlap with the non-attribute expressions in group by") {
val sqlCube = sql(
"""
SELECT a + b, b, sum(a + b) as ab, (a + b) as c FROM mytable GROUP BY a + b, b WITH CUBE
""".stripMargin)

val expected =
Row ( null, 2, 3, null) ::
Row ( null, 4, 6, null) ::
Row ( null, 9, 11, null) ::
Row ( null, null, 20, null) ::
Row ( 3, null, 3, 3 ) ::
Row ( 6, null, 6, 6 ) ::
Row (11, null, 11, 11 ) ::
Row ( 3, 2, 3, 3 ) ::
Row ( 6, 4, 6, 6 ) ::
Row (11, 9, 11, 11 ) :: Nil

checkAnswer(sqlCube, expected)

checkAnswer(
testData.rollup("a", "b").agg(sum("b")),
sql("select a, b, sum(b) from mytable group by a, b with rollup").collect()
testData.cube($"a" + $"b", $"b").agg(sum($"a" + $"b"), $"a" + $"b"),
expected
)
}

test("cube") {
test("cube: group by function") {
val sqlCube = sql(
"""
|SELECT a + b, b, sum(a - b) as ab
|FROM mytable
|GROUP BY a + b, b WITH CUBE
""".stripMargin)

val expected =
Row (null, 2, -1) ::
Row (null, 4, -2) ::
Row (null, 9, -7) ::
Row (null, null, -10) ::
Row (3, null, -1) ::
Row (6, null, -2) ::
Row (11, null, -7) ::
Row (3, 2, -1) ::
Row (6, 4, -2) ::
Row (11, 9, -7) :: Nil

checkAnswer(sqlCube, expected)

checkAnswer(
testData.cube($"a" + $"b", $"b").agg(sum($"a" - $"b")),
sql("select a + b, b, sum(a - b) from mytable group by a + b, b with cube").collect()
expected
)
}

test("cube: aggregation function parameters overlap with the group by columns") {
val sqlCube = sql(
"""
|SELECT a, b, sum(b), max(b), min(b+b)
|FROM mytable
|GROUP BY a, b WITH CUBE
""".stripMargin)

val expected =
Row (null, 2, 2, 2, 4) ::
Row (null, 4, 4, 4, 8) ::
Row (null, 9, 9, 9, 18) ::
Row (null, null, 15, 9, 4) ::
Row (1, null, 2, 2, 4) ::
Row (2, null, 13, 9, 8) ::
Row (1, 2, 2, 2, 4) ::
Row (2, 4, 4, 4, 8) ::
Row (2, 9, 9, 9, 18) :: Nil

checkAnswer(sqlCube, expected)

checkAnswer(
testData.cube("a", "b").agg(sum("b")),
sql("select a, b, sum(b) from mytable group by a, b with cube").collect()
testData.cube("a", "b").agg(sum("b"), max("b"), min($"b" + $"b")),
expected
)
}
}