Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
support pruning in more rules
  • Loading branch information
sigmod committed May 5, 2021
commit 98830ebaa82f976b5410bd4678084b316edf180e
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate
import org.apache.spark.sql.catalyst.analysis.{DecimalPrecision, FunctionRegistry, TypeCheckResult}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.trees.TreePattern.{AVERAGE, TreePattern}
import org.apache.spark.sql.catalyst.trees.UnaryLike
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -51,6 +52,8 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
// Return data type.
override def dataType: DataType = resultType

final override val nodePatterns: Seq[TreePattern] = Seq(AVERAGE)

private lazy val resultType = child.dataType match {
case DecimalType.Fixed(p, s) =>
DecimalType.bounded(p + 4, s + 4)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.trees.TreePattern.{SUM, TreePattern}
import org.apache.spark.sql.catalyst.trees.UnaryLike
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.internal.SQLConf
Expand Down Expand Up @@ -52,6 +53,8 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
override def checkInputDataTypes(): TypeCheckResult =
TypeUtils.checkForAnsiIntervalOrNumericType(child.dataType, "sum")

final override val nodePatterns: Seq[TreePattern] = Seq(SUM)

private lazy val resultType = child.dataType match {
case DecimalType.Fixed(precision, scale) =>
DecimalType.bounded(precision + 10, scale)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,8 @@ case class Alias(child: Expression, name: String)(
val nonInheritableMetadataKeys: Seq[String] = Seq.empty)
extends UnaryExpression with NamedExpression {

final override val nodePatterns: Seq[TreePattern] = Seq(ALIAS)

// Alias(Generator, xx) needs to be transformed into Generate(generator, ...)
override lazy val resolved =
childrenResolved && checkInputDataTypes().isSuccess && !child.isInstanceOf[Generator]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ object EliminateDistinct extends Rule[LogicalPlan] {
*/
object EliminateAggregateFilter extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan.transformExpressionsWithPruning(
_.containsAllPatterns(AGGREGATE, TRUE_OR_FALSE_LITERAL)) {
_.containsAllPatterns(TRUE_OR_FALSE_LITERAL)) {
case ae @ AggregateExpression(_, _, _, Some(Literal.TrueLiteral), _) =>
ae.copy(filter = None)
case AggregateExpression(af: DeclarativeAggregate, _, _, Some(Literal.FalseLiteral), _) =>
Expand Down Expand Up @@ -447,6 +447,9 @@ object RemoveRedundantAliases extends Rule[LogicalPlan] {
* (self) join or to prevent the removal of top-level subquery attributes.
*/
private def removeRedundantAliases(plan: LogicalPlan, excluded: AttributeSet): LogicalPlan = {
if (!plan.containsPattern(ALIAS)) {
return plan
}
plan match {
// We want to keep the same output attributes for subqueries. This means we cannot remove
// the aliases that produce these attributes
Expand Down Expand Up @@ -508,7 +511,8 @@ object RemoveRedundantAliases extends Rule[LogicalPlan] {
* only goal is to keep distinct values, while its parent aggregate would ignore duplicate values.
*/
object RemoveRedundantAggregates extends Rule[LogicalPlan] with AliasHelper {
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
_.containsPattern(AGGREGATE), ruleId) {
case upper @ Aggregate(_, _, lower: Aggregate) if lowerIsRedundant(upper, lower) =>
val aliasMap = getAliasMap(lower)

Expand Down Expand Up @@ -547,7 +551,8 @@ object RemoveRedundantAggregates extends Rule[LogicalPlan] with AliasHelper {
* Remove no-op operators from the query plan that do not make any modifications.
*/
object RemoveNoopOperators extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
_.containsAnyPattern(PROJECT, WINDOW), ruleId) {
// Eliminate no-op Projects
case p @ Project(_, child) if child.sameOutput(p) => child

Expand Down Expand Up @@ -599,7 +604,8 @@ object RemoveNoopUnion extends Rule[LogicalPlan] {
}
}

def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
_.containsAllPatterns(DISTINCT_LIKE, UNION), ruleId) {
case d @ Distinct(u: Union) =>
d.withNewChildren(Seq(simplifyUnion(u)))
case d @ Deduplicate(_, u: Union) =>
Expand Down Expand Up @@ -650,7 +656,8 @@ object LimitPushDown extends Rule[LogicalPlan] {
}
}

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
_.containsPattern(LIMIT), ruleId) {
// Adding extra Limits below UNION ALL for children which are not Limit or do not have Limit
// descendants whose maxRow is larger. This heuristic is valid assuming there does not exist any
// Limit push-down rule that is unable to infer the value of maxRows.
Expand Down Expand Up @@ -948,7 +955,8 @@ object CollapseRepartition extends Rule[LogicalPlan] {
* and user not specify.
*/
object OptimizeRepartition extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan.transform {
override def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
_.containsPattern(REPARTITION_OPERATION), ruleId) {
case r @ RepartitionByExpression(partitionExpressions, _, numPartitions)
if partitionExpressions.nonEmpty && partitionExpressions.forall(_.foldable) &&
numPartitions.isEmpty =>
Expand All @@ -960,7 +968,8 @@ object OptimizeRepartition extends Rule[LogicalPlan] {
* Replaces first(col) to nth_value(col, 1) for better performance.
*/
object OptimizeWindowFunctions extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveExpressionsWithPruning(
_.containsPattern(WINDOW_EXPRESSION), ruleId) {
case we @ WindowExpression(AggregateExpression(first: First, _, _, _, _),
WindowSpecDefinition(_, orderSpec, frameSpecification: SpecifiedWindowFrame))
if orderSpec.nonEmpty && frameSpecification.frameType == RowFrame &&
Expand Down Expand Up @@ -1001,7 +1010,8 @@ object TransposeWindow extends Rule[LogicalPlan] {
})
}

def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
_.containsPattern(WINDOW), ruleId) {
case w1 @ Window(we1, ps1, os1, w2 @ Window(we2, ps2, os2, grandChild))
if w1.references.intersect(w2.windowOutputSet).isEmpty &&
w1.expressions.forall(_.deterministic) &&
Expand All @@ -1016,7 +1026,8 @@ object TransposeWindow extends Rule[LogicalPlan] {
* by this [[Generate]] can be removed earlier - before joins and in data sources.
*/
object InferFiltersFromGenerate extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
_.containsPattern(GENERATE)) {
// This rule does not infer filters from foldable expressions to avoid constant filters
// like 'size([1, 2, 3]) > 0'. These do not show up in child's constraints and
// then the idempotence will break.
Expand Down Expand Up @@ -1066,7 +1077,8 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan]
}
}

private def inferFilters(plan: LogicalPlan): LogicalPlan = plan transform {
private def inferFilters(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
_.containsAnyPattern(FILTER, JOIN)) {
case filter @ Filter(condition, child) =>
val newFilters = filter.constraints --
(child.constraints ++ splitConjunctivePredicates(condition))
Expand Down Expand Up @@ -1210,7 +1222,8 @@ object CombineFilters extends Rule[LogicalPlan] with PredicateHelper {
* function is order irrelevant
*/
object EliminateSorts extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform applyLocally
def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
_.containsPattern(SORT), ruleId)(applyLocally)

private val applyLocally: PartialFunction[LogicalPlan, LogicalPlan] = {
case s @ Sort(orders, _, child) if orders.isEmpty || orders.exists(_.child.foldable) =>
Expand All @@ -1229,11 +1242,16 @@ object EliminateSorts extends Rule[LogicalPlan] {
g.copy(child = recursiveRemoveSort(originChild))
}

private def recursiveRemoveSort(plan: LogicalPlan): LogicalPlan = plan match {
case Sort(_, _, child) => recursiveRemoveSort(child)
case other if canEliminateSort(other) =>
other.withNewChildren(other.children.map(recursiveRemoveSort))
case _ => plan
/**
 * Strips `Sort` nodes from `plan`, descending only through operators for which
 * `canEliminateSort` holds.
 *
 * - If the subtree does not contain the SORT pattern at all, it is returned untouched
 *   without being traversed.
 * - A `Sort` node is dropped and the removal continues into its child.
 * - Any other node accepted by `canEliminateSort` is rebuilt with each of its children
 *   processed recursively.
 * - Every other node stops the descent and is returned as is.
 *
 * @param plan the (sub)plan to process
 * @return the plan with the removable `Sort` nodes stripped
 */
private def recursiveRemoveSort(plan: LogicalPlan): LogicalPlan = {
  if (!plan.containsPattern(SORT)) {
    // Fast path: nothing to remove anywhere below this node.
    plan
  } else {
    plan match {
      case Sort(_, _, child) => recursiveRemoveSort(child)
      case other if canEliminateSort(other) =>
        other.withNewChildren(other.children.map(recursiveRemoveSort))
      case _ => plan
    }
  }
}

private def canEliminateSort(plan: LogicalPlan): Boolean = plan match {
Expand Down Expand Up @@ -1272,7 +1290,8 @@ object EliminateSorts extends Rule[LogicalPlan] {
* 3) by eliminating the always-true conditions given the constraints on the child's output.
*/
object PruneFilters extends Rule[LogicalPlan] with PredicateHelper {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
_.containsPattern(FILTER)) {
// If the filter condition always evaluates to true, remove the filter.
case Filter(Literal(true, BooleanType), child) => child
// If the filter condition always evaluates to null or false,
Expand Down Expand Up @@ -1628,7 +1647,8 @@ object EliminateLimits extends Rule[LogicalPlan] {
limitExpr.foldable && child.maxRows.exists { _ <= limitExpr.eval().asInstanceOf[Int] }
}

def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
def apply(plan: LogicalPlan): LogicalPlan = plan.transformDownWithPruning(
_.containsPattern(LIMIT), ruleId) {
case Limit(l, child) if canEliminate(l, child) =>
child
case GlobalLimit(l, child) if canEliminate(l, child) =>
Expand Down Expand Up @@ -1703,8 +1723,10 @@ object DecimalAggregates extends Rule[LogicalPlan] {
/** Maximum number of decimal digits representable precisely in a Double */
private val MAX_DOUBLE_DIGITS = 15

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan => q transformExpressionsDown {
def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
_.containsAnyPattern(SUM, AVERAGE), ruleId) {
case q: LogicalPlan => q.transformExpressionsDownWithPruning(
_.containsAnyPattern(SUM, AVERAGE), ruleId) {
case we @ WindowExpression(ae @ AggregateExpression(af, _, _, _, _), _) => af match {
case Sum(e @ DecimalType.Expression(prec, scale)) if prec + 10 <= MAX_LONG_DIGITS =>
MakeDecimal(we.copy(windowFunction = ae.copy(aggregateFunction = Sum(UnscaledValue(e)))),
Expand Down Expand Up @@ -1740,7 +1762,8 @@ object DecimalAggregates extends Rule[LogicalPlan] {
* another `LocalRelation`.
*/
object ConvertToLocalRelation extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
_.containsPattern(LOCAL_RELATION), ruleId) {
case Project(projectList, LocalRelation(output, data, isStreaming))
if !projectList.exists(hasUnevaluableExpr) =>
val projection = new InterpretedMutableProjection(projectList, output)
Expand Down Expand Up @@ -1769,7 +1792,8 @@ object ConvertToLocalRelation extends Rule[LogicalPlan] {
* }}}
*/
object ReplaceDistinctWithAggregate extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
_.containsPattern(DISTINCT_LIKE), ruleId) {
case Distinct(child) => Aggregate(child.output, child.output, child)
}
}
Expand Down Expand Up @@ -1813,7 +1837,8 @@ object ReplaceDeduplicateWithAggregate extends Rule[LogicalPlan] {
* join conditions will be incorrect.
*/
object ReplaceIntersectWithSemiJoin extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
_.containsPattern(INTERSECT), ruleId) {
case Intersect(left, right, false) =>
assert(left.output.size == right.output.size)
val joinCond = left.output.zip(right.output).map { case (l, r) => EqualNullSafe(l, r) }
Expand All @@ -1834,7 +1859,8 @@ object ReplaceIntersectWithSemiJoin extends Rule[LogicalPlan] {
* join conditions will be incorrect.
*/
object ReplaceExceptWithAntiJoin extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
_.containsPattern(EXCEPT), ruleId) {
case Except(left, right, false) =>
assert(left.output.size == right.output.size)
val joinCond = left.output.zip(right.output).map { case (l, r) => EqualNullSafe(l, r) }
Expand Down Expand Up @@ -1874,7 +1900,8 @@ object ReplaceExceptWithAntiJoin extends Rule[LogicalPlan] {
*/

object RewriteExceptAll extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
_.containsPattern(EXCEPT), ruleId) {
case Except(left, right, true) =>
assert(left.output.size == right.output.size)

Expand Down Expand Up @@ -1931,7 +1958,8 @@ object RewriteExceptAll extends Rule[LogicalPlan] {
* }}}
*/
object RewriteIntersectAll extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
_.containsPattern(INTERSECT), ruleId) {
case Intersect(left, right, true) =>
assert(left.output.size == right.output.size)

Expand Down Expand Up @@ -1983,7 +2011,8 @@ object RewriteIntersectAll extends Rule[LogicalPlan] {
* but only makes the grouping key bigger.
*/
object RemoveLiteralFromGroupExpressions extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
_.containsPattern(AGGREGATE), ruleId) {
case a @ Aggregate(grouping, _, _) if grouping.nonEmpty =>
val newGrouping = grouping.filter(!_.foldable)
if (newGrouping.nonEmpty) {
Expand All @@ -2002,7 +2031,8 @@ object RemoveLiteralFromGroupExpressions extends Rule[LogicalPlan] {
* but only makes the grouping key bigger.
*/
object RemoveRepetitionFromGroupExpressions extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
_.containsPattern(AGGREGATE), ruleId) {
case a @ Aggregate(grouping, _, _) if grouping.size > 1 =>
val newGrouping = ExpressionSet(grouping).toSeq
if (newGrouping.size == grouping.size) {
Expand All @@ -2022,7 +2052,8 @@ object OptimizeLimitZero extends Rule[LogicalPlan] {
private def empty(plan: LogicalPlan) =
LocalRelation(plan.output, data = Seq.empty, isStreaming = plan.isStreaming)

def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
_.containsPattern(LIMIT)) {
// Nodes below GlobalLimit or LocalLimit can be pruned if the limit value is zero (0).
// Any subtree in the logical plan that has GlobalLimit 0 or LocalLimit 0 as its root is
// semantically equivalent to an empty relation.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@ case class Generate(
child: LogicalPlan)
extends UnaryNode {

final override val nodePatterns: Seq[TreePattern] = Seq(GENERATE)

lazy val requiredChildOutput: Seq[Attribute] = {
val unrequiredSet = unrequiredChildIndex.toSet
child.output.zipWithIndex.filterNot(t => unrequiredSet.contains(t._2)).map(_._1)
Expand Down Expand Up @@ -209,6 +211,8 @@ case class Intersect(

override def nodeName: String = getClass.getSimpleName + ( if ( isAll ) "All" else "" )

final override val nodePatterns: Seq[TreePattern] = Seq(INTERSECT)

override def output: Seq[Attribute] =
left.output.zip(right.output).map { case (leftAttr, rightAttr) =>
leftAttr.withNullability(leftAttr.nullable && rightAttr.nullable)
Expand Down Expand Up @@ -632,6 +636,7 @@ case class Sort(
override def output: Seq[Attribute] = child.output
override def maxRows: Option[Long] = child.maxRows
override def outputOrdering: Seq[SortOrder] = order
final override val nodePatterns: Seq[TreePattern] = Seq(SORT)
override protected def withNewChildInternal(newChild: LogicalPlan): Sort = copy(child = newChild)
}

Expand Down
Loading