Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -336,9 +336,15 @@ class Analyzer(
}
j.copy(right = newRight)

// When resolve `SortOrder`s in Sort based on child, don't report errors as
// we still have chance to resolve it based on grandchild
case s @ Sort(ordering, global, child) if child.resolved && !s.resolved =>
val newOrdering = resolveSortOrders(ordering, child, throws = false)
Sort(newOrdering, global, child)

case q: LogicalPlan =>
logTrace(s"Attempting to resolve ${q.simpleString}")
q transformExpressionsUp {
q transformExpressionsUp {
case u @ UnresolvedAttribute(nameParts) if nameParts.length == 1 &&
resolver(nameParts(0), VirtualColumn.groupingIdName) &&
q.isInstanceOf[GroupingAnalytics] =>
Expand Down Expand Up @@ -373,6 +379,26 @@ class Analyzer(
exprs.exists(_.collect { case _: Star => true }.nonEmpty)
}

private def resolveSortOrders(ordering: Seq[SortOrder], plan: LogicalPlan, throws: Boolean) = {
ordering.map { order =>
// Resolve SortOrder in one round.
// If throws == false or the desired attribute doesn't exist
// (like try to resolve `a.b` but `a` doesn't exist), fail and return the origin one.
// Else, throw exception.
try {
val newOrder = order transformUp {
case u @ UnresolvedAttribute(nameParts) =>
plan.resolve(nameParts, resolver).getOrElse(u)
case UnresolvedExtractValue(child, fieldName) if child.resolved =>
ExtractValue(child, fieldName, resolver)
}
newOrder.asInstanceOf[SortOrder]
} catch {
case a: AnalysisException if !throws => order
}
}
}

/**
* In many dialects of SQL it is valid to sort by attributes that are not present in the SELECT
* clause. This rule detects such queries and adds the required attributes to the original
Expand All @@ -383,13 +409,13 @@ class Analyzer(
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case s @ Sort(ordering, global, p @ Project(projectList, child))
if !s.resolved && p.resolved =>
val (resolvedOrdering, missing) = resolveAndFindMissing(ordering, p, child)
val (newOrdering, missing) = resolveAndFindMissing(ordering, p, child)

// If this rule was not a no-op, return the transformed plan, otherwise return the original.
if (missing.nonEmpty) {
// Add missing attributes and then project them away after the sort.
Project(p.output,
Sort(resolvedOrdering, global,
Sort(newOrdering, global,
Project(projectList ++ missing, child)))
} else {
logDebug(s"Failed to find $missing in ${p.output.mkString(", ")}")
Expand All @@ -404,19 +430,19 @@ class Analyzer(
)

// Find sort attributes that are projected away so we can temporarily add them back in.
val (resolvedOrdering, unresolved) = resolveAndFindMissing(ordering, a, groupingRelation)
val (newOrdering, missingAttr) = resolveAndFindMissing(ordering, a, groupingRelation)

// Find aggregate expressions and evaluate them early, since they can't be evaluated in a
// Sort.
val (withAggsRemoved, aliasedAggregateList) = resolvedOrdering.map {
val (withAggsRemoved, aliasedAggregateList) = newOrdering.map {
case aggOrdering if aggOrdering.collect { case a: AggregateExpression => a }.nonEmpty =>
val aliased = Alias(aggOrdering.child, "_aggOrdering")()
(aggOrdering.copy(child = aliased.toAttribute), aliased :: Nil)
(aggOrdering.copy(child = aliased.toAttribute), Some(aliased))

case other => (other, Nil)
case other => (other, None)
}.unzip

val missing = unresolved ++ aliasedAggregateList.flatten
val missing = missingAttr ++ aliasedAggregateList.flatten

if (missing.nonEmpty) {
// Add missing grouping exprs and then project them away after the sort.
Expand All @@ -429,40 +455,25 @@ class Analyzer(
}

/**
* Given a child and a grandchild that are present beneath a sort operator, returns
* a resolved sort ordering and a list of attributes that are missing from the child
* but are present in the grandchild.
* Given a child and a grandchild that are present beneath a sort operator, try to resolve
* the sort ordering and returns it with a list of attributes that are missing from the
* child but are present in the grandchild.
*/
def resolveAndFindMissing(
ordering: Seq[SortOrder],
child: LogicalPlan,
grandchild: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = {
// Find any attributes that remain unresolved in the sort.
val unresolved: Seq[Seq[String]] =
ordering.flatMap(_.collect { case UnresolvedAttribute(nameParts) => nameParts })

// Create a map from name, to resolved attributes, when the desired name can be found
// prior to the projection.
val resolved: Map[Seq[String], NamedExpression] =
unresolved.flatMap(u => grandchild.resolve(u, resolver).map(a => u -> a)).toMap

val newOrdering = resolveSortOrders(ordering, grandchild, throws = true)
// Construct a set that contains all of the attributes that we need to evaluate the
// ordering.
val requiredAttributes = AttributeSet(resolved.values)

val requiredAttributes = AttributeSet(newOrdering.filter(_.resolved))
// Figure out which ones are missing from the projection, so that we can add them and
// remove them after the sort.
val missingInProject = requiredAttributes -- child.output

// Now that we have all the attributes we need, reconstruct a resolved ordering.
// It is important to do it here, instead of waiting for the standard resolved as adding
// attributes to the project below can actually introduce ambiquity that was not present
// before.
val resolvedOrdering = ordering.map(_ transform {
case u @ UnresolvedAttribute(name) => resolved.getOrElse(name, u)
}).asInstanceOf[Seq[SortOrder]]

(resolvedOrdering, missingInProject.toSeq)
// It is important to return the new SortOrders here, instead of waiting for the standard
// resolving process as adding attributes to the project below can actually introduce
// ambiguity that was not present before.
(newOrdering, missingInProject.toSeq)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,6 @@ trait CheckAnalysis {
case operator: LogicalPlan =>
operator transformExpressionsUp {
case a: Attribute if !a.resolved =>
if (operator.childrenResolved) {
a match {
case UnresolvedAttribute(nameParts) =>
// Throw errors for specific problems with get field.
operator.resolveChildren(nameParts, resolver, throwErrors = true)
}
}

val from = operator.inputSet.map(_.name).mkString(", ")
a.failAnalysis(s"cannot resolve '${a.prettyString}' given input columns $from")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,19 +50,19 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
* [[org.apache.spark.sql.catalyst.analysis.UnresolvedRelation UnresolvedRelation]]
* should return `false`).
*/
lazy val resolved: Boolean = !expressions.exists(!_.resolved) && childrenResolved
lazy val resolved: Boolean = expressions.forall(_.resolved) && childrenResolved

override protected def statePrefix = if (!resolved) "'" else super.statePrefix

/**
* Returns true if all its children of this query plan have been resolved.
*/
def childrenResolved: Boolean = !children.exists(!_.resolved)
def childrenResolved: Boolean = children.forall(_.resolved)

/**
* Returns true when the given logical plan will return the same results as this logical plan.
*
* Since its likely undecideable to generally determine if two given plans will produce the same
* Since its likely undecidable to generally determine if two given plans will produce the same
* results, it is okay for this function to return false, even if the results are actually
* the same. Such behavior will not affect correctness, only the application of performance
* enhancements like caching. However, it is not acceptable to return true if the results could
Expand Down Expand Up @@ -111,9 +111,8 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
*/
def resolveChildren(
nameParts: Seq[String],
resolver: Resolver,
throwErrors: Boolean = false): Option[NamedExpression] =
resolve(nameParts, children.flatMap(_.output), resolver, throwErrors)
resolver: Resolver): Option[NamedExpression] =
resolve(nameParts, children.flatMap(_.output), resolver)

/**
* Optionally resolves the given strings to a [[NamedExpression]] based on the output of this
Expand All @@ -122,9 +121,8 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
*/
def resolve(
nameParts: Seq[String],
resolver: Resolver,
throwErrors: Boolean = false): Option[NamedExpression] =
resolve(nameParts, output, resolver, throwErrors)
resolver: Resolver): Option[NamedExpression] =
resolve(nameParts, output, resolver)

/**
* Given an attribute name, split it to name parts by dot, but
Expand All @@ -134,7 +132,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
def resolveQuoted(
name: String,
resolver: Resolver): Option[NamedExpression] = {
resolve(parseAttributeName(name), resolver, true)
resolve(parseAttributeName(name), output, resolver)
}

/**
Expand Down Expand Up @@ -219,8 +217,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
protected def resolve(
nameParts: Seq[String],
input: Seq[Attribute],
resolver: Resolver,
throwErrors: Boolean): Option[NamedExpression] = {
resolver: Resolver): Option[NamedExpression] = {

// A sequence of possible candidate matches.
// Each candidate is a tuple. The first element is a resolved attribute, followed by a list
Expand Down Expand Up @@ -254,19 +251,15 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {

// One match, but we also need to extract the requested nested field.
case Seq((a, nestedFields)) =>
try {
// The foldLeft adds GetFields for every remaining parts of the identifier,
// and aliases it with the last part of the identifier.
// For example, consider "a.b.c", where "a" is resolved to an existing attribute.
// Then this will add GetField("c", GetField("b", a)), and alias
// the final expression as "c".
val fieldExprs = nestedFields.foldLeft(a: Expression)((expr, fieldName) =>
ExtractValue(expr, Literal(fieldName), resolver))
val aliasName = nestedFields.last
Some(Alias(fieldExprs, aliasName)())
} catch {
case a: AnalysisException if !throwErrors => None
}
// The foldLeft adds ExtractValues for every remaining parts of the identifier,
// and aliases it with the last part of the identifier.
// For example, consider "a.b.c", where "a" is resolved to an existing attribute.
// Then this will add ExtractValue("c", ExtractValue("b", a)), and alias
// the final expression as "c".
val fieldExprs = nestedFields.foldLeft(a: Expression)((expr, fieldName) =>
ExtractValue(expr, Literal(fieldName), resolver))
val aliasName = nestedFields.last
Some(Alias(fieldExprs, aliasName)())

// No matches.
case Seq() =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
* @param rule the function use to transform this nodes children
*/
def transformUp(rule: PartialFunction[BaseType, BaseType]): BaseType = {
val afterRuleOnChildren = transformChildrenUp(rule);
val afterRuleOnChildren = transformChildrenUp(rule)
if (this fastEquals afterRuleOnChildren) {
CurrentOrigin.withOrigin(origin) {
rule.applyOrElse(this, identity[BaseType])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1430,4 +1430,12 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
checkAnswer(sql("select i <=> b from t"), sql("select r2 from t"))
}
}

test("SPARK-7067: order by queries for complex ExtractValue chain") {
withTempTable("t") {
sqlContext.read.json(sqlContext.sparkContext.makeRDD(
"""{"a": {"b": [{"c": 1}]}, "b": [{"d": 1}]}""" :: Nil)).registerTempTable("t")
checkAnswer(sql("SELECT a.b FROM t ORDER BY b[0].d"), Row(Seq(Row(1))))
}
}
}