Skip to content
Prev Previous commit
Next Next commit
Bug fix for whole stage subexpression elimination
  • Loading branch information
wankunde committed May 9, 2023
commit e5ee05ef3457e4fb22c3e01c42b167436b070623
Original file line number Diff line number Diff line change
Expand Up @@ -1159,7 +1159,7 @@ class CodegenContext extends Logging {
}
}

var commonExpressions = Map[ExpressionEquals, ExpressionStats]()
var commonExpressions = mutable.Map[ExpressionEquals, ExpressionStats]()
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,10 @@ import org.apache.spark.util.Utils
*/
trait CodegenSupport extends SparkPlan {

def reusableExpressions(): (Seq[Expression], AttributeSet) = (Seq(), AttributeSet.empty)
def reusableExpressions(): (Seq[Expression], Seq[Attribute]) = (Seq(), Seq())

var initBlock: Block = EmptyBlock
var commonExpressions = Map[ExpressionEquals, ExpressionStats]()
var commonExpressions = mutable.Map.empty[ExpressionEquals, ExpressionStats]

/** Prefix used in the current operator's variable names. */
private def variablePrefix: String = this match {
Expand Down Expand Up @@ -667,7 +667,7 @@ case class WholeStageCodegenExec(child: SparkPlan)(val codegenStageId: Int)

if (SQLConf.get.subexpressionEliminationEnabled) {
val stack = mutable.Stack[SparkPlan](child)
var attributeSet = AttributeSet.empty
var attributeSeq = Seq[Attribute]()
val executeSeq =
new mutable.ArrayBuffer[(CodegenSupport, Seq[Expression], EquivalentExpressions)]()
var equivalence = new EquivalentExpressions
Expand All @@ -676,19 +676,20 @@ case class WholeStageCodegenExec(child: SparkPlan)(val codegenStageId: Int)
case _: WholeStageCodegenExec =>
case _: InputRDDCodegen =>
case c: CodegenSupport =>
val (newReusableExpressions, newAttributeSet) = c.reusableExpressions()
val (newReusableExpressions, newAttributeSeq) = c.reusableExpressions()
// If the input attributes changed, collect current common expressions and clear
// equivalentExpressions
if (!attributeSet.subsetOf(newAttributeSet)) {
if (attributeSeq.size != newAttributeSeq.size ||
attributeSeq.zip(newAttributeSeq).exists(tup => !tup._1.equals(tup._2))) {
equivalence = new EquivalentExpressions
}
if (newReusableExpressions.nonEmpty) {
val bondExpressions =
BindReferences.bindReferences(newReusableExpressions, newAttributeSet.toSeq)
BindReferences.bindReferences(newReusableExpressions, newAttributeSeq.toSeq)
executeSeq += ((c, bondExpressions, equivalence))
ctx.wholeStageSubexpressionElimination(bondExpressions, equivalence)
}
attributeSet = newAttributeSet
attributeSeq = newAttributeSeq
stack.pushAll(c.children)

case _ =>
Expand All @@ -698,11 +699,11 @@ case class WholeStageCodegenExec(child: SparkPlan)(val codegenStageId: Int)
val commonExprs =
equivalence.getAllExprStates(1)
.map(stat => ExpressionEquals(stat.expr) -> stat).toMap
plan.commonExpressions = commonExprs
bondExpressions.foreach {
_.foreach { expr =>
commonExprs.get(ExpressionEquals(expr)).map { stat =>
plan.initBlock += ctx.initCommonExpression(stat)
plan.commonExpressions += ExpressionEquals(expr) -> stat
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan)
references.filter(a => usedMoreThanOnce.contains(a.exprId))
}

override def reusableExpressions(): (Seq[Expression], AttributeSet) =
(projectList, AttributeSet(child.output))
override def reusableExpressions(): (Seq[Expression], Seq[Attribute]) =
(projectList, child.output)


override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
Expand Down Expand Up @@ -237,8 +237,8 @@ case class FilterExec(condition: Expression, child: SparkPlan)
child.asInstanceOf[CodegenSupport].produce(ctx, this)
}

override def reusableExpressions(): (Seq[Expression], AttributeSet) =
(otherPreds, AttributeSet(child.output))
override def reusableExpressions(): (Seq[Expression], Seq[Attribute]) =
(otherPreds, child.output)

override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
val numOutput = metricTerm(ctx, "numOutputRows")
Expand Down