diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index c222571a3464..e7ca11963de8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -46,22 +46,18 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT private def constructIsNotNullConstraints(constraints: Set[Expression]): Set[Expression] = { // Currently we only propagate constraints if the condition consists of equality // and ranges. For all other cases, we return an empty set of constraints - constraints.map { - case EqualTo(l, r) => - Set(IsNotNull(l), IsNotNull(r)) - case GreaterThan(l, r) => - Set(IsNotNull(l), IsNotNull(r)) - case GreaterThanOrEqual(l, r) => - Set(IsNotNull(l), IsNotNull(r)) - case LessThan(l, r) => - Set(IsNotNull(l), IsNotNull(r)) - case LessThanOrEqual(l, r) => - Set(IsNotNull(l), IsNotNull(r)) - case Not(EqualTo(l, r)) => - Set(IsNotNull(l), IsNotNull(r)) - case _ => - Set.empty[Expression] - }.foldLeft(Set.empty[Expression])(_ union _.toSet) + // Note: Almost all the subclasses of BinaryComparison (EqualTo, LessThan, LessThanOrEqual, + // GreaterThan and GreaterThanOrEqual) are NULL intolerant. The only exception is EqualNullSafe + var isNotNullConstraints = Set.empty[Expression] + constraints.collect { + case b @ BinaryComparison(l, r) if !b.isInstanceOf[EqualNullSafe] => + if (l.isInstanceOf[AttributeReference]) isNotNullConstraints += IsNotNull(l) + if (r.isInstanceOf[AttributeReference]) isNotNullConstraints += IsNotNull(r) + case Not(b @ BinaryComparison(l, r)) if !b.isInstanceOf[EqualNullSafe] => + if (l.isInstanceOf[AttributeReference]) isNotNullConstraints += IsNotNull(l) + if (r.isInstanceOf[AttributeReference]) isNotNullConstraints += IsNotNull(r) + } + isNotNullConstraints } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NullFilteringSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NullFilteringSuite.scala index 142e4ae6e439..fc058ccac4d4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NullFilteringSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NullFilteringSuite.scala @@ -40,6 +40,12 @@ class NullFilteringSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("filter: do not push Null-filtering of compound expressions") { + val originalQuery = testRelation.where('a + 'b === 1).analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, originalQuery) + } + test("single inner join: filter out nulls on either side on equi-join keys") { val x = testRelation.subquery('x) val y = testRelation.subquery('y) @@ -83,6 +89,15 @@ class NullFilteringSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("single inner join: no null filters are generated for compound expression") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + val originalQuery = x.join(y, + condition = Some("x.a".attr * 2 === "y.a".attr - 4)).analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, originalQuery) + } + test("single outer join: no null filters are generated") { val x = testRelation.subquery('x) val y = testRelation.subquery('y) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala index a9375a740daa..1b43be5734e2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala @@ -217,4 +217,19 @@ class ConstraintPropagationSuite extends SparkFunSuite { IsNotNull(resolveColumn(tr, "a")), IsNotNull(resolveColumn(tr, "b"))))) } + + test("IsNotNull constraints of compound expressions in filters") { + val tr = LocalRelation('a.int, 'b.string, 'c.int) + verifyConstraints(tr + .where('a.attr + 'c.attr > 10).analyze.constraints, + ExpressionSet(Seq(resolveColumn(tr, "a") + resolveColumn(tr, "c") > 10))) + } + + test("IsNotNull constraints of BinaryComparison in Not in filters") { + val tr = LocalRelation('a.int, 'b.string, 'c.int) + verifyConstraints(tr + .where(!('a.attr < 10)).analyze.constraints, + ExpressionSet(Seq(IsNotNull(resolveColumn(tr, "a")), + Not(resolveColumn(tr, "a") < 10)))) + } }