Skip to content
Next Next commit
[SPARK-42551][SQL] Support more subexpression elimination cases
  • Loading branch information
wankunde committed May 9, 2023
commit ffe4ca8b41416d82677dae72e3e74faeb0f721f8
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,16 @@ import java.util.Objects

import scala.collection.mutable

import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.expressions.EquivalentExpressions.supportedExpression
import org.apache.spark.sql.catalyst.expressions.objects.LambdaVariable
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.util.Utils

/**
* This class is used to compute equality of (sub)expression trees. Expressions can be added
* to this class and they subsequently query for expression equality. Expression trees are
* considered equal if for the same input(s), the same result is produced.
*/
class EquivalentExpressions(
skipForShortcutEnable: Boolean = SQLConf.get.subexpressionEliminationSkipForShotcutExpr) {

class EquivalentExpressions {
// For each expression, the set of equivalent expressions.
private val equivalenceMap = mutable.HashMap.empty[ExpressionEquals, ExpressionStats]

Expand Down Expand Up @@ -91,92 +88,6 @@ class EquivalentExpressions(
}
}

/**
* Adds or removes only expressions which are common in each of given expressions, in a recursive
* way.
* For example, given two expressions `(a + (b + (c + 1)))` and `(d + (e + (c + 1)))`, the common
* expression `(c + 1)` will be added into `equivalenceMap`.
*
* Note that as we don't know in advance if any child node of an expression will be common across
* all given expressions, we compute local equivalence maps for all given expressions and filter
* only the common nodes.
* Those common nodes are then removed from the local map and added to the final map of
* expressions.
*/
private def updateCommonExprs(
exprs: Seq[Expression],
map: mutable.HashMap[ExpressionEquals, ExpressionStats],
useCount: Int): Unit = {
assert(exprs.length > 1)
var localEquivalenceMap = mutable.HashMap.empty[ExpressionEquals, ExpressionStats]
updateExprTree(exprs.head, localEquivalenceMap)

exprs.tail.foreach { expr =>
val otherLocalEquivalenceMap = mutable.HashMap.empty[ExpressionEquals, ExpressionStats]
updateExprTree(expr, otherLocalEquivalenceMap)
localEquivalenceMap = localEquivalenceMap.filter { case (key, _) =>
otherLocalEquivalenceMap.contains(key)
}
}

// Start with the highest expression, remove it from `localEquivalenceMap` and add it to `map`.
// The remaining highest expression in `localEquivalenceMap` is also common expression so loop
// until `localEquivalenceMap` is not empty.
var statsOption = Some(localEquivalenceMap).filter(_.nonEmpty).map(_.maxBy(_._1.height)._2)
while (statsOption.nonEmpty) {
val stats = statsOption.get
updateExprTree(stats.expr, localEquivalenceMap, -stats.useCount)
updateExprTree(stats.expr, map, useCount)

statsOption = Some(localEquivalenceMap).filter(_.nonEmpty).map(_.maxBy(_._1.height)._2)
}
}

private def skipForShortcut(expr: Expression): Expression = {
if (skipForShortcutEnable) {
// The subexpression may not need to eval even if it appears more than once.
// e.g., `if(or(a, and(b, b)))`, the expression `b` would be skipped if `a` is true.
expr match {
case and: And => and.left
case or: Or => or.left
case other => other
}
} else {
expr
}
}

// There are some special expressions that we should not recurse into all of its children.
// 1. CodegenFallback: it's children will not be used to generate code (call eval() instead)
// 2. ConditionalExpression: use its children that will always be evaluated.
private def childrenToRecurse(expr: Expression): Seq[Expression] = expr match {
case _: CodegenFallback => Nil
case c: ConditionalExpression => c.alwaysEvaluatedInputs.map(skipForShortcut)
case other => skipForShortcut(other).children
}

// For some special expressions we cannot just recurse into all of its children, but we can
// recursively add the common expressions shared between all of its children.
private def commonChildrenToRecurse(expr: Expression): Seq[Seq[Expression]] = expr match {
case _: CodegenFallback => Nil
case c: ConditionalExpression => c.branchGroups
case _ => Nil
}

private def supportedExpression(e: Expression) = {
!e.exists {
// `LambdaVariable` is usually used as a loop variable, which can't be evaluated ahead of the
// loop. So we can't evaluate sub-expressions containing `LambdaVariable` at the beginning.
case _: LambdaVariable => true

// `PlanExpression` wraps query plan. To compare query plans of `PlanExpression` on executor,
// can cause error like NPE.
case _: PlanExpression[_] => Utils.isInRunningSparkTask

case _ => false
}
}

/**
* Adds the expression to this data structure recursively. Stops if a matching expression
* is found. That is, if `expr` has already been added, its children are not added.
Expand All @@ -197,8 +108,7 @@ class EquivalentExpressions(

if (!skip && !updateExprInMap(expr, map, useCount)) {
val uc = useCount.signum
childrenToRecurse(expr).foreach(updateExprTree(_, map, uc))
commonChildrenToRecurse(expr).filter(_.nonEmpty).foreach(updateCommonExprs(_, map, uc))
expr.children.foreach(updateExprTree(_, map, uc))
}
}

Expand Down Expand Up @@ -240,6 +150,23 @@ class EquivalentExpressions(
}
}

object EquivalentExpressions {
def supportedExpression(e: Expression): Boolean = {
!e.exists {
// `LambdaVariable` is usually used as a loop variable and `NamedLambdaVariable` is used in
// higher-order functions, which can't be evaluated ahead of the execution.
case _: LambdaVariable => true
case _: NamedLambdaVariable => true

// `PlanExpression` wraps query plan. To compare query plans of `PlanExpression` on executor,
// can cause error like NPE.
case _: PlanExpression[_] => Utils.isInRunningSparkTask

case _ => false
}
}
}

/**
* Wrapper around an Expression that provides semantic equality.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -196,9 +196,41 @@ abstract class Expression extends TreeNode[Expression] {
}.getOrElse {
val isNull = ctx.freshName("isNull")
val value = ctx.freshName("value")
val eval = doGenCode(ctx, ExprCode(
JavaCode.isNullVariable(isNull),
JavaCode.variable(value, dataType)))
val exprKey = ExpressionEquals(this)
val eval = if (EquivalentExpressions.supportedExpression(this)) {
ctx.commonExpressions.get(exprKey) match {
case Some((useCount, genFunc, Some(reuseExprCode))) =>
ctx.commonExpressions -= exprKey
if (useCount <= 1) {
ctx.commonExpressions -= exprKey
} else {
ctx.commonExpressions += exprKey ->
(useCount - 1, genFunc, Some(reuseExprCode))
}
reuseExprCode
case Some((useCount, genFunc, None)) =>
val eval = doGenCode(ctx, ExprCode(
JavaCode.isNullVariable(isNull),
JavaCode.variable(value, dataType)))
val reuseExprCode = genFunc(eval)
ctx.commonExpressions -= exprKey
if (useCount <= 1) {
ctx.commonExpressions -= exprKey
} else {
ctx.commonExpressions += exprKey ->
(useCount - 1, genFunc, Some(reuseExprCode))
}
reuseExprCode
case None =>
doGenCode(ctx, ExprCode(
JavaCode.isNullVariable(isNull),
JavaCode.variable(value, dataType)))
}
} else {
doGenCode(ctx, ExprCode(
JavaCode.isNullVariable(isNull),
JavaCode.variable(value, dataType)))
}
reduceCodeSize(ctx, eval)
if (eval.code.toString.nonEmpty) {
// Add `this` in the comment.
Expand Down
Loading