Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -211,45 +211,35 @@ class Analyzer(
GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations)
case x: GroupingSets =>
val gid = AttributeReference(VirtualColumn.groupingIdName, IntegerType, false)()
// We will insert another Projection if the GROUP BY keys contains the
// non-attribute expressions. And the top operators can references those
// expressions by its alias.
// e.g. SELECT key%5 as c1 FROM src GROUP BY key%5 ==>
// SELECT a as c1 FROM (SELECT key%5 AS a FROM src) GROUP BY a

// find all of the non-attribute expressions in the GROUP BY keys
val nonAttributeGroupByExpressions = new ArrayBuffer[Alias]()

// The pair of (the original GROUP BY key, associated attribute)
val groupByExprPairs = x.groupByExprs.map(_ match {
case e: NamedExpression => (e, e.toAttribute)
case other => {
val alias = Alias(other, other.toString)()
nonAttributeGroupByExpressions += alias // add the non-attributes expression alias
(other, alias.toAttribute)
}
})

// substitute the non-attribute expressions for aggregations.
val aggregation = x.aggregations.map(expr => expr.transformDown {
case e => groupByExprPairs.find(_._1.semanticEquals(e)).map(_._2).getOrElse(e)
}.asInstanceOf[NamedExpression])

// substitute the group by expressions.
val newGroupByExprs = groupByExprPairs.map(_._2)
// Expand works by setting grouping expressions to null as determined by the bitmasks. To
// prevent these null values from being used in an aggregate instead of the original value
// we need to create new aliases for all group by expressions that will only be used for
// the intended purpose.
val groupByAliases: Seq[Alias] = x.groupByExprs.map {
case e: NamedExpression => Alias(e, e.name)()
case other => Alias(other, other.toString)()
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I see. We need to distinguish the real grouping expression values and those nulls added by expand. Otherwise, we generate wrong results.


val child = if (nonAttributeGroupByExpressions.length > 0) {
// insert additional projection if contains the
// non-attribute expressions in the GROUP BY keys
Project(x.child.output ++ nonAttributeGroupByExpressions, x.child)
} else {
x.child
val aggregations: Seq[NamedExpression] = x.aggregations.map {
// If an expression is an aggregate (contains a AggregateExpression) then we dont change
// it so that the aggregation is computed on the unmodified value of its argument
// expressions.
case expr if expr.find(_.isInstanceOf[AggregateExpression]).nonEmpty => expr
// If not then its a grouping expression and we need to use the modified (with nulls from
// Expand) value of the expression.
case expr => expr.transformDown {
case e => groupByAliases.find(_.child.semanticEquals(e)).map(_.toAttribute).getOrElse(e)
}.asInstanceOf[NamedExpression]
}

val child = Project(x.child.output ++ groupByAliases, x.child)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, we will rely on our optimizer to remove this Project if it is not necessary, right?

val groupByAttributes = groupByAliases.map(_.toAttribute)

Aggregate(
newGroupByExprs :+ VirtualColumn.groupingIdAttribute,
aggregation,
Expand(x.bitmasks, newGroupByExprs, gid, child))
groupByAttributes :+ VirtualColumn.groupingIdAttribute,
aggregations,
Expand(x.bitmasks, groupByAttributes, gid, child))
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,10 @@ trait GroupingAnalytics extends UnaryNode {

override def output: Seq[Attribute] = aggregations.map(_.toAttribute)

// Needs to be unresolved before its translated to Aggregate + Expand because output attributes
// will change in analysis.
override lazy val resolved: Boolean = false
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this the only change we need to fix this problem?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No. This was a (minor) problem before that was not caught by any of the test cases, it's now more necessary since we duplicate all the grouping columns in the analyzer rule.


def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,68 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
)
}

test("rollup") {
checkAnswer(
courseSales.rollup("course", "year").sum("earnings"),
Row("Java", 2012, 20000.0) ::
Row("Java", 2013, 30000.0) ::
Row("Java", null, 50000.0) ::
Row("dotNET", 2012, 15000.0) ::
Row("dotNET", 2013, 48000.0) ::
Row("dotNET", null, 63000.0) ::
Row(null, null, 113000.0) :: Nil
)
}

test("cube") {
checkAnswer(
courseSales.cube("course", "year").sum("earnings"),
Row("Java", 2012, 20000.0) ::
Row("Java", 2013, 30000.0) ::
Row("Java", null, 50000.0) ::
Row("dotNET", 2012, 15000.0) ::
Row("dotNET", 2013, 48000.0) ::
Row("dotNET", null, 63000.0) ::
Row(null, 2012, 35000.0) ::
Row(null, 2013, 78000.0) ::
Row(null, null, 113000.0) :: Nil
)
}

test("rollup overlapping columns") {
checkAnswer(
testData2.rollup($"a" + $"b" as "foo", $"b" as "bar").agg(sum($"a" - $"b") as "foo"),
Row(2, 1, 0) :: Row(3, 2, -1) :: Row(3, 1, 1) :: Row(4, 2, 0) :: Row(4, 1, 2) :: Row(5, 2, 1)
:: Row(2, null, 0) :: Row(3, null, 0) :: Row(4, null, 2) :: Row(5, null, 1)
:: Row(null, null, 3) :: Nil
)

checkAnswer(
testData2.rollup("a", "b").agg(sum("b")),
Row(1, 1, 1) :: Row(1, 2, 2) :: Row(2, 1, 1) :: Row(2, 2, 2) :: Row(3, 1, 1) :: Row(3, 2, 2)
:: Row(1, null, 3) :: Row(2, null, 3) :: Row(3, null, 3)
:: Row(null, null, 9) :: Nil
)
}

test("cube overlapping columns") {
checkAnswer(
testData2.cube($"a" + $"b", $"b").agg(sum($"a" - $"b")),
Row(2, 1, 0) :: Row(3, 2, -1) :: Row(3, 1, 1) :: Row(4, 2, 0) :: Row(4, 1, 2) :: Row(5, 2, 1)
:: Row(2, null, 0) :: Row(3, null, 0) :: Row(4, null, 2) :: Row(5, null, 1)
:: Row(null, 1, 3) :: Row(null, 2, 0)
:: Row(null, null, 3) :: Nil
)

checkAnswer(
testData2.cube("a", "b").agg(sum("b")),
Row(1, 1, 1) :: Row(1, 2, 2) :: Row(2, 1, 1) :: Row(2, 2, 2) :: Row(3, 1, 1) :: Row(3, 2, 2)
:: Row(1, null, 3) :: Row(2, null, 3) :: Row(3, null, 3)
:: Row(null, 1, 3) :: Row(null, 2, 6)
:: Row(null, null, 9) :: Nil
)
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these * overlapping * cases will fail without the fix, right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

correct

test("spark.sql.retainGroupColumns config") {
checkAnswer(
testData2.groupBy("a").agg(sum($"b")),
Expand Down