diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala index 2bbde304c281..3ceab0493b6a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala @@ -85,10 +85,14 @@ class EquivalentExpressions { } /** - * Adds or removes only expressions which are common in each of given expressions, in a recursive - * way. - * For example, given two expressions `(a + (b + (c + 1)))` and `(d + (e + (c + 1)))`, the common + * Adds or removes expressions only when the expressions: + * 1. are common in exprs. + * For example, given two exprs `(a + (b + (c + 1)))` and `(d + (e + (c + 1)))`, the common * expression `(c + 1)` will be added into `equivalenceMap`. + * 2. are common between alwaysEvaluateExprs and exprs. + * For example, given alwaysEvaluateExprs `(a + (b + (c + 1)))` and two exprs `e + f` + * and `(d + (e + (c + 1)))`, the common expression `(c + 1)` + * will be added into `equivalenceMap`. * * Note that as we don't know in advance if any child node of an expression will be common across * all given expressions, we compute local equivalence maps for all given expressions and filter @@ -97,6 +101,7 @@ class EquivalentExpressions { * expressions. */ private def updateCommonExprs( + alwaysEvaluateExprs: Seq[Expression], exprs: Seq[Expression], map: mutable.HashMap[ExpressionEquals, ExpressionStats], useCount: Int): Unit = { @@ -112,6 +117,21 @@ class EquivalentExpressions { } } + if (alwaysEvaluateExprs.length > 0) { + val alwaysEvaluateEquivalenceMap = mutable.HashMap.empty[ExpressionEquals, ExpressionStats] + alwaysEvaluateExprs.foreach(updateExprTree(_, alwaysEvaluateEquivalenceMap)) + // check if part of `exprs` have common expressions with `alwaysEvaluateExprs`. 
+ exprs.filter(e => !alwaysEvaluateExprs.exists(_.semanticEquals(e))).foreach { expr => + val otherLocalEquivalenceMap = mutable.HashMap.empty[ExpressionEquals, ExpressionStats] + updateExprTree(expr, otherLocalEquivalenceMap) + otherLocalEquivalenceMap.foreach { case (key, stats) => + if (alwaysEvaluateEquivalenceMap.contains(key) && !localEquivalenceMap.contains(key)) { + updateExprTree(stats.expr, localEquivalenceMap) + } + } + } + } + // Start with the highest expression, remove it from `localEquivalenceMap` and add it to `map`. // The remaining highest expression in `localEquivalenceMap` is also common expression so loop // until `localEquivalenceMap` is not empty. @@ -136,9 +156,10 @@ class EquivalentExpressions { // For some special expressions we cannot just recurse into all of its children, but we can // recursively add the common expressions shared between all of its children. - private def commonChildrenToRecurse(expr: Expression): Seq[Seq[Expression]] = expr match { + private def commonChildrenToRecurse(expr: Expression): Seq[(Seq[Expression], Seq[Expression])] = + expr match { case _: CodegenFallback => Nil - case c: ConditionalExpression => c.branchGroups + case c: ConditionalExpression => c.branchGroups.map((c.alwaysEvaluatedInputs, _)) case _ => Nil } @@ -168,7 +189,11 @@ class EquivalentExpressions { if (!skip && !updateExprInMap(expr, map, useCount)) { val uc = useCount.signum childrenToRecurse(expr).foreach(updateExprTree(_, map, uc)) - commonChildrenToRecurse(expr).filter(_.nonEmpty).foreach(updateCommonExprs(_, map, uc)) + commonChildrenToRecurse(expr) + .filter(_._2.nonEmpty) + .foreach { case (alwaysEvaluatedInputs, branchGroups) => + updateCommonExprs(alwaysEvaluatedInputs, branchGroups, map, uc) + } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala index 
2375d3ed35f2..13f97fba6320 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala @@ -165,12 +165,12 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel assert(equivalence1.getAllExprStates().filter(_.useCount == 1).exists(_.expr eq condition)) // Repeated `add` is only in one branch, so we don't count it. - val ifExpr2 = If(condition, Add(Literal(1), Literal(3)), Add(add, add)) + val ifExpr2 = If(Literal(true), Add(Literal(1), Literal(3)), Add(add, add)) val equivalence2 = new EquivalentExpressions equivalence2.addExprTree(ifExpr2) assert(equivalence2.getAllExprStates(1).isEmpty) - assert(equivalence2.getAllExprStates().count(_.useCount == 1) == 3) + assert(equivalence2.getAllExprStates().count(_.useCount == 1) == 1) val ifExpr3 = If(condition, ifExpr1, ifExpr1) val equivalence3 = new EquivalentExpressions @@ -461,6 +461,104 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel assert(e2.getCommonSubexpressions.size == 1) assert(e2.getCommonSubexpressions.head == add) } + + test("SubExpr elimination should work when `branchGroups` " + + "has overlap with `alwaysEvaluatedInputs`: If") { + val add = Add(Literal(1), Literal(2)) + + // `add` is in one branch of `If` and predicate, it should be a common expression. + val ifExpr1 = If(IsNull(add), Literal(null), KnownNotNull(add)) + val equivalence1 = new EquivalentExpressions + equivalence1.addExprTree(ifExpr1) + assert(equivalence1.getCommonSubexpressions.size == 1) + assert(equivalence1.getCommonSubexpressions.head == add) + assert(equivalence1.getAllExprStates().count(_.useCount == 2) == 1) + assert(equivalence1.getAllExprStates().filter(_.useCount == 2).head.expr eq add) + + // `add` is in both branches of `If` and predicate, it should be a common expression. 
+ val ifExpr2 = If(IsNull(add), add, KnownNotNull(add)) + val equivalence2 = new EquivalentExpressions + equivalence2.addExprTree(ifExpr2) + assert(equivalence2.getCommonSubexpressions.size == 1) + assert(equivalence2.getCommonSubexpressions.head == add) + assert(equivalence2.getAllExprStates().count(_.useCount == 2) == 1) + assert(equivalence2.getAllExprStates().filter(_.useCount == 2).head.expr eq add) + } + + test("SubExpr elimination should work when `branchGroups` " + + "has overlap with `alwaysEvaluatedInputs`: CaseWhen") { + val add1 = Add(Literal(1), Literal(2)) + val add2 = Add(Literal(2), Literal(3)) + + // `add2` exists in the first and third condition, it should be a common expression + val conditions1 = (GreaterThan(add2, Literal(3)), add1) :: + (GreaterThan(add1, Literal(4)), add1) :: + (GreaterThan(add2, Literal(5)), add1) :: Nil + val caseWhenExpr1 = CaseWhen(conditions1, None) + val equivalence1 = new EquivalentExpressions + equivalence1.addExprTree(caseWhenExpr1) + assert(equivalence1.getCommonSubexpressions.size == 1) + assert(equivalence1.getCommonSubexpressions.head == add2) + assert(equivalence1.getAllExprStates().count(_.useCount == 2) == 1) + assert(equivalence1.getAllExprStates().filter(_.useCount == 2).head.expr eq add2) + + // `add2` exists only in the first condition, it should NOT be a common expression + val conditions2 = (GreaterThan(add2, Literal(3)), add1) :: + (GreaterThan(add1, Literal(4)), add1) :: + (GreaterThan(add1, Literal(5)), add1) :: Nil + val caseWhenExpr2 = CaseWhen(conditions2, add1) + val equivalence2 = new EquivalentExpressions + equivalence2.addExprTree(caseWhenExpr2) + assert(equivalence2.getCommonSubexpressions.size == 0) + + // `add2` is again only in the first condition, it should NOT be a common expression + val conditions3 = (GreaterThan(add2, Literal(3)), add1) :: + (GreaterThan(add1, Literal(4)), add1) :: + (GreaterThan(add1, Literal(5)), add1) :: Nil + val caseWhenExpr3 = CaseWhen(conditions3, add1) + val 
equivalence3 = new EquivalentExpressions + equivalence3.addExprTree(caseWhenExpr3) + assert(equivalence3.getCommonSubexpressions.size == 0) + + // `add2` exists in the first condition and else value, it should be a common expression + val conditions4 = (GreaterThan(add2, Literal(3)), add1) :: + (GreaterThan(add1, Literal(4)), add1) :: + (GreaterThan(add1, Literal(5)), add1) :: Nil + val caseWhenExpr4 = CaseWhen(conditions4, Some(add2)) + val equivalence4 = new EquivalentExpressions + equivalence4.addExprTree(caseWhenExpr4) + assert(equivalence4.getCommonSubexpressions.size == 1) + assert(equivalence4.getCommonSubexpressions.head == add2) + assert(equivalence4.getAllExprStates().count(_.useCount == 2) == 1) + assert(equivalence4.getAllExprStates().filter(_.useCount == 2).head.expr eq add2) + } + + test("SubExpr elimination should work when `branchGroups` " + + "has overlap with `alwaysEvaluatedInputs`: Coalesce") { + val add1 = Add(Literal(1), Literal(2)) + val add2 = Add(Literal(2), Literal(3)) + + // `add2` is in first and third conditions, it should be a common expression + val conditions1 = GreaterThan(add2, Literal(3)) :: + GreaterThan(add1, Literal(4)) :: + GreaterThan(add2, Literal(5)) :: Nil + val coalesceExpr1 = Coalesce(conditions1) + val equivalence1 = new EquivalentExpressions + equivalence1.addExprTree(coalesceExpr1) + assert(equivalence1.getCommonSubexpressions.size == 1) + assert(equivalence1.getCommonSubexpressions.head == add2) + assert(equivalence1.getAllExprStates().count(_.useCount == 2) == 1) + assert(equivalence1.getAllExprStates().filter(_.useCount == 2).head.expr eq add2) + + // `add2` is only in the last two conditions, it should NOT be a common expression + val conditions2 = GreaterThan(add1, Literal(3)) :: + GreaterThan(add2, Literal(4)) :: + GreaterThan(add2, Literal(5)) :: Nil + val coalesceExpr2 = Coalesce(conditions2) + val equivalence2 = new EquivalentExpressions + equivalence2.addExprTree(coalesceExpr2) + 
assert(equivalence2.getCommonSubexpressions.size == 0) + } } case class CodegenFallbackExpression(child: Expression)