Skip to content
Closed
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1129,7 +1129,8 @@ class Analyzer(
case sa @ Sort(_, _, AnalysisBarrier(child: Aggregate)) => sa
case sa @ Sort(_, _, child: Aggregate) => sa

case s @ Sort(order, _, child) if !s.resolved && child.resolved =>
case s @ Sort(order, _, child)
if (!s.resolved || s.missingInput.nonEmpty) && child.resolved =>
val (newOrder, newChild) = resolveExprsAndAddMissingAttrs(order, child)
val ordering = newOrder.map(_.asInstanceOf[SortOrder])
if (child.output == newChild.output) {
Expand All @@ -1140,7 +1141,7 @@ class Analyzer(
Project(child.output, newSort)
}

case f @ Filter(cond, child) if !f.resolved && child.resolved =>
case f @ Filter(cond, child) if (!f.resolved || f.missingInput.nonEmpty) && child.resolved =>
val (newCond, newChild) = resolveExprsAndAddMissingAttrs(Seq(cond), child)
if (child.output == newChild.output) {
f.copy(condition = newCond.head)
Expand All @@ -1151,10 +1152,17 @@ class Analyzer(
}
}

/**
* This method tries to resolve expressions and find missing attributes recursively. Specially,
* when the expressions used in `Sort` or `Filter` contain unresolved attributes or resolved
* attributes which are missed from child output. This method tries to find the missing
* attributes out and add into the projection.
*/
private def resolveExprsAndAddMissingAttrs(
exprs: Seq[Expression], plan: LogicalPlan): (Seq[Expression], LogicalPlan) = {
if (exprs.forall(_.resolved)) {
// All given expressions are resolved, no need to continue anymore.
// Missing attributes can be unresolved attributes or resolved attributes which are not in
// the output attributes of the plan.
if (exprs.forall(e => e.resolved && e.references.subsetOf(plan.outputSet))) {
(exprs, plan)
} else {
plan match {
Expand All @@ -1165,15 +1173,19 @@ class Analyzer(
(newExprs, AnalysisBarrier(newChild))

case p: Project =>
// Resolving expressions against current plan.
val maybeResolvedExprs = exprs.map(resolveExpression(_, p))
// Recursively resolving expressions on the child of current plan.
val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, p.child)
val missingAttrs = AttributeSet(newExprs) -- AttributeSet(maybeResolvedExprs)
// If some attributes used by expressions are resolvable only on the rewritten child
// plan, we need to add them into original projection.
val missingAttrs = (AttributeSet(newExprs) -- p.outputSet).intersect(newChild.outputSet)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a second time, but we need to fix in Aggregate case? The logic seems completely different. Or can we remove Aggregate case if ResolveAggregateFunctions can handle this? I don't think we have any reason to keep a wrong logic.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. I think it's better to have a re-producible test case before changing Aggregate case. I'm trying to create a test case for it. Then it can be more confident to change Aggregate case.

Actually I found another place we need to fix. Seems we don't have enough test coverage for similar features.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The logic gets convoluted here and we need to add comments. Basically we need to explain when we should expand the project list.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what if we do not do the .intersect(newChild.outputSet)?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Without this intersect, some tests fail, e.g., group-analytics.sql in SQLQueryTestSuite. Some attributes are resolved on parent plans, not on child plans. We can't add them as missing attributes here.

(newExprs, Project(p.projectList ++ missingAttrs, newChild))

case a @ Aggregate(groupExprs, aggExprs, child) =>
val maybeResolvedExprs = exprs.map(resolveExpression(_, a))
val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, child)
val missingAttrs = AttributeSet(newExprs) -- AttributeSet(maybeResolvedExprs)
val missingAttrs = (AttributeSet(newExprs) -- a.outputSet).intersect(newChild.outputSet)
if (missingAttrs.forall(attr => groupExprs.exists(_.semanticEquals(attr)))) {
// All the missing attributes are grouping expressions, valid case.
(newExprs, a.copy(aggregateExpressions = aggExprs ++ missingAttrs, child = newChild))
Expand Down Expand Up @@ -1493,7 +1505,11 @@ class Analyzer(

// Try resolving the ordering as though it is in the aggregate clause.
try {
val unresolvedSortOrders = sortOrder.filter(s => !s.resolved || containsAggregate(s))
// If a sort order is unresolved, containing references not in aggregate, or containing
// `AggregateExpression`, we need to push down it to the underlying aggregate operator.
val unresolvedSortOrders = sortOrder.filter { s =>
!s.resolved || !s.references.subsetOf(aggregate.outputSet) || containsAggregate(s)
}
val aliasedOrdering =
unresolvedSortOrders.map(o => Alias(o.child, "aggOrder")())
val aggregatedOrdering = aggregate.copy(aggregateExpressions = aliasedOrdering)
Expand Down
21 changes: 21 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2387,4 +2387,25 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
val mapWithBinaryKey = map(lit(Array[Byte](1.toByte)), lit(1))
checkAnswer(spark.range(1).select(mapWithBinaryKey.getItem(Array[Byte](1.toByte))), Row(1))
}

test("SPARK-24781: Using a reference from Dataset in Filter/Sort might not work") {
val df = Seq(("test1", 0), ("test2", 1)).toDF("name", "id")
val filter1 = df.select(df("name")).filter(df("id") === 0)
val filter2 = df.select(col("name")).filter(col("id") === 0)
checkAnswer(filter1, filter2.collect())

val sort1 = df.select(df("name")).orderBy(df("id"))
val sort2 = df.select(col("name")).orderBy(col("id"))
checkAnswer(sort1, sort2.collect())

withSQLConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS.key -> "false") {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test case should be split to two.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will update it in next commit.

val aggPlusSort1 = df.groupBy(df("name")).agg(count(df("name"))).orderBy(df("name"))
val aggPlusSort2 = df.groupBy(col("name")).agg(count(col("name"))).orderBy(col("name"))
checkAnswer(aggPlusSort1, aggPlusSort2.collect())

val aggPlusFilter1 = df.groupBy(df("name")).agg(count(df("name"))).filter(df("name") === 0)
val aggPlusFilter2 = df.groupBy(col("name")).agg(count(col("name"))).filter(col("name") === 0)
checkAnswer(aggPlusFilter1, aggPlusFilter2.collect())
}
}
}