@@ -254,7 +254,7 @@ object RemoveRedundantAliases extends Rule[LogicalPlan] {
254254 // not allowed to use the same attributes. We use a blacklist to prevent us from creating a
255255 // situation in which this happens; the rule will only remove an alias if its child
256256 // attribute is not on the black list.
257- case Join (left, right, joinType, condition) =>
257+ case Join (left, right, joinType, condition, notNullAttrs ) =>
258258 val newLeft = removeRedundantAliases(left, blacklist ++ right.outputSet)
259259 val newRight = removeRedundantAliases(right, blacklist ++ newLeft.outputSet)
260260 val mapping = AttributeMap (
@@ -263,7 +263,7 @@ object RemoveRedundantAliases extends Rule[LogicalPlan] {
263263 val newCondition = condition.map(_.transform {
264264 case a : Attribute => mapping.getOrElse(a, a)
265265 })
266- Join (newLeft, newRight, joinType, newCondition)
266+ Join (newLeft, newRight, joinType, newCondition, notNullAttrs )
267267
268268 case _ =>
269269 // Remove redundant aliases in the subtree(s).
@@ -354,7 +354,7 @@ object LimitPushDown extends Rule[LogicalPlan] {
354354 // on both sides if it is applied multiple times. Therefore:
355355 // - If one side is already limited, stack another limit on top if the new limit is smaller.
356356 // The redundant limit will be collapsed by the CombineLimits rule.
357- case LocalLimit (exp, join @ Join (left, right, joinType, _)) =>
357+ case LocalLimit (exp, join @ Join (left, right, joinType, _, _ )) =>
358358 val newJoin = joinType match {
359359 case RightOuter => join.copy(right = maybePushLocalLimit(exp, right))
360360 case LeftOuter => join.copy(left = maybePushLocalLimit(exp, left))
@@ -468,7 +468,7 @@ object ColumnPruning extends Rule[LogicalPlan] {
468468 p.copy(child = g.copy(child = newChild, unrequiredChildIndex = unrequiredIndices))
469469
470470 // Eliminate unneeded attributes from right side of a Left Existence Join.
471- case j @ Join (_, right, LeftExistence (_), _) =>
471+ case j @ Join (_, right, LeftExistence (_), _, _ ) =>
472472 j.copy(right = prunedChild(right, j.references))
473473
474474 // all the columns will be used to compare, so we can't prune them
@@ -661,27 +661,38 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan]
661661 filter
662662 }
663663
664- case join @ Join (left, right, joinType, conditionOpt) =>
664+ case join @ Join (left, right, joinType, conditionOpt, notNullAttrs ) =>
665665 joinType match {
666666 // For inner join, we can infer additional filters for both sides. LeftSemi is kind of an
667667 // inner join, it just drops the right side in the final output.
668668 case _ : InnerLike | LeftSemi =>
669669 val allConstraints = getAllConstraints(left, right, conditionOpt)
670- val newLeft = inferNewFilter(left, allConstraints)
671- val newRight = inferNewFilter(right, allConstraints)
672- join.copy(left = newLeft, right = newRight)
670+ val newLeftPredicates = inferNewPredicate(left, allConstraints)
671+ val newRightPredicates = inferNewPredicate(right, allConstraints)
672+ val newNotNullAttrs = getNotNullAttributes(
673+ newLeftPredicates ++ newRightPredicates, notNullAttrs)
674+ join.copy(
675+ left = addFilterIfNeeded(left, newLeftPredicates),
676+ right = addFilterIfNeeded(right, newRightPredicates),
677+ notNullAttributes = newNotNullAttrs)
673678
674679 // For right outer join, we can only infer additional filters for left side.
675680 case RightOuter =>
676681 val allConstraints = getAllConstraints(left, right, conditionOpt)
677- val newLeft = inferNewFilter(left, allConstraints)
678- join.copy(left = newLeft)
682+ val newLeftPredicates = inferNewPredicate(left, allConstraints)
683+ val newNotNullAttrs = getNotNullAttributes(newLeftPredicates, notNullAttrs)
684+ join.copy(
685+ left = addFilterIfNeeded(left, newLeftPredicates),
686+ notNullAttributes = newNotNullAttrs)
679687
680688 // For left join, we can only infer additional filters for right side.
681689 case LeftOuter | LeftAnti =>
682690 val allConstraints = getAllConstraints(left, right, conditionOpt)
683- val newRight = inferNewFilter(right, allConstraints)
684- join.copy(right = newRight)
691+ val newRightPredicates = inferNewPredicate(right, allConstraints)
692+ val newNotNullAttrs = getNotNullAttributes(newRightPredicates, notNullAttrs)
693+ join.copy(
694+ right = addFilterIfNeeded(right, newRightPredicates),
695+ notNullAttributes = newNotNullAttrs)
685696
686697 case _ => join
687698 }
@@ -696,16 +707,32 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan]
696707 baseConstraints.union(inferAdditionalConstraints(baseConstraints))
697708 }
698709
699- private def inferNewFilter (plan : LogicalPlan , constraints : Set [Expression ]): LogicalPlan = {
700- val newPredicates = constraints
710+ private def inferNewPredicate (
711+ plan : LogicalPlan , constraints : Set [Expression ]): Set [Expression ] = {
712+ constraints
701713 .union(constructIsNotNullConstraints(constraints, plan.output))
702714 .filter { c =>
703715 c.references.nonEmpty && c.references.subsetOf(plan.outputSet) && c.deterministic
704716 } -- plan.constraints
705- if (newPredicates.isEmpty) {
706- plan
707- } else {
717+ }
718+
719+ private def getNotNullAttributes (
720+ constraints : Set [Expression ],
721+ curNotNullAttrs : Set [ExprId ]): Set [ExprId ] = {
722+
723+ // Split out all the IsNotNulls from the `constraints`
724+ val (notNullPreds, _) = constraints.partition {
725+ case IsNotNull (a) => isNullIntolerant(a)
726+ case _ => false
727+ }
728+ notNullPreds.flatMap(_.references.map(_.exprId)) ++ curNotNullAttrs
729+ }
730+
731+ private def addFilterIfNeeded (plan : LogicalPlan , newPredicates : Set [Expression ]): LogicalPlan = {
732+ if (newPredicates.nonEmpty) {
708733 Filter (newPredicates.reduce(And ), plan)
734+ } else {
735+ plan
709736 }
710737 }
711738}
@@ -1048,7 +1075,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
10481075
10491076 def apply (plan : LogicalPlan ): LogicalPlan = plan transform {
10501077 // push the where condition down into join filter
1051- case f @ Filter (filterCondition, Join (left, right, joinType, joinCondition)) =>
1078+ case f @ Filter (filterCondition, Join (left, right, joinType, joinCondition, notNullAttrs )) =>
10521079 val (leftFilterConditions, rightFilterConditions, commonFilterCondition) =
10531080 split(splitConjunctivePredicates(filterCondition), left, right)
10541081 joinType match {
@@ -1062,7 +1089,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
10621089 commonFilterCondition.partition(canEvaluateWithinJoin)
10631090 val newJoinCond = (newJoinConditions ++ joinCondition).reduceLeftOption(And )
10641091
1065- val join = Join (newLeft, newRight, joinType, newJoinCond)
1092+ val join = Join (newLeft, newRight, joinType, newJoinCond, notNullAttrs )
10661093 if (others.nonEmpty) {
10671094 Filter (others.reduceLeft(And ), join)
10681095 } else {
@@ -1074,7 +1101,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
10741101 val newRight = rightFilterConditions.
10751102 reduceLeftOption(And ).map(Filter (_, right)).getOrElse(right)
10761103 val newJoinCond = joinCondition
1077- val newJoin = Join (newLeft, newRight, RightOuter , newJoinCond)
1104+ val newJoin = Join (newLeft, newRight, RightOuter , newJoinCond, notNullAttrs )
10781105
10791106 (leftFilterConditions ++ commonFilterCondition).
10801107 reduceLeftOption(And ).map(Filter (_, newJoin)).getOrElse(newJoin)
@@ -1084,7 +1111,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
10841111 reduceLeftOption(And ).map(Filter (_, left)).getOrElse(left)
10851112 val newRight = right
10861113 val newJoinCond = joinCondition
1087- val newJoin = Join (newLeft, newRight, joinType, newJoinCond)
1114+ val newJoin = Join (newLeft, newRight, joinType, newJoinCond, notNullAttrs )
10881115
10891116 (rightFilterConditions ++ commonFilterCondition).
10901117 reduceLeftOption(And ).map(Filter (_, newJoin)).getOrElse(newJoin)
@@ -1094,7 +1121,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
10941121 }
10951122
10961123 // push down the join filter into sub query scanning if applicable
1097- case j @ Join (left, right, joinType, joinCondition) =>
1124+ case j @ Join (left, right, joinType, joinCondition, notNullAttrs ) =>
10981125 val (leftJoinConditions, rightJoinConditions, commonJoinCondition) =
10991126 split(joinCondition.map(splitConjunctivePredicates).getOrElse(Nil ), left, right)
11001127
@@ -1107,23 +1134,23 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
11071134 reduceLeftOption(And ).map(Filter (_, right)).getOrElse(right)
11081135 val newJoinCond = commonJoinCondition.reduceLeftOption(And )
11091136
1110- Join (newLeft, newRight, joinType, newJoinCond)
1137+ Join (newLeft, newRight, joinType, newJoinCond, notNullAttrs )
11111138 case RightOuter =>
11121139 // push down the left side only join filter for left side sub query
11131140 val newLeft = leftJoinConditions.
11141141 reduceLeftOption(And ).map(Filter (_, left)).getOrElse(left)
11151142 val newRight = right
11161143 val newJoinCond = (rightJoinConditions ++ commonJoinCondition).reduceLeftOption(And )
11171144
1118- Join (newLeft, newRight, RightOuter , newJoinCond)
1145+ Join (newLeft, newRight, RightOuter , newJoinCond, notNullAttrs )
11191146 case LeftOuter | LeftAnti | ExistenceJoin (_) =>
11201147 // push down the right side only join filter for right sub query
11211148 val newLeft = left
11221149 val newRight = rightJoinConditions.
11231150 reduceLeftOption(And ).map(Filter (_, right)).getOrElse(right)
11241151 val newJoinCond = (leftJoinConditions ++ commonJoinCondition).reduceLeftOption(And )
11251152
1126- Join (newLeft, newRight, joinType, newJoinCond)
1153+ Join (newLeft, newRight, joinType, newJoinCond, notNullAttrs )
11271154 case FullOuter => j
11281155 case NaturalJoin (_) => sys.error(" Untransformed NaturalJoin node" )
11291156 case UsingJoin (_, _) => sys.error(" Untransformed Using join node" )
@@ -1179,7 +1206,7 @@ object CheckCartesianProducts extends Rule[LogicalPlan] with PredicateHelper {
11791206 if (SQLConf .get.crossJoinEnabled) {
11801207 plan
11811208 } else plan transform {
1182- case j @ Join (left, right, Inner | LeftOuter | RightOuter | FullOuter , _)
1209+ case j @ Join (left, right, Inner | LeftOuter | RightOuter | FullOuter , _, _ )
11831210 if isCartesianProduct(j) =>
11841211 throw new AnalysisException (
11851212 s """ Detected cartesian product for ${j.joinType.sql} join between logical plans
0 commit comments