Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1003,18 +1003,30 @@ class Analyzer(
*/
object ResolveAggAliasInGroupBy extends Rule[LogicalPlan] {

// This is a strict check though, we put this to apply the rule only if the expression is not
// resolvable by child
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

... only if the expression is not resolvable by child

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok

/**
 * Returns true when none of the attributes in `child.output` has a name that matches
 * `attrName` according to `resolver`.
 *
 * @param attrName name of the attribute being looked up
 * @param child    plan whose output attributes are checked
 */
private def notResolvableByChild(attrName: String, child: LogicalPlan): Boolean = {
  // Equivalent to "no output attribute matches": every attribute must fail the resolver check.
  child.output.forall(outputAttr => !resolver(outputAttr.name, attrName))
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: style

    private def notResolvableByChild(attrName: String, child: LogicalPlan): Boolean = {
      !child.output.exists(a => resolver(a.name, attrName))
    }

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Fixed.


override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
case agg @ Aggregate(groups, aggs, child)
if conf.groupByAliases && child.resolved && aggs.forall(_.resolved) &&
groups.exists(_.isInstanceOf[UnresolvedAttribute]) =>
// This is a strict check though, we put this to apply the rule only in alias expressions
def notResolvableByChild(attrName: String): Boolean =
!child.output.exists(a => resolver(a.name, attrName))
agg.copy(groupingExpressions = groups.map {
case u: UnresolvedAttribute if notResolvableByChild(u.name) =>
groups.exists(!_.resolved) =>
agg.copy(groupingExpressions = groups.map { _.transform {
case u: UnresolvedAttribute if notResolvableByChild(u.name, child) =>
aggs.find(ne => resolver(ne.name, u.name)).getOrElse(u)
}
})

case gs @ GroupingSets(selectedGroups, groups, child, aggs)
if conf.groupByAliases && child.resolved && aggs.forall(_.resolved) &&
(selectedGroups :+ groups).exists(_.exists(_.isInstanceOf[UnresolvedAttribute])) =>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

groups should cover selectedGroups. So we may not need to add selectedGroups here.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aha, I see. It looks reasonable. I'll update. Thanks!

Copy link
Member

@viirya viirya May 12, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we sure that grouping expressions are all pure attributes? If not, this check might fail.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh. nvm. It is of course.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yea, thanks!

def mayResolveAttrByAggregateExprs(exprs: Seq[Expression]): Seq[Expression] = exprs.map {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should do exprs.map { _.transform { ... like above.

case u: UnresolvedAttribute if notResolvableByChild(u.name, child) =>
aggs.find(ne => resolver(ne.name, u.name)).getOrElse(u)
case e => e
})
}
gs.copy(selectedGroupByExprs = selectedGroups.map(mayResolveAttrByAggregateExprs),
groupByExprs = mayResolveAttrByAggregateExprs(groups))
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -704,7 +704,7 @@ case class Expand(
* We will transform GROUPING SETS into logical plan Aggregate(.., Expand) in Analyzer
*
* @param selectedGroupByExprs A sequence of selected GroupBy expressions, all exprs should
* exists in groupByExprs.
* exist in groupByExprs.
* @param groupByExprs The Group By expressions candidates.
* @param child Child operator
* @param aggregations The Aggregation expressions, those non selected group by expressions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,9 @@ SELECT course, year, GROUPING_ID(course, year) FROM courseSales GROUP BY CUBE(co
ORDER BY GROUPING(course), GROUPING(year), course, year;
SELECT course, year FROM courseSales GROUP BY course, year ORDER BY GROUPING(course);
SELECT course, year FROM courseSales GROUP BY course, year ORDER BY GROUPING_ID(course);
SELECT course, year FROM courseSales GROUP BY CUBE(course, year) ORDER BY grouping__id;
SELECT course, year FROM courseSales GROUP BY CUBE(course, year) ORDER BY grouping__id;

-- Aliases in SELECT could be used in ROLLUP/CUBE/GROUPING SETS
SELECT a + b AS k1, b AS k2, SUM(a - b) FROM testData GROUP BY CUBE(k1, k2);
SELECT a + b AS k, b, SUM(a - b) FROM testData GROUP BY ROLLUP(k, b);
SELECT a + b, b AS k, SUM(a - b) FROM testData GROUP BY a + b, k GROUPING SETS(k)
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 26
-- Number of queries: 29


-- !query 0
Expand Down Expand Up @@ -328,3 +328,50 @@ struct<>
-- !query 25 output
org.apache.spark.sql.AnalysisException
grouping__id is deprecated; use grouping_id() instead;


-- !query 26
SELECT a + b AS k1, b AS k2, SUM(a - b) FROM testData GROUP BY CUBE(k1, k2)
-- !query 26 schema
struct<k1:int,k2:int,sum((a - b)):bigint>
-- !query 26 output
2 1 0
2 NULL 0
3 1 1
3 2 -1
3 NULL 0
4 1 2
4 2 0
4 NULL 2
5 2 1
5 NULL 1
NULL 1 3
NULL 2 0
NULL NULL 3


-- !query 27
SELECT a + b AS k, b, SUM(a - b) FROM testData GROUP BY ROLLUP(k, b)
-- !query 27 schema
struct<k:int,b:int,sum((a - b)):bigint>
-- !query 27 output
2 1 0
2 NULL 0
3 1 1
3 2 -1
3 NULL 0
4 1 2
4 2 0
4 NULL 2
5 2 1
5 NULL 1
NULL NULL 3


-- !query 28
SELECT a + b, b AS k, SUM(a - b) FROM testData GROUP BY a + b, k GROUPING SETS(k)
-- !query 28 schema
struct<(a + b):int,k:int,sum((a - b)):bigint>
-- !query 28 output
NULL 1 3
NULL 2 0