Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,11 @@ trait ConditionalExpression extends Expression {
*/
def alwaysEvaluatedInputs: Seq[Expression]

/**
* Return a copy of itself with a new `alwaysEvaluatedInputs`.
*/
def withNewAlwaysEvaluatedInputs(alwaysEvaluatedInputs: Seq[Expression]): ConditionalExpression

/**
* Return groups of branches. For each group, at least one branch will be hit at runtime,
* so that we can eagerly evaluate the common expressions of a group.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
*/
override def alwaysEvaluatedInputs: Seq[Expression] = predicate :: Nil

override def withNewAlwaysEvaluatedInputs(alwaysEvaluatedInputs: Seq[Expression]): If = {
copy(predicate = alwaysEvaluatedInputs.head)
}

override def branchGroups: Seq[Seq[Expression]] = Seq(Seq(trueValue, falseValue))

final override val nodePatterns : Seq[TreePattern] = Seq(IF)
Expand Down Expand Up @@ -165,8 +169,15 @@ case class CaseWhen(

final override val nodePatterns : Seq[TreePattern] = Seq(CASE_WHEN)

override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
super.legacyWithNewChildren(newChildren)
override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): CaseWhen = {
if (newChildren.length % 2 == 0) {
copy(branches = newChildren.grouped(2).map { case Seq(a, b) => (a, b) }.toSeq)
} else {
copy(
branches = newChildren.dropRight(1).grouped(2).map { case Seq(a, b) => (a, b) }.toSeq,
elseValue = newChildren.lastOption)
}
}

// both then and else expressions should be considered.
@transient
Expand Down Expand Up @@ -213,6 +224,10 @@ case class CaseWhen(
*/
override def alwaysEvaluatedInputs: Seq[Expression] = children.head :: Nil

override def withNewAlwaysEvaluatedInputs(alwaysEvaluatedInputs: Seq[Expression]): CaseWhen = {
withNewChildrenInternal(alwaysEvaluatedInputs.toIndexedSeq ++ children.drop(1))
}

override def branchGroups: Seq[Seq[Expression]] = {
// We look at subexpressions in conditions and values of `CaseWhen` separately. It is
// because a subexpression in conditions will be run no matter which condition is matched
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,10 @@ case class Coalesce(children: Seq[Expression])
*/
override def alwaysEvaluatedInputs: Seq[Expression] = children.head :: Nil

override def withNewAlwaysEvaluatedInputs(alwaysEvaluatedInputs: Seq[Expression]): Coalesce = {
withNewChildrenInternal(alwaysEvaluatedInputs.toIndexedSeq ++ children.drop(1))
}

override def branchGroups: Seq[Seq[Expression]] = if (children.length > 1) {
// If there is only one child, the first child is already covered by
// `alwaysEvaluatedInputs` and we should exclude it here.
Expand Down Expand Up @@ -290,6 +294,10 @@ case class NaNvl(left: Expression, right: Expression)
*/
override def alwaysEvaluatedInputs: Seq[Expression] = left :: Nil

override def withNewAlwaysEvaluatedInputs(alwaysEvaluatedInputs: Seq[Expression]): NaNvl = {
copy(left = alwaysEvaluatedInputs.head)
}

override def branchGroups: Seq[Seq[Expression]] = Seq(children)

override def eval(input: InternalRow): Any = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ package org.apache.spark.sql.catalyst.optimizer

import scala.collection.mutable

import org.apache.spark.sql.catalyst.expressions.{Alias, CommonExpressionDef, CommonExpressionRef, Expression, With}
import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.{COMMON_EXPR_REF, WITH_EXPRESSION}
Expand All @@ -35,56 +36,92 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = {
plan.transformWithPruning(_.containsPattern(WITH_EXPRESSION)) {
case p if p.expressions.exists(_.containsPattern(WITH_EXPRESSION)) =>
var newChildren = p.children
var newPlan: LogicalPlan = p.transformExpressionsUp {
case With(child, defs) =>
val refToExpr = mutable.HashMap.empty[Long, Expression]
val childProjections = Array.fill(newChildren.size)(mutable.ArrayBuffer.empty[Alias])
val inputPlans = p.children.toArray
var newPlan: LogicalPlan = p.mapExpressions { expr =>
rewriteWithExprAndInputPlans(expr, inputPlans)
}
newPlan = newPlan.withNewChildren(inputPlans.toIndexedSeq)
if (p.output == newPlan.output) {
newPlan
} else {
Project(p.output, newPlan)
}
}
}

private def rewriteWithExprAndInputPlans(
e: Expression,
inputPlans: Array[LogicalPlan]): Expression = {
if (!e.containsPattern(WITH_EXPRESSION)) return e
e match {
case w: With =>
// Rewrite nested With expressions first
val child = rewriteWithExprAndInputPlans(w.child, inputPlans)
val defs = w.defs.map(rewriteWithExprAndInputPlans(_, inputPlans))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now that we have "manual" recursion (instead of transformExpressionsUp()), shall we deal with nested Withs in w.child too?

Copy link
Contributor

@peter-toth peter-toth Nov 23, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, the current logic seems to behave correctly if there is an inner With in an outer With's child and the inner has a definition with a reference to an outer definition . (The previous transformExpressionsUp() had issues in that case.) But the rule is not idempotent now, so maybe we should recurse into w.child after replacing CommonExpressionRefs?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a good catch! It seems doesn't matter when to recurse into w.child, either before replacing CommonExpressionRef or after is fine?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe before is better, as the expression tree may be much larger after replacing CommonExpressionRef

Copy link
Contributor

@peter-toth peter-toth Nov 24, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure. E.g. if we have With(With(x + x, Seq(x = y + y)), Seq(y = a + 1)) where x and y are references and a is an attribute and we would recurse into With(x + x, Seq(x = y + y)) before replacing the y references to actual attributes, that aliases a + 1, then the childProjectionIndex calculation for y + y won't find the right child, will it? But an UT covering this case would be good. :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh correlated nested With! I'm not sure if we want to support it or not... But at least we should fail if we don't want to support it.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then we may need a test for that (either supported or failed if not).

val refToExpr = mutable.HashMap.empty[Long, Expression]
val childProjections = Array.fill(inputPlans.length)(mutable.ArrayBuffer.empty[Alias])

defs.zipWithIndex.foreach { case (CommonExpressionDef(child, id), index) =>
if (child.containsPattern(COMMON_EXPR_REF)) {
throw SparkException.internalError(
"Common expression definition cannot reference other Common expression definitions")
}

defs.zipWithIndex.foreach { case (CommonExpressionDef(child, id), index) =>
if (CollapseProject.isCheap(child)) {
refToExpr(id) = child
} else {
val childProjectionIndex = newChildren.indexWhere(
c => child.references.subsetOf(c.outputSet)
)
if (childProjectionIndex == -1) {
// When we cannot rewrite the common expressions, force to inline them so that the
// query can still run. This can happen if the join condition contains `With` and
// the common expression references columns from both join sides.
// TODO: things can go wrong if the common expression is nondeterministic. We
// don't fix it for now to match the old buggy behavior when certain
// `RuntimeReplaceable` did not use the `With` expression.
// TODO: we should calculate the ref count and also inline the common expression
// if it's ref count is 1.
refToExpr(id) = child
} else {
val alias = Alias(child, s"_common_expr_$index")()
childProjections(childProjectionIndex) += alias
refToExpr(id) = alias.toAttribute
}
}
if (CollapseProject.isCheap(child)) {
refToExpr(id) = child
} else {
val childProjectionIndex = inputPlans.indexWhere(
c => child.references.subsetOf(c.outputSet)
)
if (childProjectionIndex == -1) {
// When we cannot rewrite the common expressions, force to inline them so that the
// query can still run. This can happen if the join condition contains `With` and
// the common expression references columns from both join sides.
// TODO: things can go wrong if the common expression is nondeterministic. We
// don't fix it for now to match the old buggy behavior when certain
// `RuntimeReplaceable` did not use the `With` expression.
// TODO: we should calculate the ref count and also inline the common expression
// if it's ref count is 1.
refToExpr(id) = child
} else {
val alias = Alias(child, s"_common_expr_$index")()
childProjections(childProjectionIndex) += alias
refToExpr(id) = alias.toAttribute
}
}
}

for (i <- inputPlans.indices) {
val projectList = childProjections(i)
if (projectList.nonEmpty) {
inputPlans(i) = Project(inputPlans(i).output ++ projectList, inputPlans(i))
}
}

newChildren = newChildren.zip(childProjections).map { case (child, projections) =>
if (projections.nonEmpty) {
Project(child.output ++ projections, child)
} else {
child
}
child.transformWithPruning(_.containsPattern(COMMON_EXPR_REF)) {
case ref: CommonExpressionRef =>
if (!refToExpr.contains(ref.id)) {
throw SparkException.internalError("Undefined common expression id " + ref.id)
}
refToExpr(ref.id)
}

case c: ConditionalExpression =>
val newAlwaysEvaluatedInputs = c.alwaysEvaluatedInputs.map(
rewriteWithExprAndInputPlans(_, inputPlans))
Comment on lines +110 to +111
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is dealing with common expressions only in always evaluated input e.g., predicate of If.

How about common expressions shared between predicate and branches?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought about it before. The problem is that it's hard to update the original ConditionalExpression with the new shared common expressions. alwaysEvaluatedInputs is static so that I can let every ConditionalExpression to implement a method to update it.

val newExpr = c.withNewAlwaysEvaluatedInputs(newAlwaysEvaluatedInputs)
// Use transformUp to handle nested With.
newExpr.transformUpWithPruning(_.containsPattern(WITH_EXPRESSION)) {
case With(child, defs) =>
// For With in the conditional branches, they may not be evaluated at all and we can't
// pull the common expressions into a project which will always be evaluated. Inline it.
Comment on lines +115 to +117
Copy link
Member

@viirya viirya Nov 24, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, for specific conditional expression, e.g., If, we can still extract common expression shared on both branches which will be evaluated for sure?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same as https://github.com/apache/spark/pull/43978/files#r1403392772 .

It's easy to find these common expressions shared on both branches, but it's hard to put them back to If. I think it's better to do it when we make it into a general rule that find shared common expressions and create With to deduplicate.

val refToExpr = defs.map(d => d.id -> d.child).toMap
child.transformWithPruning(_.containsPattern(COMMON_EXPR_REF)) {
case ref: CommonExpressionRef => refToExpr(ref.id)
}
}

newPlan = newPlan.withNewChildren(newChildren)
if (p.output == newPlan.output) {
newPlan
} else {
Project(p.output, newPlan)
}
case other => other.mapChildren(rewriteWithExprAndInputPlans(_, inputPlans))
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, CommonExpressionDef, CommonExpressionRef, With}
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Coalesce, CommonExpressionDef, CommonExpressionRef, With}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
Expand Down Expand Up @@ -57,7 +58,7 @@ class RewriteWithExpressionSuite extends PlanTest {
)
}

test("nested WITH expression") {
test("nested WITH expression in the definition expression") {
val a = testRelation.output.head
val commonExprDef = CommonExpressionDef(a + a)
val ref = new CommonExpressionRef(commonExprDef)
Expand Down Expand Up @@ -85,6 +86,57 @@ class RewriteWithExpressionSuite extends PlanTest {
)
}

test("nested WITH expression in the main expression") {
val a = testRelation.output.head
val commonExprDef = CommonExpressionDef(a + a)
val ref = new CommonExpressionRef(commonExprDef)
val innerExpr = With(ref + ref, Seq(commonExprDef))
val innerCommonExprName = "_common_expr_0"

val b = testRelation.output.last
val outerCommonExprDef = CommonExpressionDef(b + b)
val outerRef = new CommonExpressionRef(outerCommonExprDef)
val outerExpr = With(outerRef * outerRef + innerExpr, Seq(outerCommonExprDef))
val outerCommonExprName = "_common_expr_0"

val plan = testRelation.select(outerExpr.as("col"))
val rewrittenInnerExpr = (a + a).as(innerCommonExprName)
val rewrittenOuterExpr = (b + b).as(outerCommonExprName)
val finalExpr = rewrittenOuterExpr.toAttribute * rewrittenOuterExpr.toAttribute +
(rewrittenInnerExpr.toAttribute + rewrittenInnerExpr.toAttribute)
comparePlans(
Optimizer.execute(plan),
testRelation
.select((testRelation.output :+ rewrittenInnerExpr): _*)
.select((testRelation.output :+ rewrittenInnerExpr.toAttribute :+ rewrittenOuterExpr): _*)
.select(finalExpr.as("col"))
.analyze
)
}

test("correlated nested WITH expression is not supported") {
val b = testRelation.output.last
val outerCommonExprDef = CommonExpressionDef(b + b)
val outerRef = new CommonExpressionRef(outerCommonExprDef)

val a = testRelation.output.head
// The inner expression definition references the outer expression
val commonExprDef1 = CommonExpressionDef(a + a + outerRef)
val ref1 = new CommonExpressionRef(commonExprDef1)
val innerExpr1 = With(ref1 + ref1, Seq(commonExprDef1))

val outerExpr1 = With(outerRef + innerExpr1, Seq(outerCommonExprDef))
intercept[SparkException](Optimizer.execute(testRelation.select(outerExpr1.as("col"))))

val commonExprDef2 = CommonExpressionDef(a + a)
val ref2 = new CommonExpressionRef(commonExprDef2)
// The inner main expression references the outer expression
val innerExpr2 = With(ref2 + outerRef, Seq(commonExprDef1))

val outerExpr2 = With(outerRef + innerExpr2, Seq(outerCommonExprDef))
intercept[SparkException](Optimizer.execute(testRelation.select(outerExpr2.as("col"))))
}

test("WITH expression in filter") {
val a = testRelation.output.head
val commonExprDef = CommonExpressionDef(a + a)
Expand Down Expand Up @@ -154,4 +206,27 @@ class RewriteWithExpressionSuite extends PlanTest {
)
)
}

test("WITH expression inside conditional expression") {
val a = testRelation.output.head
val commonExprDef = CommonExpressionDef(a + a)
val ref = new CommonExpressionRef(commonExprDef)
val expr = Coalesce(Seq(a, With(ref * ref, Seq(commonExprDef))))
val inlinedExpr = Coalesce(Seq(a, (a + a) * (a + a)))
val plan = testRelation.select(expr.as("col"))
// With in the conditional branches is always inlined.
comparePlans(Optimizer.execute(plan), testRelation.select(inlinedExpr.as("col")))

val expr2 = Coalesce(Seq(With(ref * ref, Seq(commonExprDef)), a))
val plan2 = testRelation.select(expr2.as("col"))
val commonExprName = "_common_expr_0"
// With in the always-evaluated branches can still be optimized.
comparePlans(
Optimizer.execute(plan2),
testRelation
.select((testRelation.output :+ (a + a).as(commonExprName)): _*)
.select(Coalesce(Seq(($"$commonExprName" * $"$commonExprName"), a)).as("col"))
.analyze
)
}
}