Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

package org.apache.spark.sql.catalyst.analysis

import scala.annotation.tailrec
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.AnalysisException
Expand Down Expand Up @@ -665,14 +664,18 @@ class Analyzer(
* Generate a new logical plan for the right child with different expression IDs
* for all conflicting attributes.
*/
private def dedupRight (left: LogicalPlan, originalRight: LogicalPlan): LogicalPlan = {
// Remove analysis barrier if any.
val right = EliminateBarriers(originalRight)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the right plan is wrapped in an analysis barrier (e.g., when we join two datasets), the later `right.collect` doesn't work.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I see, you call dedupRight recursively on its child.

private def dedupRight (left: LogicalPlan, right: LogicalPlan): LogicalPlan = {
val conflictingAttributes = left.outputSet.intersect(right.outputSet)
logDebug(s"Conflicting attributes ${conflictingAttributes.mkString(",")} " +
s"between $left and $right")

right.collect {
// For `AnalysisBarrier`, recursively de-duplicate its child.
case oldVersion: AnalysisBarrier
if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
val newVersion = dedupRight(left, oldVersion.child)
(oldVersion, AnalysisBarrier(newVersion))

// Handle base relations that might appear more than once.
case oldVersion: MultiInstanceRelation
if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
Expand Down Expand Up @@ -710,10 +713,10 @@ class Analyzer(
* that this rule cannot handle. When that is the case, there must be another rule
* that resolves these conflicts. Otherwise, the analysis will fail.
*/
originalRight
right
case Some((oldRelation, newRelation)) =>
val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output))
val newRight = right transformUp {
right transformUp {
case r if r == oldRelation => newRelation
} transformUp {
case other => other transformExpressions {
Expand All @@ -723,7 +726,6 @@ class Analyzer(
s.withNewPlan(dedupOuterReferencesInSubquery(s.plan, attributeRewrites))
}
}
AnalysisBarrier(newRight)
}
}

Expand Down Expand Up @@ -958,7 +960,8 @@ class Analyzer(
protected[sql] def resolveExpression(
expr: Expression,
plan: LogicalPlan,
throws: Boolean = false) = {
throws: Boolean = false): Expression = {
if (expr.resolved) return expr
// Resolve expression in one round.
// If throws == false or the desired attribute doesn't exist
// (like try to resolve `a.b` but `a` doesn't exist), fail and return the origin one.
Expand Down Expand Up @@ -1079,100 +1082,74 @@ class Analyzer(
case sa @ Sort(_, _, AnalysisBarrier(child: Aggregate)) => sa
case sa @ Sort(_, _, child: Aggregate) => sa

case s @ Sort(order, _, originalChild) if !s.resolved && originalChild.resolved =>
val child = EliminateBarriers(originalChild)
try {
val newOrder = order.map(resolveExpressionRecursively(_, child).asInstanceOf[SortOrder])
val requiredAttrs = AttributeSet(newOrder).filter(_.resolved)
val missingAttrs = requiredAttrs -- child.outputSet
if (missingAttrs.nonEmpty) {
// Add missing attributes and then project them away after the sort.
Project(child.output,
Sort(newOrder, s.global, addMissingAttr(child, missingAttrs)))
} else if (newOrder != order) {
s.copy(order = newOrder)
} else {
s
}
} catch {
// Attempting to resolve it might fail. When this happens, return the original plan.
// Users will see an AnalysisException for resolution failure of missing attributes
// in Sort
case ae: AnalysisException => s
case s @ Sort(order, _, child) if !s.resolved && child.resolved =>
val (newOrder, newChild) = resolveExprsAndAddMissingAttrs(order, child)
val ordering = newOrder.map(_.asInstanceOf[SortOrder])
if (child.output == newChild.output) {
s.copy(order = ordering)
} else {
// Add missing attributes and then project them away.
val newSort = s.copy(order = ordering, child = newChild)
Project(child.output, newSort)
}

case f @ Filter(cond, originalChild) if !f.resolved && originalChild.resolved =>
val child = EliminateBarriers(originalChild)
try {
val newCond = resolveExpressionRecursively(cond, child)
val requiredAttrs = newCond.references.filter(_.resolved)
val missingAttrs = requiredAttrs -- child.outputSet
if (missingAttrs.nonEmpty) {
// Add missing attributes and then project them away.
Project(child.output,
Filter(newCond, addMissingAttr(child, missingAttrs)))
} else if (newCond != cond) {
f.copy(condition = newCond)
} else {
f
}
} catch {
// Attempting to resolve it might fail. When this happens, return the original plan.
// Users will see an AnalysisException for resolution failure of missing attributes
case ae: AnalysisException => f
case f @ Filter(cond, child) if !f.resolved && child.resolved =>
val (newCond, newChild) = resolveExprsAndAddMissingAttrs(Seq(cond), child)
if (child.output == newChild.output) {
f.copy(condition = newCond.head)
} else {
// Add missing attributes and then project them away.
val newFilter = Filter(newCond.head, newChild)
Project(child.output, newFilter)
}
}

/**
 * Extends `plan` so that `missingAttrs` become part of its output, by appending them to the
 * projectList of a Project or to the aggregateExpressions of an Aggregate and, where required,
 * descending into the child of a unary operator.
 *
 * When `missingAttrs` is empty the plan is left untouched and returned wrapped in an
 * `AnalysisBarrier`.
 *
 * @param plan the operator whose output should be extended
 * @param missingAttrs the attributes that still have to be added to `plan`
 * @return the rewritten plan
 * @throws AnalysisException if an attribute that matches none of the grouping expressions
 *                           would have to be added to an Aggregate, or if attributes are still
 *                           missing at a Distinct or at an operator that none of the earlier
 *                           cases handles
 */
private def addMissingAttr(plan: LogicalPlan, missingAttrs: AttributeSet): LogicalPlan = {
  if (missingAttrs.isEmpty) {
    // Nothing left to add: stop descending and hand the subtree back behind a barrier.
    AnalysisBarrier(plan)
  } else {
    // Case order matters: the specific operators must be tried before the generic UnaryNode case.
    plan match {
      case project: Project =>
        // The project list receives every missing attribute; only those the child does not
        // already output are requested from further down.
        val widenedList = project.projectList ++ missingAttrs
        val stillMissing = missingAttrs -- project.child.outputSet
        Project(widenedList, addMissingAttr(project.child, stillMissing))

      case aggregate: Aggregate =>
        // all the missing attributes should be grouping expressions
        // TODO: push down AggregateExpression
        val groupedBy = aggregate.groupingExpressions
        missingAttrs.foreach { attr =>
          val isGroupingExpr = groupedBy.exists(_.semanticEquals(attr))
          if (!isGroupingExpr) {
            throw new AnalysisException(s"Can't add $attr to ${aggregate.simpleString}")
          }
        }
        // The aggregate's child is kept as it is; only the output list grows.
        aggregate.copy(aggregateExpressions = aggregate.aggregateExpressions ++ missingAttrs)

      case generate: Generate =>
        // `join` is forced to true so that the attributes obtained through the child are
        // available alongside the generator output.
        val stillMissing = missingAttrs -- generate.child.outputSet
        generate.copy(join = true, child = addMissingAttr(generate.child, stillMissing))

      case distinct: Distinct =>
        throw new AnalysisException(s"Can't add $missingAttrs to $distinct")

      case unary: UnaryNode =>
        // Any other unary operator is kept and its single child is extended instead.
        val extendedChild = addMissingAttr(unary.child, missingAttrs)
        unary.withNewChildren(extendedChild :: Nil)

      case unsupported =>
        throw new AnalysisException(s"Can't add $missingAttrs to $unsupported")
    }
  }
}

/**
* Resolve the expression on a specified logical plan and its child (recursively), until
* the expression is resolved or a non-unary node or Subquery is met.
*/
@tailrec
private def resolveExpressionRecursively(expr: Expression, plan: LogicalPlan): Expression = {
val resolved = resolveExpression(expr, plan)
if (resolved.resolved) {
resolved
private def resolveExprsAndAddMissingAttrs(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I refactored the code to resolve expressions and add missing attributes in one shot, so that we have a central place to deal with analysis barriers and to decide which operators are supported and which are not.

exprs: Seq[Expression], plan: LogicalPlan): (Seq[Expression], LogicalPlan) = {
if (exprs.forall(_.resolved)) {
// All given expressions are resolved, no need to continue anymore.
(exprs, plan)
} else {
plan match {
case u: UnaryNode if !u.isInstanceOf[SubqueryAlias] =>
resolveExpressionRecursively(resolved, u.child)
case other => resolved
// For `AnalysisBarrier`, recursively resolve expressions and add missing attributes via
// its child.
case barrier: AnalysisBarrier =>
val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(exprs, barrier.child)
(newExprs, AnalysisBarrier(newChild))

case p: Project =>
val maybeResolvedExprs = exprs.map(resolveExpression(_, p))
val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, p.child)
val missingAttrs = AttributeSet(newExprs) -- AttributeSet(maybeResolvedExprs)
(newExprs, Project(p.projectList ++ missingAttrs, newChild))

case a @ Aggregate(groupExprs, aggExprs, child) =>
val maybeResolvedExprs = exprs.map(resolveExpression(_, a))
val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, child)
val missingAttrs = AttributeSet(newExprs) -- AttributeSet(maybeResolvedExprs)
if (missingAttrs.forall(attr => groupExprs.exists(_.semanticEquals(attr)))) {
// All the missing attributes are grouping expressions, valid case.
(newExprs, a.copy(aggregateExpressions = aggExprs ++ missingAttrs, child = newChild))
} else {
// Need to add non-grouping attributes, invalid case.
(exprs, a)
}

case g: Generate =>
val maybeResolvedExprs = exprs.map(resolveExpression(_, g))
val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, g.child)
(newExprs, g.copy(join = true, child = newChild))

// For `Distinct` and `SubqueryAlias`, we can't recursively resolve and add attributes
// via its children.
case u: UnaryNode if !u.isInstanceOf[Distinct] && !u.isInstanceOf[SubqueryAlias] =>
val maybeResolvedExprs = exprs.map(resolveExpression(_, u))
val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, u.child)
(newExprs, u.withNewChildren(Seq(newChild)))

// For other operators, we can't recursively resolve and add attributes via its children.
case other =>
(exprs.map(resolveExpression(_, other)), other)
}
}
}
Expand Down Expand Up @@ -1404,18 +1381,16 @@ class Analyzer(
*/
object ResolveAggregateFunctions extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
case filter @ Filter(havingCondition, AnalysisBarrier(aggregate: Aggregate)) =>
apply(Filter(havingCondition, aggregate)).mapChildren(AnalysisBarrier)
case filter @ Filter(havingCondition,
aggregate @ Aggregate(grouping, originalAggExprs, child))
if aggregate.resolved =>
case Filter(cond, AnalysisBarrier(agg: Aggregate)) =>
apply(Filter(cond, agg)).mapChildren(AnalysisBarrier)
case f @ Filter(cond, agg @ Aggregate(grouping, originalAggExprs, child)) if agg.resolved =>
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just make the names shorter


// Try resolving the condition of the filter as though it is in the aggregate clause
try {
val aggregatedCondition =
Aggregate(
grouping,
Alias(havingCondition, "havingCondition")() :: Nil,
Alias(cond, "havingCondition")() :: Nil,
child)
val resolvedOperator = execute(aggregatedCondition)
def resolvedAggregateFilter =
Expand All @@ -1436,7 +1411,7 @@ class Analyzer(
// Grouping functions are handled in the rule [[ResolveGroupingAnalytics]].
case e: Expression if grouping.exists(_.semanticEquals(e)) &&
!ResolveGroupingAnalytics.hasGroupingFunction(e) &&
!aggregate.output.exists(_.semanticEquals(e)) =>
!agg.output.exists(_.semanticEquals(e)) =>
e match {
case ne: NamedExpression =>
aggregateExpressions += ne
Expand All @@ -1450,22 +1425,22 @@ class Analyzer(

// Push the aggregate expressions into the aggregate (if any).
if (aggregateExpressions.nonEmpty) {
Project(aggregate.output,
Project(agg.output,
Filter(transformedAggregateFilter,
aggregate.copy(aggregateExpressions = originalAggExprs ++ aggregateExpressions)))
agg.copy(aggregateExpressions = originalAggExprs ++ aggregateExpressions)))
} else {
filter
f
}
} else {
filter
f
}
} catch {
// Attempting to resolve in the aggregate can result in ambiguity. When this happens,
// just return the original plan.
case ae: AnalysisException => filter
case ae: AnalysisException => f
}

case sort @ Sort(sortOrder, global, AnalysisBarrier(aggregate: Aggregate)) =>
case Sort(sortOrder, global, AnalysisBarrier(aggregate: Aggregate)) =>
apply(Sort(sortOrder, global, aggregate)).mapChildren(AnalysisBarrier)
case sort @ Sort(sortOrder, global, aggregate: Aggregate) if aggregate.resolved =>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1562,7 +1562,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
}

test("multi-insert with lateral view") {
withTempView("t1") {
withTempView("source") {
spark.range(10)
.select(array($"id", $"id" + 1).as("arr"), $"id")
.createOrReplaceTempView("source")
Expand Down