diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
index bd400f86ea2c..00395fe3bd0d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -230,31 +230,59 @@ object ReorderAssociativeOperator extends Rule[LogicalPlan] {
  * Optimize IN predicates:
  * 1. Converts the predicate to false when the list is empty and
  *    the value is not nullable.
- * 2. Removes literal repetitions.
- * 3. Replaces [[In (value, seq[Literal])]] with optimized version
+ * 2. Extracts the literal part of the list when the value is deterministic.
+ * 3. Removes literal repetitions.
+ * 4. Replaces [[In (value, seq[Literal])]] with optimized version
  *    [[InSet (value, HashSet[Literal])]] which is much faster.
  */
 object OptimizeIn extends Rule[LogicalPlan] {
+  /**
+   * Optimizes an IN predicate whose list consists only of literals.
+   *
+   * @param expr the IN predicate to optimize; returned (possibly with a de-duplicated list)
+   *             when no cheaper form applies
+   * @param v the value tested for membership
+   * @param list the literal-only list of candidate values
+   * @return `EqualTo` when a single distinct element remains (unless a struct is involved),
+   *         `InSet` when the number of distinct elements exceeds the conversion threshold,
+   *         and an `In` over the distinct elements otherwise
+   */
+  def optimizeIn(expr: In, v: Expression, list: Seq[Expression]): Expression = {
+    val newList = ExpressionSet(list).toSeq
+    if (newList.length == 1
+      // TODO: `EqualTo` for structural types are not working. Until SPARK-24443 is addressed,
+      // TODO: we exclude them in this rule.
+      && !v.isInstanceOf[CreateNamedStruct]
+      && !newList.head.isInstanceOf[CreateNamedStruct]) {
+      EqualTo(v, newList.head)
+    } else if (newList.length > SQLConf.get.optimizerInSetConversionThreshold) {
+      val hSet = newList.map(e => e.eval(EmptyRow))
+      InSet(v, HashSet() ++ hSet)
+    } else if (newList.length < list.length) {
+      expr.copy(list = newList)
+    } else { // newList.length == list.length && newList.length > 1
+      expr
+    }
+  }
+
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     case q: LogicalPlan => q transformExpressionsDown {
       case In(v, list) if list.isEmpty =>
         // When v is not nullable, the following expression will be optimized
         // to FalseLiteral which is tested in OptimizeInSuite.scala
         If(IsNotNull(v), FalseLiteral, Literal(null, BooleanType))
-      case expr @ In(v, list) if expr.inSetConvertible =>
-        val newList = ExpressionSet(list).toSeq
-        if (newList.length == 1
-          // TODO: `EqualTo` for structural types are not working. Until SPARK-24443 is addressed,
-          // TODO: we exclude them in this rule.
-          && !v.isInstanceOf[CreateNamedStruct]
-          && !newList.head.isInstanceOf[CreateNamedStruct]) {
-          EqualTo(v, newList.head)
-        } else if (newList.length > SQLConf.get.optimizerInSetConversionThreshold) {
-          val hSet = newList.map(e => e.eval(EmptyRow))
-          InSet(v, HashSet() ++ hSet)
-        } else if (newList.length < list.length) {
-          expr.copy(list = newList)
-        } else { // newList.length == list.length && newList.length > 1
+      case expr @ In(v, list) =>
+        // split list to 2 parts so that we can optimize convertible part
+        val (convertible, nonConvertible) = list.partition(_.isInstanceOf[Literal])
+        if (convertible.nonEmpty && nonConvertible.isEmpty) {
+          optimizeIn(expr, v, list)
+        } else if (convertible.nonEmpty && nonConvertible.nonEmpty &&
+            // The rewrite below duplicates `v`, so it is only safe when `v` is deterministic.
+            v.deterministic &&
+            SQLConf.get.optimizerInExtractLiteralPart) {
+          val optimizedIn = optimizeIn(In(v, convertible), v, convertible)
+          Or(optimizedIn, In(v, nonConvertible))
+        } else {
           expr
         }
     }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 3149d14c1ddc..deb29d38a354 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -197,6 +197,14 @@ object SQLConf {
       .intConf
       .createWithDefault(100)
 
+  val OPTIMIZER_IN_EXTRACT_LITERAL_PART =
+    buildConf("spark.sql.optimizer.inExtractLiteralPart")
+      .internal()
+      .doc("When true, we will extract and optimize the literal part of in if not all are literal.")
+      .version("3.1.0")
+      .booleanConf
+      .createWithDefault(true)
+
   val OPTIMIZER_INSET_CONVERSION_THRESHOLD =
     buildConf("spark.sql.optimizer.inSetConversionThreshold")
       .internal()
@@ -2761,6 +2769,8 @@ class SQLConf extends Serializable with Logging {
 
   def optimizerMaxIterations: Int = getConf(OPTIMIZER_MAX_ITERATIONS)
 
+  def optimizerInExtractLiteralPart: Boolean = getConf(OPTIMIZER_IN_EXTRACT_LITERAL_PART)
+
   def optimizerInSetConversionThreshold: Int = getConf(OPTIMIZER_INSET_CONVERSION_THRESHOLD)
 
   def optimizerInSetSwitchThreshold: Int = getConf(OPTIMIZER_INSET_SWITCH_THRESHOLD)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala
index a36083b84704..7fcc82cc9ab5 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation, LogicalPlan}
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.internal.SQLConf.OPTIMIZER_IN_EXTRACT_LITERAL_PART
 import org.apache.spark.sql.internal.SQLConf.OPTIMIZER_INSET_CONVERSION_THRESHOLD
 import org.apache.spark.sql.types._
 
@@ -91,21 +92,6 @@ class OptimizeInSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }
 
-  test("OptimizedIn test: In clause not optimized in case filter has attributes") {
-    val originalQuery =
-      testRelation
-        .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2), UnresolvedAttribute("b"))))
-        .analyze
-
-    val optimized = Optimize.execute(originalQuery.analyze)
-    val correctAnswer =
-      testRelation
-        .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2), UnresolvedAttribute("b"))))
-        .analyze
-
-    comparePlans(optimized, correctAnswer)
-  }
-
   test("OptimizedIn test: NULL IN (expr1, ..., exprN) gets transformed to Filter(null)") {
     val originalQuery =
       testRelation
@@ -238,4 +224,44 @@ class OptimizeInSuite extends PlanTest {
 
     comparePlans(optimized, correctAnswer)
   }
+
+  test("SPARK-32196: Extract In convertible part if it is not convertible") {
+    Seq("true", "false").foreach { enable =>
+      withSQLConf(OPTIMIZER_IN_EXTRACT_LITERAL_PART.key -> enable) {
+        val originalQuery1 =
+          testRelation
+            .where(In(UnresolvedAttribute("a"), Seq(Literal(1), UnresolvedAttribute("b"))))
+            .analyze
+        val optimized1 = Optimize.execute(originalQuery1)
+
+        if (enable.toBoolean) {
+          val correctAnswer1 =
+            testRelation
+              .where(
+                Or(EqualTo(UnresolvedAttribute("a"), Literal(1)),
+                  In(UnresolvedAttribute("a"), Seq(UnresolvedAttribute("b"))))
+              )
+              .analyze
+          comparePlans(optimized1, correctAnswer1)
+        } else {
+          val correctAnswer1 =
+            testRelation
+              .where(In(UnresolvedAttribute("a"), Seq(Literal(1), UnresolvedAttribute("b"))))
+              .analyze
+          comparePlans(optimized1, correctAnswer1)
+        }
+      }
+    }
+
+    val originalQuery2 =
+      testRelation
+        .where(In(UnresolvedAttribute("a"), Seq(UnresolvedAttribute("b"))))
+        .analyze
+    val optimized2 = Optimize.execute(originalQuery2)
+    val correctAnswer2 =
+      testRelation
+        .where(In(UnresolvedAttribute("a"), Seq(UnresolvedAttribute("b"))))
+        .analyze
+    comparePlans(optimized2, correctAnswer2)
+  }
 }