Xiao Li authored and Xiao Li committed Nov 2, 2015
commit b10418e161d5809f3b1de92cf4a33b2f362cd2b4
@@ -208,21 +208,21 @@ class Analyzer(
// We will insert another Projection if the GROUP BY keys contains the
// non-attribute expressions. And the top operators can references those
// expressions by its alias.
- // e.g. SELECT key%5 as c1 FROM src GROUP BY key%5 ==>
- // SELECT a as c1 FROM (SELECT key%5 AS a FROM src) GROUP BY a
+ // e.g. SELECT key%5 as c1 FROM src GROUP BY key%5 with rollup ==>
+ // SELECT a as c1 FROM (SELECT key%5 AS a FROM src) GROUP BY a with rollup

// find all of the non-attribute expressions in the GROUP BY keys
val nonAttributeGroupByExpressions = new ArrayBuffer[Alias]()

// The pair of (the original GROUP BY key, associated attribute)
- val groupByExprPairs = x.groupByExprs.map(_ match {
+ val groupByExprPairs = x.groupByExprs.map {
case e: NamedExpression => (e, e.toAttribute)
case other => {
val alias = Alias(other, other.toString)()
nonAttributeGroupByExpressions += alias // add the non-attributes expression alias
(other, alias.toAttribute)
}
- })
+ }

// substitute the non-attribute expressions for aggregations.
val aggregation = x.aggregations.map(expr => expr.transformDown {
@@ -232,18 +232,55 @@ class Analyzer(
// substitute the group by expressions.
val newGroupByExprs = groupByExprPairs.map(_._2)
Contributor

Take this for example:

select sum(a-b) as ab from mytable group by b with rollup;

I think we probably need to add extra column(s) to the output of Expand (e.g. (a-b) in this case). Would that make the fix simpler?

Member Author

Hi, @chenghao-intel ,

Could you explain it a little bit more?

This query is currently processed correctly and returns the correct result. Since b is used inside an aggregate function, the fix adds an extra column for b. Below is the generated plan:

== Analyzed Logical Plan ==
ab: bigint
Aggregate [b#3,grouping__id#12], [sum(cast((a#2 - b#3#13) as bigint)) AS ab#4L]
 Expand [0,1], [b#3], grouping__id#12
  Project [a#2,b#3,b#3 AS b#3#13]
   Subquery mytable
    Project [_1#0 AS a#2,_2#1 AS b#3]
     LocalRelation [_1#0,_2#1], [[1,2],[2,4],[2,9]]

== Optimized Logical Plan ==
Aggregate [b#3,grouping__id#12], [sum(cast((a#2 - b#3#13) as bigint)) AS ab#4L]
 Expand [0,1], [b#3], grouping__id#12
  LocalRelation [a#2,b#3,b#3#13], [[1,2,2],[2,4,4],[2,9,9]]
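
To illustrate why the plan above carries a second copy of b (`b#3#13`), here is a minimal plain-Scala sketch with no Spark dependency. It assumes that Expand replaces a grouping key with null in the rows emitted for a bitmask where that key's bit is 0, which is my understanding of the operator rather than something shown in this thread. `RollupSketch`, `Row`, `Expanded`, `sumUsingKey` and `sumUsingCopy` are hypothetical names used only for illustration.

```scala
object RollupSketch {
  // One row of mytable(a, b); the values are taken from the LocalRelation in the plan above.
  final case class Row(a: Int, b: Int)

  val rows: Seq[Row] = Seq(Row(1, 2), Row(2, 4), Row(2, 9))

  // Simplified stand-in for an Expand output row: the grouping key (None models null),
  // an untouched copy of b for the aggregate input, the column a, and the grouping id.
  final case class Expanded(key: Option[Int], bCopy: Int, a: Int, gid: Int)

  // Simplified Expand with bitmasks [0, 1] over the single key b (assumed semantics):
  // bitmask 0 nulls the key out (the rollup total row), bitmask 1 keeps it.
  def expand(input: Seq[Row]): Seq[Expanded] =
    for {
      row <- input
      gid <- Seq(0, 1)
    } yield Expanded(if (gid == 1) Some(row.b) else None, row.b, row.a, gid)

  // sum(a - b) reading the grouping key itself: a null key makes a - b null,
  // and a sum over only nulls is null (modelled here as None).
  def sumUsingKey(input: Seq[Expanded]): Map[(Option[Int], Int), Option[Long]] =
    input.groupBy(e => (e.key, e.gid)).map { case (group, es) =>
      val values = es.flatMap(e => e.key.map(b => (e.a - b).toLong))
      group -> (if (values.isEmpty) None else Some(values.sum))
    }

  // sum(a - b) reading the aliased copy of b, which Expand leaves untouched.
  def sumUsingCopy(input: Seq[Expanded]): Map[(Option[Int], Int), Long] =
    input.groupBy(e => (e.key, e.gid)).map { case (group, es) =>
      group -> es.map(e => (e.a - e.bCopy).toLong).sum
    }
}
```

Under these assumptions, the rollup total group `(None, 0)` gets a usable sum only when the aggregate reads the copy; reading the key directly loses the value because the key has already been nulled out.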

Contributor

Oh, sorry, there is actually no big difference from what I mentioned.

But I got an error when running a query like the one below. Could you please take a look at it?

select sum(a+b) as ab from mytable group by a+b, b with rollup;
15/11/04 17:46:36 ERROR thriftserver.SparkSQLDriver: Failed in [select sum(a+b) as ab from mytable group by a+b, b with rollup]
org.apache.spark.sql.catalyst.analysis.UnresolvedException: Invalid call to dataType on unresolved object, tree: '(cast(a#109 as double) + b#110)
    at org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute.dataType(unresolved.scala:59)
    at org.apache.spark.sql.catalyst.plans.logical.Expand$$anonfun$expand$1$$anonfun$5$$anonfun$apply$3.applyOrElse(basicOperators.scala:291)
    at org.apache.spark.sql.catalyst.plans.logical.Expand$$anonfun$expand$1$$anonfun$5$$anonfun$apply$3.applyOrElse(basicOperators.scala:287)
    at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$3.apply(TreeNode.scala:227)
    at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$3.apply(TreeNode.scala:227)
    at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:51)
    at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:226)
    at org.apache.spark.sql.catalyst.plans.logical.Expand$$anonfun$expand$1$$anonfun$5.apply(basicOperators.scala:287)
    at org.apache.spark.sql.catalyst.plans.logical.Expand$$anonfun$expand$1$$anonfun$5.apply(basicOperators.scala:287)
    at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:244)
    at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:244)
    at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
    at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:47)
    at scala.collection.TraversableLike$class.map(TraversableLike.scala:244)
    at scala.collection.AbstractTraversable.map(Traversable.scala:105)
    at org.apache.spark.sql.catalyst.plans.logical.Expand$$anonfun$expand$1.apply(basicOperators.scala:287)
    at org.apache.spark.sql.catalyst.plans.logical.Expand$$anonfun$expand$1.apply(basicOperators.scala:283)
    at scala.collection.immutable.List.foreach(List.scala:318)
    at org.apache.spark.sql.catalyst.plans.logical.Expand.expand(basicOperators.scala:283)
    at org.apache.spark.sql.catalyst.plans.logical.Expand.<init>(basicOperators.scala:254)
    at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveGroupingAnalytics$$anonfun$apply$6.applyOrElse(Analyzer.scala:293)
    at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveGroupingAnalytics$$anonfun$apply$6.applyOrElse(Analyzer.scala:200)
    at org.apache.spark.sql.catalyst.plans.logical.LogicalPlan$$anonfun$resolveOperators$1.apply(LogicalPlan.scala:57)
    at org.apache.spark.sql.catalyst.plans.logical.LogicalPlan$$anonfun$resolveOperators$1.apply(LogicalPlan.scala:57)
    at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:51)
    at org.apache.spark.sql.catalyst.plans.logical.LogicalPlan.resolveOperators(LogicalPlan.scala:56)
    at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveGroupingAnalytics$.apply(Analyzer.scala:200)
    at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveGroupingAnalytics$.apply(Analyzer.scala:173)
    at org.apache.spark.sql.catalyst.rules.RuleExecutor$$anonfun$execute$1$$anonfun$apply$1.apply(RuleExecutor.scala:83)
    at org.apache.spark.sql.catalyst.rules.RuleExecutor$$anonfun$execute$1$$anonfun$apply$1.apply(RuleExecutor.scala:80)
    at scala.collection.LinearSeqOptimized$class.foldLeft(LinearSeqOptimized.scala:111)
    at scala.collection.immutable.List.foldLeft(List.scala:84)
    at org.apache.spark.sql.catalyst.rules.RuleExecutor$$anonfun$execute$1.apply(RuleExecutor.scala:80)
    at org.apache.spark.sql.catalyst.rules.RuleExecutor$$anonfun$execute$1.apply(RuleExecutor.scala:72)
    at scala.collection.immutable.List.foreach(List.scala:318)
    at org.apache.spark.sql.catalyst.rules.RuleExecutor.execute(RuleExecutor.scala:72)
    at org.apache.spark.sql.execution.QueryExecution.analyzed$lzycompute(QueryExecution.scala:38)
    at org.apache.spark.sql.execution.QueryExecution.analyzed(QueryExecution.scala:38)
    at org.apache.spark.sql.execution.QueryExecution.assertAnalyzed(QueryExecution.scala:36)
    at org.apache.spark.sql.DataFrame.<init>(DataFrame.scala:132)
    at org.apache.spark.sql.DataFrame$.apply(DataFrame.scala:51)
    at org.apache.spark.sql.SQLContext.sql(SQLContext.scala:784)

Member Author

@chenghao-intel , a good catch! Thank you!

This issue is fixed in the latest change, so it should go away once you integrate it. I also added two more test cases to cover it.


- val child = if (nonAttributeGroupByExpressions.length > 0) {
+ val child = if (nonAttributeGroupByExpressions.nonEmpty) {
// insert additional projection if contains the
// non-attribute expressions in the GROUP BY keys
Project(x.child.output ++ nonAttributeGroupByExpressions, x.child)
} else {
x.child
}

+ // We will insert another Projection if the GROUP BY keys are contained in the
Contributor

Maybe, in addition to describing what we do, we should also explain why we are inserting this projection here?

+ // aggregation. And the top operators can references those keys by its alias.
+ // e.g. SELECT a, b, sum(a) FROM src GROUP BY a, b with rollup ==>
+ // SELECT a, b, sum(a1) FROM (SELECT a, b, a AS a1 FROM src) GROUP BY a, b with rollup

+ // collect all the distinct attributes that are in both aggregation functions and group by clauses
+ val attrInAggregatedFuncAndGroupBy = aggregation.collect {
+ case aggFunc: Alias => aggFunc.collect {case a : Attribute if newGroupByExprs.contains(a) => a}
+ }.flatten.distinct

+ val alias4AttrInAggregatedFuncAndGroupBy = new ArrayBuffer[Alias]()
Contributor

It seems we already know the initial size of this, so let's at least try to use that to reduce allocations.
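
A minimal sketch of the pre-sizing idea, assuming the buffer ends up with one element per input attribute. `PreSizedBufferSketch` and `aliasesFor` are hypothetical names, and plain strings stand in for Catalyst's `Attribute` and `Alias`; only the `ArrayBuffer` constructor that takes an initial size is a real standard-library API.

```scala
import scala.collection.mutable.ArrayBuffer

object PreSizedBufferSketch {
  // Hypothetical stand-in for building one alias per attribute.
  def aliasesFor(attrs: Seq[String]): ArrayBuffer[String] = {
    // The element count is known up front, so pass it as the initial capacity;
    // the buffer then does not need to grow while it is being filled.
    val buf = new ArrayBuffer[String](attrs.size)
    attrs.foreach(a => buf += s"$a AS ${a}_alias")
    buf
  }
}
```

For a handful of attributes the saving is likely negligible, so this is mainly a tidiness choice rather than a measured optimization.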


+ // Generate alias for each attribute in attrInAggregatedFuncAndGroupBy
+ val attrInAggregatedFuncPairs = attrInAggregatedFuncAndGroupBy.map(a => {
Contributor

I think this is just being used for lookup. Do we expect this to always be pretty small?
It also feels a little weird to construct a list of aliases as a side effect inside the map that is also creating the seq of (name, attribute) pairs, but I'm frequently a bit too concerned with functional style, so maybe this is OK.
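
One possible side-effect-free shape for this, sketched in plain Scala under the assumption that the aliases and the lookup can both be derived from a single sequence of pairs. `Attr` and `AliasOf` are hypothetical, simplified stand-ins for Catalyst's `Attribute` and `Alias`, and the `_copy` naming is invented for the example; this is not the code in the PR.

```scala
object AliasPairsSketch {
  // Hypothetical, simplified stand-ins for Catalyst's Attribute and Alias.
  final case class Attr(name: String)
  final case class AliasOf(child: Attr, name: String) {
    def toAttribute: Attr = Attr(name)
  }

  // Build (attribute, alias) pairs in one pass, without mutating anything outside the map.
  def aliasPairs(attrs: Seq[Attr]): Seq[(Attr, AliasOf)] =
    attrs.map(a => a -> AliasOf(a, a.name + "_copy"))

  // The aliases to append to the projection are derived from the pairs afterwards.
  def projectedAliases(pairs: Seq[(Attr, AliasOf)]): Seq[AliasOf] =
    pairs.map(_._2)

  // A Map gives constant-time lookup instead of a linear find over the pairs.
  def lookup(pairs: Seq[(Attr, AliasOf)]): Map[Attr, Attr] =
    pairs.map { case (attr, alias) => attr -> alias.toAttribute }.toMap
}
```

If the set of attributes is always tiny, the linear `find` is probably fine in practice; the `Map` mostly makes the lookup intent explicit.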

+ val alias = Alias(a, a.toString)()
+ alias4AttrInAggregatedFuncAndGroupBy += alias
+ (a, alias.toAttribute)
+ })

+ val nonAttributeGroupByExpressionsToAttribute = nonAttributeGroupByExpressions.map(a=>a.toAttribute)

+ val newAggregation = aggregation.map {
+ case a : Alias => a.transform {
+ // must avoid the alias replacement by the first step; otherwise, the following case will fail:
+ // select a + b, b, sum(a - b) from test group by a + b, b with cube
+ case e => attrInAggregatedFuncPairs.find(_._1==e && !nonAttributeGroupByExpressionsToAttribute.contains(e)).
+ map(_._2).getOrElse(e)
+ }.asInstanceOf[NamedExpression]
+ case other => other
+ }

+ val newChild = if (alias4AttrInAggregatedFuncAndGroupBy.nonEmpty) {
+ Project(child.output ++ alias4AttrInAggregatedFuncAndGroupBy, child)
+ } else {
+ child
+ }

Aggregate(
newGroupByExprs :+ VirtualColumn.groupingIdAttribute,
- aggregation,
- Expand(x.bitmasks, newGroupByExprs, gid, child))
+ newAggregation,
+ Expand(x.bitmasks, newGroupByExprs, gid, newChild))
}
}
