
Commit 71b6eac

[SPARK-18609][SPARK-18841][SQL][BACKPORT-2.1] Fix redundant Alias removal in the optimizer
This is a backport of apache@73ee739

## What changes were proposed in this pull request?

The optimizer tries to remove redundant alias-only projections from the query plan using the `RemoveAliasOnlyProject` rule. The current rule identifies such a project, removes it, and rewrites the project's attributes in the **entire** tree. This causes problems when parts of the tree are duplicated (for instance a self join on a temporary view/CTE) and the duplicated part contains the alias-only project; in that case the rewrite breaks the tree.

This PR fixes these problems by using a blacklist for attributes that are not to be removed, and by making sure that attribute remapping is only done for the parent tree, and not for unrelated parts of the query plan.

The current tree transformation infrastructure works very well if the transformation at hand requires little to no global contextual information. In this case we need to know both the attributes that are not to be removed, and which child attributes were modified. This cannot be done easily using the current infrastructure, and solutions typically involve traversing the query plan multiple times (which is super slow). I have moved around some code in `TreeNode`, `QueryPlan` and `LogicalPlan` to make this much more straightforward; this basically allows you to manually traverse the tree.

## How was this patch tested?

I have added unit tests to `RemoveRedundantAliasAndProjectSuite` and I have added integration tests to the `SQLQueryTestSuite.union` and `SQLQueryTestSuite.cte` test cases.

Author: Herman van Hovell <[email protected]>

Closes apache#16843 from hvanhovell/SPARK-18609-2.1.
1 parent 4d04029 commit 71b6eac
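To make the failure mode concrete, the following is a sketch of the kind of query that exercises this path: a self join on a CTE whose body contains an alias-only projection, so the projection appears in two duplicated subtrees. The query and the names `t` and `v` are illustrative only; they are not taken from the patch or its tests.

    // Spark 2.1 shell sketch; `spark` is the usual SparkSession.
    spark.range(10).createOrReplaceTempView("t")
    // The CTE body is a Project that only re-aliases `id` to itself; inlining the CTE on both
    // sides of the self join duplicates that alias-only Project in the plan.
    spark.sql(
      """WITH v AS (SELECT id AS id FROM t)
        |SELECT * FROM v a JOIN v b ON a.id = b.id""".stripMargin).explain(true)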


9 files changed: +302 additions, -115 deletions


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 84 additions & 41 deletions
@@ -109,7 +109,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
       SimplifyCaseConversionExpressions,
       RewriteCorrelatedScalarSubquery,
       EliminateSerialization,
-      RemoveAliasOnlyProject) ::
+      RemoveRedundantAliases,
+      RemoveRedundantProject) ::
     Batch("Check Cartesian Products", Once,
       CheckCartesianProducts(conf)) ::
     Batch("Decimal Optimizations", fixedPoint,
@@ -153,56 +154,98 @@ class SimpleTestOptimizer extends Optimizer(
   new SimpleCatalystConf(caseSensitiveAnalysis = true))
 
 /**
- * Removes the Project only conducting Alias of its child node.
- * It is created mainly for removing extra Project added in EliminateSerialization rule,
- * but can also benefit other operators.
+ * Remove redundant aliases from a query plan. A redundant alias is an alias that does not change
+ * the name or metadata of a column, and does not deduplicate it.
  */
-object RemoveAliasOnlyProject extends Rule[LogicalPlan] {
+object RemoveRedundantAliases extends Rule[LogicalPlan] {
+
   /**
-   * Returns true if the project list is semantically same as child output, after strip alias on
-   * attribute.
+   * Create an attribute mapping from the old to the new attributes. This function will only
+   * return the attribute pairs that have changed.
    */
-  private def isAliasOnly(
-      projectList: Seq[NamedExpression],
-      childOutput: Seq[Attribute]): Boolean = {
-    if (projectList.length != childOutput.length) {
-      false
-    } else {
-      stripAliasOnAttribute(projectList).zip(childOutput).forall {
-        case (a: Attribute, o) if a semanticEquals o => true
-        case _ => false
-      }
+  private def createAttributeMapping(current: LogicalPlan, next: LogicalPlan)
+    : Seq[(Attribute, Attribute)] = {
+    current.output.zip(next.output).filterNot {
+      case (a1, a2) => a1.semanticEquals(a2)
     }
   }
 
-  private def stripAliasOnAttribute(projectList: Seq[NamedExpression]) = {
-    projectList.map {
-      // Alias with metadata can not be stripped, or the metadata will be lost.
-      // If the alias name is different from attribute name, we can't strip it either, or we may
-      // accidentally change the output schema name of the root plan.
-      case a @ Alias(attr: Attribute, name) if a.metadata == Metadata.empty && name == attr.name =>
-        attr
-      case other => other
-    }
+  /**
+   * Remove the top-level alias from an expression when it is redundant.
+   */
+  private def removeRedundantAlias(e: Expression, blacklist: AttributeSet): Expression = e match {
+    // Alias with metadata can not be stripped, or the metadata will be lost.
+    // If the alias name is different from attribute name, we can't strip it either, or we
+    // may accidentally change the output schema name of the root plan.
+    case a @ Alias(attr: Attribute, name)
+      if a.metadata == Metadata.empty && name == attr.name && !blacklist.contains(attr) =>
+      attr
+    case a => a
   }
 
-  def apply(plan: LogicalPlan): LogicalPlan = {
-    val aliasOnlyProject = plan.collectFirst {
-      case p @ Project(pList, child) if isAliasOnly(pList, child.output) => p
-    }
+  /**
+   * Remove redundant alias expression from a LogicalPlan and its subtree. A blacklist is used to
+   * prevent the removal of seemingly redundant aliases used to deduplicate the input for a (self)
+   * join.
+   */
+  private def removeRedundantAliases(plan: LogicalPlan, blacklist: AttributeSet): LogicalPlan = {
+    plan match {
+      // A join has to be treated differently, because the left and the right side of the join are
+      // not allowed to use the same attributes. We use a blacklist to prevent us from creating a
+      // situation in which this happens; the rule will only remove an alias if its child
+      // attribute is not on the black list.
+      case Join(left, right, joinType, condition) =>
+        val newLeft = removeRedundantAliases(left, blacklist ++ right.outputSet)
+        val newRight = removeRedundantAliases(right, blacklist ++ newLeft.outputSet)
+        val mapping = AttributeMap(
+          createAttributeMapping(left, newLeft) ++
+          createAttributeMapping(right, newRight))
+        val newCondition = condition.map(_.transform {
+          case a: Attribute => mapping.getOrElse(a, a)
+        })
+        Join(newLeft, newRight, joinType, newCondition)
+
+      case _ =>
+        // Remove redundant aliases in the subtree(s).
+        val currentNextAttrPairs = mutable.Buffer.empty[(Attribute, Attribute)]
+        val newNode = plan.mapChildren { child =>
+          val newChild = removeRedundantAliases(child, blacklist)
+          currentNextAttrPairs ++= createAttributeMapping(child, newChild)
+          newChild
+        }
 
-    aliasOnlyProject.map { case proj =>
-      val attributesToReplace = proj.output.zip(proj.child.output).filterNot {
-        case (a1, a2) => a1 semanticEquals a2
-      }
-      val attrMap = AttributeMap(attributesToReplace)
-      plan transform {
-        case plan: Project if plan eq proj => plan.child
-        case plan => plan transformExpressions {
-          case a: Attribute if attrMap.contains(a) => attrMap(a)
+        // Create the attribute mapping. Note that the currentNextAttrPairs can contain duplicate
+        // keys in case of Union (this is caused by the PushProjectionThroughUnion rule); in this
+        // case we use the the first mapping (which should be provided by the first child).
+        val mapping = AttributeMap(currentNextAttrPairs)
+
+        // Create a an expression cleaning function for nodes that can actually produce redundant
+        // aliases, use identity otherwise.
+        val clean: Expression => Expression = plan match {
+          case _: Project => removeRedundantAlias(_, blacklist)
+          case _: Aggregate => removeRedundantAlias(_, blacklist)
+          case _: Window => removeRedundantAlias(_, blacklist)
+          case _ => identity[Expression]
         }
-      }
-    }.getOrElse(plan)
+
+        // Transform the expressions.
+        newNode.mapExpressions { expr =>
+          clean(expr.transform {
+            case a: Attribute => mapping.getOrElse(a, a)
+          })
+        }
+    }
+  }
+
+  def apply(plan: LogicalPlan): LogicalPlan = removeRedundantAliases(plan, AttributeSet.empty)
+}
+
+/**
+ * Remove projections from the query plan that do not make any modifications.
+ */
+object RemoveRedundantProject extends Rule[LogicalPlan] {
+  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    case p @ Project(_, child) if p.output == child.output => child
   }
 }
 
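As a quick illustration of how the two new rules compose, here is a minimal sketch using the Catalyst test DSL; `relation` and the column names are made up here, and this is not code from the patch or its test suite.

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.dsl.plans._
    import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

    val relation = LocalRelation('a.int, 'b.int)
    // A Project whose only effect is re-aliasing each column to its own name.
    val query = relation.select('a as "a", 'b as "b").analyze
    // RemoveRedundantAliases strips the aliases (same name, empty metadata, not blacklisted),
    // leaving Project [a, b]; RemoveRedundantProject then drops the Project because its output
    // equals the child's output.
    val optimized = RemoveRedundantProject(RemoveRedundantAliases(query))
    // `optimized` should now just be the LocalRelation.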

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala

Lines changed: 13 additions & 29 deletions
@@ -242,31 +242,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
    * @param rule the rule to be applied to every expression in this operator.
    */
   def transformExpressionsDown(rule: PartialFunction[Expression, Expression]): this.type = {
-    var changed = false
-
-    @inline def transformExpressionDown(e: Expression): Expression = {
-      val newE = e.transformDown(rule)
-      if (newE.fastEquals(e)) {
-        e
-      } else {
-        changed = true
-        newE
-      }
-    }
-
-    def recursiveTransform(arg: Any): AnyRef = arg match {
-      case e: Expression => transformExpressionDown(e)
-      case Some(e: Expression) => Some(transformExpressionDown(e))
-      case m: Map[_, _] => m
-      case d: DataType => d // Avoid unpacking Structs
-      case seq: Traversable[_] => seq.map(recursiveTransform)
-      case other: AnyRef => other
-      case null => null
-    }
-
-    val newArgs = mapProductIterator(recursiveTransform)
-
-    if (changed) makeCopy(newArgs).asInstanceOf[this.type] else this
+    mapExpressions(_.transformDown(rule))
   }
 
   /**
@@ -276,10 +252,18 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
    * @return
    */
   def transformExpressionsUp(rule: PartialFunction[Expression, Expression]): this.type = {
+    mapExpressions(_.transformUp(rule))
+  }
+
+  /**
+   * Apply a map function to each expression present in this query operator, and return a new
+   * query operator based on the mapped expressions.
+   */
+  def mapExpressions(f: Expression => Expression): this.type = {
     var changed = false
 
-    @inline def transformExpressionUp(e: Expression): Expression = {
-      val newE = e.transformUp(rule)
+    @inline def transformExpression(e: Expression): Expression = {
+      val newE = f(e)
       if (newE.fastEquals(e)) {
         e
       } else {
@@ -289,8 +273,8 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
     }
 
     def recursiveTransform(arg: Any): AnyRef = arg match {
-      case e: Expression => transformExpressionUp(e)
-      case Some(e: Expression) => Some(transformExpressionUp(e))
+      case e: Expression => transformExpression(e)
+      case Some(e: Expression) => Some(transformExpression(e))
       case m: Map[_, _] => m
       case d: DataType => d // Avoid unpacking Structs
       case seq: Traversable[_] => seq.map(recursiveTransform)
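The new `mapExpressions` hook is what `RemoveRedundantAliases` uses above to rewrite a single operator's expressions without recursing into its children. A minimal usage sketch (the helper name `remapAttributes` is hypothetical):

    import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap}
    import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

    // Remap the attributes of this one operator; its children are left untouched.
    def remapAttributes(plan: LogicalPlan, mapping: AttributeMap[Attribute]): LogicalPlan = {
      plan.mapExpressions { expr =>
        expr.transform {
          case a: Attribute => mapping.getOrElse(a, a)
        }
      }
    }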

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala

Lines changed: 1 addition & 1 deletion
@@ -55,7 +55,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
    */
   def resolveOperators(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = {
     if (!analyzed) {
-      val afterRuleOnChildren = transformChildren(rule, (t, r) => t.resolveOperators(r))
+      val afterRuleOnChildren = mapChildren(_.resolveOperators(rule))
       if (this fastEquals afterRuleOnChildren) {
         CurrentOrigin.withOrigin(origin) {
           rule.applyOrElse(this, identity[LogicalPlan])

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala

Lines changed: 11 additions & 35 deletions
@@ -191,26 +191,6 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
     arr
   }
 
-  /**
-   * Returns a copy of this node where `f` has been applied to all the nodes children.
-   */
-  def mapChildren(f: BaseType => BaseType): BaseType = {
-    var changed = false
-    val newArgs = mapProductIterator {
-      case arg: TreeNode[_] if containsChild(arg) =>
-        val newChild = f(arg.asInstanceOf[BaseType])
-        if (newChild fastEquals arg) {
-          arg
-        } else {
-          changed = true
-          newChild
-        }
-      case nonChild: AnyRef => nonChild
-      case null => null
-    }
-    if (changed) makeCopy(newArgs) else this
-  }
-
   /**
    * Returns a copy of this node with the children replaced.
    * TODO: Validate somewhere (in debug mode?) that children are ordered correctly.
@@ -290,9 +270,9 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
 
     // Check if unchanged and then possibly return old copy to avoid gc churn.
     if (this fastEquals afterRule) {
-      transformChildren(rule, (t, r) => t.transformDown(r))
+      mapChildren(_.transformDown(rule))
     } else {
-      afterRule.transformChildren(rule, (t, r) => t.transformDown(r))
+      afterRule.mapChildren(_.transformDown(rule))
     }
   }
 
@@ -304,7 +284,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
    * @param rule the function use to transform this nodes children
    */
   def transformUp(rule: PartialFunction[BaseType, BaseType]): BaseType = {
-    val afterRuleOnChildren = transformChildren(rule, (t, r) => t.transformUp(r))
+    val afterRuleOnChildren = mapChildren(_.transformUp(rule))
     if (this fastEquals afterRuleOnChildren) {
       CurrentOrigin.withOrigin(origin) {
         rule.applyOrElse(this, identity[BaseType])
@@ -317,26 +297,22 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
   }
 
   /**
-   * Returns a copy of this node where `rule` has been recursively applied to all the children of
-   * this node. When `rule` does not apply to a given node it is left unchanged.
-   * @param rule the function used to transform this nodes children
+   * Returns a copy of this node where `f` has been applied to all the nodes children.
    */
-  protected def transformChildren(
-      rule: PartialFunction[BaseType, BaseType],
-      nextOperation: (BaseType, PartialFunction[BaseType, BaseType]) => BaseType): BaseType = {
+  def mapChildren(f: BaseType => BaseType): BaseType = {
     if (children.nonEmpty) {
       var changed = false
       val newArgs = mapProductIterator {
         case arg: TreeNode[_] if containsChild(arg) =>
-          val newChild = nextOperation(arg.asInstanceOf[BaseType], rule)
+          val newChild = f(arg.asInstanceOf[BaseType])
           if (!(newChild fastEquals arg)) {
            changed = true
            newChild
          } else {
            arg
          }
         case Some(arg: TreeNode[_]) if containsChild(arg) =>
-          val newChild = nextOperation(arg.asInstanceOf[BaseType], rule)
+          val newChild = f(arg.asInstanceOf[BaseType])
           if (!(newChild fastEquals arg)) {
             changed = true
             Some(newChild)
@@ -345,7 +321,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
           }
         case m: Map[_, _] => m.mapValues {
           case arg: TreeNode[_] if containsChild(arg) =>
-            val newChild = nextOperation(arg.asInstanceOf[BaseType], rule)
+            val newChild = f(arg.asInstanceOf[BaseType])
             if (!(newChild fastEquals arg)) {
               changed = true
               newChild
@@ -357,16 +333,16 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
         case d: DataType => d // Avoid unpacking Structs
         case args: Traversable[_] => args.map {
           case arg: TreeNode[_] if containsChild(arg) =>
-            val newChild = nextOperation(arg.asInstanceOf[BaseType], rule)
+            val newChild = f(arg.asInstanceOf[BaseType])
             if (!(newChild fastEquals arg)) {
               changed = true
               newChild
             } else {
               arg
             }
           case tuple@(arg1: TreeNode[_], arg2: TreeNode[_]) =>
-            val newChild1 = nextOperation(arg1.asInstanceOf[BaseType], rule)
-            val newChild2 = nextOperation(arg2.asInstanceOf[BaseType], rule)
+            val newChild1 = f(arg1.asInstanceOf[BaseType])
+            val newChild2 = f(arg2.asInstanceOf[BaseType])
             if (!(newChild1 fastEquals arg1) || !(newChild2 fastEquals arg2)) {
               changed = true
               (newChild1, newChild2)
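Making `mapChildren` public is what enables the manual traversal mentioned in the commit message: a rewrite can now thread its own context down the tree instead of relying on `transform*`. A minimal sketch, not part of the patch:

    import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

    // Apply `f` to the current node first, then recurse into the (possibly rewritten) children.
    // Unlike transformDown, `f` is a plain function, so a closure can carry extra context
    // (for example a blacklist of attributes) from parent to child.
    def transformDownManually(plan: LogicalPlan)(f: LogicalPlan => LogicalPlan): LogicalPlan = {
      val newPlan = f(plan)
      newPlan.mapChildren(child => transformDownManually(child)(f))
    }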
