-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-20392][SQL][followup] should not add extra AnalysisBarrier #20094
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,7 +17,6 @@ | |
|
|
||
| package org.apache.spark.sql.catalyst.analysis | ||
|
|
||
| import scala.annotation.tailrec | ||
| import scala.collection.mutable.ArrayBuffer | ||
|
|
||
| import org.apache.spark.sql.AnalysisException | ||
|
|
@@ -665,14 +664,18 @@ class Analyzer( | |
| * Generate a new logical plan for the right child with different expression IDs | ||
| * for all conflicting attributes. | ||
| */ | ||
| private def dedupRight (left: LogicalPlan, originalRight: LogicalPlan): LogicalPlan = { | ||
| // Remove analysis barrier if any. | ||
| val right = EliminateBarriers(originalRight) | ||
| private def dedupRight (left: LogicalPlan, right: LogicalPlan): LogicalPlan = { | ||
| val conflictingAttributes = left.outputSet.intersect(right.outputSet) | ||
| logDebug(s"Conflicting attributes ${conflictingAttributes.mkString(",")} " + | ||
| s"between $left and $right") | ||
|
|
||
| right.collect { | ||
| // For `AnalysisBarrier`, recursively de-duplicate its child. | ||
| case oldVersion: AnalysisBarrier | ||
| if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty => | ||
| val newVersion = dedupRight(left, oldVersion.child) | ||
| (oldVersion, AnalysisBarrier(newVersion)) | ||
|
|
||
| // Handle base relations that might appear more than once. | ||
| case oldVersion: MultiInstanceRelation | ||
| if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty => | ||
|
|
@@ -710,10 +713,10 @@ class Analyzer( | |
| * that this rule cannot handle. When that is the case, there must be another rule | ||
| * that resolves these conflicts. Otherwise, the analysis will fail. | ||
| */ | ||
| originalRight | ||
| right | ||
| case Some((oldRelation, newRelation)) => | ||
| val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output)) | ||
| val newRight = right transformUp { | ||
| right transformUp { | ||
| case r if r == oldRelation => newRelation | ||
| } transformUp { | ||
| case other => other transformExpressions { | ||
|
|
@@ -723,7 +726,6 @@ class Analyzer( | |
| s.withNewPlan(dedupOuterReferencesInSubquery(s.plan, attributeRewrites)) | ||
| } | ||
| } | ||
| AnalysisBarrier(newRight) | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -958,7 +960,8 @@ class Analyzer( | |
| protected[sql] def resolveExpression( | ||
| expr: Expression, | ||
| plan: LogicalPlan, | ||
| throws: Boolean = false) = { | ||
| throws: Boolean = false): Expression = { | ||
| if (expr.resolved) return expr | ||
| // Resolve expression in one round. | ||
| // If throws == false or the desired attribute doesn't exist | ||
| // (like try to resolve `a.b` but `a` doesn't exist), fail and return the origin one. | ||
|
|
@@ -1079,100 +1082,74 @@ class Analyzer( | |
| case sa @ Sort(_, _, AnalysisBarrier(child: Aggregate)) => sa | ||
| case sa @ Sort(_, _, child: Aggregate) => sa | ||
|
|
||
| case s @ Sort(order, _, originalChild) if !s.resolved && originalChild.resolved => | ||
| val child = EliminateBarriers(originalChild) | ||
| try { | ||
| val newOrder = order.map(resolveExpressionRecursively(_, child).asInstanceOf[SortOrder]) | ||
| val requiredAttrs = AttributeSet(newOrder).filter(_.resolved) | ||
| val missingAttrs = requiredAttrs -- child.outputSet | ||
| if (missingAttrs.nonEmpty) { | ||
| // Add missing attributes and then project them away after the sort. | ||
| Project(child.output, | ||
| Sort(newOrder, s.global, addMissingAttr(child, missingAttrs))) | ||
| } else if (newOrder != order) { | ||
| s.copy(order = newOrder) | ||
| } else { | ||
| s | ||
| } | ||
| } catch { | ||
| // Attempting to resolve it might fail. When this happens, return the original plan. | ||
| // Users will see an AnalysisException for resolution failure of missing attributes | ||
| // in Sort | ||
| case ae: AnalysisException => s | ||
| case s @ Sort(order, _, child) if !s.resolved && child.resolved => | ||
| val (newOrder, newChild) = resolveExprsAndAddMissingAttrs(order, child) | ||
| val ordering = newOrder.map(_.asInstanceOf[SortOrder]) | ||
| if (child.output == newChild.output) { | ||
| s.copy(order = ordering) | ||
| } else { | ||
| // Add missing attributes and then project them away. | ||
| val newSort = s.copy(order = ordering, child = newChild) | ||
| Project(child.output, newSort) | ||
| } | ||
|
|
||
| case f @ Filter(cond, originalChild) if !f.resolved && originalChild.resolved => | ||
| val child = EliminateBarriers(originalChild) | ||
| try { | ||
| val newCond = resolveExpressionRecursively(cond, child) | ||
| val requiredAttrs = newCond.references.filter(_.resolved) | ||
| val missingAttrs = requiredAttrs -- child.outputSet | ||
| if (missingAttrs.nonEmpty) { | ||
| // Add missing attributes and then project them away. | ||
| Project(child.output, | ||
| Filter(newCond, addMissingAttr(child, missingAttrs))) | ||
| } else if (newCond != cond) { | ||
| f.copy(condition = newCond) | ||
| } else { | ||
| f | ||
| } | ||
| } catch { | ||
| // Attempting to resolve it might fail. When this happens, return the original plan. | ||
| // Users will see an AnalysisException for resolution failure of missing attributes | ||
| case ae: AnalysisException => f | ||
| case f @ Filter(cond, child) if !f.resolved && child.resolved => | ||
| val (newCond, newChild) = resolveExprsAndAddMissingAttrs(Seq(cond), child) | ||
| if (child.output == newChild.output) { | ||
| f.copy(condition = newCond.head) | ||
| } else { | ||
| // Add missing attributes and then project them away. | ||
| val newFilter = Filter(newCond.head, newChild) | ||
| Project(child.output, newFilter) | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Add the missing attributes into projectList of Project/Window or aggregateExpressions of | ||
| * Aggregate. | ||
| */ | ||
| private def addMissingAttr(plan: LogicalPlan, missingAttrs: AttributeSet): LogicalPlan = { | ||
| if (missingAttrs.isEmpty) { | ||
| return AnalysisBarrier(plan) | ||
| } | ||
| plan match { | ||
| case p: Project => | ||
| val missing = missingAttrs -- p.child.outputSet | ||
| Project(p.projectList ++ missingAttrs, addMissingAttr(p.child, missing)) | ||
| case a: Aggregate => | ||
| // all the missing attributes should be grouping expressions | ||
| // TODO: push down AggregateExpression | ||
| missingAttrs.foreach { attr => | ||
| if (!a.groupingExpressions.exists(_.semanticEquals(attr))) { | ||
| throw new AnalysisException(s"Can't add $attr to ${a.simpleString}") | ||
| } | ||
| } | ||
| val newAggregateExpressions = a.aggregateExpressions ++ missingAttrs | ||
| a.copy(aggregateExpressions = newAggregateExpressions) | ||
| case g: Generate => | ||
| // If join is false, we will convert it to true for getting from the child the missing | ||
| // attributes that its child might have or could have. | ||
| val missing = missingAttrs -- g.child.outputSet | ||
| g.copy(join = true, child = addMissingAttr(g.child, missing)) | ||
| case d: Distinct => | ||
| throw new AnalysisException(s"Can't add $missingAttrs to $d") | ||
| case u: UnaryNode => | ||
| u.withNewChildren(addMissingAttr(u.child, missingAttrs) :: Nil) | ||
| case other => | ||
| throw new AnalysisException(s"Can't add $missingAttrs to $other") | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Resolve the expression on a specified logical plan and it's child (recursively), until | ||
| * the expression is resolved or meet a non-unary node or Subquery. | ||
| */ | ||
| @tailrec | ||
| private def resolveExpressionRecursively(expr: Expression, plan: LogicalPlan): Expression = { | ||
| val resolved = resolveExpression(expr, plan) | ||
| if (resolved.resolved) { | ||
| resolved | ||
| private def resolveExprsAndAddMissingAttrs( | ||
|
||
| exprs: Seq[Expression], plan: LogicalPlan): (Seq[Expression], LogicalPlan) = { | ||
| if (exprs.forall(_.resolved)) { | ||
| // All given expressions are resolved, no need to continue anymore. | ||
| (exprs, plan) | ||
| } else { | ||
| plan match { | ||
| case u: UnaryNode if !u.isInstanceOf[SubqueryAlias] => | ||
| resolveExpressionRecursively(resolved, u.child) | ||
| case other => resolved | ||
| // For `AnalysisBarrier`, recursively resolve expressions and add missing attributes via | ||
| // its child. | ||
| case barrier: AnalysisBarrier => | ||
| val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(exprs, barrier.child) | ||
| (newExprs, AnalysisBarrier(newChild)) | ||
|
|
||
| case p: Project => | ||
| val maybeResolvedExprs = exprs.map(resolveExpression(_, p)) | ||
| val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, p.child) | ||
| val missingAttrs = AttributeSet(newExprs) -- AttributeSet(maybeResolvedExprs) | ||
| (newExprs, Project(p.projectList ++ missingAttrs, newChild)) | ||
|
|
||
| case a @ Aggregate(groupExprs, aggExprs, child) => | ||
| val maybeResolvedExprs = exprs.map(resolveExpression(_, a)) | ||
| val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, child) | ||
| val missingAttrs = AttributeSet(newExprs) -- AttributeSet(maybeResolvedExprs) | ||
| if (missingAttrs.forall(attr => groupExprs.exists(_.semanticEquals(attr)))) { | ||
| // All the missing attributes are grouping expressions, valid case. | ||
| (newExprs, a.copy(aggregateExpressions = aggExprs ++ missingAttrs, child = newChild)) | ||
| } else { | ||
| // Need to add non-grouping attributes, invalid case. | ||
| (exprs, a) | ||
| } | ||
|
|
||
| case g: Generate => | ||
| val maybeResolvedExprs = exprs.map(resolveExpression(_, g)) | ||
| val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, g.child) | ||
| (newExprs, g.copy(join = true, child = newChild)) | ||
|
|
||
| // For `Distinct` and `SubqueryAlias`, we can't recursively resolve and add attributes | ||
| // via its children. | ||
| case u: UnaryNode if !u.isInstanceOf[Distinct] && !u.isInstanceOf[SubqueryAlias] => | ||
| val maybeResolvedExprs = exprs.map(resolveExpression(_, u)) | ||
| val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, u.child) | ||
| (newExprs, u.withNewChildren(Seq(newChild))) | ||
|
|
||
| // For other operators, we can't recursively resolve and add attributes via its children. | ||
| case other => | ||
| (exprs.map(resolveExpression(_, other)), other) | ||
| } | ||
| } | ||
| } | ||
|
|
@@ -1404,18 +1381,16 @@ class Analyzer( | |
| */ | ||
| object ResolveAggregateFunctions extends Rule[LogicalPlan] { | ||
| def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp { | ||
| case filter @ Filter(havingCondition, AnalysisBarrier(aggregate: Aggregate)) => | ||
| apply(Filter(havingCondition, aggregate)).mapChildren(AnalysisBarrier) | ||
| case filter @ Filter(havingCondition, | ||
| aggregate @ Aggregate(grouping, originalAggExprs, child)) | ||
| if aggregate.resolved => | ||
| case Filter(cond, AnalysisBarrier(agg: Aggregate)) => | ||
| apply(Filter(cond, agg)).mapChildren(AnalysisBarrier) | ||
| case f @ Filter(cond, agg @ Aggregate(grouping, originalAggExprs, child)) if agg.resolved => | ||
|
||
|
|
||
| // Try resolving the condition of the filter as though it is in the aggregate clause | ||
| try { | ||
| val aggregatedCondition = | ||
| Aggregate( | ||
| grouping, | ||
| Alias(havingCondition, "havingCondition")() :: Nil, | ||
| Alias(cond, "havingCondition")() :: Nil, | ||
| child) | ||
| val resolvedOperator = execute(aggregatedCondition) | ||
| def resolvedAggregateFilter = | ||
|
|
@@ -1436,7 +1411,7 @@ class Analyzer( | |
| // Grouping functions are handled in the rule [[ResolveGroupingAnalytics]]. | ||
| case e: Expression if grouping.exists(_.semanticEquals(e)) && | ||
| !ResolveGroupingAnalytics.hasGroupingFunction(e) && | ||
| !aggregate.output.exists(_.semanticEquals(e)) => | ||
| !agg.output.exists(_.semanticEquals(e)) => | ||
| e match { | ||
| case ne: NamedExpression => | ||
| aggregateExpressions += ne | ||
|
|
@@ -1450,22 +1425,22 @@ class Analyzer( | |
|
|
||
| // Push the aggregate expressions into the aggregate (if any). | ||
| if (aggregateExpressions.nonEmpty) { | ||
| Project(aggregate.output, | ||
| Project(agg.output, | ||
| Filter(transformedAggregateFilter, | ||
| aggregate.copy(aggregateExpressions = originalAggExprs ++ aggregateExpressions))) | ||
| agg.copy(aggregateExpressions = originalAggExprs ++ aggregateExpressions))) | ||
| } else { | ||
| filter | ||
| f | ||
| } | ||
| } else { | ||
| filter | ||
| f | ||
| } | ||
| } catch { | ||
| // Attempting to resolve in the aggregate can result in ambiguity. When this happens, | ||
| // just return the original plan. | ||
| case ae: AnalysisException => filter | ||
| case ae: AnalysisException => f | ||
| } | ||
|
|
||
| case sort @ Sort(sortOrder, global, AnalysisBarrier(aggregate: Aggregate)) => | ||
| case Sort(sortOrder, global, AnalysisBarrier(aggregate: Aggregate)) => | ||
| apply(Sort(sortOrder, global, aggregate)).mapChildren(AnalysisBarrier) | ||
| case sort @ Sort(sortOrder, global, aggregate: Aggregate) if aggregate.resolved => | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If right plan is wrapped (e.g., we join two datasets) in an analysis barrier, the later
right.collectdoesn't work.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
oh, I see, you have recursively
dedupRighton it.