diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
index 8ae24e51351db..82ad2df11a3ec 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate
 import org.apache.spark.sql.catalyst.analysis.{DecimalPrecision, FunctionRegistry, TypeCheckResult}
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.trees.TreePattern.{AVERAGE, TreePattern}
 import org.apache.spark.sql.catalyst.trees.UnaryLike
 import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.types._
@@ -51,6 +52,8 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
   // Return data type.
   override def dataType: DataType = resultType
 
+  final override val nodePatterns: Seq[TreePattern] = Seq(AVERAGE)
+
   private lazy val resultType = child.dataType match {
     case DecimalType.Fixed(p, s) =>
       DecimalType.bounded(p + 4, s + 4)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
index 31150fc31ba1c..16cd9d76f7b14 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.trees.TreePattern.{SUM, TreePattern}
 import org.apache.spark.sql.catalyst.trees.UnaryLike
 import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.internal.SQLConf
@@ -52,6 +53,8 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
   override def checkInputDataTypes(): TypeCheckResult =
     TypeUtils.checkForAnsiIntervalOrNumericType(child.dataType, "sum")
 
+  final override val nodePatterns: Seq[TreePattern] = Seq(SUM)
+
   private lazy val resultType = child.dataType match {
     case DecimalType.Fixed(precision, scale) =>
       DecimalType.bounded(precision + 10, scale)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index 69f7d24d04be2..52487d4decb68 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -158,6 +158,8 @@ case class Alias(child: Expression, name: String)(
     val nonInheritableMetadataKeys: Seq[String] = Seq.empty)
   extends UnaryExpression with NamedExpression {
 
+  final override val nodePatterns: Seq[TreePattern] = Seq(ALIAS)
+
   // Alias(Generator, xx) need to be transformed into Generate(generator, ...)
   override lazy val resolved =
     childrenResolved && checkInputDataTypes().isSuccess && !child.isInstanceOf[Generator]
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 07c86a7e09bb8..19e9312715db0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -27,7 +27,8 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
-import org.apache.spark.sql.catalyst.trees.TreePattern.PLAN_EXPRESSION
+import org.apache.spark.sql.catalyst.trees.AlwaysProcess
+import org.apache.spark.sql.catalyst.trees.TreePattern._
 import org.apache.spark.sql.connector.catalog.CatalogManager
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -377,7 +378,8 @@ object EliminateDistinct extends Rule[LogicalPlan] {
  * This rule should be applied before RewriteDistinctAggregates.
  */
 object EliminateAggregateFilter extends Rule[LogicalPlan] {
-  override def apply(plan: LogicalPlan): LogicalPlan = plan transformExpressions {
+  override def apply(plan: LogicalPlan): LogicalPlan = plan.transformExpressionsWithPruning(
+    _.containsAllPatterns(TRUE_OR_FALSE_LITERAL), ruleId) {
     case ae @ AggregateExpression(_, _, _, Some(Literal.TrueLiteral), _) =>
       ae.copy(filter = None)
     case AggregateExpression(af: DeclarativeAggregate, _, _, Some(Literal.FalseLiteral), _) =>
@@ -445,6 +447,9 @@ object RemoveRedundantAliases extends Rule[LogicalPlan] {
    * (self) join or to prevent the removal of top-level subquery attributes.
    */
   private def removeRedundantAliases(plan: LogicalPlan, excluded: AttributeSet): LogicalPlan = {
+    if (!plan.containsPattern(ALIAS)) {
+      return plan
+    }
     plan match {
       // We want to keep the same output attributes for subqueries. This means we cannot remove
       // the aliases that produce these attributes
@@ -506,7 +511,8 @@ object RemoveRedundantAliases extends Rule[LogicalPlan] {
  * only goal is to keep distinct values, while its parent aggregate would ignore duplicate values.
  */
 object RemoveRedundantAggregates extends Rule[LogicalPlan] with AliasHelper {
-  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
+    _.containsPattern(AGGREGATE), ruleId) {
     case upper @ Aggregate(_, _, lower: Aggregate) if lowerIsRedundant(upper, lower) =>
       val aliasMap = getAliasMap(lower)
 
@@ -545,7 +551,8 @@ object RemoveRedundantAggregates extends Rule[LogicalPlan] with AliasHelper {
  * Remove no-op operators from the query plan that do not make any modifications.
  */
 object RemoveNoopOperators extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
+    _.containsAnyPattern(PROJECT, WINDOW), ruleId) {
     // Eliminate no-op Projects
     case p @ Project(_, child) if child.sameOutput(p) => child
 
@@ -597,7 +604,8 @@ object RemoveNoopUnion extends Rule[LogicalPlan] {
     }
   }
 
-  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
+    _.containsAllPatterns(DISTINCT_LIKE, UNION)) {
     case d @ Distinct(u: Union) =>
       d.withNewChildren(Seq(simplifyUnion(u)))
     case d @ Deduplicate(_, u: Union) =>
@@ -648,7 +656,8 @@ object LimitPushDown extends Rule[LogicalPlan] {
     }
   }
 
-  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+    _.containsPattern(LIMIT), ruleId) {
     // Adding extra Limits below UNION ALL for children which are not Limit or do not have Limit
     // descendants whose maxRow is larger. This heuristic is valid assuming there does not exist any
     // Limit push-down rule that is unable to infer the value of maxRows.
@@ -745,7 +754,8 @@ object PushProjectionThroughUnion extends Rule[LogicalPlan] with PredicateHelper
  */
 object ColumnPruning extends Rule[LogicalPlan] {
 
-  def apply(plan: LogicalPlan): LogicalPlan = removeProjectBeforeFilter(plan transform {
+  def apply(plan: LogicalPlan): LogicalPlan = removeProjectBeforeFilter(
+    plan.transformWithPruning(AlwaysProcess.fn, ruleId) {
     // Prunes the unused columns from project list of Project/Aggregate/Expand
     case p @ Project(_, p2: Project) if !p2.outputSet.subsetOf(p.references) =>
       p.copy(child = p2.copy(projectList = p2.projectList.filter(p.references.contains)))
@@ -863,7 +873,8 @@ object ColumnPruning extends Rule[LogicalPlan] {
  */
 object CollapseProject extends Rule[LogicalPlan] with AliasHelper {
 
-  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
+    _.containsPattern(PROJECT), ruleId) {
     case p1 @ Project(_, p2: Project) =>
       if (haveCommonNonDeterministicOutput(p1.projectList, p2.projectList)) {
         p1
@@ -921,7 +932,8 @@ object CollapseProject extends Rule[LogicalPlan] with AliasHelper {
  * Combines adjacent [[RepartitionOperation]] operators
  */
 object CollapseRepartition extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
+    _.containsPattern(REPARTITION_OPERATION), ruleId) {
     // Case 1: When a Repartition has a child of Repartition or RepartitionByExpression,
     //   1) When the top node does not enable the shuffle (i.e., coalesce API), but the child
     //      enables the shuffle. Returns the child node if the last numPartitions is bigger;
@@ -943,7 +955,8 @@ object CollapseRepartition extends Rule[LogicalPlan] {
  * and user not specify.
  */
 object OptimizeRepartition extends Rule[LogicalPlan] {
-  override def apply(plan: LogicalPlan): LogicalPlan = plan.transform {
+  override def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+    _.containsPattern(REPARTITION_OPERATION), ruleId) {
     case r @ RepartitionByExpression(partitionExpressions, _, numPartitions)
       if partitionExpressions.nonEmpty && partitionExpressions.forall(_.foldable) &&
         numPartitions.isEmpty =>
@@ -955,7 +968,8 @@ object OptimizeRepartition extends Rule[LogicalPlan] {
  * Replaces first(col) to nth_value(col, 1) for better performance.
  */
 object OptimizeWindowFunctions extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.resolveExpressionsWithPruning(
+    _.containsPattern(WINDOW_EXPRESSION), ruleId) {
     case we @ WindowExpression(AggregateExpression(first: First, _, _, _, _),
         WindowSpecDefinition(_, orderSpec, frameSpecification: SpecifiedWindowFrame))
         if orderSpec.nonEmpty && frameSpecification.frameType == RowFrame &&
@@ -972,7 +986,8 @@ object OptimizeWindowFunctions extends Rule[LogicalPlan] {
  * independent and are of the same window function type, collapse into the parent.
 */
 object CollapseWindow extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
+    _.containsPattern(WINDOW), ruleId) {
     case w1 @ Window(we1, ps1, os1, w2 @ Window(we2, ps2, os2, grandChild))
         if ps1 == ps2 && os1 == os2 && w1.references.intersect(w2.windowOutputSet).isEmpty &&
           we1.nonEmpty && we2.nonEmpty &&
@@ -995,7 +1010,8 @@ object TransposeWindow extends Rule[LogicalPlan] {
     })
   }
 
-  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
+    _.containsPattern(WINDOW), ruleId) {
     case w1 @ Window(we1, ps1, os1, w2 @ Window(we2, ps2, os2, grandChild))
         if w1.references.intersect(w2.windowOutputSet).isEmpty &&
           w1.expressions.forall(_.deterministic) &&
@@ -1010,7 +1026,8 @@ object TransposeWindow extends Rule[LogicalPlan] {
  * by this [[Generate]] can be removed earlier - before joins and in data sources.
 */
 object InferFiltersFromGenerate extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
+    _.containsPattern(GENERATE)) {
     // This rule does not infer filters from foldable expressions to avoid constant filters
     // like 'size([1, 2, 3]) > 0'. These do not show up in child's constraints and
     // then the idempotence will break.
@@ -1060,7 +1077,8 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan]
     }
   }
 
-  private def inferFilters(plan: LogicalPlan): LogicalPlan = plan transform {
+  private def inferFilters(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+    _.containsAnyPattern(FILTER, JOIN)) {
     case filter @ Filter(condition, child) =>
       val newFilters = filter.constraints --
         (child.constraints ++ splitConjunctivePredicates(condition))
@@ -1123,7 +1141,8 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan]
  * Combines all adjacent [[Union]] operators into a single [[Union]].
 */
 object CombineUnions extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformDownWithPruning(
+    _.containsAnyPattern(UNION, DISTINCT_LIKE), ruleId) {
     case u: Union => flattenUnion(u, false)
     case Distinct(u: Union) => Distinct(flattenUnion(u, true))
     // Only handle distinct-like 'Deduplicate', where the keys == output
@@ -1167,7 +1186,8 @@ object CombineUnions extends Rule[LogicalPlan] {
  * one conjunctive predicate.
 */
 object CombineFilters extends Rule[LogicalPlan] with PredicateHelper {
-  def apply(plan: LogicalPlan): LogicalPlan = plan transform applyLocally
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+    _.containsPattern(FILTER), ruleId)(applyLocally)
 
   val applyLocally: PartialFunction[LogicalPlan, LogicalPlan] = {
     // The query execution/optimization does not guarantee the expressions are evaluated in order.
@@ -1202,7 +1222,8 @@ object CombineFilters extends Rule[LogicalPlan] with PredicateHelper {
  * function is order irrelevant
 */
 object EliminateSorts extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = plan transform applyLocally
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+    _.containsPattern(SORT))(applyLocally)
 
   private val applyLocally: PartialFunction[LogicalPlan, LogicalPlan] = {
     case s @ Sort(orders, _, child) if orders.isEmpty || orders.exists(_.child.foldable) =>
@@ -1221,11 +1242,16 @@ object EliminateSorts extends Rule[LogicalPlan] {
       g.copy(child = recursiveRemoveSort(originChild))
   }
 
-  private def recursiveRemoveSort(plan: LogicalPlan): LogicalPlan = plan match {
-    case Sort(_, _, child) => recursiveRemoveSort(child)
-    case other if canEliminateSort(other) =>
-      other.withNewChildren(other.children.map(recursiveRemoveSort))
-    case _ => plan
+  private def recursiveRemoveSort(plan: LogicalPlan): LogicalPlan = {
+    if (!plan.containsPattern(SORT)) {
+      return plan
+    }
+    plan match {
+      case Sort(_, _, child) => recursiveRemoveSort(child)
+      case other if canEliminateSort(other) =>
+        other.withNewChildren(other.children.map(recursiveRemoveSort))
+      case _ => plan
+    }
   }
 
   private def canEliminateSort(plan: LogicalPlan): Boolean = plan match {
@@ -1264,7 +1290,8 @@ object EliminateSorts extends Rule[LogicalPlan] {
  * 3) by eliminating the always-true conditions given the constraints on the child's output.
 */
 object PruneFilters extends Rule[LogicalPlan] with PredicateHelper {
-  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+    _.containsPattern(FILTER), ruleId) {
     // If the filter condition always evaluate to true, remove the filter.
     case Filter(Literal(true, BooleanType), child) => child
     // If the filter condition always evaluate to null or false,
@@ -1620,7 +1647,8 @@ object EliminateLimits extends Rule[LogicalPlan] {
     limitExpr.foldable && child.maxRows.exists { _ <= limitExpr.eval().asInstanceOf[Int] }
   }
 
-  def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformDownWithPruning(
+    _.containsPattern(LIMIT), ruleId) {
     case Limit(l, child) if canEliminate(l, child) =>
       child
     case GlobalLimit(l, child) if canEliminate(l, child) =>
@@ -1667,7 +1695,7 @@ object CheckCartesianProducts extends Rule[LogicalPlan] with PredicateHelper {
   def apply(plan: LogicalPlan): LogicalPlan =
     if (conf.crossJoinEnabled) {
       plan
-    } else plan transform {
+    } else plan.transformWithPruning(_.containsAnyPattern(INNER_LIKE_JOIN, OUTER_JOIN)) {
       case j @ Join(left, right, Inner | LeftOuter | RightOuter | FullOuter, _, _)
         if isCartesianProduct(j) =>
           throw new AnalysisException(
@@ -1695,8 +1723,10 @@ object DecimalAggregates extends Rule[LogicalPlan] {
   /** Maximum number of decimal digits representable precisely in a Double */
   private val MAX_DOUBLE_DIGITS = 15
 
-  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    case q: LogicalPlan => q transformExpressionsDown {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+    _.containsAnyPattern(SUM, AVERAGE), ruleId) {
+    case q: LogicalPlan => q.transformExpressionsDownWithPruning(
+      _.containsAnyPattern(SUM, AVERAGE), ruleId) {
       case we @ WindowExpression(ae @ AggregateExpression(af, _, _, _, _), _) => af match {
         case Sum(e @ DecimalType.Expression(prec, scale)) if prec + 10 <= MAX_LONG_DIGITS =>
           MakeDecimal(we.copy(windowFunction = ae.copy(aggregateFunction = Sum(UnscaledValue(e)))),
@@ -1732,7 +1762,8 @@ object DecimalAggregates extends Rule[LogicalPlan] {
  * another `LocalRelation`.
 */
 object ConvertToLocalRelation extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+    _.containsPattern(LOCAL_RELATION), ruleId) {
     case Project(projectList, LocalRelation(output, data, isStreaming))
         if !projectList.exists(hasUnevaluableExpr) =>
       val projection = new InterpretedMutableProjection(projectList, output)
@@ -1761,7 +1792,8 @@ object ConvertToLocalRelation extends Rule[LogicalPlan] {
  * }}}
 */
 object ReplaceDistinctWithAggregate extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+    _.containsPattern(DISTINCT_LIKE), ruleId) {
     case Distinct(child) => Aggregate(child.output, child.output, child)
   }
 }
@@ -1805,7 +1837,8 @@ object ReplaceDeduplicateWithAggregate extends Rule[LogicalPlan] {
  * join conditions will be incorrect.
 */
 object ReplaceIntersectWithSemiJoin extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+    _.containsPattern(INTERSECT), ruleId) {
     case Intersect(left, right, false) =>
       assert(left.output.size == right.output.size)
       val joinCond = left.output.zip(right.output).map { case (l, r) => EqualNullSafe(l, r) }
@@ -1826,7 +1859,8 @@ object ReplaceIntersectWithSemiJoin extends Rule[LogicalPlan] {
  * join conditions will be incorrect.
 */
 object ReplaceExceptWithAntiJoin extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+    _.containsPattern(EXCEPT), ruleId) {
     case Except(left, right, false) =>
       assert(left.output.size == right.output.size)
       val joinCond = left.output.zip(right.output).map { case (l, r) => EqualNullSafe(l, r) }
@@ -1866,7 +1900,8 @@ object ReplaceExceptWithAntiJoin extends Rule[LogicalPlan] {
 */
 object RewriteExceptAll extends Rule[LogicalPlan] {
 
-  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+    _.containsPattern(EXCEPT), ruleId) {
     case Except(left, right, true) =>
       assert(left.output.size == right.output.size)
 
@@ -1923,7 +1958,8 @@ object RewriteExceptAll extends Rule[LogicalPlan] {
  * }}}
 */
 object RewriteIntersectAll extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+    _.containsPattern(INTERSECT), ruleId) {
     case Intersect(left, right, true) =>
       assert(left.output.size == right.output.size)
 
@@ -1975,7 +2011,8 @@ object RewriteIntersectAll extends Rule[LogicalPlan] {
 * but only makes the grouping key bigger.
 */
 object RemoveLiteralFromGroupExpressions extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+    _.containsPattern(AGGREGATE), ruleId) {
     case a @ Aggregate(grouping, _, _) if grouping.nonEmpty =>
       val newGrouping = grouping.filter(!_.foldable)
       if (newGrouping.nonEmpty) {
@@ -1994,7 +2031,8 @@ object RemoveLiteralFromGroupExpressions extends Rule[LogicalPlan] {
 * but only makes the grouping key bigger.
 */
 object RemoveRepetitionFromGroupExpressions extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+    _.containsPattern(AGGREGATE), ruleId) {
     case a @ Aggregate(grouping, _, _) if grouping.size > 1 =>
       val newGrouping = ExpressionSet(grouping).toSeq
       if (newGrouping.size == grouping.size) {
@@ -2014,7 +2052,8 @@ object OptimizeLimitZero extends Rule[LogicalPlan] {
   private def empty(plan: LogicalPlan) =
     LocalRelation(plan.output, data = Seq.empty, isStreaming = plan.isStreaming)
 
-  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
+    _.containsAllPatterns(LIMIT, LITERAL)) {
     // Nodes below GlobalLimit or LocalLimit can be pruned if the limit value is zero (0).
     // Any subtree in the logical plan that has GlobalLimit 0 or LocalLimit 0 as its root is
     // semantically equivalent to an empty relation.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index d3c5b5125b445..88a58fda1fa6a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -127,6 +127,8 @@ case class Generate(
     child: LogicalPlan)
   extends UnaryNode {
 
+  final override val nodePatterns: Seq[TreePattern] = Seq(GENERATE)
+
   lazy val requiredChildOutput: Seq[Attribute] = {
     val unrequiredSet = unrequiredChildIndex.toSet
     child.output.zipWithIndex.filterNot(t => unrequiredSet.contains(t._2)).map(_._1)
@@ -211,6 +213,8 @@ case class Intersect(
 
   override def nodeName: String = getClass.getSimpleName + ( if ( isAll ) "All" else "" )
 
+  final override val nodePatterns: Seq[TreePattern] = Seq(INTERSECT)
+
   override def output: Seq[Attribute] =
     left.output.zip(right.output).map { case (leftAttr, rightAttr) =>
       leftAttr.withNullability(leftAttr.nullable && rightAttr.nullable)
@@ -280,6 +284,8 @@ case class Union(
     }
   }
 
+  final override val nodePatterns: Seq[TreePattern] = Seq(UNION)
+
   /**
    * Note the definition has assumption about how union is implemented physically.
    */
@@ -632,6 +638,7 @@ case class Sort(
   override def output: Seq[Attribute] = child.output
   override def maxRows: Option[Long] = child.maxRows
   override def outputOrdering: Seq[SortOrder] = order
+  final override val nodePatterns: Seq[TreePattern] = Seq(SORT)
 
   override protected def withNewChildInternal(newChild: LogicalPlan): Sort = copy(child = newChild)
 }
@@ -1203,6 +1210,7 @@ case class Sample(
 case class Distinct(child: LogicalPlan) extends UnaryNode {
   override def maxRows: Option[Long] = child.maxRows
   override def output: Seq[Attribute] = child.output
+  final override val nodePatterns: Seq[TreePattern] = Seq(DISTINCT_LIKE)
   override protected def withNewChildInternal(newChild: LogicalPlan): Distinct =
     copy(child = newChild)
 }
@@ -1215,6 +1223,7 @@ abstract class RepartitionOperation extends UnaryNode {
   def numPartitions: Int
   override final def maxRows: Option[Long] = child.maxRows
   override def output: Seq[Attribute] = child.output
+  final override val nodePatterns: Seq[TreePattern] = Seq(REPARTITION_OPERATION)
   def partitioning: Partitioning
 }
 
@@ -1314,6 +1323,7 @@ case class Deduplicate(
     child: LogicalPlan) extends UnaryNode {
   override def maxRows: Option[Long] = child.maxRows
   override def output: Seq[Attribute] = child.output
+  final override val nodePatterns: Seq[TreePattern] = Seq(DISTINCT_LIKE)
   override protected def withNewChildInternal(newChild: LogicalPlan): Deduplicate =
     copy(child = newChild)
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
index 62f09d02ea146..605b57e46fc10 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
@@ -88,38 +88,62 @@ object RuleIdCollection {
       "org.apache.spark.sql.catalyst.analysis.UpdateOuterReferences" ::
       // Catalyst Optimizer rules
       "org.apache.spark.sql.catalyst.optimizer.BooleanSimplification" ::
+      "org.apache.spark.sql.catalyst.optimizer.CollapseProject" ::
+      "org.apache.spark.sql.catalyst.optimizer.CollapseRepartition" ::
+      "org.apache.spark.sql.catalyst.optimizer.CollapseWindow" ::
+      "org.apache.spark.sql.catalyst.optimizer.ColumnPruning" ::
       "org.apache.spark.sql.catalyst.optimizer.CombineConcats" ::
+      "org.apache.spark.sql.catalyst.optimizer.CombineFilters" ::
       "org.apache.spark.sql.catalyst.optimizer.CombineTypedFilters" ::
+      "org.apache.spark.sql.catalyst.optimizer.CombineUnions" ::
       "org.apache.spark.sql.catalyst.optimizer.ConstantFolding" ::
       "org.apache.spark.sql.catalyst.optimizer.ConstantPropagation" ::
+      "org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation" ::
       "org.apache.spark.sql.catalyst.optimizer.CostBasedJoinReorder" ::
+      "org.apache.spark.sql.catalyst.optimizer.DecimalAggregates" ::
+      "org.apache.spark.sql.catalyst.optimizer.EliminateAggregateFilter" ::
+      "org.apache.spark.sql.catalyst.optimizer.EliminateLimits" ::
       "org.apache.spark.sql.catalyst.optimizer.EliminateMapObjects" ::
       "org.apache.spark.sql.catalyst.optimizer.EliminateOuterJoin" ::
       "org.apache.spark.sql.catalyst.optimizer.EliminateSerialization" ::
       "org.apache.spark.sql.catalyst.optimizer.LikeSimplification" ::
+      "org.apache.spark.sql.catalyst.optimizer.LimitPushDown" ::
       "org.apache.spark.sql.catalyst.optimizer.LimitPushDownThroughWindow" ::
       "org.apache.spark.sql.catalyst.optimizer.NullPropagation" ::
       "org.apache.spark.sql.catalyst.optimizer.ObjectSerializerPruning" ::
       "org.apache.spark.sql.catalyst.optimizer.OptimizeCsvJsonExprs" ::
       "org.apache.spark.sql.catalyst.optimizer.OptimizeIn" ::
       "org.apache.spark.sql.catalyst.optimizer.Optimizer$OptimizeSubqueries" ::
+      "org.apache.spark.sql.catalyst.optimizer.OptimizeRepartition" ::
+      "org.apache.spark.sql.catalyst.optimizer.OptimizeWindowFunctions" ::
       "org.apache.spark.sql.catalyst.optimizer.OptimizeUpdateFields"::
       "org.apache.spark.sql.catalyst.optimizer.PropagateEmptyRelation" ::
+      "org.apache.spark.sql.catalyst.optimizer.PruneFilters" ::
       "org.apache.spark.sql.catalyst.optimizer.PushDownLeftSemiAntiJoin" ::
       "org.apache.spark.sql.catalyst.optimizer.PushExtraPredicateThroughJoin" ::
       "org.apache.spark.sql.catalyst.optimizer.PushFoldableIntoBranches" ::
       "org.apache.spark.sql.catalyst.optimizer.PushLeftSemiLeftAntiThroughJoin" ::
       "org.apache.spark.sql.catalyst.optimizer.ReassignLambdaVariableID" ::
       "org.apache.spark.sql.catalyst.optimizer.RemoveDispensableExpressions" ::
+      "org.apache.spark.sql.catalyst.optimizer.RemoveLiteralFromGroupExpressions" ::
+      "org.apache.spark.sql.catalyst.optimizer.RemoveNoopOperators" ::
+      "org.apache.spark.sql.catalyst.optimizer.RemoveRedundantAggregates" ::
+      "org.apache.spark.sql.catalyst.optimizer.RemoveRepetitionFromGroupExpressions" ::
       "org.apache.spark.sql.catalyst.optimizer.ReorderAssociativeOperator" ::
       "org.apache.spark.sql.catalyst.optimizer.ReorderJoin" ::
+      "org.apache.spark.sql.catalyst.optimizer.ReplaceExceptWithAntiJoin" ::
       "org.apache.spark.sql.catalyst.optimizer.ReplaceExceptWithFilter" ::
+      "org.apache.spark.sql.catalyst.optimizer.ReplaceDistinctWithAggregate" ::
       "org.apache.spark.sql.catalyst.optimizer.ReplaceNullWithFalseInPredicate" ::
+      "org.apache.spark.sql.catalyst.optimizer.ReplaceIntersectWithSemiJoin" ::
+      "org.apache.spark.sql.catalyst.optimizer.RewriteExceptAll" ::
+      "org.apache.spark.sql.catalyst.optimizer.RewriteIntersectAll" ::
       "org.apache.spark.sql.catalyst.optimizer.SimplifyBinaryComparison" ::
       "org.apache.spark.sql.catalyst.optimizer.SimplifyCaseConversionExpressions" ::
       "org.apache.spark.sql.catalyst.optimizer.SimplifyCasts" ::
       "org.apache.spark.sql.catalyst.optimizer.SimplifyConditionals" ::
       "org.apache.spark.sql.catalyst.optimizer.SimplifyConditionalsInPredicate" ::
+      "org.apache.spark.sql.catalyst.optimizer.TransposeWindow" ::
       "org.apache.spark.sql.catalyst.optimizer.UnwrapCastInBinaryComparison" :: Nil
   }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
index d1ba832114f18..40ef7cb592daa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
@@ -23,9 +23,11 @@ object TreePattern extends Enumeration {
 
   // Enum Ids start from 0.
   // Expression patterns (alphabetically ordered)
-  val AND_OR: Value = Value(0)
+  val ALIAS: Value = Value(0)
+  val AND_OR: Value = Value
   val ATTRIBUTE_REFERENCE: Value = Value
   val APPEND_COLUMNS: Value = Value
+  val AVERAGE: Value = Value
   val BINARY_ARITHMETIC: Value = Value
   val BINARY_COMPARISON: Value = Value
   val BOOL_AGG: Value = Value
@@ -41,10 +43,12 @@ object TreePattern extends Enumeration {
   val EXISTS_SUBQUERY = Value
   val EXPRESSION_WITH_RANDOM_SEED: Value = Value
   val EXTRACT_VALUE: Value = Value
+  val GENERATE: Value = Value
   val IF: Value = Value
   val IN: Value = Value
   val IN_SUBQUERY: Value = Value
   val INSET: Value = Value
+  val INTERSECT: Value = Value
   val JSON_TO_STRUCT: Value = Value
   val LAMBDA_VARIABLE: Value = Value
   val LIKE_FAMLIY: Value = Value
@@ -59,6 +63,8 @@ object TreePattern extends Enumeration {
   val PLAN_EXPRESSION: Value = Value
   val RUNTIME_REPLACEABLE: Value = Value
   val SCALAR_SUBQUERY: Value = Value
+  val SORT: Value = Value
+  val SUM: Value = Value
   val TRUE_OR_FALSE_LITERAL: Value = Value
   val WINDOW_EXPRESSION: Value = Value
   val UNARY_POSITIVE: Value = Value
@@ -66,6 +72,7 @@ object TreePattern extends Enumeration {
 
   // Logical plan patterns (alphabetically ordered)
   val AGGREGATE: Value = Value
+  val DISTINCT_LIKE: Value = Value
   val EXCEPT: Value = Value
   val FILTER: Value = Value
   val INNER_LIKE_JOIN: Value = Value
@@ -76,6 +83,8 @@ object TreePattern extends Enumeration {
   val NATURAL_LIKE_JOIN: Value = Value
   val OUTER_JOIN: Value = Value
   val PROJECT: Value = Value
+  val REPARTITION_OPERATION: Value = Value
+  val UNION: Value = Value
   val TYPED_FILTER: Value = Value
   val WINDOW: Value = Value
 }