@@ -80,7 +80,6 @@ class Analyzer(
EliminateUnions),
Batch("Resolution", fixedPoint,
ResolveRelations ::
ResolveStar ::
ResolveReferences ::
ResolveGroupingAnalytics ::
ResolvePivot ::
@@ -374,91 +373,6 @@ class Analyzer(
}
}

/**
* Expand [[UnresolvedStar]] or [[ResolvedStar]] to the matching attributes in child's output.
*/
object ResolveStar extends Rule[LogicalPlan] {

def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case p: LogicalPlan if !p.childrenResolved => p
// If the projection list contains Stars, expand it.
case p: Project if containsStar(p.projectList) =>
p.copy(projectList = buildExpandedProjectList(p.projectList, p.child))
// If the aggregate function argument contains Stars, expand it.
case a: Aggregate if containsStar(a.aggregateExpressions) =>
if (conf.groupByOrdinal && a.groupingExpressions.exists(IntegerIndex.unapply(_).nonEmpty)) {
failAnalysis(
"Group by position: star is not allowed to use in the select list " +
"when using ordinals in group by")
} else {
a.copy(aggregateExpressions = buildExpandedProjectList(a.aggregateExpressions, a.child))
}
// If the script transformation input contains Stars, expand it.
case t: ScriptTransformation if containsStar(t.input) =>
t.copy(
input = t.input.flatMap {
case s: Star => s.expand(t.child, resolver)
case o => o :: Nil
}
)
case g: Generate if containsStar(g.generator.children) =>
failAnalysis("Invalid usage of '*' in explode/json_tuple/UDTF")
}

/**
* Build a project list for Project/Aggregate and expand the star if possible
*/
private def buildExpandedProjectList(
exprs: Seq[NamedExpression],
child: LogicalPlan): Seq[NamedExpression] = {
exprs.flatMap {
// Using Dataframe/Dataset API: testData2.groupBy($"a", $"b").agg($"*")
case s: Star => s.expand(child, resolver)
// Using SQL API without running ResolveAlias: SELECT * FROM testData2 group by a, b
case UnresolvedAlias(s: Star, _) => s.expand(child, resolver)
case o if containsStar(o :: Nil) => expandStarExpression(o, child) :: Nil
case o => o :: Nil
}.map(_.asInstanceOf[NamedExpression])
}

/**
* Returns true if `exprs` contains a [[Star]].
*/
def containsStar(exprs: Seq[Expression]): Boolean =
exprs.exists(_.collect { case _: Star => true }.nonEmpty)

/**
* Expands the matching attribute.*'s in `child`'s output.
*/
def expandStarExpression(expr: Expression, child: LogicalPlan): Expression = {
expr.transformUp {
case f1: UnresolvedFunction if containsStar(f1.children) =>
f1.copy(children = f1.children.flatMap {
case s: Star => s.expand(child, resolver)
case o => o :: Nil
})
case c: CreateStruct if containsStar(c.children) =>
c.copy(children = c.children.flatMap {
case s: Star => s.expand(child, resolver)
case o => o :: Nil
})
case c: CreateArray if containsStar(c.children) =>
c.copy(children = c.children.flatMap {
case s: Star => s.expand(child, resolver)
case o => o :: Nil
})
case p: Murmur3Hash if containsStar(p.children) =>
p.copy(children = p.children.flatMap {
case s: Star => s.expand(child, resolver)
case o => o :: Nil
})
// count(*) has been replaced by count(1)
case o if containsStar(o.children) =>
failAnalysis(s"Invalid usage of '*' in expression '${o.prettyName}'")
}
}
}

/**
* Replaces [[UnresolvedAttribute]]s with concrete [[AttributeReference]]s from
* a logical plan node's children.
@@ -525,6 +439,29 @@ class Analyzer(
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case p: LogicalPlan if !p.childrenResolved => p

// If the projection list contains Stars, expand it.
case p: Project if containsStar(p.projectList) =>
Contributor Author: This is moved from ResolveStar without any changes.

Member: Thank you!

Contributor: Oh, I see. This can speed up resolution for nested plans. Thanks for fixing it!
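A rough sketch of the plan shape this helps with (illustrative only; testRelation and the imports are assumed to match those in AnalysisSuite, whose "union project *" test appears below):

    // 120 layers of star projection: under the old fixed-point batch,
    // each layer could cost an extra iteration of ResolveStar.
    val nested = (1 to 120).foldLeft[LogicalPlan](testRelation) { (plan, _) =>
      plan.select(UnresolvedStar(None))
    }
    // With star expansion folded into ResolveReferences, stars and
    // references resolve together in the same pass.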

p.copy(projectList = buildExpandedProjectList(p.projectList, p.child))
// If the aggregate function argument contains Stars, expand it.
case a: Aggregate if containsStar(a.aggregateExpressions) =>
if (conf.groupByOrdinal && a.groupingExpressions.exists(IntegerIndex.unapply(_).nonEmpty)) {
failAnalysis(
"Group by position: star is not allowed to use in the select list " +
"when using ordinals in group by")
} else {
a.copy(aggregateExpressions = buildExpandedProjectList(a.aggregateExpressions, a.child))
}
// If the script transformation input contains Stars, expand it.
case t: ScriptTransformation if containsStar(t.input) =>
t.copy(
input = t.input.flatMap {
case s: Star => s.expand(t.child, resolver)
case o => o :: Nil
}
)
case g: Generate if containsStar(g.generator.children) =>
failAnalysis("Invalid usage of '*' in explode/json_tuple/UDTF")

// To resolve duplicate expression IDs for Join and Intersect
case j @ Join(left, right, _, _) if !j.duplicateResolved =>
j.copy(right = dedupRight(left, right))
@@ -619,6 +556,59 @@ class Analyzer(
def findAliases(projectList: Seq[NamedExpression]): AttributeSet = {
AttributeSet(projectList.collect { case a: Alias => a.toAttribute })
}

/**
* Build a project list for Project/Aggregate and expand the star if possible
*/
private def buildExpandedProjectList(
exprs: Seq[NamedExpression],
child: LogicalPlan): Seq[NamedExpression] = {
exprs.flatMap {
// Using Dataframe/Dataset API: testData2.groupBy($"a", $"b").agg($"*")
case s: Star => s.expand(child, resolver)
// Using SQL API without running ResolveAlias: SELECT * FROM testData2 group by a, b
case UnresolvedAlias(s: Star, _) => s.expand(child, resolver)
case o if containsStar(o :: Nil) => expandStarExpression(o, child) :: Nil
case o => o :: Nil
}.map(_.asInstanceOf[NamedExpression])
}

/**
* Returns true if `exprs` contains a [[Star]].
*/
def containsStar(exprs: Seq[Expression]): Boolean =
exprs.exists(_.collect { case _: Star => true }.nonEmpty)

/**
* Expands the matching attribute.*'s in `child`'s output.
*/
def expandStarExpression(expr: Expression, child: LogicalPlan): Expression = {
expr.transformUp {
case f1: UnresolvedFunction if containsStar(f1.children) =>
f1.copy(children = f1.children.flatMap {
case s: Star => s.expand(child, resolver)
case o => o :: Nil
})
case c: CreateStruct if containsStar(c.children) =>
c.copy(children = c.children.flatMap {
case s: Star => s.expand(child, resolver)
case o => o :: Nil
})
case c: CreateArray if containsStar(c.children) =>
c.copy(children = c.children.flatMap {
case s: Star => s.expand(child, resolver)
case o => o :: Nil
})
case p: Murmur3Hash if containsStar(p.children) =>
p.copy(children = p.children.flatMap {
case s: Star => s.expand(child, resolver)
case o => o :: Nil
})
// count(*) has been replaced by count(1)
case o if containsStar(o.children) =>
failAnalysis(s"Invalid usage of '*' in expression '${o.prettyName}'")
}
}
}

protected[sql] def resolveExpression(
@@ -306,21 +306,21 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper {
}

/**
* Attempts to eliminate the reading of unneeded columns from the query plan using the following
* transformations:
* Attempts to eliminate the reading of unneeded columns from the query plan.
*
* - Inserting Projections beneath the following operators:
* - Aggregate
* - Generate
* - Project <- Join
* - LeftSemiJoin
* Since adding Project before Filter conflicts with PushPredicateThroughProject, this rule will
* remove the Project p2 in the following pattern:
*
* p1 @ Project(_, Filter(_, p2 @ Project(_, child))) if p2.outputSet.subsetOf(p2.inputSet)
Contributor: If this is our target, why not add a new case before the last case and handle case p1 @ Project(_, f @ Filter(_, child)) specially? I.e., do not insert a Project if it is unnecessary.

Contributor Author: We need to insert p2 to prune the columns further. For example, in Project(Filter(Join)), we need p2 to prune the columns coming out of the Join.
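A minimal sketch of that Project(Filter(Join)) shape, assuming the Catalyst test DSL used in ColumnPruningSuite (the relations and column names are illustrative):

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.dsl.plans._
    import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

    val x = LocalRelation('a.int, 'b.int, 'c.int)
    val y = LocalRelation('d.int, 'e.int)
    // Only 'a survives to the top, but the Join would otherwise emit
    // a, b, c, d, e. During the transform, ColumnPruning inserts a pruning
    // Project (p2) beneath the Filter so the Join's inputs can be narrowed;
    // removeProjectBeforeFilter then drops the leftover p2 afterwards.
    val query = x.join(y, condition = Some('c === 'd)).where('b > 1).select('a).analyze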

*
* p2 is usually inserted by this rule and is useless afterwards, since p1 can prune the columns anyway.
*/
object ColumnPruning extends Rule[LogicalPlan] {
private def sameOutput(output1: Seq[Attribute], output2: Seq[Attribute]): Boolean =
output1.size == output2.size &&
output1.zip(output2).forall(pair => pair._1.semanticEquals(pair._2))

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
def apply(plan: LogicalPlan): LogicalPlan = removeProjectBeforeFilter(plan transform {
Member: Here we are using transform, which is actually transformDown. In this rule, ColumnPruning, we could add many Projects into the child. This could easily cause a stack overflow. That is why my PR #11745 changes it to transformUp. Do you think this change makes sense?

Contributor Author: Column pruning has to go from top to bottom, or you will need multiple runs of this rule. The added Project is exactly the same whether you go from the top or the bottom, but going bottom-up sometimes will not work (because the added Project will be moved by other rules, for example filter pushdown).

Have you actually seen a stack overflow from this rule? I don't think so.
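A small illustration of the top-down argument (a sketch reusing the DSL above, not a claim about any specific test):

    val rel = LocalRelation('a.int, 'b.int, 'c.int)
    // transformDown prunes the outer pair first and then continues into the
    // already-narrowed child, so one pass prunes the whole chain; transformUp
    // would prune inner nodes against stale parent references and need extra
    // fixed-point iterations to converge.
    val chain = rel.select('a, 'b, 'c).select('a, 'b).select('a).analyze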

Member: If we were using transformUp, removeProjectBeforeFilter's assumption would not hold. The following line does not cover all the cases:

case p1 @ Project(_, f @ Filter(_, p2 @ Project(_, child)))
   if p2.outputSet.subsetOf(child.outputSet) =>

Member: I saw the stack overflow in my local environment.

Member: I think my PR #11745 covers all the cases even if we change transform to transformUp.

Contributor Author: We should not change transform to transformUp. It would be great if you could post a test case that causes the StackOverflow, thanks!

Member: Will do it tonight. I do not have it right now.

Member: I am unable to reproduce the stack overflow now if we keep the following lines in ColumnPruning:

    // Eliminate no-op Projects
    case p @ Project(projectList, child) if sameOutput(child.output, p.output) => child

If we remove the above line, we will get the stack overflow easily, because we can generate duplicate Projects. Anyway, I am fine if you want to use transformDown.

Contributor Author: There is no reason we should remove this line.

Member: If transformDown is required here, could you change transform to transformDown? I got this from the comment on the transform function:
https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala#L242-L243
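For reference, transform is a thin alias (a paraphrase of the linked TreeNode code, not part of this diff):

    def transform(rule: PartialFunction[BaseType, BaseType]): BaseType =
      transformDown(rule)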

// Prunes the unused columns from project list of Project/Aggregate/Expand
case p @ Project(_, p2: Project) if (p2.outputSet -- p.references).nonEmpty =>
p.copy(child = p2.copy(projectList = p2.projectList.filter(p.references.contains)))
@@ -399,7 +399,7 @@ object ColumnPruning extends Rule[LogicalPlan] {
} else {
p
}
}
})

/** Applies a projection only when the child is producing unnecessary attributes */
private def prunedChild(c: LogicalPlan, allReferences: AttributeSet) =
@@ -408,6 +408,16 @@ object ColumnPruning extends Rule[LogicalPlan] {
} else {
c
}

/**
* The Project before Filter is not necessary but conflicts with PushPredicateThroughProject,
* so remove it.
*/
private def removeProjectBeforeFilter(plan: LogicalPlan): LogicalPlan = plan transform {
Member: Same here. We still need to explicitly use transformDown.

Contributor Author: We use transform nearly everywhere, even where we know transformDown is what is meant, for example in all the rules that push down a predicate.

I think it's fine; otherwise we should update all those places.

Contributor Author: It is still correct even if someone suddenly changes transform to transformUp.

Member: I see.

Is it possible that there are two consecutive Projects following the Filter?

Contributor Author: Two consecutive Projects will be combined by other rules.

Member: CollapseProject is called after this rule. Anyway, we can leave it here if no test case fails because of it.
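A quick sketch of what CollapseProject does to such a pair (illustrative, using the same test DSL as above):

    import org.apache.spark.sql.catalyst.optimizer.CollapseProject

    val rel2 = LocalRelation('a.int, 'b.int)
    // Two stacked Projects merge into one, which is why
    // removeProjectBeforeFilter only ever needs to match a single Project.
    val merged = CollapseProject(rel2.select('a, 'b).select('a).analyze)
    // comparePlans(merged, rel2.select('a).analyze) would pass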

case p1 @ Project(_, f @ Filter(_, p2 @ Project(_, child)))
if p2.outputSet.subsetOf(child.outputSet) =>
p1.copy(child = f.copy(child = child))
}
}

/**
@@ -22,8 +22,10 @@ import scala.collection.JavaConverters._
import com.google.common.util.concurrent.AtomicLongMap

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.catalyst.util.sideBySide
import org.apache.spark.util.Utils

object RuleExecutor {
protected val timeMap = AtomicLongMap.create[String]()
@@ -98,7 +100,12 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging {
if (iteration > batch.strategy.maxIterations) {
// Only log if this is a rule that is supposed to run more than once.
if (iteration != 2) {
logInfo(s"Max iterations (${iteration - 1}) reached for batch ${batch.name}")
val message = s"Max iterations (${iteration - 1}) reached for batch ${batch.name}"
if (Utils.isTesting) {
throw new TreeNodeException(curPlan, message, null)
} else {
logWarning(message)
}
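// Note (illustrative, not part of this diff): Utils.isTesting roughly checks
// sys.env.contains("SPARK_TESTING") || sys.props.contains("spark.testing"),
// so a non-converging batch now fails fast under tests but only logs a
// warning in production.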
}
continue = false
}
@@ -29,7 +29,7 @@ class AnalysisSuite extends AnalysisTest {
import org.apache.spark.sql.catalyst.analysis.TestRelations._

test("union project *") {
val plan = (1 to 100)
val plan = (1 to 120)
.map(_ => testRelation)
.fold[LogicalPlan](testRelation) { (a, b) =>
a.select(UnresolvedStar(None)).select('a).union(b.select(UnresolvedStar(None)))
@@ -34,6 +34,7 @@ class ColumnPruningSuite extends PlanTest {

object Optimize extends RuleExecutor[LogicalPlan] {
val batches = Batch("Column pruning", FixedPoint(100),
PushPredicateThroughProject,
ColumnPruning,
CollapseProject) :: Nil
}
@@ -133,12 +134,16 @@ class ColumnPruningSuite extends PlanTest {

test("Column pruning on Filter") {
val input = LocalRelation('a.int, 'b.string, 'c.double)
val plan1 = Filter('a > 1, input).analyze
comparePlans(Optimize.execute(plan1), plan1)
val query = Project('a :: Nil, Filter('c > Literal(0.0), input)).analyze
val expected =
Project('a :: Nil,
Filter('c > Literal(0.0),
Project(Seq('a, 'c), input))).analyze
comparePlans(Optimize.execute(query), expected)
comparePlans(Optimize.execute(query), query)
val plan2 = Filter('b > 1, Project(Seq('a, 'b), input)).analyze
val expected2 = Project(Seq('a, 'b), Filter('b > 1, input)).analyze
comparePlans(Optimize.execute(plan2), expected2)
val plan3 = Project(Seq('a), Filter('b > 1, Project(Seq('a, 'b), input))).analyze
val expected3 = Project(Seq('a), Filter('b > 1, input)).analyze
comparePlans(Optimize.execute(plan3), expected3)
}

test("Column pruning on except/intersect/distinct") {
@@ -297,7 +302,7 @@ class ColumnPruningSuite extends PlanTest {
SortOrder('b, Ascending) :: Nil,
UnspecifiedFrame)).as('window) :: Nil,
'a :: Nil, 'b.asc :: Nil)
.select('a, 'c, 'window).where('window > 1).select('a, 'c).analyze
.where('window > 1).select('a, 'c).analyze
Member: Any reason for removing .select('a, 'c, 'window)? It seems like the previous one is a better plan, right?

Contributor Author: The select before the where helps nothing and could be worse (without whole-stage codegen); is it really a better plan?

Member: If so, it becomes harder for the Optimizer to judge which plan is better. Based on my understanding, the general principle of ColumnPruning is to do its best to add extra Projects that prune unnecessary columns, pushing the Project down as deep as possible. In this case, .select('a, 'c, 'window) prunes the useless column b.

Could you explain the current strategy for this rule? We might need to add more test cases to check that it does the desired work.

Member: After more thought, could we modify the existing Filter operator by adding the functionality of Project into Filter?

Contributor Author: I added a comment for that.

I don't think it is necessary, or a good idea, to add the functionality of Project into Filter.

Member: Got it. It is easier to understand now. :)


val optimized = Optimize.execute(originalQuery.analyze)

@@ -18,7 +18,9 @@
package org.apache.spark.sql.catalyst.trees

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions.{Expression, IntegerLiteral, Literal}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}

class RuleExecutorSuite extends SparkFunSuite {
@@ -49,6 +51,9 @@ class RuleExecutorSuite extends SparkFunSuite {
val batches = Batch("fixedPoint", FixedPoint(10), DecrementLiterals) :: Nil
}

assert(ToFixedPoint.execute(Literal(100)) === Literal(90))
val message = intercept[TreeNodeException[LogicalPlan]] {
ToFixedPoint.execute(Literal(100))
}.getMessage
assert(message.contains("Max iterations (10) reached for batch fixedPoint"))
}
}
@@ -338,6 +338,9 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"udf_round_3",
"view_cast",

// enable this after fixing SPARK-14137
"union20",

// These tests check the VIEW table definition, but Spark handles CREATE VIEW itself and
// generates different View Expanded Text.
"alter_view_as_select",
@@ -1040,7 +1043,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"union18",
"union19",
"union2",
"union20",
"union22",
"union23",
"union24",