Skip to content
Prev Previous commit
Next Next commit
Remove whole-stage expression elimination
  • Loading branch information
wankunde committed May 9, 2023
commit f524ef51b5b118796225a120f81d7fe579f23f28
Original file line number Diff line number Diff line change
Expand Up @@ -1027,22 +1027,21 @@ class CodegenContext extends Logging {
splitExpressions(subexprFunctions.toSeq, "subexprFunc_split", Seq("InternalRow" -> INPUT_ROW))
}

/**
 * Collects the common subexpressions found in `expressions` and returns the code block
 * that initializes them.
 *
 * When subexpression elimination is disabled in the active `SQLConf`, nothing is collected
 * and `EmptyBlock` is returned.
 *
 * @param expressions the expression trees to scan
 * @return the concatenation of `initCommonExpression` for every state returned by
 *         `getAllExprStates(1)`, or `EmptyBlock` when the feature is disabled
 */
def subexpressionElimination(expressions: Seq[Expression]): Block = {
  if (!SQLConf.get.subexpressionEliminationEnabled) {
    EmptyBlock
  } else {
    val equivalence = new EquivalentExpressions
    // Register each tree exactly once. Also routing the same trees through
    // wholeStageSubexpressionElimination would add them a second time.
    // NOTE(review): presumably that inflates the use counts so that every expression
    // looks "common" -- confirm against EquivalentExpressions.
    // `foreach` rather than `map`: the call is made only for its side effect.
    expressions.foreach(equivalence.addExprTree(_))
    var initBlock: Block = EmptyBlock
    equivalence.getAllExprStates(1).foreach { stats =>
      initBlock += initCommonExpression(stats)
    }
    initBlock
  }
}

/**
 * Registers every expression tree in `expressions` with `equivalence`.
 *
 * @param expressions the expression trees to add
 * @param equivalence the `EquivalentExpressions` instance that receives them
 */
def wholeStageSubexpressionElimination(
    expressions: Seq[Expression],
    equivalence: EquivalentExpressions): Unit = {
  // `foreach`, not `map`: the call is made only for its side effect, and `map` on a lazily
  // evaluated Seq (e.g. a Stream) would not visit every element.
  expressions.foreach(equivalence.addExprTree(_))
}

def initCommonExpression(stats: ExpressionStats): Block = {
if (stats.initialized.isEmpty) {
val expr = stats.expr
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,6 @@ import org.apache.spark.util.Utils
*/
trait CodegenSupport extends SparkPlan {

/**
 * Expressions of this operator paired with the attributes they are to be bound against.
 * The default is an empty pair; operators override this to supply their own expressions.
 */
def reusableExpressions(): (Seq[Expression], Seq[Attribute]) =
  (Seq.empty[Expression], Seq.empty[Attribute])

var initBlock: Block = EmptyBlock
var commonExpressions = mutable.Map.empty[ExpressionEquals, ExpressionStats]

Expand Down Expand Up @@ -664,56 +662,6 @@ case class WholeStageCodegenExec(child: SparkPlan)(val codegenStageId: Int)
def doCodeGen(): (CodegenContext, CodeAndComment) = {
val startTime = System.nanoTime()
val ctx = new CodegenContext

if (SQLConf.get.subexpressionEliminationEnabled) {
val stack = mutable.Stack[SparkPlan](child)
var attributeSeq = Seq[Attribute]()
val executeSeq =
new mutable.ArrayBuffer[(CodegenSupport, Seq[Expression], EquivalentExpressions)]()
var equivalence = new EquivalentExpressions
while (stack.nonEmpty) {
stack.pop() match {
case _: WholeStageCodegenExec =>
case _: InputRDDCodegen =>
case plan: CodegenSupport =>
// This plan may already have been optimized, so remove stale commonExpressions
plan.initBlock = EmptyBlock
plan.commonExpressions.clear()
val (newReusableExpressions, newAttributeSeq) = plan.reusableExpressions()
// If the input attributes changed, collect current common expressions and clear
// equivalentExpressions
if (attributeSeq.size != newAttributeSeq.size ||
attributeSeq.zip(newAttributeSeq).exists { case (left, right) => left != right }) {
equivalence = new EquivalentExpressions
}
if (newReusableExpressions.nonEmpty) {
val bondExpressions =
BindReferences.bindReferences(newReusableExpressions, newAttributeSeq)
executeSeq += ((plan, bondExpressions, equivalence))
ctx.wholeStageSubexpressionElimination(bondExpressions, equivalence)
}
attributeSeq = newAttributeSeq
stack.pushAll(plan.children)

case _ =>
}
}
executeSeq.reverse.foreach { case (plan, bondExpressions, equivalence) =>
val commonExprs =
equivalence.getAllExprStates(1)
.map(stat => ExpressionEquals(stat.expr) -> stat).toMap
bondExpressions.foreach {
_.foreach { expr =>
commonExprs.get(ExpressionEquals(expr)).map { stat =>
plan.initBlock += ctx.initCommonExpression(stat)
plan.commonExpressions += ExpressionEquals(expr) -> stat
}
}
}
}
// Do not support CSE in produce method.
ctx.commonExpressions.clear()
}
val code = child.asInstanceOf[CodegenSupport].produce(ctx, this)

// main next function.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,9 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan)
references.filter(a => usedMoreThanOnce.contains(a.exprId))
}

/** Returns the project list together with the child's output attributes. */
override def reusableExpressions(): (Seq[Expression], Seq[Attribute]) = {
  val candidates: Seq[Expression] = projectList
  (candidates, child.output)
}


override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
val exprs = bindReferences[Expression](projectList, child.output)
initBlock += ctx.subexpressionElimination(exprs: _*)
val resultVars = exprs.map(_.genCode(ctx))

// Evaluation of non-deterministic expressions can't be deferred.
Expand Down Expand Up @@ -169,6 +166,8 @@ trait GeneratePredicateHelper extends PredicateHelper {
// TODO: revisit this. We can consider reordering predicates as well.
val generatedIsNotNullChecks = new Array[Boolean](notNullPreds.length)
val extraIsNotNullAttrs = mutable.Set[Attribute]()
initBlock +=
ctx.subexpressionElimination(otherPreds.map(BindReferences.bindReference(_, inputAttrs)): _*)
val generated = otherPreds.map { c =>
val nullChecks = c.references.map { r =>
val idx = notNullPreds.indexWhere { n => n.asInstanceOf[IsNotNull].child.semanticEquals(r)}
Expand Down Expand Up @@ -237,9 +236,6 @@ case class FilterExec(condition: Expression, child: SparkPlan)
child.asInstanceOf[CodegenSupport].produce(ctx, this)
}

/** Returns `otherPreds` together with the child's output attributes. */
override def reusableExpressions(): (Seq[Expression], Seq[Attribute]) = {
  val candidates: Seq[Expression] = otherPreds
  (candidates, child.output)
}

override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
val numOutput = metricTerm(ctx, "numOutputRows")

Expand Down