Skip to content
Closed
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate
import org.apache.spark.sql.catalyst.analysis.{DecimalPrecision, FunctionRegistry, TypeCheckResult}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.trees.TreePattern.{AVERAGE, TreePattern}
import org.apache.spark.sql.catalyst.trees.UnaryLike
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -51,6 +52,8 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
// Return data type.
override def dataType: DataType = resultType

final override val nodePatterns: Seq[TreePattern] = Seq(AVERAGE)

private lazy val resultType = child.dataType match {
case DecimalType.Fixed(p, s) =>
DecimalType.bounded(p + 4, s + 4)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.trees.TreePattern.{SUM, TreePattern}
import org.apache.spark.sql.catalyst.trees.UnaryLike
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.internal.SQLConf
Expand Down Expand Up @@ -52,6 +53,8 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
// Validates the input type by delegating to TypeUtils.checkForAnsiIntervalOrNumericType,
// passing the child's data type and the function name "sum" used by that check.
override def checkInputDataTypes(): TypeCheckResult =
TypeUtils.checkForAnsiIntervalOrNumericType(child.dataType, "sum")

final override val nodePatterns: Seq[TreePattern] = Seq(SUM)

private lazy val resultType = child.dataType match {
case DecimalType.Fixed(precision, scale) =>
DecimalType.bounded(precision + 10, scale)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,8 @@ case class Alias(child: Expression, name: String)(
val nonInheritableMetadataKeys: Seq[String] = Seq.empty)
extends UnaryExpression with NamedExpression {

final override val nodePatterns: Seq[TreePattern] = Seq(ALIAS)

// Alias(Generator, xx) needs to be transformed into Generate(generator, ...), so an alias
// that still wraps a Generator is never reported as resolved. Otherwise the alias is
// resolved once its children are resolved and its input data types check out.
override lazy val resolved: Boolean =
  childrenResolved && checkInputDataTypes().isSuccess && !child.isInstanceOf[Generator]
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan)
override def output: Seq[Attribute] = projectList.map(_.toAttribute)
override def maxRows: Option[Long] = child.maxRows

final override val nodePatterns: Seq[TreePattern] = Seq(PROJECT)

override lazy val resolved: Boolean = {
val hasSpecialExpressions = projectList.exists ( _.collect {
case agg: AggregateExpression => agg
Expand Down Expand Up @@ -124,6 +126,8 @@ case class Generate(
child: LogicalPlan)
extends UnaryNode {

final override val nodePatterns: Seq[TreePattern] = Seq(GENERATE)

lazy val requiredChildOutput: Seq[Attribute] = {
val unrequiredSet = unrequiredChildIndex.toSet
child.output.zipWithIndex.filterNot(t => unrequiredSet.contains(t._2)).map(_._1)
Expand Down Expand Up @@ -207,6 +211,8 @@ case class Intersect(

// Node name is the simple class name, with "All" appended when `isAll` is set.
override def nodeName: String = {
  val baseName = getClass.getSimpleName
  val suffix = if (isAll) "All" else ""
  baseName + suffix
}

final override val nodePatterns: Seq[TreePattern] = Seq(INTERSECT)

override def output: Seq[Attribute] =
left.output.zip(right.output).map { case (leftAttr, rightAttr) =>
leftAttr.withNullability(leftAttr.nullable && rightAttr.nullable)
Expand Down Expand Up @@ -276,6 +282,8 @@ case class Union(
}
}

final override val nodePatterns: Seq[TreePattern] = Seq(UNION)

/**
* Note that the definition makes an assumption about how union is implemented physically.
*/
Expand Down Expand Up @@ -628,6 +636,7 @@ case class Sort(
override def output: Seq[Attribute] = child.output
override def maxRows: Option[Long] = child.maxRows
override def outputOrdering: Seq[SortOrder] = order
final override val nodePatterns: Seq[TreePattern] = Seq(SORT)
override protected def withNewChildInternal(newChild: LogicalPlan): Sort = copy(child = newChild)
}

Expand Down Expand Up @@ -1189,6 +1198,7 @@ case class Sample(
case class Distinct(child: LogicalPlan) extends UnaryNode {
  // Tagged with DISTINCT_LIKE (the same pattern Deduplicate uses) for pattern-based lookups.
  final override val nodePatterns: Seq[TreePattern] = Seq(DISTINCT_LIKE)

  // The output attributes are exactly those of the child.
  override def output: Seq[Attribute] = child.output

  // The row-count bound is taken from the child unchanged.
  override def maxRows: Option[Long] = child.maxRows

  // Rebuilds this node around a replacement child.
  override protected def withNewChildInternal(newChild: LogicalPlan): Distinct = {
    copy(child = newChild)
  }
}
Expand All @@ -1201,6 +1211,7 @@ abstract class RepartitionOperation extends UnaryNode {
def numPartitions: Int
override final def maxRows: Option[Long] = child.maxRows
override def output: Seq[Attribute] = child.output
final override val nodePatterns: Seq[TreePattern] = Seq(REPARTITION_OPERATION)
def partitioning: Partitioning
}

Expand All @@ -1221,6 +1232,7 @@ case class Repartition(numPartitions: Int, shuffle: Boolean, child: LogicalPlan)
case _ => RoundRobinPartitioning(numPartitions)
}
}

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: unnecessary change.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

override protected def withNewChildInternal(newChild: LogicalPlan): Repartition =
copy(child = newChild)
}
Expand Down Expand Up @@ -1300,6 +1312,7 @@ case class Deduplicate(
child: LogicalPlan) extends UnaryNode {
override def maxRows: Option[Long] = child.maxRows
override def output: Seq[Attribute] = child.output
final override val nodePatterns: Seq[TreePattern] = Seq(DISTINCT_LIKE)
override protected def withNewChildInternal(newChild: LogicalPlan): Deduplicate =
copy(child = newChild)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,33 +88,57 @@ object RuleIdCollection {
"org.apache.spark.sql.catalyst.analysis.UpdateOuterReferences" ::
// Catalyst Optimizer rules
"org.apache.spark.sql.catalyst.optimizer.BooleanSimplification" ::
"org.apache.spark.sql.catalyst.optimizer.CollapseProject" ::
"org.apache.spark.sql.catalyst.optimizer.CollapseRepartition" ::
"org.apache.spark.sql.catalyst.optimizer.CollapseWindow" ::
"org.apache.spark.sql.catalyst.optimizer.ColumnPruning" ::
"org.apache.spark.sql.catalyst.optimizer.CombineConcats" ::
"org.apache.spark.sql.catalyst.optimizer.CombineFilters" ::
"org.apache.spark.sql.catalyst.optimizer.CombineUnions" ::
"org.apache.spark.sql.catalyst.optimizer.ConstantFolding" ::
"org.apache.spark.sql.catalyst.optimizer.ConstantPropagation" ::
"org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation" ::
"org.apache.spark.sql.catalyst.optimizer.CostBasedJoinReorder" ::
"org.apache.spark.sql.catalyst.optimizer.DecimalAggregates" ::
"org.apache.spark.sql.catalyst.optimizer.EliminateAggregateFilter" ::
"org.apache.spark.sql.catalyst.optimizer.EliminateLimits" ::
"org.apache.spark.sql.catalyst.optimizer.EliminateOuterJoin" ::
"org.apache.spark.sql.catalyst.optimizer.LikeSimplification" ::
"org.apache.spark.sql.catalyst.optimizer.LimitPushDown" ::
"org.apache.spark.sql.catalyst.optimizer.LimitPushDownThroughWindow" ::
"org.apache.spark.sql.catalyst.optimizer.NullPropagation" ::
"org.apache.spark.sql.catalyst.optimizer.OptimizeCsvJsonExprs" ::
"org.apache.spark.sql.catalyst.optimizer.OptimizeIn" ::
"org.apache.spark.sql.catalyst.optimizer.Optimizer$OptimizeSubqueries" ::
"org.apache.spark.sql.catalyst.optimizer.OptimizeRepartition" ::
"org.apache.spark.sql.catalyst.optimizer.OptimizeWindowFunctions" ::
"org.apache.spark.sql.catalyst.optimizer.OptimizeUpdateFields"::
"org.apache.spark.sql.catalyst.optimizer.PropagateEmptyRelation" ::
"org.apache.spark.sql.catalyst.optimizer.PruneFilters" ::
"org.apache.spark.sql.catalyst.optimizer.PushDownLeftSemiAntiJoin" ::
"org.apache.spark.sql.catalyst.optimizer.PushExtraPredicateThroughJoin" ::
"org.apache.spark.sql.catalyst.optimizer.PushFoldableIntoBranches" ::
"org.apache.spark.sql.catalyst.optimizer.PushLeftSemiLeftAntiThroughJoin" ::
"org.apache.spark.sql.catalyst.optimizer.RemoveDispensableExpressions" ::
"org.apache.spark.sql.catalyst.optimizer.RemoveLiteralFromGroupExpressions" ::
"org.apache.spark.sql.catalyst.optimizer.RemoveNoopOperators" ::
"org.apache.spark.sql.catalyst.optimizer.RemoveRedundantAggregates" ::
"org.apache.spark.sql.catalyst.optimizer.RemoveRepetitionFromGroupExpressions" ::
"org.apache.spark.sql.catalyst.optimizer.ReorderAssociativeOperator" ::
"org.apache.spark.sql.catalyst.optimizer.ReorderJoin" ::
"org.apache.spark.sql.catalyst.optimizer.ReplaceExceptWithAntiJoin" ::
"org.apache.spark.sql.catalyst.optimizer.ReplaceExceptWithFilter" ::
"org.apache.spark.sql.catalyst.optimizer.ReplaceDistinctWithAggregate" ::
"org.apache.spark.sql.catalyst.optimizer.ReplaceNullWithFalseInPredicate" ::
"org.apache.spark.sql.catalyst.optimizer.ReplaceIntersectWithSemiJoin" ::
"org.apache.spark.sql.catalyst.optimizer.RewriteExceptAll" ::
"org.apache.spark.sql.catalyst.optimizer.RewriteIntersectAll" ::
"org.apache.spark.sql.catalyst.optimizer.SimplifyBinaryComparison" ::
"org.apache.spark.sql.catalyst.optimizer.SimplifyCaseConversionExpressions" ::
"org.apache.spark.sql.catalyst.optimizer.SimplifyCasts" ::
"org.apache.spark.sql.catalyst.optimizer.SimplifyConditionals" ::
"org.apache.spark.sql.catalyst.optimizer.SimplifyConditionalsInPredicate" ::
"org.apache.spark.sql.catalyst.optimizer.TransposeWindow" ::
"org.apache.spark.sql.catalyst.optimizer.UnwrapCastInBinaryComparison" :: Nil
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@ object TreePattern extends Enumeration {

// Enum Ids start from 0.
// Expression patterns (alphabetically ordered)
// ALIAS is the alphabetically-first pattern, so it carries the explicit starting id 0;
// every later entry takes the next auto-assigned id. AND_OR must be declared only once
// and must not reuse id 0.
val ALIAS: Value = Value(0)
val AND_OR: Value = Value
val ATTRIBUTE_REFERENCE: Value = Value
val AVERAGE: Value = Value
val BINARY_ARITHMETIC: Value = Value
val BINARY_COMPARISON: Value = Value
val CASE_WHEN: Value = Value
Expand All @@ -36,10 +38,12 @@ object TreePattern extends Enumeration {
val EXISTS_SUBQUERY: Value = Value
val EXPRESSION_WITH_RANDOM_SEED: Value = Value
val EXTRACT_VALUE: Value = Value
val GENERATE: Value = Value
val IF: Value = Value
val IN: Value = Value
val IN_SUBQUERY: Value = Value
val INSET: Value = Value
val INTERSECT: Value = Value
val JSON_TO_STRUCT: Value = Value
val LIKE_FAMLIY: Value = Value
val LIST_SUBQUERY: Value = Value
Expand All @@ -50,13 +54,16 @@ object TreePattern extends Enumeration {
val OUTER_REFERENCE: Value = Value
val PLAN_EXPRESSION: Value = Value
val SCALAR_SUBQUERY: Value = Value
val SORT: Value = Value
val SUM: Value = Value
val TRUE_OR_FALSE_LITERAL: Value = Value
val WINDOW_EXPRESSION: Value = Value
val UNARY_POSITIVE: Value = Value
val UPPER_OR_LOWER: Value = Value

// Logical plan patterns (alphabetically ordered)
val AGGREGATE: Value = Value
val DISTINCT_LIKE: Value = Value
val EXCEPT: Value = Value
val FILTER: Value = Value
val INNER_LIKE_JOIN: Value = Value
Expand All @@ -65,6 +72,9 @@ object TreePattern extends Enumeration {
val LIMIT: Value = Value
val LOCAL_RELATION: Value = Value
val NATURAL_LIKE_JOIN: Value = Value
val OUTER_JOIN: Value = Value
val PROJECT: Value = Value
val REPARTITION_OPERATION: Value = Value
val UNION: Value = Value
val WINDOW: Value = Value
}