Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ class Analyzer(
ResolveGroupingAnalytics ::
ResolvePivot ::
ResolveOrdinalInOrderByAndGroupBy ::
ResolveSortReferences ::
ResolveMissingReferences ::
ResolveGenerate ::
ResolveFunctions ::
ResolveAliases ::
Expand Down Expand Up @@ -228,21 +228,56 @@ class Analyzer(
Seq.tabulate(1 << c.groupByExprs.length)(i => i)
}

private def hasGroupingId(expr: Seq[Expression]): Boolean = {
expr.exists(_.collectFirst {
case u: UnresolvedAttribute if resolver(u.name, VirtualColumn.groupingIdName) => u
}.isDefined)
private def hasGroupingAttribute(expr: Expression): Boolean = {
expr.collectFirst {
case u: UnresolvedAttribute if resolver(u.name, VirtualColumn.hiveGroupingIdName) => u
}.isDefined
}

def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
private def hasGroupingFunction(e: Expression): Boolean = {
e.collectFirst {
case g: Grouping => g
case g: GroupingID => g
}.isDefined
}

private def replaceGroupingFunc(
expr: Expression,
groupByExprs: Seq[Expression],
gid: Expression): Expression = {
expr transform {
case e: GroupingID =>
if (e.groupByExprs.isEmpty || e.groupByExprs == groupByExprs) {
gid
} else {
throw new AnalysisException(
s"Columns of grouping_id (${e.groupByExprs.mkString(",")}) does not match " +
s"grouping columns (${groupByExprs.mkString(",")})")
}
case Grouping(col: Expression) =>
val idx = groupByExprs.indexOf(col)
if (idx >= 0) {
Cast(BitwiseAnd(ShiftRight(gid, Literal(groupByExprs.length - 1 - idx)),
Literal(1)), ByteType)
} else {
throw new AnalysisException(s"Column of grouping ($col) can't be found " +
s"in grouping columns ${groupByExprs.mkString(",")}")
}
}
}

// This require transformUp to replace grouping()/grouping_id() in resolved Filter/Sort
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case a if !a.childrenResolved => a // be sure all of the children are resolved.
case p if p.expressions.exists(hasGroupingAttribute) =>
failAnalysis(
s"${VirtualColumn.hiveGroupingIdName} is deprecated; use grouping_id() instead")

case Aggregate(Seq(c @ Cube(groupByExprs)), aggregateExpressions, child) =>
GroupingSets(bitmasks(c), groupByExprs, child, aggregateExpressions)
case Aggregate(Seq(r @ Rollup(groupByExprs)), aggregateExpressions, child) =>
GroupingSets(bitmasks(r), groupByExprs, child, aggregateExpressions)
case g: GroupingSets if g.expressions.exists(!_.resolved) && hasGroupingId(g.expressions) =>
failAnalysis(
s"${VirtualColumn.groupingIdName} is deprecated; use grouping_id() instead")

// Ensure all the expressions have been resolved.
case x: GroupingSets if x.expressions.forall(_.resolved) =>
val gid = AttributeReference(VirtualColumn.groupingIdName, IntegerType, false)()
Expand Down Expand Up @@ -270,31 +305,14 @@ class Analyzer(
def isPartOfAggregation(e: Expression): Boolean = {
aggsBuffer.exists(a => a.find(_ eq e).isDefined)
}
expr.transformDown {
replaceGroupingFunc(expr, x.groupByExprs, gid).transformDown {
// AggregateExpression should be computed on the unmodified value of its argument
// expressions, so we should not replace any references to grouping expression
// inside it.
case e: AggregateExpression =>
aggsBuffer += e
e
case e if isPartOfAggregation(e) => e
case e: GroupingID =>
if (e.groupByExprs.isEmpty || e.groupByExprs == x.groupByExprs) {
gid
} else {
throw new AnalysisException(
s"Columns of grouping_id (${e.groupByExprs.mkString(",")}) does not match " +
s"grouping columns (${x.groupByExprs.mkString(",")})")
}
case Grouping(col: Expression) =>
val idx = x.groupByExprs.indexOf(col)
if (idx >= 0) {
Cast(BitwiseAnd(ShiftRight(gid, Literal(x.groupByExprs.length - 1 - idx)),
Literal(1)), ByteType)
} else {
throw new AnalysisException(s"Column of grouping ($col) can't be found " +
s"in grouping columns ${x.groupByExprs.mkString(",")}")
}
case e =>
val index = groupByAliases.indexWhere(_.child.semanticEquals(e))
if (index == -1) {
Expand All @@ -306,9 +324,37 @@ class Analyzer(
}

Aggregate(
groupByAttributes :+ VirtualColumn.groupingIdAttribute,
groupByAttributes :+ gid,
aggregations,
Expand(x.bitmasks, groupByAliases, groupByAttributes, gid, x.child))

case f @ Filter(cond, child) if hasGroupingFunction(cond) =>
val groupingExprs = findGroupingExprs(child)
// The unresolved grouping id will be resolved by ResolveMissingReferences
val newCond = replaceGroupingFunc(cond, groupingExprs, VirtualColumn.groupingIdAttribute)
f.copy(condition = newCond)

case s @ Sort(order, _, child) if order.exists(hasGroupingFunction) =>
val groupingExprs = findGroupingExprs(child)
val gid = VirtualColumn.groupingIdAttribute
// The unresolved grouping id will be resolved by ResolveMissingReferences
val newOrder = order.map(replaceGroupingFunc(_, groupingExprs, gid).asInstanceOf[SortOrder])
s.copy(order = newOrder)
}

private def findGroupingExprs(plan: LogicalPlan): Seq[Expression] = {
plan.collectFirst {
case a: Aggregate =>
// this Aggregate should have grouping id as the last grouping key.
val gid = a.groupingExpressions.last
if (!gid.isInstanceOf[AttributeReference]
|| gid.asInstanceOf[AttributeReference].name != VirtualColumn.groupingIdName) {
failAnalysis(s"grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup")
}
a.groupingExpressions.take(a.groupingExpressions.length - 1)
}.getOrElse {
failAnalysis(s"grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup")
}
}
}

Expand Down Expand Up @@ -663,13 +709,15 @@ class Analyzer(
* clause. This rule detects such queries and adds the required attributes to the original
* projection, so that they will be available during sorting. Another projection is added to
* remove these attributes after sorting.
*
* The HAVING clause could also used a grouping columns that is not presented in the SELECT.
*/
object ResolveSortReferences extends Rule[LogicalPlan] {
object ResolveMissingReferences extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
// Skip sort with aggregate. This will be handled in ResolveAggregateFunctions
case sa @ Sort(_, _, child: Aggregate) => sa

case s @ Sort(order, _, child) if !s.resolved && child.resolved =>
case s @ Sort(order, _, child) if child.resolved =>
try {
val newOrder = order.map(resolveExpressionRecursively(_, child).asInstanceOf[SortOrder])
val requiredAttrs = AttributeSet(newOrder).filter(_.resolved)
Expand All @@ -689,6 +737,26 @@ class Analyzer(
// in Sort
case ae: AnalysisException => s
}

case f @ Filter(cond, child) if child.resolved =>
try {
val newCond = resolveExpressionRecursively(cond, child)
val requiredAttrs = newCond.references.filter(_.resolved)
val missingAttrs = requiredAttrs -- child.outputSet
if (missingAttrs.nonEmpty) {
// Add missing attributes and then project them away.
Project(child.output,
Filter(newCond, addMissingAttr(child, missingAttrs)))
} else if (newCond != cond) {
f.copy(condition = newCond)
} else {
f
}
} catch {
// Attempting to resolve it might fail. When this happens, return the original plan.
// Users will see an AnalysisException for resolution failure of missing attributes
case ae: AnalysisException => f
}
}

/**
Expand Down Expand Up @@ -843,27 +911,33 @@ class Analyzer(
if aggregate.resolved =>

// Try resolving the condition of the filter as though it is in the aggregate clause
val aggregatedCondition =
Aggregate(
grouping,
Alias(havingCondition, "havingCondition")(isGenerated = true) :: Nil,
child)
val resolvedOperator = execute(aggregatedCondition)
def resolvedAggregateFilter =
resolvedOperator
.asInstanceOf[Aggregate]
.aggregateExpressions.head

// If resolution was successful and we see the filter has an aggregate in it, add it to
// the original aggregate operator.
if (resolvedOperator.resolved && containsAggregate(resolvedAggregateFilter)) {
val aggExprsWithHaving = resolvedAggregateFilter +: originalAggExprs

Project(aggregate.output,
Filter(resolvedAggregateFilter.toAttribute,
aggregate.copy(aggregateExpressions = aggExprsWithHaving)))
} else {
filter
try {
val aggregatedCondition =
Aggregate(
grouping,
Alias(havingCondition, "havingCondition")(isGenerated = true) :: Nil,
child)
val resolvedOperator = execute(aggregatedCondition)
def resolvedAggregateFilter =
resolvedOperator
.asInstanceOf[Aggregate]
.aggregateExpressions.head

// If resolution was successful and we see the filter has an aggregate in it, add it to
// the original aggregate operator.
if (resolvedOperator.resolved && containsAggregate(resolvedAggregateFilter)) {
val aggExprsWithHaving = resolvedAggregateFilter +: originalAggExprs

Project(aggregate.output,
Filter(resolvedAggregateFilter.toAttribute,
aggregate.copy(aggregateExpressions = aggExprsWithHaving)))
} else {
filter
}
} catch {
// Attempting to resolve in the aggregate can result in ambiguity. When this happens,
// just return the original plan.
case ae: AnalysisException => filter
}

case sort @ Sort(sortOrder, global, aggregate: Aggregate) if aggregate.resolved =>
Expand Down Expand Up @@ -927,11 +1001,8 @@ class Analyzer(
}
}

private def isAggregateExpression(e: Expression): Boolean = {
e.isInstanceOf[AggregateExpression] || e.isInstanceOf[Grouping] || e.isInstanceOf[GroupingID]
}
def containsAggregate(condition: Expression): Boolean = {
condition.find(isAggregateExpression).isDefined
condition.find(_.isInstanceOf[AggregateExpression]).isDefined
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,8 @@ case class PrettyAttribute(
}

object VirtualColumn {
val groupingIdName: String = "grouping__id"
// The attribute name used by Hive, which has different result than Spark, deprecated.
val hiveGroupingIdName: String = "grouping__id"
val groupingIdName: String = "spark_grouping_id"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you explain what's going on here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"grouping__id" came from Hive, but unfortunately the implementation is wrong, see https://issues.apache.org/jira/browse/HIVE-12833. So we deprecated to favor the standard function grouping_id() as public API. "spark_grouping_id" is the virtual column only used internally.

val groupingIdAttribute: UnresolvedAttribute = UnresolvedAttribute(groupingIdName)
}
82 changes: 82 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2230,6 +2230,88 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
assert(error.getMessage contains "grouping__id is deprecated; use grouping_id() instead")
}

test("grouping and grouping_id in having") {
checkAnswer(
sql("select course, year from courseSales group by cube(course, year)" +
" having grouping(year) = 1 and grouping_id(course, year) > 0"),
Row("Java", null) ::
Row("dotNET", null) ::
Row(null, null) :: Nil
)

var error = intercept[AnalysisException] {
sql("select course, year from courseSales group by course, year" +
" having grouping(course) > 0")
}
assert(error.getMessage contains
"grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup")
error = intercept[AnalysisException] {
sql("select course, year from courseSales group by course, year" +
" having grouping_id(course, year) > 0")
}
assert(error.getMessage contains
"grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup")
error = intercept[AnalysisException] {
sql("select course, year from courseSales group by cube(course, year)" +
" having grouping__id > 0")
}
assert(error.getMessage contains "grouping__id is deprecated; use grouping_id() instead")
}

test("grouping and grouping_id in sort") {
checkAnswer(
sql("select course, year, grouping(course), grouping(year) from courseSales" +
" group by cube(course, year) order by grouping_id(course, year), course, year"),
Row("Java", 2012, 0, 0) ::
Row("Java", 2013, 0, 0) ::
Row("dotNET", 2012, 0, 0) ::
Row("dotNET", 2013, 0, 0) ::
Row("Java", null, 0, 1) ::
Row("dotNET", null, 0, 1) ::
Row(null, 2012, 1, 0) ::
Row(null, 2013, 1, 0) ::
Row(null, null, 1, 1) :: Nil
)

checkAnswer(
sql("select course, year, grouping_id(course, year) from courseSales" +
" group by cube(course, year) order by grouping(course), grouping(year), course, year"),
Row("Java", 2012, 0) ::
Row("Java", 2013, 0) ::
Row("dotNET", 2012, 0) ::
Row("dotNET", 2013, 0) ::
Row("Java", null, 1) ::
Row("dotNET", null, 1) ::
Row(null, 2012, 2) ::
Row(null, 2013, 2) ::
Row(null, null, 3) :: Nil
)

var error = intercept[AnalysisException] {
sql("select course, year from courseSales group by course, year" +
" order by grouping(course)")
}
assert(error.getMessage contains
"grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup")
error = intercept[AnalysisException] {
sql("select course, year from courseSales group by course, year" +
" order by grouping_id(course, year)")
}
assert(error.getMessage contains
"grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup")
error = intercept[AnalysisException] {
sql("select course, year from courseSales group by cube(course, year)" +
" order by grouping__id")
}
assert(error.getMessage contains "grouping__id is deprecated; use grouping_id() instead")
}

test("filter on a grouping column that is not presented in SELECT") {
checkAnswer(
sql("select count(1) from (select 1 as a) t group by a having a > 0"),
Row(1) :: Nil)
}

test("SPARK-13056: Null in map value causes NPE") {
val df = Seq(1 -> Map("abc" -> "somestring", "cba" -> null)).toDF("key", "value")
withTempTable("maptest") {
Expand Down