Skip to content

Commit 7f05b1f

Browse files
cloud-fanmarmbrus
authored andcommitted
[SPARK-7067] [SQL] fix bug when use complex nested fields in ORDER BY
This PR is a improvement for #5189. The resolution rule for ORDER BY is: first resolve based on what comes from the select clause and then fall back on its child only when this fails. There are 2 steps. First, try to resolve `Sort` in `ResolveReferences` based on select clause, and ignore exceptions. Second, try to resolve `Sort` in `ResolveSortReferences` and add missing projection. However, the way we resolve `SortOrder` is wrong. We just resolve `UnresolvedAttribute` and use the result to indicate if we can resolve `SortOrder`. But `UnresolvedAttribute` is only part of `GetField` chain(broken by `GetItem`), so we need to go through the whole chain to indicate if we can resolve `SortOrder`. With this change, we can also avoid re-throw GetField exception in `CheckAnalysis` which is little ugly. Author: Wenchen Fan <[email protected]> Closes #5659 from cloud-fan/order-by and squashes the following commits: cfa79f8 [Wenchen Fan] update test 3245d28 [Wenchen Fan] minor improve 465ee07 [Wenchen Fan] address comment 1fc41a2 [Wenchen Fan] fix SPARK-7067
1 parent a411a40 commit 7f05b1f

File tree

5 files changed

+70
-66
lines changed

5 files changed

+70
-66
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 43 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -336,9 +336,15 @@ class Analyzer(
336336
}
337337
j.copy(right = newRight)
338338

