diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index cf17f5959996..4696699337c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -390,6 +390,8 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { case If(TrueLiteral, trueValue, _) => trueValue case If(FalseLiteral, _, falseValue) => falseValue case If(Literal(null, _), _, falseValue) => falseValue + case If(cond, trueValue, falseValue) + if cond.deterministic && trueValue.semanticEquals(falseValue) => trueValue case e @ CaseWhen(branches, elseValue) if branches.exists(x => falseOrNullLiteral(x._1)) => // If there are branches that are always false, remove them. @@ -403,14 +405,14 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { e.copy(branches = newBranches) } - case e @ CaseWhen(branches, _) if branches.headOption.map(_._1) == Some(TrueLiteral) => + case CaseWhen(branches, _) if branches.headOption.map(_._1).contains(TrueLiteral) => // If the first branch is a true literal, remove the entire CaseWhen and use the value // from that. Note that CaseWhen.branches should never be empty, and as a result the // headOption (rather than head) added above is just an extra (and unnecessary) safeguard. branches.head._2 case CaseWhen(branches, _) if branches.exists(_._1 == TrueLiteral) => - // a branc with a TRue condition eliminates all following branches, + // a branch with a true condition eliminates all following branches, // these branches can be pruned away val (h, t) = branches.span(_._1 != TrueLiteral) CaseWhen( h :+ t.head, None) @@ -651,6 +653,7 @@ object SimplifyCaseConversionExpressions extends Rule[LogicalPlan] { } } + /** * Combine nested [[Concat]] expressions. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala index b597c8e162c8..e210874a55d8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.optimizer +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} @@ -29,7 +31,8 @@ import org.apache.spark.sql.types.{IntegerType, NullType} class SimplifyConditionalSuite extends PlanTest with PredicateHelper { object Optimize extends RuleExecutor[LogicalPlan] { - val batches = Batch("SimplifyConditionals", FixedPoint(50), SimplifyConditionals) :: Nil + val batches = Batch("SimplifyConditionals", FixedPoint(50), + BooleanSimplification, ConstantFolding, SimplifyConditionals) :: Nil } protected def assertEquivalent(e1: Expression, e2: Expression): Unit = { @@ -43,6 +46,8 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper { private val unreachableBranch = (FalseLiteral, Literal(20)) private val nullBranch = (Literal.create(null, NullType), Literal(30)) + private val testRelation = LocalRelation('a.int) + test("simplify if") { assertEquivalent( If(TrueLiteral, Literal(10), Literal(20)), @@ -57,6 +62,23 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper { Literal(20)) } + test("remove unnecessary if when the outputs are semantic equivalence") { + assertEquivalent( + If(IsNotNull(UnresolvedAttribute("a")), + Subtract(Literal(10), Literal(1)), + Add(Literal(6), Literal(3))), + Literal(9)) + + // For non-deterministic condition, we don't remove the `If` statement. + assertEquivalent( + If(GreaterThan(Rand(0), Literal(0.5)), + Subtract(Literal(10), Literal(1)), + Add(Literal(6), Literal(3))), + If(GreaterThan(Rand(0), Literal(0.5)), + Literal(9), + Literal(9))) + } + test("remove unreachable branches") { // i.e. removing branches whose conditions are always false assertEquivalent( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index e562be83822e..ac70488febc5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -393,7 +393,7 @@ private[sql] trait SQLTestUtilsBase } /** - * Returns full path to the given file in the resouce folder + * Returns full path to the given file in the resource folder */ protected def testFile(fileName: String): String = { Thread.currentThread().getContextClassLoader.getResource(fileName).toString