Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
update
  • Loading branch information
sigmod committed May 3, 2021
commit 6a35a882c39f4df302df0134ca6bf75491cf1cbf
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.trees.TreePattern.PLAN_EXPRESSION
import org.apache.spark.sql.catalyst.trees.TreePattern._
import org.apache.spark.sql.connector.catalog.CatalogManager
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -863,7 +863,8 @@ object ColumnPruning extends Rule[LogicalPlan] {
*/
object CollapseProject extends Rule[LogicalPlan] with AliasHelper {

def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
_.containsPattern(PROJECT), ruleId) {
case p1 @ Project(_, p2: Project) =>
if (haveCommonNonDeterministicOutput(p1.projectList, p2.projectList)) {
p1
Expand Down Expand Up @@ -921,7 +922,8 @@ object CollapseProject extends Rule[LogicalPlan] with AliasHelper {
* Combines adjacent [[RepartitionOperation]] operators
*/
object CollapseRepartition extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
_.containsPattern(REPARTITION_OPERATION)) {
// Case 1: When a Repartition has a child of Repartition or RepartitionByExpression,
// 1) When the top node does not enable the shuffle (i.e., coalesce API), but the child
// enables the shuffle. Returns the child node if the last numPartitions is bigger;
Expand Down Expand Up @@ -972,7 +974,8 @@ object OptimizeWindowFunctions extends Rule[LogicalPlan] {
* independent and are of the same window function type, collapse into the parent.
*/
object CollapseWindow extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
_.containsPattern(WINDOW), ruleId) {
case w1 @ Window(we1, ps1, os1, w2 @ Window(we2, ps2, os2, grandChild))
if ps1 == ps2 && os1 == os2 && w1.references.intersect(w2.windowOutputSet).isEmpty &&
we1.nonEmpty && we2.nonEmpty &&
Expand Down Expand Up @@ -1123,7 +1126,8 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan]
* Combines all adjacent [[Union]] operators into a single [[Union]].
*/
object CombineUnions extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
def apply(plan: LogicalPlan): LogicalPlan = plan.transformDownWithPruning(
_.containsAnyPattern(UNION, DISTINCT_LIKE), ruleId) {
case u: Union => flattenUnion(u, false)
case Distinct(u: Union) => Distinct(flattenUnion(u, true))
// Only handle distinct-like 'Deduplicate', where the keys == output
Expand Down Expand Up @@ -1167,7 +1171,8 @@ object CombineUnions extends Rule[LogicalPlan] {
* one conjunctive predicate.
*/
object CombineFilters extends Rule[LogicalPlan] with PredicateHelper {
def apply(plan: LogicalPlan): LogicalPlan = plan transform applyLocally
def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
_.containsPattern(FILTER), ruleId)(applyLocally)

val applyLocally: PartialFunction[LogicalPlan, LogicalPlan] = {
// The query execution/optimization does not guarantee the expressions are evaluated in order.
Expand Down Expand Up @@ -1667,7 +1672,7 @@ object CheckCartesianProducts extends Rule[LogicalPlan] with PredicateHelper {
def apply(plan: LogicalPlan): LogicalPlan =
if (conf.crossJoinEnabled) {
plan
} else plan transform {
} else plan.transformWithPruning(_.containsAnyPattern(INNER_LIKE_JOIN, OUTER_JOIN)) {
case j @ Join(left, right, Inner | LeftOuter | RightOuter | FullOuter, _, _)
if isCartesianProduct(j) =>
throw new AnalysisException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan)
override def output: Seq[Attribute] = projectList.map(_.toAttribute)
override def maxRows: Option[Long] = child.maxRows

final override val nodePatterns: Seq[TreePattern] = Seq(PROJECT)

override lazy val resolved: Boolean = {
val hasSpecialExpressions = projectList.exists ( _.collect {
case agg: AggregateExpression => agg
Expand Down Expand Up @@ -274,6 +276,8 @@ case class Union(
}
}

final override val nodePatterns: Seq[TreePattern] = Seq(UNION)

/**
* Note the definition has assumption about how union is implemented physically.
*/
Expand Down Expand Up @@ -833,6 +837,8 @@ case class Window(

override def producedAttributes: AttributeSet = windowOutputSet

final override val nodePatterns: Seq[TreePattern] = Seq(WINDOW)

def windowOutputSet: AttributeSet = AttributeSet(windowExpressions.map(_.toAttribute))

override protected def withNewChildInternal(newChild: LogicalPlan): Window =
Expand Down Expand Up @@ -1179,6 +1185,7 @@ case class Sample(
case class Distinct(child: LogicalPlan) extends UnaryNode {
override def maxRows: Option[Long] = child.maxRows
override def output: Seq[Attribute] = child.output
final override val nodePatterns: Seq[TreePattern] = Seq(DISTINCT_LIKE)
override protected def withNewChildInternal(newChild: LogicalPlan): Distinct =
copy(child = newChild)
}
Expand All @@ -1191,6 +1198,7 @@ abstract class RepartitionOperation extends UnaryNode {
def numPartitions: Int
override final def maxRows: Option[Long] = child.maxRows
override def output: Seq[Attribute] = child.output
final override val nodePatterns: Seq[TreePattern] = Seq(REPARTITION_OPERATION)
def partitioning: Partitioning
}

Expand All @@ -1211,6 +1219,7 @@ case class Repartition(numPartitions: Int, shuffle: Boolean, child: LogicalPlan)
case _ => RoundRobinPartitioning(numPartitions)
}
}

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: unnecessary change.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

override protected def withNewChildInternal(newChild: LogicalPlan): Repartition =
copy(child = newChild)
}
Expand Down Expand Up @@ -1290,6 +1299,7 @@ case class Deduplicate(
child: LogicalPlan) extends UnaryNode {
override def maxRows: Option[Long] = child.maxRows
override def output: Seq[Attribute] = child.output
final override val nodePatterns: Seq[TreePattern] = Seq(DISTINCT_LIKE)
override protected def withNewChildInternal(newChild: LogicalPlan): Deduplicate =
copy(child = newChild)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,11 @@ object RuleIdCollection {
"org.apache.spark.sql.catalyst.analysis.UpdateOuterReferences" ::
// Catalyst Optimizer rules
"org.apache.spark.sql.catalyst.optimizer.BooleanSimplification" ::
"org.apache.spark.sql.catalyst.optimizer.CollapseProject" ::
"org.apache.spark.sql.catalyst.optimizer.CollapseWindow" ::
"org.apache.spark.sql.catalyst.optimizer.CombineConcats" ::
"org.apache.spark.sql.catalyst.optimizer.CombineFilters" ::
"org.apache.spark.sql.catalyst.optimizer.CombineUnions" ::
"org.apache.spark.sql.catalyst.optimizer.ConstantFolding" ::
"org.apache.spark.sql.catalyst.optimizer.ConstantPropagation" ::
"org.apache.spark.sql.catalyst.optimizer.CostBasedJoinReorder" ::
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,15 @@ object TreePattern extends Enumeration {
val UPPER_OR_LOWER: Value = Value

// Logical plan patterns (alphabetically ordered)
val DISTINCT_LIKE: Value = Value
val FILTER: Value = Value
val INNER_LIKE_JOIN: Value = Value
val JOIN: Value = Value
val LEFT_SEMI_OR_ANTI_JOIN: Value = Value
val NATURAL_LIKE_JOIN: Value = Value
val PROJECT: Value = Value
val OUTER_JOIN: Value = Value
val REPARTITION_OPERATION: Value = Value
val WINDOW: Value = Value
val UNION: Value = Value
}