339+
// When resolve `SortOrder`s in Sort based on child, don't report errors as
340+
// we still have chance to resolve it based on grandchild
341+
case s @ Sort(ordering, global, child) if child.resolved && !s.resolved =>
342+
val newOrdering = resolveSortOrders(ordering, child, throws = false)
343+
Sort(newOrdering, global, child)
344+
339345
case q: LogicalPlan =>
340346
logTrace(s"Attempting to resolve ${q.simpleString}")
341-
q transformExpressionsUp {
347+
q transformExpressionsUp {
342348
case u @ UnresolvedAttribute(nameParts) if nameParts.length == 1 &&
343349
resolver(nameParts(0), VirtualColumn.groupingIdName) &&
344350
q.isInstanceOf[GroupingAnalytics] =>
@@ -373,6 +379,26 @@ class Analyzer(
373379
exprs.exists(_.collect { case _: Star => true }.nonEmpty)
374380
}
375381

382+
private def resolveSortOrders(ordering: Seq[SortOrder], plan: LogicalPlan, throws: Boolean) = {
383+
ordering.map { order =>
384+
// Resolve SortOrder in one round.
385+
// If throws == false or the desired attribute doesn't exist
386+
// (like try to resolve `a.b` but `a` doesn't exist), fail and return the origin one.
387+
// Else, throw exception.
388+
try {
389+
val newOrder = order transformUp {
390+
case u @ UnresolvedAttribute(nameParts) =>
391+
plan.resolve(nameParts, resolver).getOrElse(u)
392+
case UnresolvedExtractValue(child, fieldName) if child.resolved =>
393+
ExtractValue(child, fieldName, resolver)
394+
}
395+
newOrder.asInstanceOf[SortOrder]
396+
} catch {
397+
case a: AnalysisException if !throws => order
398+
}
399+
}
400+
}
401+
376402
/**
377403
* In many dialects of SQL it is valid to sort by attributes that are not present in the SELECT
378404
* clause. This rule detects such queries and adds the required attributes to the original
@@ -383,13 +409,13 @@ class Analyzer(
383409
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
384410
case s @ Sort(ordering, global, p @ Project(projectList, child))
385411
if !s.resolved && p.resolved =>
386-
val (resolvedOrdering, missing) = resolveAndFindMissing(ordering, p, child)
412+
val (newOrdering, missing) = resolveAndFindMissing(ordering, p, child)
387413

388414
// If this rule was not a no-op, return the transformed plan, otherwise return the original.
389415
if (missing.nonEmpty) {
390416
// Add missing attributes and then project them away after the sort.
391417
Project(p.output,
392-
Sort(resolvedOrdering, global,
418+
Sort(newOrdering, global,
393419
Project(projectList ++ missing, child)))
394420
} else {
395421
logDebug(s"Failed to find $missing in ${p.output.mkString(", ")}")
@@ -404,19 +430,19 @@ class Analyzer(
404430
)
405431

406432
// Find sort attributes that are projected away so we can temporarily add them back in.
407-
val (resolvedOrdering, unresolved) = resolveAndFindMissing(ordering, a, groupingRelation)
433+
val (newOrdering, missingAttr) = resolveAndFindMissing(ordering, a, groupingRelation)
408434

409435
// Find aggregate expressions and evaluate them early, since they can't be evaluated in a
410436
// Sort.
411-
val (withAggsRemoved, aliasedAggregateList) = resolvedOrdering.map {
437+
val (withAggsRemoved, aliasedAggregateList) = newOrdering.map {
412438
case aggOrdering if aggOrdering.collect { case a: AggregateExpression => a }.nonEmpty =>
413439
val aliased = Alias(aggOrdering.child, "_aggOrdering")()
414-
(aggOrdering.copy(child = aliased.toAttribute), aliased :: Nil)
440+
(aggOrdering.copy(child = aliased.toAttribute), Some(aliased))
415441

416-
case other => (other, Nil)
442+
case other => (other, None)
417443
}.unzip
418444

419-
val missing = unresolved ++ aliasedAggregateList.flatten
445+
val missing = missingAttr ++ aliasedAggregateList.flatten
420446

421447
if (missing.nonEmpty) {
422448
// Add missing grouping exprs and then project them away after the sort.
@@ -429,40 +455,25 @@ class Analyzer(
429455
}
430456

431457
/**
432-
* Given a child and a grandchild that are present beneath a sort operator, returns
433-
* a resolved sort ordering and a list of attributes that are missing from the child
434-
* but are present in the grandchild.
458+
* Given a child and a grandchild that are present beneath a sort operator, try to resolve
459+
* the sort ordering and returns it with a list of attributes that are missing from the
460+
* child but are present in the grandchild.
435461
*/
436462
def resolveAndFindMissing(
437463
ordering: Seq[SortOrder],
438464
child: LogicalPlan,
439465
grandchild: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = {
440-
// Find any attributes that remain unresolved in the sort.
441-
val unresolved: Seq[Seq[String]] =
442-
ordering.flatMap(_.collect { case UnresolvedAttribute(nameParts) => nameParts })
443-
444-
// Create a map from name, to resolved attributes, when the desired name can be found
445-
// prior to the projection.
446-
val resolved: Map[Seq[String], NamedExpression] =
447-
unresolved.flatMap(u => grandchild.resolve(u, resolver).map(a => u -> a)).toMap
448-
466+
val newOrdering = resolveSortOrders(ordering, grandchild, throws = true)
449467
// Construct a set that contains all of the attributes that we need to evaluate the
450468
// ordering.
451-
val requiredAttributes = AttributeSet(resolved.values)
452-
469+
val requiredAttributes = AttributeSet(newOrdering.filter(_.resolved))
453470
// Figure out which ones are missing from the projection, so that we can add them and
454471
// remove them after the sort.
455472
val missingInProject = requiredAttributes -- child.output
456-
457-
// Now that we have all the attributes we need, reconstruct a resolved ordering.
458-
// It is important to do it here, instead of waiting for the standard resolved as adding
459-
// attributes to the project below can actually introduce ambiquity that was not present
460-
// before.
461-
val resolvedOrdering = ordering.map(_ transform {
462-
case u @ UnresolvedAttribute(name) => resolved.getOrElse(name, u)
463-
}).asInstanceOf[Seq[SortOrder]]
464-
465-
(resolvedOrdering, missingInProject.toSeq)
473+
// It is important to return the new SortOrders here, instead of waiting for the standard
474+
// resolving process as adding attributes to the project below can actually introduce
475+
// ambiguity that was not present before.
476+
(newOrdering, missingInProject.toSeq)
466477
}
467478
}
468479

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -51,14 +51,6 @@ trait CheckAnalysis {
5151
case operator: LogicalPlan =>
5252
operator transformExpressionsUp {
5353
case a: Attribute if !a.resolved =>
54-
if (operator.childrenResolved) {
55-
a match {
56-
case UnresolvedAttribute(nameParts) =>
57-
// Throw errors for specific problems with get field.
58-
operator.resolveChildren(nameParts, resolver, throwErrors = true)
59-
}
60-
}
61-
6254
val from = operator.inputSet.map(_.name).mkString(", ")
6355
a.failAnalysis(s"cannot resolve '${a.prettyString}' given input columns $from")
6456

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala

Lines changed: 18 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -50,19 +50,19 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
5050
* [[org.apache.spark.sql.catalyst.analysis.UnresolvedRelation UnresolvedRelation]]
5151
* should return `false`).
5252
*/
53-
lazy val resolved: Boolean = !expressions.exists(!_.resolved) && childrenResolved
53+
lazy val resolved: Boolean = expressions.forall(_.resolved) && childrenResolved
5454

5555
override protected def statePrefix = if (!resolved) "'" else super.statePrefix
5656

5757
/**
5858
* Returns true if all its children of this query plan have been resolved.
5959
*/
60-
def childrenResolved: Boolean = !children.exists(!_.resolved)
60+
def childrenResolved: Boolean = children.forall(_.resolved)
6161

6262
/**
6363
* Returns true when the given logical plan will return the same results as this logical plan.
6464
*
65-
* Since its likely undecideable to generally determine if two given plans will produce the same
65+
* Since its likely undecidable to generally determine if two given plans will produce the same
6666
* results, it is okay for this function to return false, even if the results are actually
6767
* the same. Such behavior will not affect correctness, only the application of performance
6868
* enhancements like caching. However, it is not acceptable to return true if the results could
@@ -111,9 +111,8 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
111111
*/
112112
def resolveChildren(
113113
nameParts: Seq[String],
114-
resolver: Resolver,
115-
throwErrors: Boolean = false): Option[NamedExpression] =
116-
resolve(nameParts, children.flatMap(_.output), resolver, throwErrors)
114+
resolver: Resolver): Option[NamedExpression] =
115+
resolve(nameParts, children.flatMap(_.output), resolver)
117116

118117
/**
119118
* Optionally resolves the given strings to a [[NamedExpression]] based on the output of this
@@ -122,9 +121,8 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
122121
*/
123122
def resolve(
124123
nameParts: Seq[String],
125-
resolver: Resolver,
126-
throwErrors: Boolean = false): Option[NamedExpression] =
127-
resolve(nameParts, output, resolver, throwErrors)
124+
resolver: Resolver): Option[NamedExpression] =
125+
resolve(nameParts, output, resolver)
128126

129127
/**
130128
* Given an attribute name, split it to name parts by dot, but
@@ -134,7 +132,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
134132
def resolveQuoted(
135133
name: String,
136134
resolver: Resolver): Option[NamedExpression] = {
137-
resolve(parseAttributeName(name), resolver, true)
135+
resolve(parseAttributeName(name), output, resolver)
138136
}
139137

140138
/**
@@ -219,8 +217,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
219217
protected def resolve(
220218
nameParts: Seq[String],
221219
input: Seq[Attribute],
222-
resolver: Resolver,
223-
throwErrors: Boolean): Option[NamedExpression] = {
220+
resolver: Resolver): Option[NamedExpression] = {
224221

225222
// A sequence of possible candidate matches.
226223
// Each candidate is a tuple. The first element is a resolved attribute, followed by a list
@@ -254,19 +251,15 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
254251

255252
// One match, but we also need to extract the requested nested field.
256253
case Seq((a, nestedFields)) =>
257-
try {
258-
// The foldLeft adds GetFields for every remaining parts of the identifier,
259-
// and aliases it with the last part of the identifier.
260-
// For example, consider "a.b.c", where "a" is resolved to an existing attribute.
261-
// Then this will add GetField("c", GetField("b", a)), and alias
262-
// the final expression as "c".
263-
val fieldExprs = nestedFields.foldLeft(a: Expression)((expr, fieldName) =>
264-
ExtractValue(expr, Literal(fieldName), resolver))
265-
val aliasName = nestedFields.last
266-
Some(Alias(fieldExprs, aliasName)())
267-
} catch {
268-
case a: AnalysisException if !throwErrors => None
269-
}
254+
// The foldLeft adds ExtractValues for every remaining parts of the identifier,
255+
// and aliases it with the last part of the identifier.
256+
// For example, consider "a.b.c", where "a" is resolved to an existing attribute.
257+
// Then this will add ExtractValue("c", ExtractValue("b", a)), and alias
258+
// the final expression as "c".
259+
val fieldExprs = nestedFields.foldLeft(a: Expression)((expr, fieldName) =>
260+
ExtractValue(expr, Literal(fieldName), resolver))
261+
val aliasName = nestedFields.last
262+
Some(Alias(fieldExprs, aliasName)())
270263

271264
// No matches.
272265
case Seq() =>

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
285285
* @param rule the function use to transform this nodes children
286286
*/
287287
def transformUp(rule: PartialFunction[BaseType, BaseType]): BaseType = {
288-
val afterRuleOnChildren = transformChildrenUp(rule);
288+
val afterRuleOnChildren = transformChildrenUp(rule)
289289
if (this fastEquals afterRuleOnChildren) {
290290
CurrentOrigin.withOrigin(origin) {
291291
rule.applyOrElse(this, identity[BaseType])

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1440,4 +1440,12 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
14401440
checkAnswer(sql("select i <=> b from t"), sql("select r2 from t"))
14411441
}
14421442
}
1443+
1444+
test("SPARK-7067: order by queries for complex ExtractValue chain") {
1445+
withTempTable("t") {
1446+
sqlContext.read.json(sqlContext.sparkContext.makeRDD(
1447+
"""{"a": {"b": [{"c": 1}]}, "b": [{"d": 1}]}""" :: Nil)).registerTempTable("t")
1448+
checkAnswer(sql("SELECT a.b FROM t ORDER BY b[0].d"), Row(Seq(Row(1))))
1449+
}
1450+
}
14431451
}

0 commit comments

Comments
 (0)