From 48a52a25c8f43e7bbf2a15bf1892ff9f161fd83e Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 18 Dec 2018 23:21:52 -0800 Subject: [PATCH] [SPARK-26366][SQL][BACKPORT-2.3] ReplaceExceptWithFilter should consider NULL as False In `ReplaceExceptWithFilter` we do not consider properly the case in which the condition returns NULL. Indeed, in that case, since negating NULL still returns NULL, so it is not true the assumption that negating the condition returns all the rows which didn't satisfy it, rows returning NULL may not be returned. This happens when constraints inferred by `InferFiltersFromConstraints` are not enough, as it happens with `OR` conditions. The rule had also problems with non-deterministic conditions: in such a scenario, this rule would change the probability of the output. The PR fixes these problem by: - returning False for the condition when it is Null (in this way we do return all the rows which didn't satisfy it); - avoiding any transformation when the condition is non-deterministic. added UTs Closes #23315 from mgaido91/SPARK-26366. Authored-by: Marco Gaido Signed-off-by: gatorsmile --- .../optimizer/ReplaceExceptWithFilter.scala | 32 ++++++++------ .../optimizer/ReplaceOperatorSuite.scala | 42 ++++++++++++++----- .../org/apache/spark/sql/DatasetSuite.scala | 11 +++++ .../org/apache/spark/sql/SQLQuerySuite.scala | 38 +++++++++++++++++ 4 files changed, 99 insertions(+), 24 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala index 45edf266bbce..08cf16038a65 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala @@ -36,7 +36,8 @@ import org.apache.spark.sql.catalyst.rules.Rule * Note: * Before flipping the filter condition of the right node, we should: * 1. Combine all it's [[Filter]]. - * 2. Apply InferFiltersFromConstraints rule (to take into account of NULL values in the condition). + * 2. Update the attribute references to the left node; + * 3. Add a Coalesce(condition, False) (to take into account of NULL values in the condition). */ object ReplaceExceptWithFilter extends Rule[LogicalPlan] { @@ -47,23 +48,28 @@ object ReplaceExceptWithFilter extends Rule[LogicalPlan] { plan.transform { case e @ Except(left, right) if isEligible(left, right) => - val newCondition = transformCondition(left, skipProject(right)) - newCondition.map { c => - Distinct(Filter(Not(c), left)) - }.getOrElse { + val filterCondition = combineFilters(skipProject(right)).asInstanceOf[Filter].condition + if (filterCondition.deterministic) { + transformCondition(left, filterCondition).map { c => + Distinct(Filter(Not(c), left)) + }.getOrElse { + e + } + } else { e } } } - private def transformCondition(left: LogicalPlan, right: LogicalPlan): Option[Expression] = { - val filterCondition = - InferFiltersFromConstraints(combineFilters(right)).asInstanceOf[Filter].condition - - val attributeNameMap: Map[String, Attribute] = left.output.map(x => (x.name, x)).toMap - - if (filterCondition.references.forall(r => attributeNameMap.contains(r.name))) { - Some(filterCondition.transform { case a: AttributeReference => attributeNameMap(a.name) }) + private def transformCondition(plan: LogicalPlan, condition: Expression): Option[Expression] = { + val attributeNameMap: Map[String, Attribute] = plan.output.map(x => (x.name, x)).toMap + if (condition.references.forall(r => attributeNameMap.contains(r.name))) { + val rewrittenCondition = condition.transform { + case a: AttributeReference => attributeNameMap(a.name) + } + // We need to consider as False when the condition is NULL, otherwise we do not return those + // rows containing NULL which are instead filtered in the Except right plan + Some(Coalesce(Seq(rewrittenCondition, Literal.FalseLiteral))) } else { None } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala index 52dc2e9fb076..b9c4a72a7a94 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala @@ -20,11 +20,12 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{Alias, Literal, Not} +import org.apache.spark.sql.catalyst.expressions.{Alias, Coalesce, If, Literal, Not} import org.apache.spark.sql.catalyst.expressions.aggregate.First import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi, PlanTest} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.types.BooleanType class ReplaceOperatorSuite extends PlanTest { @@ -65,8 +66,7 @@ class ReplaceOperatorSuite extends PlanTest { val correctAnswer = Aggregate(table1.output, table1.output, - Filter(Not((attributeA.isNotNull && attributeB.isNotNull) && - (attributeA >= 2 && attributeB < 1)), + Filter(Not(Coalesce(Seq(attributeA >= 2 && attributeB < 1, Literal.FalseLiteral))), Filter(attributeB === 2, Filter(attributeA === 1, table1)))).analyze comparePlans(optimized, correctAnswer) @@ -84,8 +84,8 @@ class ReplaceOperatorSuite extends PlanTest { val correctAnswer = Aggregate(table1.output, table1.output, - Filter(Not((attributeA.isNotNull && attributeB.isNotNull) && - (attributeA >= 2 && attributeB < 1)), table1)).analyze + Filter(Not(Coalesce(Seq(attributeA >= 2 && attributeB < 1, Literal.FalseLiteral))), + table1)).analyze comparePlans(optimized, correctAnswer) } @@ -104,8 +104,7 @@ class ReplaceOperatorSuite extends PlanTest { val correctAnswer = Aggregate(table1.output, table1.output, - Filter(Not((attributeA.isNotNull && attributeB.isNotNull) && - (attributeA >= 2 && attributeB < 1)), + Filter(Not(Coalesce(Seq(attributeA >= 2 && attributeB < 1, Literal.FalseLiteral))), Project(Seq(attributeA, attributeB), table1))).analyze comparePlans(optimized, correctAnswer) @@ -125,8 +124,7 @@ class ReplaceOperatorSuite extends PlanTest { val correctAnswer = Aggregate(table1.output, table1.output, - Filter(Not((attributeA.isNotNull && attributeB.isNotNull) && - (attributeA >= 2 && attributeB < 1)), + Filter(Not(Coalesce(Seq(attributeA >= 2 && attributeB < 1, Literal.FalseLiteral))), Filter(attributeB === 2, Filter(attributeA === 1, table1)))).analyze comparePlans(optimized, correctAnswer) @@ -146,8 +144,7 @@ class ReplaceOperatorSuite extends PlanTest { val correctAnswer = Aggregate(table1.output, table1.output, - Filter(Not((attributeA.isNotNull && attributeB.isNotNull) && - (attributeA === 1 && attributeB === 2)), + Filter(Not(Coalesce(Seq(attributeA === 1 && attributeB === 2, Literal.FalseLiteral))), Project(Seq(attributeA, attributeB), Filter(attributeB < 1, Filter(attributeA >= 2, table1))))).analyze @@ -229,4 +226,27 @@ class ReplaceOperatorSuite extends PlanTest { comparePlans(optimized, query) } + + test("SPARK-26366: ReplaceExceptWithFilter should handle properly NULL") { + val basePlan = LocalRelation(Seq('a.int, 'b.int)) + val otherPlan = basePlan.where('a.in(1, 2) || 'b.in()) + val except = Except(basePlan, otherPlan) + val result = OptimizeIn(Optimize.execute(except.analyze)) + val correctAnswer = Aggregate(basePlan.output, basePlan.output, + Filter(!Coalesce(Seq('a.in(1, 2) || 'b.in(), Literal.FalseLiteral)), + basePlan)).analyze + comparePlans(result, correctAnswer) + } + + test("SPARK-26366: ReplaceExceptWithFilter should not transform non-detrministic") { + val basePlan = LocalRelation(Seq('a.int, 'b.int)) + val otherPlan = basePlan.where('a > rand(1L)) + val except = Except(basePlan, otherPlan) + val result = Optimize.execute(except.analyze) + val condition = basePlan.output.zip(otherPlan.output).map { case (a1, a2) => + a1 <=> a2 }.reduce( _ && _) + val correctAnswer = Aggregate(basePlan.output, otherPlan.output, + Join(basePlan, otherPlan, LeftAnti, Option(condition))).analyze + comparePlans(result, correctAnswer) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 3b7bd84be047..522ed8d4a42d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1467,6 +1467,17 @@ class DatasetSuite extends QueryTest with SharedSQLContext { checkAnswer(df.groupBy(col("a")).agg(first(col("b"))), Seq(Row("0", BigDecimal.valueOf(0.1111)), Row("1", BigDecimal.valueOf(1.1111)))) } + + test("SPARK-26366: return nulls which are not filtered in except") { + val inputDF = sqlContext.createDataFrame( + sparkContext.parallelize(Seq(Row("0", "a"), Row("1", null))), + StructType(Seq( + StructField("a", StringType, nullable = true), + StructField("b", StringType, nullable = true)))) + + val exceptDF = inputDF.filter(col("a").isin("0") or col("b") > "c") + checkAnswer(inputDF.except(exceptDF), Seq(Row("1", null))) + } } case class TestDataUnion(x: Int, y: Int, z: Int) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 0af6d8715032..6848b66514ef 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2831,6 +2831,44 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { checkAnswer(sql("select 26393499451 / (1e6 * 1000)"), Row(BigDecimal("26.3934994510000"))) } } + + test("SPARK-26366: verify ReplaceExceptWithFilter") { + Seq(true, false).foreach { enabled => + withSQLConf(SQLConf.REPLACE_EXCEPT_WITH_FILTER.key -> enabled.toString) { + val df = spark.createDataFrame( + sparkContext.parallelize(Seq(Row(0, 3, 5), + Row(0, 3, null), + Row(null, 3, 5), + Row(0, null, 5), + Row(0, null, null), + Row(null, null, 5), + Row(null, 3, null), + Row(null, null, null))), + StructType(Seq(StructField("c1", IntegerType), + StructField("c2", IntegerType), + StructField("c3", IntegerType)))) + val where = "c2 >= 3 OR c1 >= 0" + val whereNullSafe = + """ + |(c2 IS NOT NULL AND c2 >= 3) + |OR (c1 IS NOT NULL AND c1 >= 0) + """.stripMargin + + val df_a = df.filter(where) + val df_b = df.filter(whereNullSafe) + checkAnswer(df.except(df_a), df.except(df_b)) + + val whereWithIn = "c2 >= 3 OR c1 in (2)" + val whereWithInNullSafe = + """ + |(c2 IS NOT NULL AND c2 >= 3) + """.stripMargin + val dfIn_a = df.filter(whereWithIn) + val dfIn_b = df.filter(whereWithInNullSafe) + checkAnswer(df.except(dfIn_a), df.except(dfIn_b)) + } + } + } } case class Foo(bar: Option[String])