diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/ProtoToParsedPlanTestSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/ProtoToParsedPlanTestSuite.scala
index 04c15e3f2bd15..e0c4e21503e91 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/ProtoToParsedPlanTestSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/ProtoToParsedPlanTestSuite.scala
@@ -187,7 +187,7 @@ class ProtoToParsedPlanTestSuite
     object Helper extends RuleExecutor[LogicalPlan] {
       val batches =
         Batch("Finish Analysis", Once, ReplaceExpressions) ::
-          Batch("Rewrite With expression", FixedPoint(10), RewriteWithExpression) :: Nil
+          Batch("Rewrite With expression", Once, RewriteWithExpression) :: Nil
     }
     Helper.execute(catalystPlan)
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 8d46b46eefbe5..decef766ae97d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -149,7 +149,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
     Batch("Finish Analysis", Once, FinishAnalysis) ::
     // We must run this batch after `ReplaceExpressions`, as `RuntimeReplaceable` expression
     // may produce `With` expressions that need to be rewritten.
-    Batch("Rewrite With expression", fixedPoint, RewriteWithExpression) ::
+    Batch("Rewrite With expression", Once, RewriteWithExpression) ::
     //////////////////////////////////////////////////////////////////////////////////////////
     // Optimizer rules start here
     //////////////////////////////////////////////////////////////////////////////////////////
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala
index 073f60bca47f7..c5bd71b4a7d1f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer
 
 import scala.collection.mutable
 
-import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, CommonExpressionRef, Expression, With}
+import org.apache.spark.sql.catalyst.expressions.{Alias, CommonExpressionDef, CommonExpressionRef, Expression, With}
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.TreePattern.{COMMON_EXPR_REF, WITH_EXPRESSION}
@@ -35,57 +35,48 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
   override def apply(plan: LogicalPlan): LogicalPlan = {
     plan.transformWithPruning(_.containsPattern(WITH_EXPRESSION)) {
       case p if p.expressions.exists(_.containsPattern(WITH_EXPRESSION)) =>
-        val commonExprs = mutable.ArrayBuffer.empty[Alias]
-        // `With` can be nested, we should only rewrite the leaf `With` expression, as the outer
-        // `With` needs to add its own Project, in the next iteration when it becomes leaf.
-        // This is done via "transform down" and check if the common expression definitions does not
-        // contain nested `With`.
-        var newPlan: LogicalPlan = p.transformExpressionsDown {
-          case With(child, defs) if defs.forall(!_.containsPattern(WITH_EXPRESSION)) =>
-            val idToCheapExpr = mutable.HashMap.empty[Long, Expression]
-            val idToNonCheapExpr = mutable.HashMap.empty[Long, Alias]
-            defs.zipWithIndex.foreach { case (commonExprDef, index) =>
-              if (CollapseProject.isCheap(commonExprDef.child)) {
-                idToCheapExpr(commonExprDef.id) = commonExprDef.child
+        var newChildren = p.children
+        var newPlan: LogicalPlan = p.transformExpressionsUp {
+          case With(child, defs) =>
+            val refToExpr = mutable.HashMap.empty[Long, Expression]
+            val childProjections = Array.fill(newChildren.size)(mutable.ArrayBuffer.empty[Alias])
+
+            defs.zipWithIndex.foreach { case (CommonExpressionDef(child, id), index) =>
+              if (CollapseProject.isCheap(child)) {
+                refToExpr(id) = child
               } else {
-                // TODO: we should calculate the ref count and also inline the common expression
-                //       if it's ref count is 1.
-                val alias = Alias(commonExprDef.child, s"_common_expr_$index")()
-                commonExprs += alias
-                idToNonCheapExpr(commonExprDef.id) = alias
+                val childProjectionIndex = newChildren.indexWhere(
+                  c => child.references.subsetOf(c.outputSet)
+                )
+                if (childProjectionIndex == -1) {
+                  // When we cannot rewrite the common expressions, force to inline them so that the
+                  // query can still run. This can happen if the join condition contains `With` and
+                  // the common expression references columns from both join sides.
+                  // TODO: things can go wrong if the common expression is nondeterministic. We
+                  //       don't fix it for now to match the old buggy behavior when certain
+                  //       `RuntimeReplaceable` did not use the `With` expression.
+                  // TODO: we should calculate the ref count and also inline the common expression
+                  //       if it's ref count is 1.
+                  refToExpr(id) = child
+                } else {
+                  val alias = Alias(child, s"_common_expr_$index")()
+                  childProjections(childProjectionIndex) += alias
+                  refToExpr(id) = alias.toAttribute
+                }
               }
             }
 
-            child.transformWithPruning(_.containsPattern(COMMON_EXPR_REF)) {
-              case ref: CommonExpressionRef =>
-                idToCheapExpr.getOrElse(ref.id, idToNonCheapExpr(ref.id).toAttribute)
+            newChildren = newChildren.zip(childProjections).map { case (child, projections) =>
+              if (projections.nonEmpty) {
+                Project(child.output ++ projections, child)
+              } else {
+                child
+              }
             }
-        }
-
-        var exprsToAdd = commonExprs.toSeq
-        val newChildren = newPlan.children.map { child =>
-          val (newExprs, others) = exprsToAdd.partition(_.references.subsetOf(child.outputSet))
-          exprsToAdd = others
-          if (newExprs.nonEmpty) {
-            Project(child.output ++ newExprs, child)
-          } else {
-            child
-          }
-        }
-
-        if (exprsToAdd.nonEmpty) {
-          // When we cannot rewrite the common expressions, force to inline them so that the query
-          // can still run. This can happen if the join condition contains `With` and the common
-          // expression references columns from both join sides.
-          // TODO: things can go wrong if the common expression is nondeterministic. We don't fix
-          //       it for now to match the old buggy behavior when certain `RuntimeReplaceable`
-          //       did not use the `With` expression.
-          val attrToExpr = AttributeMap(exprsToAdd.map { alias =>
-            alias.toAttribute -> alias.child
-          })
-          newPlan = newPlan.transformExpressionsUp {
-            case a: Attribute => attrToExpr.getOrElse(a, a)
-          }
+
+            child.transformWithPruning(_.containsPattern(COMMON_EXPR_REF)) {
+              case ref: CommonExpressionRef => refToExpr(ref.id)
+            }
         }
 
         newPlan = newPlan.withNewChildren(newChildren)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala
index c4b08e6e5de85..c625379eb5ffd 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.types.IntegerType
 class RewriteWithExpressionSuite extends PlanTest {
 
   object Optimizer extends RuleExecutor[LogicalPlan] {
-    val batches = Batch("Rewrite With expression", FixedPoint(10), RewriteWithExpression) :: Nil
+    val batches = Batch("Rewrite With expression", Once, RewriteWithExpression) :: Nil
   }
 
   private val testRelation = LocalRelation($"a".int, $"b".int)
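
Note (illustration, not part of the patch): after this change the rule handles each `With` in a single bottom-up pass, deciding per common-expression definition whether to project it on the child whose output covers its references or to inline it (cheap expressions, and expressions spanning children such as join conditions, are inlined). The standalone Scala sketch below models only that routing decision; `Def`, `Child`, and the `cheap` flag are simplified stand-ins for `CommonExpressionDef`, the plan's children, and `CollapseProject.isCheap`.

object CommonExprRoutingSketch {
  final case class Def(id: Long, references: Set[String], cheap: Boolean)
  final case class Child(name: String, output: Set[String])

  sealed trait Decision
  case object Inline extends Decision                    // cheap expr, or no single child covers it
  final case class ProjectOn(child: String) extends Decision

  // Mirrors the loop over `defs`: cheap defs are inlined; otherwise pick the first child whose
  // output covers the def's references (indexWhere in the rule) and project an alias there;
  // if none covers it (e.g. a join condition using both sides), fall back to inlining.
  def route(defs: Seq[Def], children: Seq[Child]): Map[Long, Decision] =
    defs.map { d =>
      val decision =
        if (d.cheap) Inline
        else children.find(c => d.references.subsetOf(c.output)) match {
          case Some(c) => ProjectOn(c.name)
          case None => Inline
        }
      d.id -> decision
    }.toMap

  def main(args: Array[String]): Unit = {
    val children = Seq(Child("left", Set("a", "b")), Child("right", Set("x", "y")))
    val defs = Seq(
      Def(1L, Set("a"), cheap = true),         // cheap: stays inline, no Project needed
      Def(2L, Set("a", "b"), cheap = false),   // covered by "left": projected as _common_expr_1
      Def(3L, Set("a", "x"), cheap = false))   // spans both children: forced inline
    route(defs, children).toSeq.sortBy(_._1).foreach(println)
  }
}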