Analyzer.scala
@@ -37,11 +37,12 @@ object SimpleAnalyzer extends Analyzer(EmptyCatalog, EmptyFunctionRegistry, true
 * [[UnresolvedRelation]]s into fully typed objects using information in a schema [[Catalog]] and
 * a [[FunctionRegistry]].
 */
-class Analyzer(catalog: Catalog,
-    registry: FunctionRegistry,
-    caseSensitive: Boolean,
-    maxIterations: Int = 100)
-  extends RuleExecutor[LogicalPlan] with HiveTypeCoercion {
+class Analyzer(
+    catalog: Catalog,
+    registry: FunctionRegistry,
+    caseSensitive: Boolean,
+    maxIterations: Int = 100)
+  extends RuleExecutor[LogicalPlan] with HiveTypeCoercion with CheckAnalysis {

   val resolver = if (caseSensitive) caseSensitiveResolution else caseInsensitiveResolution
@@ -340,19 +341,16 @@ class Analyzer(catalog: Catalog,
     def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
       case s @ Sort(ordering, global, p @ Project(projectList, child))
           if !s.resolved && p.resolved =>
-        val unresolved = ordering.flatMap(_.collect { case UnresolvedAttribute(name) => name })
-        val resolved = unresolved.flatMap(child.resolve(_, resolver))
-        val requiredAttributes =
-          AttributeSet(resolved.flatMap(_.collect { case a: Attribute => a }))
+        val (resolvedOrdering, missing) = resolveAndFindMissing(ordering, p, child)

-        val missingInProject = requiredAttributes -- p.output
-        if (missingInProject.nonEmpty) {
+        // If this rule was not a no-op, return the transformed plan; otherwise return the original.
+        if (missing.nonEmpty) {
           // Add missing attributes and then project them away after the sort.
-          Project(projectList.map(_.toAttribute),
-            Sort(ordering, global,
-              Project(projectList ++ missingInProject, child)))
+          Project(p.output,
+            Sort(resolvedOrdering, global,
+              Project(projectList ++ missing, child)))
         } else {
-          logDebug(s"Failed to find $missingInProject in ${p.output.mkString(", ")}")
+          logDebug(s"Failed to find $missing in ${p.output.mkString(", ")}")
           s // Nothing we can do here. Return original plan.
Reviewer comment (Contributor): Even if there is nothing missing in the project, we still need to build a new Sort with the resolved SortOrder.

}
case s @ Sort(ordering, global, a @ Aggregate(grouping, aggs, child))
@@ -364,18 +362,54 @@ class Analyzer(catalog: Catalog,
           grouping.collect { case ne: NamedExpression => ne.toAttribute }
         )

-        logDebug(s"Grouping expressions: $groupingRelation")
-        val resolved = unresolved.flatMap(groupingRelation.resolve(_, resolver))
-        val missingInAggs = resolved.filterNot(a.outputSet.contains)
-        logDebug(s"Resolved: $resolved Missing in aggs: $missingInAggs")
-        if (missingInAggs.nonEmpty) {
+        val (resolvedOrdering, missing) = resolveAndFindMissing(ordering, a, groupingRelation)
+
+        if (missing.nonEmpty) {
           // Add missing grouping exprs and then project them away after the sort.
           Project(a.output,
-            Sort(ordering, global, Aggregate(grouping, aggs ++ missingInAggs, child)))
+            Sort(resolvedOrdering, global,
+              Aggregate(grouping, aggs ++ missing, child)))
         } else {
           s // Nothing we can do here. Return original plan.
         }
       }

+    /**
+     * Given a child and a grandchild that are present beneath a sort operator, returns
+     * a resolved sort ordering and a list of attributes that are missing from the child
+     * but are present in the grandchild.
+     */
+    def resolveAndFindMissing(
+        ordering: Seq[SortOrder],
+        child: LogicalPlan,
+        grandchild: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = {
+      // Find any attributes that remain unresolved in the sort.
+      val unresolved: Seq[String] =
+        ordering.flatMap(_.collect { case UnresolvedAttribute(name) => name })
+
+      // Create a map from name to resolved attribute, when the desired name can be found
+      // prior to the projection.
+      val resolved: Map[String, NamedExpression] =
+        unresolved.flatMap(u => grandchild.resolve(u, resolver).map(a => u -> a)).toMap
+
+      // Construct a set that contains all of the attributes that we need to evaluate the
+      // ordering.
+      val requiredAttributes = AttributeSet(resolved.values)
+
+      // Figure out which ones are missing from the projection, so that we can add them and
+      // remove them after the sort.
+      val missingInProject = requiredAttributes -- child.output
+
+      // Now that we have all the attributes we need, reconstruct a resolved ordering.
+      // It is important to do it here, instead of waiting for the standard resolution,
+      // as adding attributes to the project below can actually introduce ambiguity that
+      // was not present before.
+      val resolvedOrdering = ordering.map(_ transform {
+        case u @ UnresolvedAttribute(name) => resolved.getOrElse(name, u)
+      }).asInstanceOf[Seq[SortOrder]]
+
+      (resolvedOrdering, missingInProject.toSeq)
+    }
}
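Note: to make the rewrite concrete, for a query like `SELECT a FROM t ORDER BY b`, the Sort references `b`, which the Project below it has already dropped. The rule widens the projection, sorts, then projects the extra column away again. Below is a minimal standalone sketch of that add-sort-drop technique on toy plan nodes; these are illustrative types, not Spark's Catalyst classes.

```scala
// Toy plan nodes for illustration only; not Spark's Catalyst classes.
sealed trait Plan { def output: Seq[String] }
case class Table(output: Seq[String]) extends Plan
case class Proj(output: Seq[String], child: Plan) extends Plan
case class SortBy(by: Seq[String], child: Plan) extends Plan { def output = child.output }

object RewriteDemo extends App {
  def resolveSort(plan: Plan): Plan = plan match {
    case SortBy(by, Proj(cols, child)) =>
      // Attributes the sort needs but the projection dropped:
      val missing = by.filterNot(cols.contains).filter(child.output.contains)
      if (missing.nonEmpty)
        // Widen the projection, sort, then project the extras away again.
        Proj(cols, SortBy(by, Proj(cols ++ missing, child)))
      else plan
    case other => other
  }

  // SELECT a FROM t ORDER BY b, where t has columns a and b:
  val plan = SortBy(Seq("b"), Proj(Seq("a"), Table(Seq("a", "b"))))
  println(resolveSort(plan))
  // Proj(List(a),SortBy(List(b),Proj(List(a, b),Table(List(a, b)))))
}
```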

CheckAnalysis.scala
@@ -25,25 +25,31 @@ import org.apache.spark.sql.types._
 /**
  * Throws user-facing errors when passed invalid queries that fail to analyze.
  */
-class CheckAnalysis {
+trait CheckAnalysis {
+  self: Analyzer =>

   /**
    * Override to provide additional checks for correct analysis.
    * These rules will be evaluated after our built-in check rules.
    */
   val extendedCheckRules: Seq[LogicalPlan => Unit] = Nil

-  def failAnalysis(msg: String): Nothing = {
+  protected def failAnalysis(msg: String): Nothing = {
     throw new AnalysisException(msg)
   }

-  def apply(plan: LogicalPlan): Unit = {
+  def checkAnalysis(plan: LogicalPlan): Unit = {
     // We transform up and order the rules so as to catch the first possible failure instead
     // of the result of cascading resolution failures.
     plan.foreachUp {
       case operator: LogicalPlan =>
         operator transformExpressionsUp {
           case a: Attribute if !a.resolved =>
+            if (operator.childrenResolved) {
+              // Throw errors for specific problems with get field.
+              operator.resolveChildren(a.name, resolver, throwErrors = true)
+            }
+
             val from = operator.inputSet.map(_.name).mkString(", ")
             a.failAnalysis(s"cannot resolve '${a.prettyString}' given input columns $from")

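Note: the class-to-trait change above relies on a self-type (`self: Analyzer =>`) so that CheckAnalysis can use Analyzer members such as `resolver` and `resolveChildren` while only ever being mixed into an Analyzer. A minimal standalone sketch of the pattern follows; `Checker` and `Engine` are illustrative names, not Spark's.

```scala
// Minimal sketch of the self-type pattern; Checker/Engine are illustrative names.
trait Checker {
  self: Engine =>  // this trait can only be mixed into an Engine

  // Members of Engine (here, `rules`) are usable as if they were Checker's own.
  def check(input: String): Unit =
    if (!rules.forall(rule => rule(input))) sys.error(s"check failed for: '$input'")
}

class Engine(val rules: Seq[String => Boolean]) extends Checker

object SelfTypeDemo extends App {
  val engine = new Engine(Seq(_.nonEmpty))
  engine.check("ok")  // passes silently
  // engine.check("") // would throw: check failed for: ''
}
```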
AttributeSet.scala
@@ -34,7 +34,7 @@ object AttributeSet {
   def apply(a: Attribute): AttributeSet = new AttributeSet(Set(new AttributeEquals(a)))

   /** Constructs a new [[AttributeSet]] given a sequence of [[Expression Expressions]]. */
-  def apply(baseSet: Seq[Expression]): AttributeSet = {
+  def apply(baseSet: Iterable[Expression]): AttributeSet = {
     new AttributeSet(
       baseSet
         .flatMap(_.references)
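Note: this Seq-to-Iterable widening appears to be what lets `resolveAndFindMissing` above write `AttributeSet(resolved.values)` directly, since `Map#values` returns an `Iterable`, not a `Seq`. A tiny illustration:

```scala
object IterableDemo extends App {
  // Map#values returns an Iterable, not a Seq, so the old Seq[Expression]
  // overload would have forced an extra .toSeq at the call site.
  val resolved: Map[String, Int] = Map("a" -> 1, "b" -> 2)
  val values: Iterable[Int] = resolved.values
  // val asSeq: Seq[Int] = resolved.values  // would not compile without .toSeq
  println(values.toList)  // List(1, 2)
}
```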
LogicalPlan.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.trees.TreeNode
 import org.apache.spark.sql.catalyst.trees
+import org.apache.spark.sql.types.{ArrayType, StructType, StructField}


 abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
@@ -109,16 +110,22 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
    * nodes of this LogicalPlan. The attribute is expressed as
    * a string in the following form: `[scope].AttributeName.[nested].[fields]...`.
    */
-  def resolveChildren(name: String, resolver: Resolver): Option[NamedExpression] =
-    resolve(name, children.flatMap(_.output), resolver)
+  def resolveChildren(
+      name: String,
+      resolver: Resolver,
+      throwErrors: Boolean = false): Option[NamedExpression] =
+    resolve(name, children.flatMap(_.output), resolver, throwErrors)

   /**
    * Optionally resolves the given string to a [[NamedExpression]] based on the output of this
    * LogicalPlan. The attribute is expressed as a string in the following form:
    * `[scope].AttributeName.[nested].[fields]...`.
    */
-  def resolve(name: String, resolver: Resolver): Option[NamedExpression] =
-    resolve(name, output, resolver)
+  def resolve(
+      name: String,
+      resolver: Resolver,
+      throwErrors: Boolean = false): Option[NamedExpression] =
+    resolve(name, output, resolver, throwErrors)

/**
* Resolve the given `name` string against the given attribute, returning either 0 or 1 match.
@@ -162,7 +169,8 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
   protected def resolve(
       name: String,
       input: Seq[Attribute],
-      resolver: Resolver): Option[NamedExpression] = {
+      resolver: Resolver,
+      throwErrors: Boolean): Option[NamedExpression] = {

     val parts = name.split("\\.")
@@ -196,14 +204,19 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {

       // One match, but we also need to extract the requested nested field.
       case Seq((a, nestedFields)) =>
-        // The foldLeft adds UnresolvedGetField for every remaining part of the name,
-        // and aliases it with the last part of the name.
-        // For example, consider name "a.b.c", where "a" is resolved to an existing attribute.
-        // Then this will add UnresolvedGetField("b") and UnresolvedGetField("c"), and alias
-        // the final expression as "c".
-        val fieldExprs = nestedFields.foldLeft(a: Expression)(UnresolvedGetField)
-        val aliasName = nestedFields.last
-        Some(Alias(fieldExprs, aliasName)())
+        try {
+          // The foldLeft adds UnresolvedGetField for every remaining part of the name,
+          // and aliases it with the last part of the name.
+          // For example, consider name "a.b.c", where "a" is resolved to an existing attribute.
+          // Then this will add UnresolvedGetField("b") and UnresolvedGetField("c"), and alias
+          // the final expression as "c".
+          val fieldExprs = nestedFields.foldLeft(a: Expression)(resolveGetField(_, _, resolver))
+          val aliasName = nestedFields.last
+          Some(Alias(fieldExprs, aliasName)())
+        } catch {
+          case a: AnalysisException if !throwErrors => None
+        }

Reviewer comment (Contributor): Update this comment? It seems we will not add UnresolvedGetField any more.

// No matches.
case Seq() =>
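Note: the foldLeft above threads an expression through each remaining name part, wrapping one field access per part. A standalone sketch of the same shape, with toy expression types rather than Spark's:

```scala
// Toy expression nodes; illustrative only.
sealed trait Expr
case class Attr(name: String) extends Expr
case class Field(child: Expr, name: String) extends Expr

object FoldDemo extends App {
  // For "a.b.c": "a" resolves to an attribute, and each remaining part
  // wraps the accumulated expression in one more field access.
  val nestedFields = Seq("b", "c")
  val expr = nestedFields.foldLeft(Attr("a"): Expr)(Field(_, _))
  println(expr)  // Field(Field(Attr(a),b),c)
}
```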
@@ -212,11 +225,46 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {

       // More than one match.
       case ambiguousReferences =>
-        val referenceNames = ambiguousReferences.map(_._1.qualifiedName).mkString(", ")
+        val referenceNames = ambiguousReferences.map(_._1).mkString(", ")
         throw new AnalysisException(
           s"Reference '$name' is ambiguous, could be: $referenceNames.")
     }
   }

+  /**
+   * Returns the resolved `GetField`, and reports an error if the desired field is not
+   * found or if more than one matching field is found.
+   *
+   * TODO: this code is duplicated from Analyzer and should be refactored to avoid this.
+   */
+  protected def resolveGetField(
+      expr: Expression,
+      fieldName: String,
+      resolver: Resolver): Expression = {
+    def findField(fields: Array[StructField]): Int = {
+      val checkField = (f: StructField) => resolver(f.name, fieldName)
+      val ordinal = fields.indexWhere(checkField)
+      if (ordinal == -1) {
+        throw new AnalysisException(
+          s"No such struct field $fieldName in ${fields.map(_.name).mkString(", ")}")
+      } else if (fields.indexWhere(checkField, ordinal + 1) != -1) {
+        throw new AnalysisException(
+          s"Ambiguous reference to fields ${fields.filter(checkField).mkString(", ")}")
+      } else {
+        ordinal
+      }
+    }
+    expr.dataType match {
+      case StructType(fields) =>
+        val ordinal = findField(fields)
+        StructGetField(expr, fields(ordinal), ordinal)
+      case ArrayType(StructType(fields), containsNull) =>
+        val ordinal = findField(fields)
+        ArrayGetField(expr, fields(ordinal), ordinal, containsNull)
+      case otherType =>
+        throw new AnalysisException(s"GetField is not valid on fields of type $otherType")
+    }
+  }
}
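Note: one detail worth calling out in `findField` is that a second `indexWhere` starting just past the first match is a cheap way to detect ambiguity without materializing all matches. A standalone sketch under a case-insensitive matcher (illustrative code, not Spark's):

```scala
object FindFieldDemo extends App {
  // Returns the unique index of `target` in `names`, failing on zero or
  // multiple matches: the same double-indexWhere trick as findField above.
  def findIndex(names: Array[String], target: String): Int = {
    val matches = (n: String) => n.equalsIgnoreCase(target)  // case-insensitive
    val first = names.indexWhere(matches)
    if (first == -1) sys.error(s"no such field: $target")
    else if (names.indexWhere(matches, first + 1) != -1) sys.error(s"ambiguous field: $target")
    else first
  }

  println(findIndex(Array("aField", "bField"), "AFIELD"))  // 0
  // findIndex(Array("dup", "DUP"), "dup")  // would throw: ambiguous field
}
```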

AnalysisSuite.scala
@@ -40,14 +40,12 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
     override val extendedResolutionRules = EliminateSubQueries :: Nil
   }

-  val checkAnalysis = new CheckAnalysis
-
   def caseSensitiveAnalyze(plan: LogicalPlan) =
-    checkAnalysis(caseSensitiveAnalyzer(plan))
+    caseSensitiveAnalyzer.checkAnalysis(caseSensitiveAnalyzer(plan))

   def caseInsensitiveAnalyze(plan: LogicalPlan) =
-    checkAnalysis(caseInsensitiveAnalyzer(plan))
+    caseInsensitiveAnalyzer.checkAnalysis(caseInsensitiveAnalyzer(plan))

   val testRelation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)())
   val testRelation2 = LocalRelation(
@@ -57,6 +55,21 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
AttributeReference("d", DecimalType.Unlimited)(),
AttributeReference("e", ShortType)())

val nestedRelation = LocalRelation(
AttributeReference("top", StructType(
StructField("duplicateField", StringType) ::
StructField("duplicateField", StringType) ::
StructField("differentCase", StringType) ::
StructField("differentcase", StringType) :: Nil
))())

val nestedRelation2 = LocalRelation(
AttributeReference("top", StructType(
StructField("aField", StringType) ::
StructField("bField", StringType) ::
StructField("cField", StringType) :: Nil
))())

before {
caseSensitiveCatalog.registerTable(Seq("TaBlE"), testRelation)
caseInsensitiveCatalog.registerTable(Seq("TaBlE"), testRelation)
@@ -169,6 +182,24 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
"'b'" :: "group by" :: Nil
)

errorTest(
"ambiguous field",
nestedRelation.select($"top.duplicateField"),
"Ambiguous reference to fields" :: "duplicateField" :: Nil,
caseSensitive = false)

errorTest(
"ambiguous field due to case insensitivity",
nestedRelation.select($"top.differentCase"),
"Ambiguous reference to fields" :: "differentCase" :: "differentcase" :: Nil,
caseSensitive = false)

errorTest(
"missing field",
nestedRelation2.select($"top.c"),
"No such struct field" :: "aField" :: "bField" :: "cField" :: Nil,
caseSensitive = false)

case class UnresolvedTestPlan() extends LeafNode {
override lazy val resolved = false
override def output = Nil
sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala (5 additions, 9 deletions)
@@ -120,6 +120,10 @@ class SQLContext(@transient val sparkContext: SparkContext)
       ExtractPythonUdfs ::
       sources.PreInsertCastAndRename ::
       Nil
+
+    override val extendedCheckRules = Seq(
+      sources.PreWriteCheck(catalog)
+    )
   }

@transient
@@ -1065,14 +1069,6 @@ class SQLContext(@transient val sparkContext: SparkContext)
Batch("Add exchange", Once, AddExchange(self)) :: Nil
}

@transient
protected[sql] lazy val checkAnalysis = new CheckAnalysis {
override val extendedCheckRules = Seq(
sources.PreWriteCheck(catalog)
)
}


protected[sql] def openSession(): SQLSession = {
detachSession()
val session = createSession()
@@ -1105,7 +1101,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
    */
   @DeveloperApi
   protected[sql] class QueryExecution(val logical: LogicalPlan) {
-    def assertAnalyzed(): Unit = checkAnalysis(analyzed)
+    def assertAnalyzed(): Unit = analyzer.checkAnalysis(analyzed)

     lazy val analyzed: LogicalPlan = analyzer(logical)
     lazy val withCachedData: LogicalPlan = {
sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala (14 additions, 5 deletions)
@@ -1084,10 +1084,19 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
test("SPARK-6145: ORDER BY test for nested fields") {
jsonRDD(sparkContext.makeRDD(
"""{"a": {"b": 1, "a": {"a": 1}}, "c": [{"d": 1}]}""" :: Nil)).registerTempTable("nestedOrder")
// These should be successfully analyzed
sql("SELECT 1 FROM nestedOrder ORDER BY a.b").queryExecution.analyzed
sql("SELECT a.b FROM nestedOrder ORDER BY a.b").queryExecution.analyzed
sql("SELECT 1 FROM nestedOrder ORDER BY a.a.a").queryExecution.analyzed
sql("SELECT 1 FROM nestedOrder ORDER BY c[0].d").queryExecution.analyzed

checkAnswer(sql("SELECT 1 FROM nestedOrder ORDER BY a.b"), Row(1))
checkAnswer(sql("SELECT a.b FROM nestedOrder ORDER BY a.b"), Row(1))
checkAnswer(sql("SELECT 1 FROM nestedOrder ORDER BY a.a.a"), Row(1))
checkAnswer(sql("SELECT a.a.a FROM nestedOrder ORDER BY a.a.a"), Row(1))
checkAnswer(sql("SELECT 1 FROM nestedOrder ORDER BY c[0].d"), Row(1))
checkAnswer(sql("SELECT c[0].d FROM nestedOrder ORDER BY c[0].d"), Row(1))
}

test("SPARK-6145: special cases") {
jsonRDD(sparkContext.makeRDD(
"""{"a": {"b": [1]}, "b": [{"a": 1}], "c0": {"a": 1}}""" :: Nil)).registerTempTable("t")
checkAnswer(sql("SELECT a.b[0] FROM t ORDER BY c0.a"), Row(1))
checkAnswer(sql("SELECT b[0].a FROM t ORDER BY c0.a"), Row(1))
}
}